Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion wmin/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import numpy as np
import tensorflow as tf
from colibri.constants import EXPORT_LABELS, LHAPDF_XGRID
from colibri.constants import EXPORT_LABELS, LHAPDF_XGRID, FLAVOUR_TO_ID_MAPPING
from colibri.export_results import write_exportgrid
from n3fit.model_gen import _pdfNN_layer_generator, ReplicaSettings

Expand Down Expand Up @@ -67,6 +67,8 @@ def n3fit_pdf_model(
def n3fit_pdf_grid(
n3fit_pdf_model,
filter_arclength_outliers: bool = True,
filter_integrability: bool = True,
integrability_threshold: float = 0.5,
):
Comment thread
vschutze-alt marked this conversation as resolved.
"""
Returns the PDF grid for the n3fit model evaluated on the LHAPDF_XGRID.
Expand All @@ -81,6 +83,11 @@ def n3fit_pdf_grid(
The xgrid to use.
filter_arclength_outliers: bool, default is True
Whether to filter out the arclength outliers from the PDF grid.
filter_integrability: bool, default is True
Whether to filter replicas based on integrability conditions.
integrability_threshold: float, default=0.5
Tolerance used to enforce integrability. If the sum is above
this value, the replica is discarded.

Returns
-------
Expand Down Expand Up @@ -112,6 +119,37 @@ def n3fit_pdf_grid(
log.info("No more outliers found in the PDF grid")
filter_arclength_outliers = False

# filter from integrability outliers
if filter_integrability:
log.info("Filtering out non-integrable replicas")
log.info(f"Integrability threshold: {integrability_threshold}")

flavours = ["V", "V3", "V8", "T3", "T8"]

while True:
initial_size = pdf_array.shape[0]

for flavour in flavours:
grid = pdf_array[:, FLAVOUR_TO_ID_MAPPING[flavour], :]

# sum over the first 20 points in the xgrid: all points of order e-9 up to first point of order e-7
mask = np.abs(grid[:, :20].sum(axis=1)) <= integrability_threshold

n_discarded = (~mask).sum()
n_kept = mask.sum()

log.info(
f"Filtering {flavour} integrability: "
f"discarded {n_discarded}, kept {n_kept}"
)

pdf_array = pdf_array[mask, :, :]

# stop condition: nothing changed in full cycle
if pdf_array.shape[0] == initial_size:
log.info("No more integrability outliers found")
break

return pdf_array


Expand Down
20 changes: 6 additions & 14 deletions wmin/tests/test_wmin_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,11 @@
class MockPDFModel:
n_basis = N_MOCK_DATA

@staticmethod
def pred_and_pdf_func(FIT_XGRID, forward_map):
def pred_and_pdf(params, fast_kernel_arrays):
predictions = params * 2
pdf = params * 3
return predictions, pdf

return pred_and_pdf
def __call__(self, fast_kernel_arrays, params):
"""Make the model callable for LogLikelihood"""
predictions = params * 2
pdf = params * 3
return predictions, pdf


def mock_bayesian_prior(rng):
Expand All @@ -64,7 +61,7 @@ def mock_bayesian_prior(rng):
"positivity_fast_kernel_arrays": positivity_fast_kernel_arrays,
"_pred_data": mock_pred_data,
"FIT_XGRID": TEST_XGRID,
"pdf_model": MockPDFModel(),
"forward_map": MockPDFModel(),
"bayesian_prior": mock_bayesian_prior,
"theoryid": MOCK_NAME_THEORY,
"n_prior_samples": 100,
Expand All @@ -80,11 +77,6 @@ def test_likelihood_time_structure():
"""
test the structure of the output of likelihood_time
"""
# Debug: Check what likelihood_time actually is
print(f"likelihood_time type: {type(likelihood_time)}")
print(f"likelihood_time: {likelihood_time}")
print(f"likelihood_time module: {likelihood_time.__module__}")

result = likelihood_time(**SETUP)
assert isinstance(result, pd.DataFrame)
assert list(result.columns) == ["Ndata", "Theory", "Likelihood eval time (s)"]
Expand Down
11 changes: 5 additions & 6 deletions wmin/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def likelihood_time(
positivity_fast_kernel_arrays,
_pred_data,
FIT_XGRID,
pdf_model,
forward_map,
bayesian_prior,
theoryid,
n_prior_samples=1000,
Expand Down Expand Up @@ -117,7 +117,7 @@ def likelihood_time(
FIT_XGRID: array
The xgrid to use.

pdf_model: PDFModel
forward_map: ForwardMap

bayesian_prior: function
The prior function to use.
Expand All @@ -144,9 +144,8 @@ def likelihood_time(

log_likelihood = LogLikelihood(
central_covmat_index,
pdf_model,
FIT_XGRID,
_pred_data,
forward_map,
forward_map,
fast_kernel_arrays,
positivity_fast_kernel_arrays,
_penalty_posdata,
Expand All @@ -159,7 +158,7 @@ def likelihood_time(
prior_samples = []
for i in range(n_prior_samples):
prior_samples.append(
bayesian_prior(jax.random.uniform(rng, shape=(pdf_model.n_basis,)))
bayesian_prior(jax.random.uniform(rng, shape=(forward_map.n_basis,)))
)

# compile likelihood
Expand Down
Loading