Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions colibri/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,15 @@ class GradientDescentResult:
Recorded (epoch) validation losses (sampled according to record_every).
specs: dict
Dictionary of settings used for the run (epochs, batch size, etc.).
best_epoch: dict
Dictionary containing the best epoch and corresponding parameters and losses.
"""

optimized_parameters: Any
training_loss: jnp.array
validation_loss: jnp.array
specs: Dict[str, Any]
best_epoch: Dict[str, Any]


@dataclass(frozen=True)
Expand Down
10 changes: 9 additions & 1 deletion colibri/doc/sphinx/source/tutorials/scripts/fit-folders.rst
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,14 @@ the first few lines would look like this:

which would represent the losses for the first 150 epochs (i.e. 0, 1, 2 are just labels).

Best epoch
-----------
The final model weights for each replica are taken from the best epoch. This is defined
as the epoch that satisfies the positivity threshold (as specified in `likelihood.py``)
and achieves the lowest validation loss among all such epochs.
The information for the best epoch is stored in ``fit_replicas/replica_n/best_epoch_specs.csv``,
which contains the epoch number, the training loss and the validation loss.

Postfit selection
"""""""""""""""""

Expand All @@ -135,7 +143,7 @@ You can run a postfit selection by running:

where the ``-c`` is optional and ``CHI2_THRESHOLD`` is a number that determines
the :math:`\chi^2` threshold above which a MC replica will be rejected, where this
value is taken from the last row of the ``training_loss`` column shown above.
value is taken from the training loss of the best epoch.
This can also be run as ``--chi2_threshold`` instead of ``-c``. If no value is
specified, a default value of 1.5 will be applied.

Expand Down
48 changes: 48 additions & 0 deletions colibri/gradient_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def run_gradient_descent(
max_epochs: int,
data_batch: Optional[colibri.DataBatches] = None,
record_every: int = 50,
positivity_check_fn: Optional[Callable[[jnp.ndarray], bool]] = None,
threshold_chi2: float = 10.0,
) -> GradientDescentResult:
"""Generic gradient descent loop.

Expand Down Expand Up @@ -90,6 +92,12 @@ def _step(p, ostate, batch: BatchSpec):
train_losses = []
val_losses = []

best_params = params
best_train_loss = jnp.inf
best_val_loss = jnp.inf
best_epoch_idx = 0
any_pos_pass = False

if data_batch is None:
# single fake iterator repeatedly yielding EMPTY_BATCH
def _gen():
Expand All @@ -114,6 +122,27 @@ def _gen():
epoch_val_loss = validation_loss_fn(params)
early_stopper = early_stopper.update(epoch_val_loss)

# Update best epoch based on positivity and validation loss
pos_pass = True
if positivity_check_fn is not None:
pos_pass = positivity_check_fn(params)

update_best = False
meets_threshold = epoch_val_loss < threshold_chi2
if meets_threshold:
if pos_pass and not any_pos_pass:
update_best = True
any_pos_pass = True
elif pos_pass == any_pos_pass and pos_pass:
if epoch_val_loss < best_val_loss:
update_best = True

if update_best:
best_val_loss = epoch_val_loss
best_train_loss = epoch_train_loss
best_params = params
best_epoch_idx = epoch

if record_every and (epoch % record_every == 0):
log.info(
f"Epoch {epoch}, loss: {epoch_train_loss:.3f}, "
Expand All @@ -127,6 +156,24 @@ def _gen():
log.info(f"Early stopping at epoch {epoch}")
break

if best_epoch_idx == 0:
log.warning(
"No epoch passed positivity check. Returning last epoch's parameters."
)
best_epoch_dict = {
"epoch": epoch,
"best_parameters": params,
"best_val_loss": epoch_val_loss,
"best_train_loss": epoch_train_loss,
}
else:
best_epoch_dict = {
"epoch": best_epoch_idx,
"best_parameters": best_params,
"best_val_loss": best_val_loss,
"best_train_loss": best_train_loss,
}

return GradientDescentResult(
optimized_parameters=params,
training_loss=jnp.array(train_losses),
Expand All @@ -136,4 +183,5 @@ def _gen():
"batch_size": batch_size,
"record_every": record_every,
},
best_epoch=best_epoch_dict,
)
45 changes: 33 additions & 12 deletions colibri/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from colibri.commondata_utils import CentralCovmatIndex
from colibri.data_batch import BatchSpec

THRESHOLD_POS = 1e-6


class LogLikelihood(object):
"""
Expand Down Expand Up @@ -69,6 +71,14 @@ def __init__(
self.fast_kernel_arrays = fast_kernel_arrays
self.positivity_fast_kernel_arrays = positivity_fast_kernel_arrays

def get_pos_pass(self, params):
_, pdf = self.pred_and_pdf(params, self.fast_kernel_arrays)
pos_pass, _ = self.positivity_check_and_penalty(
pdf,
self.positivity_fast_kernel_arrays,
)
return pos_pass

def __call__(self, params, batch: BatchSpec | None = None):
"""
Note that this function is called by the samplers, and it must be
Expand Down Expand Up @@ -99,6 +109,25 @@ def __call__(self, params, batch: BatchSpec | None = None):
batch=batch,
)

def positivity_check_and_penalty(self, pdf, positivity_fast_kernel_arrays):
if self.positivity_penalty_settings["positivity_penalty"]:
pos_penalties = self.penalty_posdata(
pdf,
self.positivity_penalty_settings["alpha"],
self.positivity_penalty_settings["lambda_positivity"],
positivity_fast_kernel_arrays,
)
pos_pass = jnp.all(pos_penalties < THRESHOLD_POS)

pos_penalty = jnp.sum(
pos_penalties,
axis=-1,
)
else:
pos_penalty = 0
pos_pass = True
return pos_pass, pos_penalty

@partial(jax.jit, static_argnames=("self",))
def log_likelihood(
self,
Expand Down Expand Up @@ -140,18 +169,10 @@ def log_likelihood(
else:
inv_covmat = batch.inv_cov

if self.positivity_penalty_settings["positivity_penalty"]:
pos_penalty = jnp.sum(
self.penalty_posdata(
pdf,
self.positivity_penalty_settings["alpha"],
self.positivity_penalty_settings["lambda_positivity"],
positivity_fast_kernel_arrays,
),
axis=-1,
)
else:
pos_penalty = 0
_, pos_penalty = self.positivity_check_and_penalty(
pdf,
positivity_fast_kernel_arrays,
)

integ_penalty = jnp.sum(
self.integrability_penalty(
Expand Down
27 changes: 25 additions & 2 deletions colibri/monte_carlo_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def monte_carlo_fit(
early_stopper,
max_epochs,
data_batches,
threshold_chi2=10.0,
):
"""
This function performs a Monte Carlo fit.
Expand Down Expand Up @@ -81,7 +82,8 @@ def loss_validation(parameters):
log.info("Starting Monte Carlo fit...")
t0 = time.time()

# Delegate to generic gradient descent
positivity_check_fn = mc_log_likelihood[0].get_pos_pass

gd_result = run_gradient_descent(
initial_parameters=pdf_initial_parameters.copy(),
training_loss_fn=loss_training,
Expand All @@ -91,6 +93,8 @@ def loss_validation(parameters):
max_epochs=max_epochs,
data_batch=data_batches,
record_every=50,
positivity_check_fn=positivity_check_fn,
threshold_chi2=threshold_chi2,
)

t1 = time.time()
Expand All @@ -101,10 +105,11 @@ def loss_validation(parameters):
"max_epochs": max_epochs,
"batch_size": data_batches.batch_size,
"batch_seed": data_batches.batch_seed,
"best_epoch_specs": gd_result.best_epoch,
},
training_loss=gd_result.training_loss,
validation_loss=gd_result.validation_loss,
optimized_parameters=gd_result.optimized_parameters,
optimized_parameters=gd_result.best_epoch["best_parameters"],
)


