diff --git a/colibri/core.py b/colibri/core.py index 238b947e..c9d2cd33 100644 --- a/colibri/core.py +++ b/colibri/core.py @@ -206,12 +206,15 @@ class GradientDescentResult: Recorded (epoch) validation losses (sampled according to record_every). specs: dict Dictionary of settings used for the run (epochs, batch size, etc.). + best_epoch: dict + Dictionary containing the best epoch and corresponding parameters and losses. """ optimized_parameters: Any training_loss: jnp.array validation_loss: jnp.array specs: Dict[str, Any] + best_epoch: Dict[str, Any] @dataclass(frozen=True) diff --git a/colibri/doc/sphinx/source/tutorials/scripts/fit-folders.rst b/colibri/doc/sphinx/source/tutorials/scripts/fit-folders.rst index 616c645c..05a9c137 100644 --- a/colibri/doc/sphinx/source/tutorials/scripts/fit-folders.rst +++ b/colibri/doc/sphinx/source/tutorials/scripts/fit-folders.rst @@ -119,6 +119,14 @@ the first few lines would look like this: which would represent the losses for the first 150 epochs (i.e. 0, 1, 2 are just labels). +Best epoch +----------- +The final model weights for each replica are taken from the best epoch. This is defined +as the epoch that satisfies the positivity threshold (as specified in ``likelihood.py``) +and achieves the lowest validation loss among all such epochs. +The information for the best epoch is stored in ``fit_replicas/replica_n/best_epoch_specs.csv``, +which contains the epoch number, the training loss and the validation loss. + Postfit selection """"""""""""""""" @@ -135,7 +143,7 @@ You can run a postfit selection by running: where the ``-c`` is optional and ``CHI2_THRESHOLD`` is a number that determines the :math:`\chi^2` threshold above which a MC replica will be rejected, where this -value is taken from the last row of the ``training_loss`` column shown above. +value is taken from the training loss of the best epoch. This can also be run as ``--chi2_threshold`` instead of ``-c``. 
If no value is specified, a default value of 1.5 will be applied. diff --git a/colibri/gradient_descent.py b/colibri/gradient_descent.py index 39a62e9c..d9d50cfd 100644 --- a/colibri/gradient_descent.py +++ b/colibri/gradient_descent.py @@ -36,6 +36,8 @@ def run_gradient_descent( max_epochs: int, data_batch: Optional[colibri.DataBatches] = None, record_every: int = 50, + positivity_check_fn: Optional[Callable[[jnp.ndarray], bool]] = None, + threshold_chi2: float = 10.0, ) -> GradientDescentResult: """Generic gradient descent loop. @@ -90,6 +92,12 @@ def _step(p, ostate, batch: BatchSpec): train_losses = [] val_losses = [] + best_params = params + best_train_loss = jnp.inf + best_val_loss = jnp.inf + best_epoch_idx = 0 + any_pos_pass = False + if data_batch is None: # single fake iterator repeatedly yielding EMPTY_BATCH def _gen(): @@ -114,6 +122,27 @@ def _gen(): epoch_val_loss = validation_loss_fn(params) early_stopper = early_stopper.update(epoch_val_loss) + # Update best epoch based on positivity and validation loss + pos_pass = True + if positivity_check_fn is not None: + pos_pass = positivity_check_fn(params) + + update_best = False + meets_threshold = epoch_val_loss < threshold_chi2 + if meets_threshold: + if pos_pass and not any_pos_pass: + update_best = True + any_pos_pass = True + elif pos_pass == any_pos_pass and pos_pass: + if epoch_val_loss < best_val_loss: + update_best = True + + if update_best: + best_val_loss = epoch_val_loss + best_train_loss = epoch_train_loss + best_params = params + best_epoch_idx = epoch + if record_every and (epoch % record_every == 0): log.info( f"Epoch {epoch}, loss: {epoch_train_loss:.3f}, " @@ -127,6 +156,24 @@ def _gen(): log.info(f"Early stopping at epoch {epoch}") break + if best_epoch_idx == 0: + log.warning( + "No epoch passed positivity check. Returning last epoch's parameters." 
+ ) + best_epoch_dict = { + "epoch": epoch, + "best_parameters": params, + "best_val_loss": epoch_val_loss, + "best_train_loss": epoch_train_loss, + } + else: + best_epoch_dict = { + "epoch": best_epoch_idx, + "best_parameters": best_params, + "best_val_loss": best_val_loss, + "best_train_loss": best_train_loss, + } + return GradientDescentResult( optimized_parameters=params, training_loss=jnp.array(train_losses), @@ -136,4 +183,5 @@ def _gen(): "batch_size": batch_size, "record_every": record_every, }, + best_epoch=best_epoch_dict, ) diff --git a/colibri/likelihood.py b/colibri/likelihood.py index 7f3d2563..3171c675 100644 --- a/colibri/likelihood.py +++ b/colibri/likelihood.py @@ -12,6 +12,8 @@ from colibri.commondata_utils import CentralCovmatIndex from colibri.data_batch import BatchSpec +THRESHOLD_POS = 1e-6 + class LogLikelihood(object): """ @@ -69,6 +71,14 @@ def __init__( self.fast_kernel_arrays = fast_kernel_arrays self.positivity_fast_kernel_arrays = positivity_fast_kernel_arrays + def get_pos_pass(self, params): + _, pdf = self.pred_and_pdf(params, self.fast_kernel_arrays) + pos_pass, _ = self.positivity_check_and_penalty( + pdf, + self.positivity_fast_kernel_arrays, + ) + return pos_pass + def __call__(self, params, batch: BatchSpec | None = None): """ Note that this function is called by the samplers, and it must be @@ -99,6 +109,25 @@ def __call__(self, params, batch: BatchSpec | None = None): batch=batch, ) + def positivity_check_and_penalty(self, pdf, positivity_fast_kernel_arrays): + if self.positivity_penalty_settings["positivity_penalty"]: + pos_penalties = self.penalty_posdata( + pdf, + self.positivity_penalty_settings["alpha"], + self.positivity_penalty_settings["lambda_positivity"], + positivity_fast_kernel_arrays, + ) + pos_pass = jnp.all(pos_penalties < THRESHOLD_POS) + + pos_penalty = jnp.sum( + pos_penalties, + axis=-1, + ) + else: + pos_penalty = 0 + pos_pass = True + return pos_pass, pos_penalty + @partial(jax.jit, 
static_argnames=("self",)) def log_likelihood( self, @@ -140,18 +169,10 @@ def log_likelihood( else: inv_covmat = batch.inv_cov - if self.positivity_penalty_settings["positivity_penalty"]: - pos_penalty = jnp.sum( - self.penalty_posdata( - pdf, - self.positivity_penalty_settings["alpha"], - self.positivity_penalty_settings["lambda_positivity"], - positivity_fast_kernel_arrays, - ), - axis=-1, - ) - else: - pos_penalty = 0 + _, pos_penalty = self.positivity_check_and_penalty( + pdf, + positivity_fast_kernel_arrays, + ) integ_penalty = jnp.sum( self.integrability_penalty( diff --git a/colibri/monte_carlo_fit.py b/colibri/monte_carlo_fit.py index 76809ffe..c05c3a95 100644 --- a/colibri/monte_carlo_fit.py +++ b/colibri/monte_carlo_fit.py @@ -28,6 +28,7 @@ def monte_carlo_fit( early_stopper, max_epochs, data_batches, + threshold_chi2=10.0, ): """ This function performs a Monte Carlo fit. @@ -81,7 +82,8 @@ def loss_validation(parameters): log.info("Starting Monte Carlo fit...") t0 = time.time() - # Delegate to generic gradient descent + positivity_check_fn = mc_log_likelihood[0].get_pos_pass + gd_result = run_gradient_descent( initial_parameters=pdf_initial_parameters.copy(), training_loss_fn=loss_training, @@ -91,6 +93,8 @@ def loss_validation(parameters): max_epochs=max_epochs, data_batch=data_batches, record_every=50, + positivity_check_fn=positivity_check_fn, + threshold_chi2=threshold_chi2, ) t1 = time.time() @@ -101,10 +105,11 @@ def loss_validation(parameters): "max_epochs": max_epochs, "batch_size": data_batches.batch_size, "batch_seed": data_batches.batch_seed, + "best_epoch_specs": gd_result.best_epoch, }, training_loss=gd_result.training_loss, validation_loss=gd_result.validation_loss, - optimized_parameters=gd_result.optimized_parameters, + optimized_parameters=gd_result.best_epoch["best_parameters"], ) @@ -168,3 +173,21 @@ def run_monte_carlo_fit(monte_carlo_fit, pdf_model, output_path, replica_index, index=False, float_format="%.5e", ) + + best_epoch_specs 
= mc_fit.monte_carlo_specs.get("best_epoch_specs") + + df = pd.DataFrame( + { + "best_epoch": best_epoch_specs["epoch"], + "best_val_loss": best_epoch_specs["best_val_loss"], + "best_train_loss": best_epoch_specs["best_train_loss"], + }, + index=[0], + ) + + df.to_csv( + str(output_path) + + f"/fit_replicas/replica_{replica_index}" + + "/best_epoch_specs.csv", + index=False, + ) diff --git a/colibri/scripts/mc_postfit.py b/colibri/scripts/mc_postfit.py index 0b78d2e8..71ec6568 100644 --- a/colibri/scripts/mc_postfit.py +++ b/colibri/scripts/mc_postfit.py @@ -71,24 +71,24 @@ def main(): replicas_list = sorted(list(replicas_path.iterdir())) - final_losses = jnp.array([]) + best_epoch_losses = jnp.array([]) valid_replicas = [] # Keep track of which replicas are valid for replica in replicas_list: try: - df = pd.read_csv(replica / "mc_loss.csv") + df = pd.read_csv(replica / "best_epoch_specs.csv") if ( df.empty - or df["training_loss"].iloc[-1] is pd.NA - or pd.isna(df["training_loss"].iloc[-1]) + or df["best_train_loss"].iloc[0] is pd.NA + or pd.isna(df["best_train_loss"].iloc[0]) ): log.warning(f"Skipping replica {replica} - empty or NaN training_loss") continue - final_loss = df.iloc[-1]["training_loss"] - final_losses = jnp.concatenate( - (final_losses, jnp.array([final_loss])), axis=0 + best_epoch_loss = df.iloc[-1]["best_train_loss"] + best_epoch_losses = jnp.concatenate( + (best_epoch_losses, jnp.array([best_epoch_loss])), axis=0 ) valid_replicas.append(replica) @@ -96,8 +96,8 @@ def main(): log.critical(f"Skipping replica {replica} - error reading file: {e}") continue - mean_loss = jnp.mean(final_losses) - std_loss = jnp.std(final_losses) + mean_loss = jnp.mean(best_epoch_losses) + std_loss = jnp.std(best_epoch_losses) # List of replicas to keep good_replicas = [] @@ -105,7 +105,7 @@ def main(): # We will copy the replicas and order them starting with 0 # and increasing the index for each good replica we find i = 0 - for replica, loss in zip(valid_replicas, 
final_losses): + for replica, loss in zip(valid_replicas, best_epoch_losses): index = int(replica.name.split("_")[1]) diff --git a/colibri/tests/test_likelihood.py b/colibri/tests/test_likelihood.py index 4d00e77c..b08ce67b 100644 --- a/colibri/tests/test_likelihood.py +++ b/colibri/tests/test_likelihood.py @@ -8,6 +8,7 @@ import jax.numpy as jnp import pytest from numpy.testing import assert_allclose +from unittest.mock import MagicMock from colibri.likelihood import LogLikelihood, log_likelihood, mc_log_likelihood from colibri.mc_utils import MCPseudodata @@ -501,3 +502,36 @@ def test_LogLikelihood_call_with_batch_with_inv_cov(pos_penalty): expected = -0.5 * (chi2_b + pos_pen + integ_pen) assert_allclose(float(ll_value_batched), float(expected)) + + +def test_LogLikelihood_get_pos_pass(): + """ + Tests the get_pos_pass method of LogLikelihood. + """ + positivity_penalty_settings = { + "positivity_penalty": True, + "alpha": 1e-7, + "lambda_positivity": 1000, + } + + log_likelihood_class = LogLikelihood( + central_covmat_index=MOCK_CENTRAL_COVMAT_INDEX, + pdf_model=MOCK_PDF_MODEL, + fit_xgrid=TEST_XGRID, + forward_map=TEST_FORWARD_MAP_DIS, + fast_kernel_arrays=TEST_FK_ARRAYS, + positivity_fast_kernel_arrays=TEST_POS_FK_ARRAYS, + penalty_posdata=MOCK_PENALTY_POSDATA, + positivity_penalty_settings=positivity_penalty_settings, + integrability_penalty=integrability_penalty, + ) + + params = jnp.array([0.3, 0.4]) + + # MOCK_PENALTY_POSDATA returns [5.0] by default, which is > THRESHOLD_POS (1e-6) + pos_pass = log_likelihood_class.get_pos_pass(params) + assert pos_pass == False + + log_likelihood_class.penalty_posdata = MagicMock(return_value=jnp.array([1e-10])) + pos_pass = log_likelihood_class.get_pos_pass(params) + assert pos_pass == True diff --git a/colibri/tests/test_monte_carlo_fit.py b/colibri/tests/test_monte_carlo_fit.py index 1ce522d4..ae968e8c 100644 --- a/colibri/tests/test_monte_carlo_fit.py +++ b/colibri/tests/test_monte_carlo_fit.py @@ -39,13 +39,21 @@ 
def update(self, epoch_val_loss): return self +class MockLikelihood: + def __call__(self, *args, **kwargs): + return 0.0 + + def get_pos_pass(self, params): + return True + + def test_monte_carlo_fit_runs_without_errors(): # Provide necessary inputs for the function training_indices = jnp.arange(100) data_batch = data_batches(training_indices, 100) result = monte_carlo_fit( - mc_log_likelihood=(lambda *args: 0.0, lambda *args: 0.0), + mc_log_likelihood=(MockLikelihood(), MockLikelihood()), len_trval_data=(100, 50), pdf_initial_parameters=np.zeros((N_PARAMS,)), optimizer_provider=MockOptimizerProvider(), @@ -77,7 +85,14 @@ def create_directory_side_effect(*args, **kwargs): # Define mock ultranest fit mock_monte_carlo_fit = Mock() - mock_monte_carlo_fit.monte_carlo_specs = {} + mock_monte_carlo_fit.monte_carlo_specs = { + "best_epoch_specs": { + "epoch": 1, + "best_parameters": 2, + "best_val_loss": 3, + "best_train_loss": 4, + } + } mock_monte_carlo_fit.training_loss = jnp.array([0.1, 0.2, 0.3]) mock_monte_carlo_fit.validation_loss = jnp.array([0.2, 0.3, 0.4]) mock_monte_carlo_fit.optimized_parameters = jnp.array([0.0, 0.0]) @@ -95,3 +110,4 @@ def create_directory_side_effect(*args, **kwargs): # Assertions - check if files are created in the output path assert (tmp_path / "fit_replicas/replica_1/mc_loss.csv").exists() assert (tmp_path / "fit_replicas/replica_1/mc_result_replica_1.csv").exists() + assert (tmp_path / "fit_replicas/replica_1/best_epoch_specs.csv").exists()