Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
316 changes: 297 additions & 19 deletions wmin/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,14 @@

import numpy as np
import tensorflow as tf
from colibri.constants import EXPORT_LABELS, LHAPDF_XGRID
from colibri.constants import (
EXPORT_LABELS,
FLAVOUR_TO_ID_MAPPING,
LHAPDF_XGRID,
flavour_to_evolution_matrix,
)
from colibri.export_results import write_exportgrid
from n3fit.model_gen import _pdfNN_layer_generator
from n3fit.model_gen import _pdfNN_layer_generator, ReplicaSettings

from wmin.utils import (
FLAV_INFO,
Expand All @@ -38,28 +43,42 @@ def n3fit_pdf_model(
However, for better stability, the sum rules are also imposed later-on in a more accurate way
using a quadrature integration.
"""
min_rep = replica_range_settings["min_replica"]
max_rep = replica_range_settings["max_replica"]

# Build one ReplicaSettings per replica
replicas_settings = [
ReplicaSettings(
nodes=nodes,
activations=activations,
initializer=initializer_name,
architecture=layer_type,
seed=seed,
)
for seed in range(min_rep, max_rep + 1)
]

pdf_model = _pdfNN_layer_generator(
nodes=nodes,
activations=activations,
initializer_name=initializer_name,
layer_type=layer_type,
seed=range(
replica_range_settings["min_replica"],
replica_range_settings["max_replica"] + 1,
),
impose_sumrule=True, # sum-rules are also imposed later-on in a more accurate way.
replicas_settings=replicas_settings,
flav_info=flav_info,
fitbasis=fitbasis,
num_replicas=replica_range_settings["max_replica"]
- replica_range_settings["min_replica"]
+ 1,
# leave impose_sumrule=None here so it defaults to "All"
# and still produces xgrid_integration (x_in) as before
)

# Make flavour ordering available downstream (n3fit may not expose it publicly).
pdf_model.flav_info = flav_info

return pdf_model


def n3fit_pdf_grid(
n3fit_pdf_model,
filter_arclength_outliers: bool = True,
overwrite_non_gluon_from_lhapdf: bool = True,
lhapdf_set: str = "NNPDF31_nnlo_as_0118",
lhapdf_member: int = 0,
lhapdf_q: float = 1.65,
):
"""
Returns the PDF grid for the n3fit model evaluated on the LHAPDF_XGRID.
Expand All @@ -74,6 +93,13 @@ def n3fit_pdf_grid(
The xgrid to use.
filter_arclength_outliers: bool, default is True
Whether to filter out the arclength outliers from the PDF grid.
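overwrite_non_gluon_from_lhapdf: bool, default is True
Whether to take all non-gluon channels from a fixed LHAPDF baseline and
keep only the (momentum-sum-rule rescaled) gluon channel from the NN.
If LHAPDF or the requested set cannot be loaded, the NN-generated grid
is used for all flavours instead.
lhapdf_set: str, default is "NNPDF31_nnlo_as_0118"
LHAPDF set used to build the non-gluon baseline when
overwrite_non_gluon_from_lhapdf is True.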
lhapdf_member: int, default is 0
LHAPDF member index used for the non-gluon baseline.
lhapdf_q: float, default is 1.65
Scale Q used when sampling the LHAPDF baseline.

Returns
-------
Expand All @@ -87,7 +113,230 @@ def n3fit_pdf_grid(
pdf_grid = tf.squeeze(n3fit_pdf_model(input), axis=0)

# shapes here are (nreplicas, nflavours, nx)
pdf_array = np.array(tf.transpose(pdf_grid, perm=[0, 2, 1]))
pdf_array_nn = np.array(
tf.transpose(pdf_grid, perm=[0, 2, 1])
) # vanilla functions generated by n3fit_pdf_model
nreplicas, nflavours, _nx = pdf_array_nn.shape
xgrid = np.asarray(LHAPDF_XGRID, dtype=float)
if _nx != len(xgrid):
raise ValueError(f"NN output nx={_nx} does not match LHAPDF_XGRID length={len(xgrid)}")

# Canonical evolution-basis ordering follows validphys/colibri FK flavour indices.
flavour_to_id = {str(key): int(val) for key, val in FLAVOUR_TO_ID_MAPPING.items()}
id_to_flavour = {idx: flav for flav, idx in flavour_to_id.items()}
canonical_evol_labels = [id_to_flavour.get(i) for i in range(nflavours)]

def _normalize_evol_label(label):
if label is None:
return None
s = str(label)
aliases = {
"sng": "\\Sigma",
"sigma": "\\Sigma",
"singlet": "\\Sigma",
"gluon": "g",
"v": "V",
"v3": "V3",
"v8": "V8",
"v15": "V15",
"v24": "V24",
"v35": "V35",
"t3": "T3",
"t8": "T8",
"t15": "T15",
"t24": "T24",
"t35": "T35",
}
return aliases.get(s, aliases.get(s.lower(), s))

# n3fit model output channels are in a fixed evolution ordering.
# Reindex explicitly to canonical FK ordering expected by the downstream rotation.
source_evol_labels = list(getattr(n3fit_pdf_model, "evolution_labels", []))
n3fit_evol_output_labels = [
"photon",
"sigma",
"gluon",
"V",
"V3",
"V8",
"V15",
"V24",
"V35",
"T3",
"T8",
"T15",
"T24",
"T35",
]
if len(source_evol_labels) != nflavours:
if nflavours == len(n3fit_evol_output_labels):
source_evol_labels = n3fit_evol_output_labels
else:
source_evol_labels = canonical_evol_labels.copy()

source_norm = [_normalize_evol_label(fl) for fl in source_evol_labels]
canonical_norm = [_normalize_evol_label(fl) for fl in canonical_evol_labels]
if None not in source_norm and None not in canonical_norm:
source_positions = {fl: i for i, fl in enumerate(source_norm)}
missing = [fl for fl in canonical_norm if fl not in source_positions]
if missing:
log.warning(
"Could not fully reindex NN flavours to canonical FK order; missing=%s. "
"Proceeding with original NN ordering.",
missing,
)
else:
permutation = [source_positions[fl] for fl in canonical_norm]
if permutation != list(range(nflavours)):
pdf_array_nn = pdf_array_nn[:, permutation, :]
log.info("Reindexed NN flavour channels to canonical FK ordering: %s", permutation)

if "g" not in canonical_norm:
raise ValueError(
"Could not determine gluon index in canonical evolution ordering. "
f"canonical_labels={canonical_evol_labels}"
)
gluon_index = canonical_norm.index("g")

# Build the baseline pdf_array from the configured LHAPDF member (the central
# member by default), then overwrite only the gluon channel with the
# momentum-sum-rule-rescaled NN evaluation.
if overwrite_non_gluon_from_lhapdf:
try:
import lhapdf # type: ignore
except Exception as exc:
log.warning(
"LHAPDF not available (%s); using NN-generated grid for all flavours.",
exc,
)
pdf_array = pdf_array_nn
else:
Comment on lines +204 to +213

Copilot AI Feb 11, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New behavior branches on overwrite_non_gluon_from_lhapdf and has multiple error/fallback paths (LHAPDF missing, set load failure, flavour mapping mismatch). There are existing tests for wmin.basis, but none cover n3fit_pdf_grid or the new LHAPDF overwrite logic. Add unit tests that patch lhapdf.mkPDF / xfxQ and assert that only the gluon channel varies across replicas when overwrite is enabled, and that it falls back to NN-only grids on failures.

Copilot uses AI. Check for mistakes.
if not lhapdf_set:
log.warning(
"No LHAPDF set configured; using NN-generated grid for all flavours."
)
pdf_array = pdf_array_nn
else:
try:
lha_pdf = lhapdf.mkPDF(lhapdf_set, lhapdf_member)
log.info(
"Loaded LHAPDF set %s member %s at Q=%s",
lhapdf_set,
lhapdf_member,
lhapdf_q,
)
except Exception as exc:
log.warning(
"Failed to load LHAPDF set %s member %s (%s); using NN-generated grid for all flavours.",
lhapdf_set,
lhapdf_member,
exc,
)
pdf_array = pdf_array_nn
else:
# Build baseline in physical flavour basis (LHAPDF PIDs), then rotate to EVOL.
pid_from_export_label = {
"TBAR": -6,
"BBAR": -5,
"CBAR": -4,
"SBAR": -3,
"UBAR": -2,
"DBAR": -1,
"GLUON": 21,
"D": 1,
"U": 2,
"S": 3,
"C": 4,
"B": 5,
"T": 6,
"PHT": 22,
}
baseline_flavour = np.zeros(
(len(EXPORT_LABELS), len(xgrid)), dtype=pdf_array_nn.dtype
)
for flavour_index, export_label in enumerate(EXPORT_LABELS):
lha_id = pid_from_export_label[export_label]
try:
baseline_flavour[flavour_index, :] = np.asarray(
[lha_pdf.xfxQ(lha_id, float(x), lhapdf_q) for x in xgrid],
dtype=pdf_array_nn.dtype,
)
except Exception as exc:
log.warning(
"Could not read physical flavour '%s' (pid=%s) from LHAPDF baseline; "
"setting this channel to zero (%s).",
export_label,
lha_id,
exc,
)
baseline_flavour[flavour_index, :] = 0.0

baseline_evol = np.einsum(
"ef,fx->ex",
np.asarray(flavour_to_evolution_matrix, dtype=pdf_array_nn.dtype),
baseline_flavour,
)

# Keep the downstream flavour-channel convention unchanged.
baseline = np.zeros((nflavours, len(xgrid)), dtype=pdf_array_nn.dtype)
for flav_index, flav_key in enumerate(canonical_evol_labels):
if flav_key is None:
continue

evol_index = flavour_to_id.get(str(flav_key))
if evol_index is None or evol_index >= baseline_evol.shape[0]:
continue
baseline[flav_index, :] = baseline_evol[evol_index, :]

# Build pdf_array as LHAPDF baseline broadcast across replicas.
pdf_array = np.repeat(baseline[np.newaxis, :, :], nreplicas, axis=0)

# Impose MSR with fixed singlet from the reference:
# ∫ dx [x f_g^(m)(x) + x f_Sigma^(ref)(x)] = 1
# The arrays here are x*f(x), so we integrate them directly.
singlet_label = "\\Sigma"
if singlet_label not in canonical_norm:
raise ValueError(
"Could not determine singlet index in canonical evolution ordering. "
f"canonical_labels={canonical_evol_labels}"
)
singlet_index = canonical_norm.index(singlet_label)

sigma_ref = pdf_array[0, singlet_index, :]
sigma_momentum = float(np.trapezoid(sigma_ref, xgrid))
target_gluon_momentum = 1.0 - sigma_momentum

gluon_raw = pdf_array_nn[:, gluon_index, :].copy()
gluon_momenta = np.trapezoid(gluon_raw, xgrid, axis=1)

safe = np.abs(gluon_momenta) > 1e-16
if not np.all(safe):
bad = np.where(~safe)[0].tolist()
log.warning(
"Skipping MSR rescaling for gluon replicas with near-zero momentum moment: %s",
bad,
)

scales = np.ones(nreplicas, dtype=pdf_array_nn.dtype)
scales[safe] = target_gluon_momentum / gluon_momenta[safe]
gluon_rescaled = gluon_raw * scales[:, np.newaxis]

# Overwrite the gluon channel with the MSR-rescaled per-replica NN values.
pdf_array[:, gluon_index, :] = gluon_rescaled
else:
pdf_array = pdf_array_nn

# Sanity check: report which flavours vary across replicas.
# Expected when the LHAPDF baseline is in use: only the gluon channel varies.
# On the NN-only fallback paths every channel may vary.
if pdf_array.shape[0] > 1:
maxdiff = np.max(np.abs(pdf_array - pdf_array[0:1, :, :]), axis=(0, 2))
varying = [int(i) for i, d in enumerate(maxdiff) if float(d) > 1e-12]
log.info(
"Post-merge varying flavour indices=%s (gluon_index=%s; nflavours=%s)",
varying,
gluon_index if "gluon_index" in locals() else None,
pdf_array.shape[1],
)

# filter from arclength outliers
while filter_arclength_outliers:
Expand Down Expand Up @@ -197,12 +446,13 @@ def write_pod_basis(
"""
pod, phi0 = pod_basis
basis = pod + phi0
output_dir = os.fspath(output_path)

replicas_path = str(output_path) + "/replicas"
replicas_path = output_dir + "/replicas"
if not os.path.exists(replicas_path):
os.mkdir(replicas_path)

fit_name = str(output_path).split("/")[-1]
fit_name = output_dir.split("/")[-1]

for i in range(basis.shape[0]):

Expand All @@ -220,7 +470,7 @@ def write_pod_basis(
grid_for_writing=phi0,
grid_name=grid_name,
replica_index=i + 1,
Q=Q,
Q0=Q,
xgrid=xgrid,
export_labels=export_labels,
)
Expand All @@ -229,15 +479,43 @@ def write_pod_basis(
grid_for_writing=basis[i - 1],
grid_name=grid_name,
replica_index=i + 1,
Q=Q,
Q0=Q,
xgrid=xgrid,
export_labels=export_labels,
)

# POD modes are U_i * S_i, so ||mode_i||^2 gives the associated eigenvalue.
pod_flat = pod.reshape(pod.shape[0], -1)
singular_values = np.linalg.norm(pod_flat, axis=1)
eigenvalues = singular_values**2
leading_eigenvalue = float(eigenvalues[0]) if eigenvalues.size > 0 else 0.0
if leading_eigenvalue > 0.0:
eigenvalue_ratio = eigenvalues / leading_eigenvalue
else:
eigenvalue_ratio = np.zeros_like(eigenvalues)

spectrum = np.column_stack(
[
np.arange(1, pod.shape[0] + 1, dtype=float),
eigenvalues,
eigenvalue_ratio,
]
)
eigenvalues_path = os.path.join(output_dir, "pod_eigenvalues.csv")
np.savetxt(
eigenvalues_path,
spectrum,
delimiter=",",
comments="",
header="mode_index,eigenvalue,eigenvalue_ratio",
fmt=["%.0f", "%.18e", "%.18e"],
)

# TODO: how can we ensure that the postfit step of the evolution does not mistakenly create another central member?
log.info(
f"Replicas written to {replicas_path}, with the central member at replica_1."
)
log.info("POD eigenvalue spectrum written to %s", eigenvalues_path)

log.warning(
"Note: this is a POD basis, so the central member is not the mean but always replica_1.\n"
Expand Down
5 changes: 5 additions & 0 deletions wmin/runcards/pod_basis_example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ filter_sr_outliers: false

fitbasis: EVOL

# Baseline LHAPDF settings used by n3fit_pdf_grid
lhapdf_set: NNPDF40_nnlo_as_01180
lhapdf_member: 0
lhapdf_q: 1.65

nodes: [25, 20, 8]

activations: ["tanh", "tanh", "linear"]
Expand Down
Loading