Expand Down Expand Up @@ -168,3 +173,21 @@ def run_monte_carlo_fit(monte_carlo_fit, pdf_model, output_path, replica_index,
index=False,
float_format="%.5e",
)

best_epoch_specs = mc_fit.monte_carlo_specs.get("best_epoch_specs")

df = pd.DataFrame(
{
"best_epoch": best_epoch_specs["epoch"],
"best_val_loss": best_epoch_specs["best_val_loss"],
"best_train_loss": best_epoch_specs["best_train_loss"],
},
index=[0],
)

df.to_csv(
str(output_path)
+ f"/fit_replicas/replica_{replica_index}"
+ "/best_epoch_specs.csv",
index=False,
)
20 changes: 10 additions & 10 deletions colibri/scripts/mc_postfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,41 +71,41 @@ def main():

replicas_list = sorted(list(replicas_path.iterdir()))

final_losses = jnp.array([])
best_epoch_losses = jnp.array([])

valid_replicas = [] # Keep track of which replicas are valid

for replica in replicas_list:
try:
df = pd.read_csv(replica / "mc_loss.csv")
df = pd.read_csv(replica / "best_epoch_specs.csv")
if (
df.empty
or df["training_loss"].iloc[-1] is pd.NA
or pd.isna(df["training_loss"].iloc[-1])
or df["best_train_loss"].iloc[0] is pd.NA
or pd.isna(df["best_train_loss"].iloc[0])
):
log.warning(f"Skipping replica {replica} - empty or NaN training_loss")
continue

final_loss = df.iloc[-1]["training_loss"]
final_losses = jnp.concatenate(
(final_losses, jnp.array([final_loss])), axis=0
best_epoch_loss = df.iloc[-1]["best_train_loss"]
best_epoch_losses = jnp.concatenate(
(best_epoch_losses, jnp.array([best_epoch_loss])), axis=0
)
valid_replicas.append(replica)

except (FileNotFoundError, KeyError, IndexError) as e:
log.critical(f"Skipping replica {replica} - error reading file: {e}")
continue

mean_loss = jnp.mean(final_losses)
std_loss = jnp.std(final_losses)
mean_loss = jnp.mean(best_epoch_losses)
std_loss = jnp.std(best_epoch_losses)

# List of replicas to keep
good_replicas = []

# We will copy the replicas and order them starting with 0
# and increasing the index for each good replica we find
i = 0
for replica, loss in zip(valid_replicas, final_losses):
for replica, loss in zip(valid_replicas, best_epoch_losses):

index = int(replica.name.split("_")[1])

Expand Down
34 changes: 34 additions & 0 deletions colibri/tests/test_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import jax.numpy as jnp
import pytest
from numpy.testing import assert_allclose
from unittest.mock import MagicMock

from colibri.likelihood import LogLikelihood, log_likelihood, mc_log_likelihood
from colibri.mc_utils import MCPseudodata
Expand Down Expand Up @@ -501,3 +502,36 @@ def test_LogLikelihood_call_with_batch_with_inv_cov(pos_penalty):
expected = -0.5 * (chi2_b + pos_pen + integ_pen)

assert_allclose(float(ll_value_batched), float(expected))


def test_LogLikelihood_get_pos_pass():
"""
Tests the get_pos_pass method of LogLikelihood.
"""
positivity_penalty_settings = {
"positivity_penalty": True,
"alpha": 1e-7,
"lambda_positivity": 1000,
}

log_likelihood_class = LogLikelihood(
central_covmat_index=MOCK_CENTRAL_COVMAT_INDEX,
pdf_model=MOCK_PDF_MODEL,
fit_xgrid=TEST_XGRID,
forward_map=TEST_FORWARD_MAP_DIS,
fast_kernel_arrays=TEST_FK_ARRAYS,
positivity_fast_kernel_arrays=TEST_POS_FK_ARRAYS,
penalty_posdata=MOCK_PENALTY_POSDATA,
positivity_penalty_settings=positivity_penalty_settings,
integrability_penalty=integrability_penalty,
)

params = jnp.array([0.3, 0.4])

# MOCK_PENALTY_POSDATA returns [5.0] by default, which is > THRESHOLD_POS (1e-6)
pos_pass = log_likelihood_class.get_pos_pass(params)
assert pos_pass == False

log_likelihood_class.penalty_posdata = MagicMock(return_value=jnp.array([1e-10]))
pos_pass = log_likelihood_class.get_pos_pass(params)
assert pos_pass == True
20 changes: 18 additions & 2 deletions colibri/tests/test_monte_carlo_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,21 @@ def update(self, epoch_val_loss):
return self


class MockLikelihood:
def __call__(self, *args, **kwargs):
return 0.0

def get_pos_pass(self, params):
return True


def test_monte_carlo_fit_runs_without_errors():
# Provide necessary inputs for the function
training_indices = jnp.arange(100)
data_batch = data_batches(training_indices, 100)

result = monte_carlo_fit(
mc_log_likelihood=(lambda *args: 0.0, lambda *args: 0.0),
mc_log_likelihood=(MockLikelihood(), MockLikelihood()),
len_trval_data=(100, 50),
pdf_initial_parameters=np.zeros((N_PARAMS,)),
optimizer_provider=MockOptimizerProvider(),
Expand Down Expand Up @@ -77,7 +85,14 @@ def create_directory_side_effect(*args, **kwargs):

# Define mock ultranest fit
mock_monte_carlo_fit = Mock()
mock_monte_carlo_fit.monte_carlo_specs = {}
mock_monte_carlo_fit.monte_carlo_specs = {
"best_epoch_specs": {
"epoch": 1,
"best_parameters": 2,
"best_val_loss": 3,
"best_train_loss": 4,
}
}
mock_monte_carlo_fit.training_loss = jnp.array([0.1, 0.2, 0.3])
mock_monte_carlo_fit.validation_loss = jnp.array([0.2, 0.3, 0.4])
mock_monte_carlo_fit.optimized_parameters = jnp.array([0.0, 0.0])
Expand All @@ -95,3 +110,4 @@ def create_directory_side_effect(*args, **kwargs):
# Assertions - check if files are created in the output path
assert (tmp_path / "fit_replicas/replica_1/mc_loss.csv").exists()
assert (tmp_path / "fit_replicas/replica_1/mc_result_replica_1.csv").exists()
assert (tmp_path / "fit_replicas/replica_1/best_epoch_specs.csv").exists()
Loading