diff --git a/README.md b/README.md index ea78088..aaf4fa2 100644 --- a/README.md +++ b/README.md @@ -32,8 +32,6 @@ This link points you to additional references for setting up your environment co 4. Install `pre-commit` hooks using `pre-commit install`. -5. INstall netcdf4 with pip. After activating your environment, run `pip install netcdf4`. This package cannot be installed with poetry because of dependencies. - ### 3. Downloading input data For running the model on real cliamte data, please download monthly climate model data and regrid it to an icosahedral grid using ClimateSet https://github.com/RolnickLab/ClimateSet. diff --git a/climatem/config.py b/climatem/config.py index 6d6165b..3487d5e 100644 --- a/climatem/config.py +++ b/climatem/config.py @@ -4,6 +4,8 @@ class expParams: + """Experiment setup: paths, dimensions, random seed, and hardware config.""" + def __init__( self, exp_path, # Path to where the output will be saved i.e. model runs, plots @@ -38,6 +40,8 @@ def __init__( class dataParams: + """Data loading: paths, scenarios, variables, batch size, and preprocessing options.""" + def __init__( self, data_dir, # The processed (normalized, deseasonalized, numpy...) data will be stored here @@ -64,6 +68,7 @@ def __init__( channels_last: bool = False, # last dimension of data is the channel ishdf5: bool = False, # numpy vs hdf5. for now only numpy is supported. Redundant with next param data_format: str = "numpy", # numpy vs hdf5. for now only numpy is supported + forcing_conditioning: str = "raw", # how to condition on forcings: raw | template | mode | region (SAVAR) seq_to_seq: bool = True, # predicting a sequence from a sequence? 
train_val_interval_length: int = 11, load_train_into_mem: bool = True, @@ -100,6 +105,7 @@ def __init__( self.channels_last = channels_last self.ishdf5 = ishdf5 self.data_format = data_format + self.forcing_conditioning = forcing_conditioning self.seq_to_seq = seq_to_seq self.train_val_interval_length = train_val_interval_length self.load_train_into_mem = load_train_into_mem @@ -123,6 +129,8 @@ def __init__( class trainParams: + """Training loop: learning rate, iterations, patience for phase transition, and validation frequency.""" + def __init__( self, ratio_train: float = 0.9, @@ -147,6 +155,8 @@ def __init__( class modelParams: + """Model architecture: latent dynamics type, MLP sizes, embedding, and causal mask options.""" + def __init__( self, instantaneous: bool = False, # Allow instantaneous connections? @@ -160,11 +170,21 @@ def __init__( num_layers: int = 2, num_output: int = 2, # NOT SURE position_embedding_dim: int = 100, # Dimension of positional embedding + reduce_encoding_pos_dim: bool = False, # Reduce encoder positional embedding dimension by x10 + tau_neigh: int = 0, # Legacy neighborhood radius used in older configs + hard_gumbel: bool = False, # Legacy mask sampling flag used in analysis scripts transition_param_sharing: bool = True, position_embedding_transition: int = 100, fixed: bool = False, # Do we fix the causal graph? 
Should be in gt_params maybe fixed_output_fraction=None, # This is used if we fix the mask, and want to get a fix number of 0 and 1 constraint_func: str = "trace", # This is used for the constraint - trace is the correct one here + use_exogenous: bool = False, # NEW: Enable conditioning on exogenous forcings (CO2 + aerosols) + d_y_co2: int = 1, # NEW: Dimension of CO2 forcing (typically 1 for global, or spatial_dim for local) + d_y_aerosol: int = 900, # NEW: Dimension of aerosol forcing (typically spatial_dim for local effects) + use_forced_latents: bool = False, # NEW: Map forcings directly to dedicated latent dimensions + n_forced_latents_co2: int = 1, # NEW: Number of latent dimensions for CO2 forcing + n_forced_latents_aerosol: int = 2, # NEW: Number of latent dimensions for aerosol forcing + forcing_arch: str = "baseline", # NEW: baseline | transitioned | predefined ): self.instantaneous = instantaneous self.no_w_constraint = no_w_constraint @@ -177,14 +197,26 @@ def __init__( self.num_hidden_mixing = num_hidden_mixing self.num_layers_mixing = num_layers_mixing self.position_embedding_dim = position_embedding_dim + self.reduce_encoding_pos_dim = reduce_encoding_pos_dim + self.tau_neigh = tau_neigh + self.hard_gumbel = hard_gumbel self.transition_param_sharing = transition_param_sharing self.position_embedding_transition = position_embedding_transition self.fixed = fixed self.fixed_output_fraction = fixed_output_fraction self.constraint_func = constraint_func + self.use_exogenous = use_exogenous + self.d_y_co2 = d_y_co2 + self.d_y_aerosol = d_y_aerosol + self.use_forced_latents = use_forced_latents + self.n_forced_latents_co2 = n_forced_latents_co2 + self.n_forced_latents_aerosol = n_forced_latents_aerosol + self.forcing_arch = forcing_arch class optimParams: + """Optimization: loss coefficients, ALM penalty parameters, and constraint schedules.""" + def __init__( self, optimizer: str = "rmsprop", @@ -227,6 +259,11 @@ def __init__( acyclic_min_iter_convergence: 
float = 1_000, mu_acyclic_init: float = 0, h_acyclic_threshold: float = 0, + forcing_co2_coeff: float = 10.0, # Weight for CO2 forcing reconstruction loss + forcing_aerosol_coeff: float = 10.0, # Weight for aerosol forcing reconstruction loss + forcing_latent_supervision_coeff: float = 10.0, # Weight for direct forcing latent supervision loss + decoder_utilization_coeff: float = 0.1, # Penalty coefficient for underutilized forcing latent decoder weights + min_forcing_decoder_norm: float = 1.5, # Target minimum L2 norm for forcing latent decoder weights udpate_ALM_using_valid: bool = True, # If False use training loss convergence if True uses valid loss convergence udpate_ALM_using_nll: bool = True, # If False use augmented loss convergence if True uses NLL convergence ): @@ -275,11 +312,19 @@ def __init__( self.mu_acyclic_init = mu_acyclic_init self.h_acyclic_threshold = h_acyclic_threshold + self.forcing_co2_coeff = forcing_co2_coeff + self.forcing_aerosol_coeff = forcing_aerosol_coeff + self.forcing_latent_supervision_coeff = forcing_latent_supervision_coeff + self.decoder_utilization_coeff = decoder_utilization_coeff + self.min_forcing_decoder_norm = min_forcing_decoder_norm + self.udpate_ALM_using_valid = udpate_ALM_using_valid self.udpate_ALM_using_nll = udpate_ALM_using_nll class plotParams: + """Plotting frequency and toggle options for training diagnostics.""" + def __init__( self, plot_freq: int = 500, plot_through_time: bool = True, print_freq: int = 500, savar: bool = False ): @@ -290,25 +335,80 @@ def __init__( class savarParams: - # Params for generating synthetic data + """ + Configuration for SAVAR synthetic data generation. + + Controls all aspects of the Seasonal Vector Auto-Regressive data generator: + spatial grid, temporal length, causal graph structure, seasonality, external + forcing (CO2 + aerosol), noise characteristics, and background state. + See ``climatem/synthetic_data/savar.py`` for the generator implementation. 
+ """ + def __init__( self, - time_len: int = 10_000, # Time length of the data - comp_size: int = 10, # Each component size - noise_val: float = 0.2, # Noise variance relative to signal - n_per_col: int = 2, # square grid, equivalent of lat/lon - difficulty: str = "easy", # easy, med_easy, med_hard, hard: difficulty of the graph - seasonality: bool = False, # Seasonality in synthetic data - overlap: float = 0, # Modes overlap between 0 and 1 - is_forced: bool = False, # Forcings in synthetic data - f_1: int = 1, - f_2: int = 2, - f_time_1: int = 4000, - f_time_2: int = 8000, - ramp_type: str = "linear", - linearity: str = "linear", - poly_degrees: List[int] = [2], - plot_original_data: bool = True, + # Basic data generation parameters + time_len: int = 10_000, # Total number of timesteps to generate (longer = more data for training) + comp_size: int = 10, # Size of each spatial component/mode + noise_val: float = 0.02, # Noise strength relative to signal (higher = noisier data) + n_per_col: int = 2, # Number of grid points per row/column in square spatial grid (total spatial size = n_per_col^2 * comp_size) + # Causal graph structure + difficulty: str = "easy", # Complexity of causal graph: "easy" (sparse), "med_easy", "med_hard", "hard" (dense/complex) + # Seasonality parameters + seasonality: bool = False, # Whether to add seasonal variations (e.g., annual cycles like climate data) + periods: List[float] = [ + 365, + 182.5, + 60, + ], # Seasonal periods in days (e.g., annual=365, semi-annual=182.5, bi-monthly=60) + amplitudes: List[float] = [0.06, 0.02, 0.01], # Amplitude of each seasonal component (matched to periods list) + phases: List[float] = [ + 0.0, + 0.7853981634, + 1.5707963268, + ], # Phase shifts for seasonality in radians (0, π/4, π/2) + yearly_jitter_amp: float = 0.05, # Year-to-year random variation in seasonal amplitude (adds realism) + yearly_jitter_phase: float = 0.10, # Year-to-year random variation in seasonal phase (adds realism) + # Spatial 
structure + overlap: float = 0, # Whether spatial modes can overlap between 0 and 1 (True = modes share spatial regions) + # External forcing parameters + is_forced: bool = False, # Whether to include external forcings like CO2 and aerosols (mimics climate change) + f_1: int = 0, # Initial forcing value at start of ramp (baseline level). NOTE: used as float downstream + f_2: int = 1, # Final forcing value at end of ramp (target level). NOTE: used as float downstream + f_time_1: int = 4000, # Timestep when forcing ramp begins (relative to start after transient) + f_time_2: int = 8000, # Timestep when forcing ramp ends and forcing becomes constant at f_2 + ramp_type: str = "linear", # Temporal evolution of forcing: "linear", "quadratic", "exponential", "sigmoid", "sinusoidal" + # Dynamics type + linearity: str = "linear", # Type of dynamics: "linear" (VAR model), "polynomial", or "nonlinear" (neural net) + poly_degrees: List[int] = [ + 2 + ], # Polynomial degrees to use if linearity="polynomial" (e.g., [2] for quadratic, [2,3] for quad+cubic) + # Visualization + plot_original_data: bool = True, # Whether to generate plots during data generation + # Separate forcing fields (more realistic than single forcing) + use_separate_forcings: bool = False, # Use distinct CO2 and aerosol forcing fields with different dynamics + forcing_amplification: float = 1.2, # Overall scaling factor for forcing magnitudes + # Aerosol forcing parameters + aerosol_scale: float = 0.02, # Strength of aerosol forcing (typically negative for cooling effect, positive here for magnitude) + aerosol_spatial_contrast: float = 1.05, # Regional variability of aerosol effects (>1 increases heterogeneity across space) + aerosol_ramp_up_time: int = 2000, # When aerosol forcing starts increasing (default: 20% of time_len) + aerosol_peak_time: int = 5000, # When aerosol forcing reaches maximum (default: 50% of time_len) + aerosol_decline_time: int = 8000, # When aerosol forcing finishes declining to 
baseline (default: 80% of time_len) + aerosol_timing_stagger: float = 0.3, # Fraction of timeline to stagger aerosol latents (creates distinct temporal patterns per latent) + # Forcing causal structure parameters + n_co2_latents: int = 1, # Number of latent variables representing CO2 forcing in causal graph (typically 1 for global) + n_aerosol_latents: int = 2, # Number of latent variables representing aerosol forcing (multiple for regional effects) + co2_effect_strength: float = 0.25, # Causal coefficient strength for CO2 → climate mode links (larger = stronger influence) + aerosol_effect_strength: float = 0.20, # Causal coefficient strength for aerosol → climate mode links (larger = stronger influence) + # Noise temporal correlation (AR(1) / Ornstein-Uhlenbeck) + noise_ar1_rho: float = 0.95, # AR(1) persistence parameter ρ (0=white noise, 0.95=realistic red noise). Can also be "decay" for mode-dependent ρₖ = exp(-k/K) + noise_ar1: bool = True, # Use AR(1) (red) noise instead of white noise for realistic temporal correlations + # Background state parameters + enable_background: bool = False, # Whether to add low-frequency background state (slow climate mean state drift) + background_strength: float = 0.3, # Strength relative to mode std (if < 1 and mode="relative") or absolute magnitude + background_strength_mode: str = "relative", # "relative" to mode std or "absolute" + background_smoothness: float = 0.15, # Controls spatial frequency (higher = smoother spatial patterns) + background_timescale_rho: float = 0.995, # AR(1) persistence (higher = slower temporal evolution, 0.995 ≈ 200 step timescale) + background_n_modes: int = 3, # Number of low-frequency Fourier components for spatial smoothness use_correct_hyperparams: bool = True, # Override some of the model params to match those of savar data if true ): self.time_len = time_len @@ -317,6 +417,11 @@ def __init__( self.n_per_col = n_per_col self.difficulty = difficulty self.seasonality = seasonality + 
self.periods = periods + self.amplitudes = amplitudes + self.phases = phases + self.yearly_jitter_amp = yearly_jitter_amp + self.yearly_jitter_phase = yearly_jitter_phase self.overlap = overlap self.is_forced = is_forced self.f_1 = f_1 @@ -327,6 +432,29 @@ def __init__( self.linearity = linearity self.poly_degrees = poly_degrees self.plot_original_data = plot_original_data + self.use_separate_forcings = use_separate_forcings + self.forcing_amplification = forcing_amplification + self.aerosol_scale = aerosol_scale + self.aerosol_spatial_contrast = aerosol_spatial_contrast + self.aerosol_ramp_up_time = aerosol_ramp_up_time + self.aerosol_peak_time = aerosol_peak_time + self.aerosol_decline_time = aerosol_decline_time + self.aerosol_timing_stagger = aerosol_timing_stagger + # Forcing causal structure + self.n_co2_latents = n_co2_latents + self.n_aerosol_latents = n_aerosol_latents + self.co2_effect_strength = co2_effect_strength + self.aerosol_effect_strength = aerosol_effect_strength + # Noise temporal correlation + self.noise_ar1_rho = noise_ar1_rho + self.noise_ar1 = noise_ar1 + # Background state parameters + self.enable_background = enable_background + self.background_strength = background_strength + self.background_strength_mode = background_strength_mode + self.background_smoothness = background_smoothness + self.background_timescale_rho = background_timescale_rho + self.background_n_modes = background_n_modes self.use_correct_hyperparams = use_correct_hyperparams diff --git a/climatem/data_loader/causal_datamodule.py b/climatem/data_loader/causal_datamodule.py index 8c35e7f..c1b0aa2 100644 --- a/climatem/data_loader/causal_datamodule.py +++ b/climatem/data_loader/causal_datamodule.py @@ -17,12 +17,39 @@ class CausalDataset(torch.utils.data.Dataset): - def __init__(self, x, y): + def __init__(self, x, y, co2_forcing=None, aerosol_forcing=None, gt_co2_latent=None, gt_aerosol_latent=None): self.x = x self.y = y + self.co2_forcing = co2_forcing + 
self.aerosol_forcing = aerosol_forcing + self.gt_co2_latent = gt_co2_latent + self.gt_aerosol_latent = gt_aerosol_latent def __getitem__(self, index: int): - return self.x[index], self.y[index] + """ + Return batch as dictionary if forcings are available, otherwise as tuple. + + Returns: + dict: {'x': x, 'y': y, 'co2_forcing': co2, 'aerosol_forcing': aerosol, + 'gt_co2_latent': gt_co2, 'gt_aerosol_latent': gt_aerosol} if forcings present + tuple: (x, y) if forcings not present + """ + if self.co2_forcing is not None and self.aerosol_forcing is not None: + result = { + "x": self.x[index], + "y": self.y[index], + "co2_forcing": self.co2_forcing[index], + "aerosol_forcing": self.aerosol_forcing[index], + } + # Add ground truth forcing latents if available + if self.gt_co2_latent is not None: + result["gt_co2_latent"] = self.gt_co2_latent[index] + if self.gt_aerosol_latent is not None: + result["gt_aerosol_latent"] = self.gt_aerosol_latent[index] + return result + else: + # Backward compatibility: return tuple if no forcings + return self.x[index], self.y[index] def __len__(self): return len(self.x) @@ -36,7 +63,7 @@ class CausalClimateDataModule(ClimateDataModule): """ def __init__(self, tau=5, future_timesteps=1, num_months_aggregated=1, train_val_interval_length=100, **kwargs): - super().__init__(self) + super().__init__(**kwargs) # kwargs are initialized as self.hparams by the Lightning module # WHat is this line? 
We cannot have different test vs train models @@ -98,6 +125,11 @@ def setup(self, stage: Optional[str] = None): n_per_col=self.hparams.n_per_col, difficulty=self.hparams.difficulty, seasonality=self.hparams.seasonality, + periods=self.hparams.periods, + amplitudes=self.hparams.amplitudes, + phases=self.hparams.phases, + yearly_jitter_amp=self.hparams.yearly_jitter_amp, + yearly_jitter_phase=self.hparams.yearly_jitter_phase, overlap=self.hparams.overlap, is_forced=self.hparams.is_forced, f_1=self.hparams.f_1, @@ -108,6 +140,26 @@ def setup(self, stage: Optional[str] = None): linearity=self.hparams.linearity, poly_degrees=self.hparams.poly_degrees, plot_original_data=self.hparams.plot_original_data, + use_separate_forcings=self.hparams.use_separate_forcings, + forcing_amplification=self.hparams.forcing_amplification, + forcing_conditioning=self.hparams.forcing_conditioning, + aerosol_scale=self.hparams.aerosol_scale, + aerosol_spatial_contrast=self.hparams.aerosol_spatial_contrast, + aerosol_ramp_up_time=self.hparams.aerosol_ramp_up_time, + aerosol_peak_time=self.hparams.aerosol_peak_time, + aerosol_decline_time=self.hparams.aerosol_decline_time, + n_co2_latents=self.hparams.n_co2_latents, + n_aerosol_latents=self.hparams.n_aerosol_latents, + co2_effect_strength=self.hparams.co2_effect_strength, + aerosol_effect_strength=self.hparams.aerosol_effect_strength, + noise_ar1=self.hparams.noise_ar1, + noise_ar1_rho=self.hparams.noise_ar1_rho, + enable_background=self.hparams.enable_background, + background_strength=self.hparams.background_strength, + background_strength_mode=self.hparams.background_strength_mode, + background_smoothness=self.hparams.background_smoothness, + background_timescale_rho=self.hparams.background_timescale_rho, + background_n_modes=self.hparams.background_n_modes, ) self.savar_name = train_val_input4mips.savar_name @@ -198,20 +250,50 @@ def setup(self, stage: Optional[str] = None): self.savar_gt_modes = train_val_input4mips.gt_modes 
self.savar_gt_noise = train_val_input4mips.gt_noise self.savar_gt_adj = train_val_input4mips.gt_adj + self.forcing_indices = getattr(train_val_input4mips, "forcing_indices", None) + # Store reference to SAVAR instance for later plotting + self.savar = train_val_input4mips train_x, train_y = train train_x = train_x.reshape((train_x.shape[0], train_x.shape[1], train_x.shape[2], -1)) train_y = train_y.reshape((train_y.shape[0], train_y.shape[1], train_y.shape[2], -1)) + # Get forcing data if available (only for SAVAR with dual exogenous forcings) + # For other datasets, getattr returns None and CausalDataset falls back to tuple mode + co2_forcing_train = getattr(train_val_input4mips, "co2_forcing_train", None) + aerosol_forcing_train = getattr(train_val_input4mips, "aerosol_forcing_train", None) + co2_forcing_valid = getattr(train_val_input4mips, "co2_forcing_valid", None) + aerosol_forcing_valid = getattr(train_val_input4mips, "aerosol_forcing_valid", None) + + # Get ground truth forcing latents if available (for SAVAR with forcing latent supervision) + gt_co2_latent_train = getattr(train_val_input4mips, "gt_co2_latent_train", None) + gt_aerosol_latent_train = getattr(train_val_input4mips, "gt_aerosol_latent_train", None) + gt_co2_latent_valid = getattr(train_val_input4mips, "gt_co2_latent_valid", None) + gt_aerosol_latent_valid = getattr(train_val_input4mips, "gt_aerosol_latent_valid", None) + self.d = train_x.shape[2] - self._data_train = CausalDataset(train_x, train_y) + self._data_train = CausalDataset( + train_x, + train_y, + co2_forcing=co2_forcing_train, + aerosol_forcing=aerosol_forcing_train, + gt_co2_latent=gt_co2_latent_train, + gt_aerosol_latent=gt_aerosol_latent_train, + ) self.n_train = train_x.shape[0] if val is not None: val_x, val_y = val val_x = val_x.reshape((val_x.shape[0], val_x.shape[1], val_x.shape[2], -1)) val_y = val_y.reshape((val_y.shape[0], val_y.shape[1], val_y.shape[2], -1)) - self._data_val = CausalDataset(val_x, val_y) + self._data_val 
= CausalDataset( + val_x, + val_y, + co2_forcing=co2_forcing_valid, + aerosol_forcing=aerosol_forcing_valid, + gt_co2_latent=gt_co2_latent_valid, + gt_aerosol_latent=gt_aerosol_latent_valid, + ) self.coordinates = train_val_input4mips.coordinates diff --git a/climatem/data_loader/climate_datamodule.py b/climatem/data_loader/climate_datamodule.py index da7419d..2e47d51 100644 --- a/climatem/data_loader/climate_datamodule.py +++ b/climatem/data_loader/climate_datamodule.py @@ -56,6 +56,7 @@ def __init__( output_save_dir: Optional[str] = "Climateset_DATA", reload_climate_set_data=True, seasonality_removal=True, + global_normalization: bool = True, num_ensembles: int = 1, # 1 for first ensemble, -1 for all lon: int = 125, lat: int = 125, @@ -70,11 +71,41 @@ def __init__( n_per_col: int = 2, difficulty: str = "easy", seasonality: bool = False, + periods: List[float] = [365, 182.5, 60], + amplitudes: List[float] = [0.06, 0.02, 0.01], + phases: List[float] = [0.0, 0.7853981634, 1.5707963268], + yearly_jitter_amp: float = 0.05, + yearly_jitter_phase: float = 0.10, overlap: bool = False, is_forced: bool = False, + f_1: float = 0.0, + f_2: float = 1.0, + f_time_1: int = 4000, + f_time_2: int = 8000, + ramp_type: str = "linear", linearity: str = "linear", poly_degrees: List[int] = [2], plot_original_data: bool = True, + use_separate_forcings: bool = False, + forcing_amplification: float = 1.2, + forcing_conditioning: str = "raw", + aerosol_scale: float = 0.02, + aerosol_spatial_contrast: float = 1.05, + aerosol_ramp_up_time: int = 2000, + aerosol_peak_time: int = 5000, + aerosol_decline_time: int = 8000, + n_co2_latents: int = 1, + n_aerosol_latents: int = 2, + co2_effect_strength: float = 0.25, + aerosol_effect_strength: float = 0.20, + noise_ar1: bool = True, + noise_ar1_rho: float = 0.95, + enable_background: bool = False, + background_strength: float = 0.3, + background_strength_mode: str = "relative", + background_smoothness: float = 0.15, + background_timescale_rho: 
float = 0.995, + background_n_modes: int = 3, ): """ Args: diff --git a/climatem/data_loader/savar_dataset.py b/climatem/data_loader/savar_dataset.py index 32eb952..eb0d059 100644 --- a/climatem/data_loader/savar_dataset.py +++ b/climatem/data_loader/savar_dataset.py @@ -1,48 +1,113 @@ +""" +PyTorch Dataset for SAVAR synthetic climate data. + +Wraps the SAVAR data generation pipeline, handling: +- Generation (via ``generate_save_savar_data``) or loading of pre-generated data +- Normalization and deseasonalization of observations +- Splitting into train/validation temporal sequences for causal discovery +- Extraction and conditioning of CO2 and aerosol forcing inputs + +Typical usage:: + + dataset = SavarDataset(output_save_dir="data/savar", ...) + train, valid = dataset.get_causal_data(tau=5, future_timesteps=1, ...) +""" + import os from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Sequence import numpy as np import torch from climatem.synthetic_data.generate_savar_datasets import generate_save_savar_data from climatem.synthetic_data.graph_evaluation import extract_adjacency_matrix +from climatem.utils import get_logger + +logger = get_logger(__name__) class SavarDataset(torch.utils.data.Dataset): + """ + PyTorch Dataset for SAVAR synthetic climate data. + + Handles loading or generating SAVAR data, applying normalization and deseasonalization, and producing temporal + sequences suitable for training the causal discovery model (LatentTSDCD). + + The dataset stores ground truth information (causal graph, mode weights, forcing trajectories) for evaluation during + training. 
+ """ + def __init__( self, + # --- Output and grid --- output_save_dir: Optional[str] = "Savar_DATA", - lat: int = 125, - lon: int = 125, - tau: int = 5, - global_normalization: bool = True, - seasonality_removal: bool = True, - reload_climate_set_data: Optional[bool] = True, - time_len: int = 10_000, - comp_size: int = 10, - noise_val: float = 0.2, - n_per_col: int = 2, - difficulty: str = "easy", - seasonality: bool = False, - overlap: bool = False, - is_forced: bool = False, - f_1: int = 1, - f_2: int = 2, - f_time_1: int = 4000, - f_time_2: int = 8000, - ramp_type: str = "linear", - linearity: str = "linear", - poly_degrees: List[int] = [2, 3], - plot_original_data: bool = True, + lat: int = 125, # number of latitude grid points (square grid: lat*lon = spatial_resolution) + lon: int = 125, # number of longitude grid points + tau: int = 5, # number of autoregressive lags for causal discovery + global_normalization: bool = True, # z-score normalize the full dataset + seasonality_removal: bool = True, # subtract estimated seasonal cycle before training + reload_climate_set_data: Optional[bool] = True, # if True, load from disk; if False, generate fresh data + # --- Temporal and spatial resolution --- + time_len: int = 10_000, # total timesteps to generate + comp_size: int = 10, # spatial size of each mode pattern (comp_size^2 per mode) + noise_val: float = 0.2, # noise strength (controls SNR) + n_per_col: int = 2, # modes per grid column; total climate modes = n_per_col^2 + difficulty: str = "easy", # causal graph density: "easy", "medium", "hard" + # --- Seasonality --- + seasonality: bool = False, # add seasonal harmonics to the data + periods: List[float] = [365, 182.5, 60], # harmonic periods in timesteps (annual, semi-annual, bi-monthly) + amplitudes: List[float] = [0.06, 0.02, 0.01], # amplitude per harmonic + phases: List[float] = [0.0, 0.7853981634, 1.5707963268], # phase offset (radians) per harmonic + yearly_jitter_amp: float = 0.05, # inter-annual 
amplitude jitter (fraction) + yearly_jitter_phase: float = 0.10, # inter-annual phase jitter (radians) + overlap: bool = False, # allow spatial mode patterns to overlap + # --- Forcing (external drivers) --- + is_forced: bool = False, # enable external forcing (CO2 + aerosol) + f_1: float = 0.1, # forcing ramp start value + f_2: float = 0.2, # forcing ramp end value + f_time_1: int = 4000, # forcing ramp start timestep + f_time_2: int = 8000, # forcing ramp end timestep + ramp_type: str = "linear", # ramp shape: "linear", "quadratic", "exponential", "sigmoid", "sinusoidal" + linearity: str = "linear", # dynamics type: "linear", "polynomial", or "nonlinear" (tanh) + poly_degrees: List[int] = [2, 3], # polynomial degrees (used when linearity="polynomial") + plot_original_data: bool = True, # produce diagnostic plots during generation + use_separate_forcings: bool = False, # enable dual exogenous forcings (CO2 + aerosol as separate inputs) + forcing_amplification: float = 1.0, # overall scaling factor for forcing magnitudes + forcing_conditioning: str = "raw", # how to condition forcings: "raw" (spatial fields), "mode" (project onto modes) + aerosol_scale: float = 0.02, # aerosol forcing magnitude + aerosol_spatial_contrast: float = 1.05, # spatial contrast exponent for aerosol templates + aerosol_ramp_up_time: int = 2000, # aerosol ramp-up start timestep + aerosol_peak_time: int = 5000, # aerosol peak timestep + aerosol_decline_time: int = 8000, # aerosol decline start timestep + # --- Forcing causal structure --- + n_co2_latents: int = 1, # number of CO2 latent variables (typically 1, global) + n_aerosol_latents: int = 2, # number of regional aerosol latent variables + co2_effect_strength: float = 0.25, # strength of CO2 → climate causal links + aerosol_effect_strength: float = 0.20, # strength of aerosol → climate causal links + # --- Noise temporal correlation (AR(1) / Ornstein-Uhlenbeck) --- + noise_ar1: bool = True, # use AR(1) (red) noise instead of white noise 
+ noise_ar1_rho: float = 0.95, # AR(1) persistence parameter rho (or "decay" for mode-dependent) + # --- Background state (slow climate drift) --- + enable_background: bool = False, # add slow, spatially-smooth background state + background_strength: float = 0.3, # magnitude of background (relative to data std if < 1) + background_strength_mode: str = "relative", # "relative" (fraction of data std) or "absolute" + background_smoothness: float = 0.15, # spatial frequency decay (higher = smoother) + background_timescale_rho: float = 0.995, # AR(1) persistence for background evolution + background_n_modes: int = 3, # number of spatial Fourier modes for background ): super().__init__() self.output_save_dir = Path(output_save_dir) - savar_poly_deg = ( - str(poly_degrees)[1:-1].translate({ord("'"): None}).translate({ord(","): None}).translate({ord(" "): None}) + self.savar_name = ( + f"m_{n_per_col**2}_tl_{time_len}_ifd_{is_forced}_dif_{difficulty}_ns_" + f"{noise_val}_ses_{seasonality}_ol_{overlap}_f1_{f_1}_f2_{f_2}_ft1_{f_time_1}_ft2_{f_time_2}" + f"_rmp_{ramp_type}_lin_{linearity}_pds_{poly_degrees}_asp_{aerosol_scale}_asc_{aerosol_spatial_contrast}" + f"_art_{aerosol_ramp_up_time}_apt_{aerosol_peak_time}_adt_{aerosol_decline_time}" ) - self.savar_name = f"modes_{n_per_col**2}_tl_{time_len}_forced_{is_forced}_dif_{difficulty}_noise_{noise_val}_season_{seasonality}_over_{overlap}_lin_{linearity}_poldeg_{savar_poly_deg}" - self.savar_path = self.output_save_dir / f"{self.savar_name}.npy" + + # SAVAR data is stored in a subfolder named after the dataset + self.savar_dataset_dir = self.output_save_dir / self.savar_name + self.savar_path = self.savar_dataset_dir / "savar.npy" self.global_normalization = global_normalization self.seasonality_removal = seasonality_removal @@ -59,6 +124,11 @@ def __init__( self.n_per_col = n_per_col self.difficulty = difficulty self.seasonality = seasonality + self.periods = periods + self.amplitudes = amplitudes + self.phases = phases + 
self.yearly_jitter_amp = yearly_jitter_amp + self.yearly_jitter_phase = yearly_jitter_phase self.overlap = overlap self.is_forced = is_forced self.f_1 = f_1 @@ -69,19 +139,142 @@ def __init__( self.linearity = linearity self.poly_degrees = poly_degrees self.plot_original_data = plot_original_data + self.use_separate_forcings = use_separate_forcings + self.forcing_amplification = forcing_amplification + self.forcing_conditioning = forcing_conditioning + self._forcing_conditioning_logged = False + self.aerosol_scale = aerosol_scale + self.aerosol_spatial_contrast = aerosol_spatial_contrast + self.aerosol_ramp_up_time = aerosol_ramp_up_time + self.aerosol_peak_time = aerosol_peak_time + self.aerosol_decline_time = aerosol_decline_time + # Forcing causal structure parameters + self.n_co2_latents = n_co2_latents + self.n_aerosol_latents = n_aerosol_latents + self.co2_effect_strength = co2_effect_strength + self.aerosol_effect_strength = aerosol_effect_strength + # Noise temporal correlation + self.noise_ar1 = noise_ar1 + self.noise_ar1_rho = noise_ar1_rho + # Background state parameters + self.enable_background = enable_background + self.background_strength = background_strength + self.background_strength_mode = background_strength_mode + self.background_smoothness = background_smoothness + self.background_timescale_rho = background_timescale_rho + self.background_n_modes = background_n_modes + self.tau = tau if self.reload_climate_set_data: - self.gt_modes = np.load(self.output_save_dir / f"{self.savar_name}_modes.npy") - self.gt_noise = np.load(self.output_save_dir / f"{self.savar_name}_noise_modes.npy") - links_coeffs = np.load( - self.output_save_dir / f"{self.savar_name}_parameters.npy", allow_pickle=True - ).item()["links_coeffs"] - self.gt_adj = np.array(extract_adjacency_matrix(links_coeffs, n_per_col**2, tau))[::-1] + self._load_mode_artifacts() + + # Load noise data field for signal-noise decomposition plots + noise_data_path = self.savar_dataset_dir / 
"noise_data_field.npy" + if noise_data_path.exists(): + self.noise_data_field = np.load(noise_data_path) + logger.info(f"Loaded noise data field from {noise_data_path}, shape: {self.noise_data_field.shape}") + else: + self.noise_data_field = None + logger.warning(f"Noise data field not found at {noise_data_path}") + + # Load deterministic data for signal-noise decomposition plots + deterministic_data_path = self.savar_dataset_dir / "savar_deterministic.npy" + if deterministic_data_path.exists(): + self.deterministic_data_field = np.load(deterministic_data_path) + logger.info( + f"Loaded deterministic data field from {deterministic_data_path}, " + f"shape: {self.deterministic_data_field.shape}" + ) + else: + self.deterministic_data_field = None + logger.warning(f"Deterministic data field not found at {deterministic_data_path}") + + # Load background data field (if available) + background_data_path = self.savar_dataset_dir / "background_data_field.npy" + if background_data_path.exists(): + self.background_data_field = np.load(background_data_path) + logger.info( + f"Loaded background data field from {background_data_path}, shape: {self.background_data_field.shape}" + ) + else: + self.background_data_field = None + # Only print warning if background was enabled in parameters + + params = np.load(self.savar_dataset_dir / "parameters.npy", allow_pickle=True).item() + links_coeffs = params["links_coeffs"] + + # Use n_total_latents if available (includes forcing latents), otherwise fall back to n_per_col**2 + n_total_latents = params.get("n_total_latents", n_per_col**2) + self.forcing_indices = params.get("forcing_indices", None) + self.n_climate_modes = params.get("n_climate_modes", n_per_col**2) + + self.gt_adj = np.array(extract_adjacency_matrix(links_coeffs, n_total_latents, tau))[::-1] + + if self.forcing_indices is not None: + logger.info( + f"Loaded extended causal graph with {n_total_latents} latents " + f"(climate: {self.n_climate_modes}, CO2: 
{len(self.forcing_indices.get('co2', []))}, " + f"aerosol: {len(self.forcing_indices.get('aerosol', []))})" + ) + + # Load separate forcing files if requested + if self.use_separate_forcings and self.is_forced: + co2_forcing_path = self.savar_dataset_dir / "co2_forcing.npy" + aerosol_forcing_path = self.savar_dataset_dir / "aerosol_forcing.npy" + + if co2_forcing_path.exists(): + self.co2_forcing = np.load(co2_forcing_path) + logger.info(f"Loaded CO2 forcing from {co2_forcing_path}, shape: {self.co2_forcing.shape}") + else: + logger.warning(f"CO2 forcing file not found: {co2_forcing_path}") + self.co2_forcing = None + + if aerosol_forcing_path.exists(): + self.aerosol_forcing = np.load(aerosol_forcing_path) + logger.info( + f"Loaded aerosol forcing from {aerosol_forcing_path}, shape: {self.aerosol_forcing.shape}" + ) + else: + logger.warning(f"Aerosol forcing file not found: {aerosol_forcing_path}") + self.aerosol_forcing = None + + # Load ground truth forcing latent trajectories for supervision + co2_latent_path = self.savar_dataset_dir / "co2_latent_trajectory.npy" + aerosol_latent_path = self.savar_dataset_dir / "aerosol_latent_trajectory.npy" + + if co2_latent_path.exists(): + self.gt_co2_latent = np.load(co2_latent_path) + logger.info( + f"Loaded CO2 latent trajectory from {co2_latent_path}, shape: {self.gt_co2_latent.shape}" + ) + else: + logger.warning(f"CO2 latent trajectory not found: {co2_latent_path}") + self.gt_co2_latent = None + + if aerosol_latent_path.exists(): + self.gt_aerosol_latent = np.load(aerosol_latent_path) + logger.info( + f"Loaded aerosol latent trajectory from {aerosol_latent_path}, shape: {self.gt_aerosol_latent.shape}" + ) + else: + logger.warning(f"Aerosol latent trajectory not found: {aerosol_latent_path}") + self.gt_aerosol_latent = None + else: + self.co2_forcing = None + self.aerosol_forcing = None + self.gt_co2_latent = None + self.gt_aerosol_latent = None else: self.gt_modes = None self.gt_noise = None + self.mode_weights = None 
links_coeffs = None self.gt_adj = None + self.co2_forcing = None + self.aerosol_forcing = None + self.background_data_field = None + self.deterministic_data_field = None + self.noise_data_field = None @staticmethod def aggregate_months(data, num_months_aggregated): @@ -89,27 +282,27 @@ def aggregate_months(data, num_months_aggregated): # check if time dim is divisible by num_months_aggregated # if not print warning and drop the last few months if data.shape[1] % num_months_aggregated != 0: - print("WARNING:num_months_aggregated does not divide time dimension. Dropping last few months.") + logger.warning("num_months_aggregated does not divide time dimension. Dropping last few months.") end_idx = (data.shape[1] // num_months_aggregated) * num_months_aggregated data = data[:, :end_idx] # introduce a new dimension of size num_months_aggregated - print("Inside aggregate_months, and the data before reshaping is:", data.shape) + logger.debug("Inside aggregate_months, and the data before reshaping is:", data.shape) reshaped_data = data.reshape(data.shape[0], -1, num_months_aggregated, *data.shape[2:]) - print("Still inside aggregate months, reshaped_data shape:", reshaped_data.shape) + logger.debug("Still inside aggregate months, reshaped_data shape:", reshaped_data.shape) # average over the new dimension aggregated_data = np.nanmean(reshaped_data, axis=2) - print("Shape of the aggregated data?:", aggregated_data.shape) + logger.debug("Shape of the aggregated data?:", aggregated_data.shape) return aggregated_data def split_data_by_interval(self, data, tau, ratio_train, interval_length=100): """Given a dataset and interval length, divide the data into intervals, then splits each interval into training and validation indices based on ratio.""" # interval_length=10 - print(f"intervallength{interval_length}") - print(f"datashape{data.shape[0]}") + logger.debug(f"intervallength{interval_length}") + logger.debug(f"datashape{data.shape[0]}") assert interval_length <= 
data.shape[0], "interval length is longer than the data" idx_train, idx_valid = [], [] @@ -148,37 +341,143 @@ def get_overlapping_sequences(self, data, idxs, tau, future_timesteps): # (year, months, lon, lat) def load_savar_data(self, filepath): data = np.load(filepath, allow_pickle=True) - print(f"Loaded data shape: {data.shape}") + logger.debug(f"Loaded data shape: {data.shape}") time_steps = data.shape[1] data_reshaped = data.T.reshape((time_steps, self.lat, self.lon)) - print(f"Loaded data shape after: {data_reshaped.shape}") + logger.debug(f"Loaded data shape after: {data_reshaped.shape}") return data_reshaped - def get_causal_data( - self, - tau, - future_timesteps, - channels_last, - num_vars, - num_scenarios, - num_ensembles, - num_years, - mode, - num_months_aggregated=1, - ratio_train=None, - interval_length=100, - ): + def reshape_forcing_data(self, forcing_data): """ - Constructs dataset for causal discovery model. + Reshape forcing data from SAVAR format (spatial_res, time) to match observations. - Splits each scenario into training and validation sets, then generates overlapping sequences. + Args: + forcing_data: Array of shape (spatial_resolution, time_length) + + Returns: + Reshaped array of shape (time_length, lat, lon) """ - print(f"Getting causal data [mode={mode}] ...") - # TODO: change + .npy... + if forcing_data is None: + return None + + logger.debug(f"Reshaping forcing data from shape: {forcing_data.shape}") + time_steps = forcing_data.shape[1] + # Transpose and reshape to match observation format + forcing_reshaped = forcing_data.T.reshape((time_steps, self.lat, self.lon)) + logger.debug(f"Reshaped forcing data to: {forcing_reshaped.shape}") + return forcing_reshaped + + def _apply_forcing_conditioning(self, co2_reshaped, aerosol_reshaped): + """ + Apply a conditioning mapping to forcings before sequencing. 
+ + Supported modes (scaffold only): + - raw: keep full spatial fields + - template: project onto aerosol templates (TODO) + - mode: project onto climate modes (implemented) + - region: region averages (TODO) + """ + if self.forcing_conditioning == "raw": + if not self._forcing_conditioning_logged and co2_reshaped is not None and aerosol_reshaped is not None: + logger.info("[ForcingConditioning] Applied forcing_conditioning='raw' (no projection)") + self._forcing_conditioning_logged = True + return co2_reshaped, aerosol_reshaped + if self.forcing_conditioning == "mode": + if co2_reshaped is None or aerosol_reshaped is None: + return co2_reshaped, aerosol_reshaped + if self.mode_weights is None: + raise FileNotFoundError("mode_weights.npy is required for forcing_conditioning='mode'") + + # Flatten spatial dims: (time, spatial) + co2_flat = co2_reshaped.reshape(co2_reshaped.shape[0], -1) + aerosol_flat = aerosol_reshaped.reshape(aerosol_reshaped.shape[0], -1) + + # CO2 conditioning: spatial mean (time, 1) + co2_mean = co2_flat.mean(axis=1, keepdims=True) + + # Aerosol conditioning: projection onto climate mode weights (time, n_modes) + mode_weights = self.mode_weights.reshape(self.mode_weights.shape[0], -1) + norms = np.linalg.norm(mode_weights, axis=1, keepdims=True) + mode_weights = mode_weights / (norms + 1e-8) + aerosol_mode = aerosol_flat @ mode_weights.T + + if not self._forcing_conditioning_logged: + logger.info("[ForcingConditioning] Applied forcing_conditioning='mode' (climate mode projection)") + self._forcing_conditioning_logged = True + return co2_mean.astype("float32"), aerosol_mode.astype("float32") + if self.forcing_conditioning in {"template", "region"}: + raise NotImplementedError( + f"forcing_conditioning='{self.forcing_conditioning}' is scaffolded but not implemented yet." 
+ ) + raise ValueError(f"Unknown forcing_conditioning='{self.forcing_conditioning}'") + + def get_forcing_sequences(self, forcing_data, idxs): + """ + Extract forcing values at specified timestep indices. + + Args: + forcing_data: Array of shape (time, lat, lon) - already reshaped forcing data + idxs: Timestep indices to extract + + Returns: + Array of shape (len(idxs), lat, lon) containing forcing at each timestep + """ + if forcing_data is None: + return None + + forcing_sequences = [] + for idx in idxs: + forcing_sequences.append(forcing_data[idx]) + + return np.stack(forcing_sequences) + + def _load_mode_artifacts(self): + """Load SAVAR mode/noise artifacts with backward-compatible fallbacks.""" + modes_path = self.savar_dataset_dir / "modes.npy" + mode_weights_path = self.savar_dataset_dir / "mode_weights.npy" + noise_modes_path = self.savar_dataset_dir / "noise_modes.npy" + noise_weights_path = self.savar_dataset_dir / "noise_weights.npy" + + if modes_path.exists(): + self.gt_modes = np.load(modes_path) + elif mode_weights_path.exists(): + self.gt_modes = np.load(mode_weights_path) + logger.warning("modes.npy missing at %s; using mode_weights.npy fallback.", modes_path) + else: + raise FileNotFoundError( + f"Missing SAVAR mode files in {self.savar_dataset_dir}: expected modes.npy or mode_weights.npy" + ) + + if noise_modes_path.exists(): + self.gt_noise = np.load(noise_modes_path) + elif noise_weights_path.exists(): + noise_weights = np.load(noise_weights_path) + self.gt_noise = noise_weights.sum(axis=0) if noise_weights.ndim == 3 else noise_weights + logger.warning("noise_modes.npy missing at %s; using noise_weights.npy fallback.", noise_modes_path) + else: + if self.gt_modes.ndim == 3: + self.gt_noise = np.zeros_like(self.gt_modes.sum(axis=0)) + else: + self.gt_noise = np.zeros_like(self.gt_modes) + logger.warning( + "Missing SAVAR noise files in %s (noise_modes.npy/noise_weights.npy). 
Using zero noise map.", + self.savar_dataset_dir, + ) + + if mode_weights_path.exists(): + self.mode_weights = np.load(mode_weights_path) + logger.info("Loaded mode weights from %s, shape: %s", mode_weights_path, self.mode_weights.shape) + elif self.gt_modes.ndim == 3: + self.mode_weights = self.gt_modes + else: + self.mode_weights = None + + def _load_or_generate_savar_data(self, tau): + """Load existing SAVAR data or generate new data.""" if os.path.exists(self.savar_path) and self.reload_climate_set_data: data = self.load_savar_data(self.savar_path) else: - print("CREATE SAVAR DATA") + logger.info("CREATE SAVAR DATA") data = generate_save_savar_data( self.output_save_dir, self.savar_name, @@ -188,6 +487,11 @@ def get_causal_data( self.n_per_col, self.difficulty, self.seasonality, + self.periods, + self.amplitudes, + self.phases, + self.yearly_jitter_amp, + self.yearly_jitter_phase, self.overlap, self.is_forced, self.f_1, @@ -198,151 +502,382 @@ def get_causal_data( self.linearity, self.poly_degrees, self.plot_original_data, + self.aerosol_scale, + self.aerosol_spatial_contrast, + self.aerosol_ramp_up_time, + self.aerosol_peak_time, + self.aerosol_decline_time, + n_co2_latents=self.n_co2_latents, + n_aerosol_latents=self.n_aerosol_latents, + co2_effect_strength=self.co2_effect_strength, + aerosol_effect_strength=self.aerosol_effect_strength, + forcing_amplification=self.forcing_amplification, + noise_ar1=self.noise_ar1, + noise_ar1_rho=self.noise_ar1_rho, + tau=self.tau, + enable_background=self.enable_background, + background_strength=self.background_strength, + background_strength_mode=self.background_strength_mode, + background_smoothness=self.background_smoothness, + background_timescale_rho=self.background_timescale_rho, + background_n_modes=self.background_n_modes, ) time_steps = data.shape[1] data = data.T.reshape((time_steps, self.lat, self.lon)) - self.gt_modes = np.load(self.output_save_dir / f"{self.savar_name}_modes.npy") - self.gt_noise = 
np.load(self.output_save_dir / f"{self.savar_name}_noise_modes.npy") - links_coeffs = np.load( - self.output_save_dir / f"{self.savar_name}_parameters.npy", allow_pickle=True - ).item()["links_coeffs"] - self.gt_adj = np.array(extract_adjacency_matrix(links_coeffs, self.n_per_col**2, tau)) + self._load_mode_artifacts() - data = data.astype("float32") - # TODO: normalize by saveing std/mean from train data and then normalize test by reloading - # Very important to avoid normalizing differently test and train data - if self.global_normalization: - data = (data - data.mean()) / data.std() - if self.seasonality_removal: - self.norm_data = self.remove_seasonality(self.norm_data) + # Load noise data field for signal-noise decomposition plots + noise_data_path = self.savar_dataset_dir / "noise_data_field.npy" + if noise_data_path.exists(): + self.noise_data_field = np.load(noise_data_path) + logger.info(f"Loaded noise data field from {noise_data_path}, shape: {self.noise_data_field.shape}") + else: + self.noise_data_field = None + logger.warning(f"Noise data field not found at {noise_data_path}") + + # Load deterministic data for signal-noise decomposition plots + deterministic_data_path = self.savar_dataset_dir / "savar_deterministic.npy" + if deterministic_data_path.exists(): + self.deterministic_data_field = np.load(deterministic_data_path) + logger.info( + f"Loaded deterministic data field from {deterministic_data_path}, " + f"shape: {self.deterministic_data_field.shape}" + ) + else: + self.deterministic_data_field = None + logger.warning(f"Deterministic data field not found at {deterministic_data_path}") + + # Load background data field (if available) + background_data_path = self.savar_dataset_dir / "background_data_field.npy" + if background_data_path.exists(): + self.background_data_field = np.load(background_data_path) + logger.info( + f"Loaded background data field from {background_data_path}, shape: {self.background_data_field.shape}" + ) + else: + 
self.background_data_field = None - print(f"data is {data.dtype}") + params = np.load(self.savar_dataset_dir / "parameters.npy", allow_pickle=True).item() + links_coeffs = params["links_coeffs"] - try: - # NOTE:(seb) this is what we do when we have the regularly gridded data! - # (years, months, vars, lon, lat) -> (scenrios, years*months, vars, lon, lat) - # Regular data shape before reshaping: (101, 12, 1, 96, 144) - # Regular data shape after reshaping: (1, 1212, 1, 96, 144) - print("Trying to regrid to lon, lat if we have regular data...") - # data = data.reshape(num_scenarios, num_years, num_vars, LON, LAT) + n_total_latents = params.get("n_total_latents", self.n_per_col**2) + self.forcing_indices = params.get("forcing_indices", None) + self.n_climate_modes = params.get("n_climate_modes", self.n_per_col**2) - data = data.reshape(1, data.shape[0], 1, self.lon, self.lat) + self.gt_adj = np.array(extract_adjacency_matrix(links_coeffs, n_total_latents, tau))[::-1] + + return data + + def _load_forcing_files(self): + """Load separate CO2 and aerosol forcing files if available.""" + if self.use_separate_forcings and self.is_forced: + co2_forcing_path = self.savar_dataset_dir / "co2_forcing.npy" + aerosol_forcing_path = self.savar_dataset_dir / "aerosol_forcing.npy" + + if co2_forcing_path.exists(): + self.co2_forcing = np.load(co2_forcing_path) + logger.info(f"Loaded CO2 forcing from {co2_forcing_path}, shape: {self.co2_forcing.shape}") + else: + logger.warning(f"CO2 forcing file not found: {co2_forcing_path}") + self.co2_forcing = None + + if aerosol_forcing_path.exists(): + self.aerosol_forcing = np.load(aerosol_forcing_path) + logger.info(f"Loaded aerosol forcing from {aerosol_forcing_path}, shape: {self.aerosol_forcing.shape}") + else: + logger.warning(f"Aerosol forcing file not found: {aerosol_forcing_path}") + self.aerosol_forcing = None + + co2_latent_path = self.savar_dataset_dir / "co2_latent_trajectory.npy" + aerosol_latent_path = self.savar_dataset_dir / 
"aerosol_latent_trajectory.npy" + if co2_latent_path.exists(): + self.gt_co2_latent = np.load(co2_latent_path) + logger.info(f"Loaded CO2 latent trajectory from {co2_latent_path}, shape: {self.gt_co2_latent.shape}") + else: + logger.warning(f"CO2 latent trajectory not found: {co2_latent_path}") + self.gt_co2_latent = None + + if aerosol_latent_path.exists(): + self.gt_aerosol_latent = np.load(aerosol_latent_path) + logger.info( + f"Loaded aerosol latent trajectory from {aerosol_latent_path}, shape: {self.gt_aerosol_latent.shape}" + ) + else: + logger.warning(f"Aerosol latent trajectory not found: {aerosol_latent_path}") + self.gt_aerosol_latent = None + else: + self.co2_forcing = None + self.aerosol_forcing = None + self.gt_co2_latent = None + self.gt_aerosol_latent = None + + def _reshape_data(self, data): + """Reshape data to proper dimensions.""" + try: + logger.debug("Trying to regrid to lon, lat if we have regular data...") + data = data.reshape(1, data.shape[0], 1, self.lon, self.lat) except ValueError: - print( - "I saw a ValueError and now I am reshaping the data differently, probably as I have icosahedral data!" - ) - # I need to include the number of years in the reshape here...! - # How to access it? As the length of the list of paths? - # NOTE: currently hardcoding 101 year long sequences...need to unhack this... - # NOTE:(seb) now we hard code that we want to change to num_years*12, like we had before - # this -1 should probably be changed to reflect the number of coordinates that we have for the icosahedral grid... - # also the .txt file will not be right for different resolutions!!!! 
- - print("Data shape before reshaping:", data.shape) + logger.debug("Reshaping data for icosahedral grid...") + logger.debug("Data shape before reshaping:", data.shape) data = data.reshape(1, data.shape[0], 1, -1) - print("Data shape after reshaping:", data.shape) + logger.debug("Data shape after reshaping:", data.shape) + return data - if isinstance(num_months_aggregated, (int, np.integer)) and num_months_aggregated > 1: - data = self.aggregate_months(data, num_months_aggregated) - # for each scenario in data, generate overlapping sequences - if mode == "train" or mode == "train+val": - print("IN IF") - x_train_list, y_train_list = [], [] - x_valid_list, y_valid_list = [], [] - - for scenario in data: - idx_train, idx_valid = self.split_data_by_interval(scenario, tau, ratio_train, interval_length) - # np.random.shuffle(idx_train) - # np.random.shuffle(idx_valid) - - x_train, y_train = self.get_overlapping_sequences(scenario, idx_train, tau, future_timesteps) - x_train_list.extend(x_train) - y_train_list.extend(y_train) - - x_valid, y_valid = self.get_overlapping_sequences(scenario, idx_valid, tau, future_timesteps) - x_valid_list.extend(x_valid) - y_valid_list.extend(y_valid) - - train_x, train_y = np.stack(x_train_list), np.stack(y_train_list) - if ratio_train == 1: - valid_x, valid_y = np.array(x_valid_list), np.array(y_valid_list) - else: - valid_x, valid_y = np.stack(x_valid_list), np.stack(y_valid_list) - train_y = np.expand_dims(train_y, axis=1) - valid_y = np.expand_dims(valid_y, axis=1) - - # z-score normalization - # make train_y go from (2550, 4, 96, 144) to (2550, 1, 4, 96, 144) - train = train_x, train_y - valid = valid_x, valid_y - - # print(train_y.shape) - # plot_species(train_y[:, :, 0, :, :], self.coordinates, "tas", "../../TEST_REPO", "after_causal") - return train, valid + def _process_aggregated_data(self, data, tau, future_timesteps, mode, ratio_train, interval_length): + """Process aggregated monthly data for training/validation.""" + data = 
self.aggregate_months(data, num_months_aggregated=1) + + if mode == "train" or mode == "train+val": + train, valid = self._generate_train_valid_sequences( + data, tau, future_timesteps, ratio_train, interval_length + ) + return train, valid + else: + test = self._generate_test_sequences(data, tau, future_timesteps) + return test + + def _generate_train_valid_sequences(self, data, tau, future_timesteps, ratio_train, interval_length): + """Generate training and validation sequences from data.""" + x_train_list, y_train_list, x_valid_list, y_valid_list = [], [], [], [] + + for scenario in data: + idx_train, idx_valid = self.split_data_by_interval(scenario, tau, ratio_train, interval_length) + x_train, y_train = self.get_overlapping_sequences(scenario, idx_train, tau, future_timesteps) + x_train_list.extend(x_train) + y_train_list.extend(y_train) + + x_valid, y_valid = self.get_overlapping_sequences(scenario, idx_valid, tau, future_timesteps) + x_valid_list.extend(x_valid) + y_valid_list.extend(y_valid) + + train_x, train_y = np.stack(x_train_list), np.stack(y_train_list) + if ratio_train == 1: + valid_x, valid_y = np.array(x_valid_list), np.array(y_valid_list) + else: + valid_x, valid_y = np.stack(x_valid_list), np.stack(y_valid_list) + + train_y = np.expand_dims(train_y, axis=1) + valid_y = np.expand_dims(valid_y, axis=1) + + self._extract_forcing_sequences(data, tau, ratio_train, interval_length) + + return (train_x, train_y), (valid_x, valid_y) + + def _generate_test_sequences(self, data, tau, future_timesteps): + """Generate test sequences from data.""" + x_test_list, y_test_list = [], [] + for scenario in data: + idx_test = np.arange(tau, scenario.shape[0]) + x_test, y_test = self.get_overlapping_sequences(scenario, idx_test, tau, future_timesteps) + x_test_list.extend(x_test) + y_test_list.extend(y_test) + + test_x, test_y = np.stack(x_test_list), np.stack(y_test_list) + test_y = np.expand_dims(test_y, axis=1) + return test_x, test_y + + def 
_extract_forcing_sequences(self, data, tau, ratio_train, interval_length): + """Extract forcing sequences for training and validation.""" + if not (self.use_separate_forcings and hasattr(self, "co2_forcing") and self.co2_forcing is not None): + self.co2_forcing_train = None + self.aerosol_forcing_train = None + self.co2_forcing_valid = None + self.aerosol_forcing_valid = None + self.gt_co2_latent_train = None + self.gt_aerosol_latent_train = None + self.gt_co2_latent_valid = None + self.gt_aerosol_latent_valid = None + return + + co2_reshaped = self.reshape_forcing_data(self.co2_forcing) + aerosol_reshaped = self.reshape_forcing_data(self.aerosol_forcing) + co2_reshaped, aerosol_reshaped = self._apply_forcing_conditioning(co2_reshaped, aerosol_reshaped) + + co2_train_list, co2_valid_list = [], [] + aerosol_train_list, aerosol_valid_list = [], [] + + # Lists for forcing latent trajectories + co2_latent_train_list, co2_latent_valid_list = [], [] + aerosol_latent_train_list, aerosol_latent_valid_list = [], [] + + for scenario in data: + idx_train, idx_valid = self.split_data_by_interval(scenario, tau, ratio_train, interval_length) + + co2_train_list.extend([co2_reshaped[idx - tau : idx + 1] for idx in idx_train]) + aerosol_train_list.extend([aerosol_reshaped[idx - tau : idx + 1] for idx in idx_train]) + + co2_valid_list.extend([co2_reshaped[idx - tau : idx + 1] for idx in idx_valid]) + aerosol_valid_list.extend([aerosol_reshaped[idx - tau : idx + 1] for idx in idx_valid]) + + # Extract forcing latent sequences at the same indices + if hasattr(self, "gt_co2_latent") and self.gt_co2_latent is not None: + # gt_co2_latent shape: (time,) - scalar at each timestep + co2_latent_train_list.extend([self.gt_co2_latent[idx - tau : idx + 1] for idx in idx_train]) + co2_latent_valid_list.extend([self.gt_co2_latent[idx - tau : idx + 1] for idx in idx_valid]) + + if hasattr(self, "gt_aerosol_latent") and self.gt_aerosol_latent is not None: + # gt_aerosol_latent shape: 
(n_aerosol_latents, time) + aerosol_latent_train_list.extend([self.gt_aerosol_latent[:, idx - tau : idx + 1] for idx in idx_train]) + aerosol_latent_valid_list.extend([self.gt_aerosol_latent[:, idx - tau : idx + 1] for idx in idx_valid]) + + co2_train_stacked = np.stack(co2_train_list).astype("float32") + aerosol_train_stacked = np.stack(aerosol_train_list).astype("float32") + co2_valid_stacked = np.stack(co2_valid_list).astype("float32") if len(co2_valid_list) > 0 else None + aerosol_valid_stacked = np.stack(aerosol_valid_list).astype("float32") if len(aerosol_valid_list) > 0 else None + + # Keep CO2 spatial (same as aerosol) instead of averaging to scalar + self.co2_forcing_train = co2_train_stacked.reshape(co2_train_stacked.shape[0], co2_train_stacked.shape[1], -1) + self.aerosol_forcing_train = aerosol_train_stacked.reshape( + aerosol_train_stacked.shape[0], aerosol_train_stacked.shape[1], -1 + ) + + if co2_valid_stacked is not None: + # Keep CO2 spatial (same as aerosol) instead of averaging to scalar + self.co2_forcing_valid = co2_valid_stacked.reshape( + co2_valid_stacked.shape[0], co2_valid_stacked.shape[1], -1 + ) + self.aerosol_forcing_valid = aerosol_valid_stacked.reshape( + aerosol_valid_stacked.shape[0], aerosol_valid_stacked.shape[1], -1 + ) + else: + self.co2_forcing_valid = None + self.aerosol_forcing_valid = None + + # Stack forcing latent trajectories + if len(co2_latent_train_list) > 0: + # CO2 latent: (batch, tau+1) -> add feature dimension -> (batch, tau+1, 1) + self.gt_co2_latent_train = np.stack(co2_latent_train_list).astype("float32")[:, :, np.newaxis] + self.gt_co2_latent_valid = ( + np.stack(co2_latent_valid_list).astype("float32")[:, :, np.newaxis] + if len(co2_latent_valid_list) > 0 + else None + ) + logger.info(f"Extracted CO2 latent sequences - train: {self.gt_co2_latent_train.shape}") + else: + self.gt_co2_latent_train = None + self.gt_co2_latent_valid = None + + if len(aerosol_latent_train_list) > 0: + # Aerosol latent: (batch, 
n_latents, tau+1) -> transpose to (batch, tau+1, n_latents) + self.gt_aerosol_latent_train = np.stack(aerosol_latent_train_list).astype("float32").transpose(0, 2, 1) + self.gt_aerosol_latent_valid = ( + np.stack(aerosol_latent_valid_list).astype("float32").transpose(0, 2, 1) + if len(aerosol_latent_valid_list) > 0 + else None + ) + logger.info(f"Extracted aerosol latent sequences - train: {self.gt_aerosol_latent_train.shape}") + else: + self.gt_aerosol_latent_train = None + self.gt_aerosol_latent_valid = None + + logger.info( + f"Extracted forcing sequences - CO2 train: {self.co2_forcing_train.shape}, " + f"aerosol train: {self.aerosol_forcing_train.shape}" + ) + logger.info("Both CO2 and aerosols are spatially varying (not averaged)") + + def _generate_non_aggregated_data(self, data, tau, future_timesteps, mode, ratio_train, interval_length): + """Generate non-aggregated causal data for training/validation or testing.""" + if mode == "train" or mode == "train+val": + x_train_list, y_train_list = [], [] + x_valid_list, y_valid_list = [], [] + + for scenario in data: + idx_train, idx_valid = self.split_data_by_interval(scenario, tau, ratio_train, interval_length) + x_train, y_train = self.get_overlapping_sequences(scenario, idx_train, tau, future_timesteps) + x_train_list.extend(x_train) + y_train_list.extend(y_train) + + x_valid, y_valid = self.get_overlapping_sequences(scenario, idx_valid, tau, future_timesteps) + x_valid_list.extend(x_valid) + y_valid_list.extend(y_valid) + + train_x, train_y = np.stack(x_train_list), np.stack(y_train_list) + if ratio_train == 1: + valid_x, valid_y = np.array(x_valid_list), np.array(y_valid_list) else: - x_test_list, y_test_list = [], [] - for scenario in data: - idx_test = np.arange(tau, scenario.shape[0]) - x_test, y_test = self.get_overlapping_sequences(scenario, idx_test, tau, future_timesteps) - x_test_list.extend(x_test) - y_test_list.extend(y_test) + valid_x, valid_y = np.stack(x_valid_list), np.stack(y_valid_list) + + 
train_y = np.expand_dims(train_y, axis=1) + valid_y = np.expand_dims(valid_y, axis=1) - test_x, test_y = np.stack(x_test_list), np.stack(y_test_list) - test_y = np.expand_dims(test_y, axis=1) + self._extract_forcing_sequences(data, tau, ratio_train, interval_length) - test = test_x, test_y + return (train_x, train_y), (valid_x, valid_y) + else: + x_test_list, y_test_list = [], [] + + for scenario in data: + idx_test = np.arange(tau, scenario.shape[0]) + x_test, y_test = self.get_overlapping_sequences(scenario, idx_test, tau, future_timesteps) + x_test_list.extend(x_test) + y_test_list.extend(y_test) - return test + test_x, test_y = np.stack(x_test_list), np.stack(y_test_list) + test_y = np.expand_dims(test_y, axis=1) - # NOTE:seb delete commented code + return test_x, test_y + def get_causal_data( + self, + tau, + future_timesteps, + channels_last, + num_vars, + num_scenarios, + num_ensembles, + num_years, + mode, + num_months_aggregated=1, + ratio_train=None, + interval_length=100, + ): + """ + Constructs dataset for causal discovery model. + + Splits each scenario into training and validation sets, then generates overlapping sequences. + """ + logger.info(f"Getting causal data [mode={mode}] ...") + data = self._load_or_generate_savar_data(tau) + self._load_forcing_files() + + data = data.astype("float32") + if self.global_normalization: + data = (data - data.mean()) / data.std() + if self.seasonality_removal: + data = self.remove_seasonality( + data, + periods=self.periods, + demean=True, + normalise=False, + rolling=True, + w=10, + ) + + logger.debug(f"data is {data.dtype}") + data = self._reshape_data(data) + + if isinstance(num_months_aggregated, (int, np.integer)) and num_months_aggregated > 1: + return self._process_aggregated_data(data, tau, future_timesteps, mode, ratio_train, interval_length) else: - # TODO create this function and use it -> put it inside the data creation... 
- # data = self.create_multi_res_data(data, num_months_aggregated) - - # for each scenario in data, generate overlapping sequences - if mode == "train" or mode == "train+val": - x_train_list, y_train_list = [], [] - x_valid_list, y_valid_list = [], [] - for scenario in data: - idx_train, idx_valid = self.split_data_by_interval(scenario, tau, ratio_train, interval_length) - # np.random.shuffle(idx_train) - # np.random.shuffle(idx_valid) - - x_train, y_train = self.get_overlapping_sequences(scenario, idx_train, tau, future_timesteps) - x_train_list.extend(x_train) - y_train_list.extend(y_train) - - x_valid, y_valid = self.get_overlapping_sequences(scenario, idx_valid, tau, future_timesteps) - x_valid_list.extend(x_valid) - y_valid_list.extend(y_valid) - train_x, train_y = np.stack(x_train_list), np.stack(y_train_list) - if ratio_train == 1: - valid_x, valid_y = np.array(x_valid_list), np.array(y_valid_list) - else: - valid_x, valid_y = np.stack(x_valid_list), np.stack(y_valid_list) - train_y = np.expand_dims(train_y, axis=1) - valid_y = np.expand_dims(valid_y, axis=1) - - train = train_x, train_y - valid = valid_x, valid_y - print(f"train: {train[0].dtype}") - return train, valid - else: - x_test_list, y_test_list = [], [] - for scenario in data: - idx_test = np.arange(tau, scenario.shape[0]) - x_test, y_test = self.get_overlapping_sequences(scenario, idx_test, tau, future_timesteps) - x_test_list.extend(x_test) - y_test_list.extend(y_test) + return self._generate_non_aggregated_data(data, tau, future_timesteps, mode, ratio_train, interval_length) + + def get_forcing_data(self): + """ + Get CO2 and aerosol forcing data, properly reshaped to match observations. 
+ + Returns: + Tuple of (co2_forcing, aerosol_forcing), each of shape (time, lat, lon) or None + """ + if not self.use_separate_forcings: + return None, None - test_x, test_y = np.stack(x_test_list), np.stack(y_test_list) - test_y = np.expand_dims(test_y, axis=1) + co2_reshaped = self.reshape_forcing_data(self.co2_forcing) if self.co2_forcing is not None else None + aerosol_reshaped = self.reshape_forcing_data(self.aerosol_forcing) if self.aerosol_forcing is not None else None - test = test_x, test_y - return test + return co2_reshaped, aerosol_reshaped def save_data_into_disk(self, data: np.ndarray, fname: str, output_save_dir: str) -> str: @@ -368,7 +903,7 @@ def get_mean_std(self, data): vars_mean = np.expand_dims(vars_mean, (1, 2, 3)) vars_std = np.expand_dims(vars_std, (1, 2, 3)) else: - print("Data dimension not recognized. Please check the dimensions of the data.") + logger.warning("Data dimension not recognized. Please check the dimensions of the data.") raise ValueError return vars_mean, vars_std @@ -390,40 +925,65 @@ def get_min_max(self, data): vars_max = np.expand_dims(vars_max, (1, 2, 3)) vars_min = np.expand_dims(vars_min, (1, 2, 3)) else: - print("Data dimension not recognized. Please check the dimensions of the data.") + logger.warning("Data dimension not recognized. Please check the dimensions of the data.") raise ValueError return vars_min, vars_max - # important? - # NOTE:(seb) I need to check the axis is correct here? - def remove_seasonality(self, data): + def remove_seasonality( + self, + data: np.ndarray, + periods: int | Sequence[int] | Sequence[float] = (12, 6, 3), + demean: bool = True, + normalise: bool = False, + rolling: bool = True, # ← default TRUE because of jitter + w: int = 10, # (10 years ≈ 120 steps @ monthly) + ): """ - Function to remove seasonality from the data There are various different options to do this These are just - different methods of removing seasonality. 
+ Remove deterministic periodic seasonality from a [time, …] array. - e.g. - monthly - remove seasonality on a per month basis - rolling monthly - remove seasonality on a per month basis but using a rolling window, - removing only the average from the months that have preceded this month - linear - remove seasonality using a linear model to predict seasonality - - or trend removal - emissions - remove the trend using the emissions data, such as cumulative CO2 + Parameters + ---------- + period single cycle length **or** list/tuple of lengths + (e.g. [12, 6] for annual + semi-annual) + … """ - mean = np.nanmean(data, axis=0) - std = np.nanstd(data, axis=0) - - # return data - - # NOTE: SH - do we not do this above? - # standardise - I hope this is doing by month, to check - - return (data - mean[None]) / std[None] - - # now just divide by std... - # return data / std[None] + def _remove_one(x: np.ndarray, p: int) -> np.ndarray: + """Inner helper that handles a single period length.""" + t = x.shape[0] + rem = t % p + if rem: + x = x[:-rem] + t -= rem + folded = x.reshape((t // p, p) + x.shape[1:]) + if rolling: + k = min(w, folded.shape[0]) + mean = np.nanmean(folded[-k:], axis=0) + std = np.nanstd(folded[-k:], axis=0) + else: + mean = np.nanmean(folded, axis=0) + std = np.nanstd(folded, axis=0) + mean_full = np.tile(mean, (t // p, *[1] * (x.ndim - 1))) + std_full = np.tile(std, (t // p, *[1] * (x.ndim - 1))) + out = x.copy() + if demean: + out -= mean_full + if normalise: + out /= np.where(std_full == 0, 1, std_full) + return out.astype(np.float32) + + # handle one or many cycle lengths + if isinstance(periods, (list, tuple, np.ndarray)): + # remove the longest cycle first to avoid leakage + _periods = sorted([int(round(p)) for p in periods], reverse=True) + else: # single scalar + _periods = [int(round(periods))] + + out = data.astype(np.float32) + for p in _periods: + out = _remove_one(out, p) + return out def write_dataset_statistics(self, fname, stats): # fname 
= fname.replace('.npz.npy', '.npy') @@ -454,6 +1014,6 @@ def __str__(self): # NOTE(seb): is this a good way to get the length? def __len__(self): - print("Input4mips", self.input4mips_ds.length, "CMIP6 data", self.cmip6_ds.length) + logger.debug("Input4mips", self.input4mips_ds.length, "CMIP6 data", self.cmip6_ds.length) assert self.input4mips_ds.length == self.cmip6_ds.length, "Datasets not of same length" return self.input4mips_ds.length diff --git a/climatem/model/train_model.py b/climatem/model/train_model.py index 3eb69fb..d1e3faf 100644 --- a/climatem/model/train_model.py +++ b/climatem/model/train_model.py @@ -1,23 +1,83 @@ -# Adapting to do training across multiple GPUs with huggingface accelerate. +""" +Training loop for LatentTSDCD using the Augmented Lagrangian Method (ALM). + +This module implements the full training pipeline for the LatentTSDCD causal discovery model, including multi-phase +training (encoder/decoder warmup followed by joint optimization with causal graph learning), ALM-based constrained +optimization (orthogonality, sparsity, and acyclicity constraints), loss computation (ELBO, CRPS, spectral losses, +forcing losses), validation, checkpoint saving, and integration with HuggingFace Accelerate for distributed training. 
+""" + import numpy as np import torch import torch.distributions as dist - -# we use accelerate for distributed training from geopy import distance - -# from torch.nn.parallel import DistributedDataParallel as DDP from torch.profiler import ProfilerActivity from climatem.model.dag_optim import compute_dag_constraint from climatem.model.prox import monkey_patch_RMSprop from climatem.model.utils import ALM from climatem.plotting.plot_model_output import Plotter +from climatem.plotting.plot_savar_output import SavarPlotter +from climatem.utils import get_logger +logger = get_logger(__name__) + +# Euler-Mascheroni constant, used in Gumbel distribution CDF for GEV loss computation (duplicated from tsdcd_latent.py) euler_mascheroni = 0.57721566490153286060 class TrainingLatent: + """ + Training wrapper for the LatentTSDCD causal discovery model. + + This class manages the complete training loop for causal discovery with + latent variables, including data loading, loss computation, validation, + plotting, and checkpoint saving. + + Training proceeds in two phases: + + - **Phase 1** (first ``patience`` iterations): Only the encoder and decoder + are optimized, allowing the autoencoder to learn meaningful latent + representations before introducing causal graph learning. + - **Phase 2** (remaining iterations): Joint optimization of the encoder, + decoder, and the learnable causal graph structure. ALM constraints + (orthogonality on decoder weights, sparsity on the adjacency matrix, + and optionally acyclicity for instantaneous connections) are gradually + enforced via penalty terms and Lagrange multipliers with automatic + mu (penalty weight) scaling. + + Once all ALM constraints converge, training continues without penalties + until validation loss patience is exhausted, at which point the learned + adjacency matrix is thresholded to a binary graph and training resumes + briefly for fine-tuning. 
+ + Parameters + ---------- + model : LatentTSDCD + The causal discovery model to train. + datamodule : CausalClimateDataModule + Provides train and validation dataloaders. + data_params, exp_params, gt_params, model_params, train_params, + optim_params, plot_params, savar_params : dataclass + Configuration parameter objects (see ``climatem.config``). + save_path : pathlib.Path + Directory for model checkpoints and intermediate outputs. + plots_path : pathlib.Path + Directory for training visualizations. + best_metrics : dict + Dictionary to store best validation metrics. + d : int + Number of observed climate variables. + accelerator : accelerate.Accelerator + HuggingFace Accelerate instance for distributed training. + wandbname : str, optional + Name for the Weights & Biases run. + profiler : bool, optional + Whether to enable PyTorch profiling. + profiler_path : str, optional + Output path for profiler traces. + """ + def __init__( self, model, @@ -39,7 +99,7 @@ def __init__( profiler=False, profiler_path="./log", ): - # TODO: do we want to have the profiler as an argument? 
Maybe not, but useful to speed up the code + # TODO(dev): Consider whether profiler should be a constructor arg or a separate config option self.accelerator = accelerator self.model = model self.model.to(accelerator.device) @@ -138,13 +198,13 @@ def __init__( self.best_spatial_spectra_score = None - self.plotter = Plotter() - - # if MULTI_GPU: - # print("I am using multiple GPUs!!") - # # setup_ddp() - # # DistributedSampler - # self.model = DDP(self.model) + # Automatically use SavarPlotter for SAVAR synthetic data, otherwise use standard Plotter + if plot_params.savar: + self.plotter = SavarPlotter() + logger.info("Using SavarPlotter for synthetic SAVAR data visualization") + else: + self.plotter = Plotter() + logger.info("Using standard Plotter for climate data visualization") # I think this is just initialising a tensor of zeroes to store results in # if not self.no_gt: @@ -188,29 +248,11 @@ def __init__( ) # prepare the model, optimizer, data loader, and scheduler using Accerate for distributed training - print("Preparing all the models here!") + logger.info("Preparing all the models here!") self.data_loader_train, self.model, self.optimizer, self.scheduler = accelerator.prepare( self.data_loader_train, self.model, self.optimizer, self.scheduler ) - # Check that model and everything is on gpu - # print("\nModel Parameter Devices after moving to GPU:") - # for name, param in self.model.named_parameters(): - # print(f"{name}: {param.device}") - - # # Check the device of a sample batch (after iterating through the prepared dataloader) - # for batch in self.data_loader_train: - # inputs, labels = batch - # print(f"Input tensor device: {inputs.device}") - # print(f"Label tensor device: {labels.device}") - # break - - # # Check the device of the optimizer's state (this might vary) - # for group in self.optimizer.param_groups: - # for param in group['params']: - # if param in self.optimizer.state: - # print(f"Optimizer state for parameter '{param.shape}': 
{self.optimizer.state[param].get('step', torch.tensor(0)).device}") - # compute constraint normalization with torch.no_grad(): d = model.d * model.d_z @@ -218,7 +260,12 @@ def __init__( self.acyclic_constraint_normalization = compute_dag_constraint(full_adjacency).item() if self.latent: - self.ortho_normalization = self.d_x * self.d_z + # Ortho normalization only applies to climate latents + if model.use_forced_latents: + n_climate = self.d_z - model.n_forced_latents_co2 - model.n_forced_latents_aerosol + else: + n_climate = self.d_z + self.ortho_normalization = self.d_x * n_climate if self.instantaneous: self.sparsity_normalization = (self.tau + 1) * self.d_z * self.d_z else: @@ -233,26 +280,19 @@ def train_with_QPM(self): # noqa: C901 the adjacency matrix """ - # Pre-Accelerate - start a new wandb run to track this script - # wandb.init( - # set the wandb project where this run will be logged - # please alter this project, and set the name to something appropriate for your experiments - # project="test-gpu-code-wandb", - # name=... - # # ) - - # print("what is the cuda device count?", torch.cuda.device_count()) - # print("MULTI GPU?", MULTI_GPU) - - # TODO: Why config here? 
- # config = self.hp self.accelerator.init_trackers( "gpu-code-wandb", - # config=config, init_kwargs={"wandb": {"name": self.wandbname}}, ) # initialize ALM/QPM for orthogonality and acyclicity constraints + # Orthogonality constraint only applies to climate latents (not forcing latents) + if self.model.use_forced_latents: + n_climate = self.d_z - self.model.n_forced_latents_co2 - self.model.n_forced_latents_aerosol + else: + n_climate = self.d_z + self.n_climate_latents = n_climate # Store for later use + self.ALM_ortho = ALM( self.optim_params.ortho_mu_init, self.optim_params.ortho_mu_mult_factor, @@ -260,7 +300,7 @@ def train_with_QPM(self): # noqa: C901 self.optim_params.ortho_omega_mu, self.optim_params.ortho_h_threshold, self.optim_params.ortho_min_iter_convergence, - dim_gamma=(self.d_z, self.d_z), + dim_gamma=(n_climate, n_climate), # Only climate latents valid_freq=self.train_params.valid_freq, ) @@ -292,11 +332,11 @@ def train_with_QPM(self): # noqa: C901 # we should have this function elsewhere. It is rarely used. 
def trace_handler(p): - print("Printing profiler key averages from trace handler!") + logger.debug("Printing profiler key averages from trace handler!") output_cpu = p.key_averages().table(sort_by="cpu_time_total", row_limit=20) output_cuda = p.key_averages().table(sort_by="cuda_time_total", row_limit=20) - print(output_cpu) - print(output_cuda) + logger.debug(output_cpu) + logger.debug(output_cuda) prof = torch.profiler.profile( activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], @@ -314,6 +354,9 @@ def trace_handler(p): # print out the output of the profiler while self.iteration < self.train_params.max_iteration and not self.ended: + if self.plot_params.savar and self.iteration == 1 and hasattr(self.plotter, "prepare_savar_context"): + logger.info("[SAVAR] GT diagnostics triggered at iteration 1") + self.plotter.prepare_savar_context(self) # train and valid step self.train_step() @@ -323,7 +366,6 @@ def trace_handler(p): if self.iteration % self.train_params.valid_freq == 0: self.logging_iter += 1 - # HERE MODIFY valid_step() self.valid_step() self.log_losses() @@ -351,6 +393,9 @@ def trace_handler(p): "mae_recons_valid": self.val_mae_recons, "mae_pred_valid": self.val_mae_pred, "mse_recons_train": self.train_mse_recons, + "forcing_loss_co2": self.train_forcing_co2_loss, + "forcing_loss_aerosol": self.train_forcing_aerosol_loss, + "forcing_latent_supervision": self.train_forcing_latent_loss, "mse_pred_train": self.train_mse_pred, "mse_recons_valid": self.val_mse_recons, "mse_pred_valid": self.val_mse_pred, @@ -391,19 +436,21 @@ def trace_handler(p): "kl_valid": self.valid_kl, "loss_valid": self.valid_loss, "recons_valid": self.valid_recons, + "forcing_loss_co2": self.train_forcing_co2_loss, + "forcing_loss_aerosol": self.train_forcing_aerosol_loss, + "forcing_latent_supervision": self.train_forcing_latent_loss, } ) # print and plot losses - # TODO : the plotting frrequency is hard to control and unintuitive... 
update the code here + # TODO(dev): The plotting frequency is hard to control and unintuitive; update the code here if self.iteration % (self.plot_params.print_freq) == 0: self.print_results() if self.logging_iter > 0 and self.iteration % (self.plot_params.plot_freq) == 0: - print(f"Plotting Iteration {self.iteration}") + logger.info(f"Plotting Iteration {self.iteration}") self.plotter.plot_sparsity(self) - # trying to save coords and adjacency matrices - # Todo propagate the path! + # TODO(dev): Propagate the path for saving coordinates and adjacency matrices if not self.plot_params.savar: self.plotter.save_coordinates_and_adjacency_matrices(self) torch.save(self.model.state_dict(), self.save_path / f"model_{self.iteration}.pth") @@ -466,16 +513,13 @@ def trace_handler(p): else: # continue training without penalty method if not self.thresholded and self.iteration % self.patience_freq == 0: - # self.plotter.plot(self, save=True) if not self.has_patience(self.train_params.patience, self.valid_loss): self.threshold() self.patience = self.train_params.patience_post_thresh self.best_valid_loss = np.inf - # self.plotter.plot(self, save=True) # continue training after thresholding else: if self.iteration % self.patience_freq == 0: - # self.plotter.plot(self, save=True) if not self.has_patience(self.train_params.patience_post_thresh, self.valid_loss): self.ended = True @@ -530,14 +574,37 @@ def train_step(self): # noqa: C901 # x, y = next(self.data_loader_train) #.sample(self.batch_size, valid=False) try: - x, y = next(self.data_loader_train) - x = torch.nan_to_num(x) - y = torch.nan_to_num(y) + batch = next(self.data_loader_train) except StopIteration: self.data_loader_train = iter(self.datamodule.train_dataloader(accelerator=self.accelerator)) - x, y = next(self.data_loader_train) - x = torch.nan_to_num(x) - y = torch.nan_to_num(y) + batch = next(self.data_loader_train) + + # Extract data from batch (handles both dict and tuple formats) + if isinstance(batch, dict): + # 
New format with forcings + x = batch["x"] + y = batch["y"] + y_co2 = batch.get("co2_forcing", None) + y_aerosol = batch.get("aerosol_forcing", None) + # Extract ground truth forcing latents for supervision + gt_co2_latent = batch.get("gt_co2_latent", None) + gt_aerosol_latent = batch.get("gt_aerosol_latent", None) + else: + # Legacy format (tuple) + x, y = batch + y_co2, y_aerosol = None, None + gt_co2_latent, gt_aerosol_latent = None, None + + x = torch.nan_to_num(x) + y = torch.nan_to_num(y) + if y_co2 is not None: + y_co2 = torch.nan_to_num(y_co2) + if y_aerosol is not None: + y_aerosol = torch.nan_to_num(y_aerosol) + if gt_co2_latent is not None: + gt_co2_latent = torch.nan_to_num(gt_co2_latent) + if gt_aerosol_latent is not None: + gt_aerosol_latent = torch.nan_to_num(gt_aerosol_latent) # y = y[:, 0] z = None @@ -546,19 +613,33 @@ def train_step(self): # noqa: C901 nll = 0 recons = 0 kl = 0 + forcing_loss_co2 = torch.tensor(0.0, device=self.accelerator.device) + forcing_loss_aerosol = torch.tensor(0.0, device=self.accelerator.device) + forcing_latent_loss = torch.tensor(0.0, device=self.accelerator.device) # also make the proper prediction, not the reconstruction as we do above # With multiple future timesteps we append the prediction to x and compute the nll of next timestep etc.. 
# We add to the loss the sum multiplied by the decay in future timesteps # we have to take care here to make sure that we have the right tensors with requires_grad for k in range(self.future_timesteps): - nll_bis, recons_bis, kl_bis, y_pred_recons = self.get_nll(x_bis, y[:, k], z) + ( + nll_bis, + recons_bis, + kl_bis, + y_pred_recons, + forcing_co2_bis, + forcing_aerosol_bis, + forcing_latent_bis, + ) = self.get_nll(x_bis, y[:, k], z, y_co2, y_aerosol, gt_co2_latent, gt_aerosol_latent) nll += (self.optim_params.loss_decay_future_timesteps**k) * nll_bis recons += (self.optim_params.loss_decay_future_timesteps**k) * recons_bis kl += (self.optim_params.loss_decay_future_timesteps**k) * kl_bis + forcing_loss_co2 += (self.optim_params.loss_decay_future_timesteps**k) * forcing_co2_bis + forcing_loss_aerosol += (self.optim_params.loss_decay_future_timesteps**k) * forcing_aerosol_bis + forcing_latent_loss += (self.optim_params.loss_decay_future_timesteps**k) * forcing_latent_bis # Shall we do this if instantaneous?? if not self.instantaneous: - y_pred, y_spare, z_spare, pz_mu, pz_std = self.model.predict(x_bis, y[:, k]) + y_pred, y_spare, z_spare, pz_mu, pz_std = self.model.predict(x_bis, y[:, k], y_co2, y_aerosol) else: y_pred = y_pred_recons y_pred_all[:, k] = y_pred @@ -591,6 +672,13 @@ def train_step(self): # noqa: C901 h_acyclic = self.get_acyclicity_violation() h_ortho = self.get_ortho_violation(self.model.autoencoder.get_w_decoder()) + # NOTE: Decoder utilization penalty is no longer applicable. + # With the architectural fix, forcing latents are excluded from the observation decoder + # (they only go through forcing decoders and the causal transition model). + # The penalty below was a workaround for blob patterns in forcing decoder weights, + # which is now fixed by excluding forcing latents from observation decoding entirely. 
+ decoder_utilization_penalty = torch.tensor(0.0, device=self.accelerator.device) + # compute total loss - here we are removing the sparsity regularisation as we are usings the constraint here. loss = nll + connect_reg + sparsity_reg if not self.no_w_constraint: @@ -612,7 +700,9 @@ def train_step(self): # noqa: C901 spectral_loss = torch.as_tensor([0.0]) for k in range(self.future_timesteps): # This step (predict) could be removed - need to rewrite predict function, to speed things up - px_mu, px_std = self.model.predict_pxmu_pxstd(torch.cat((x[:, k:], y_pred_all[:, :k]), dim=1), y[:, k]) + px_mu, px_std = self.model.predict_pxmu_pxstd( + torch.cat((x[:, k:], y_pred_all[:, :k]), dim=1), y[:, k], y_co2, y_aerosol + ) crps += (self.optim_params.loss_decay_future_timesteps**k) * self.get_crps_loss(y[:, k], px_mu, px_std) if self.optim_params.spectral_coeff > 0: spectral_loss += (self.optim_params.loss_decay_future_timesteps**k) * self.get_spatial_spectral_loss( @@ -634,6 +724,10 @@ def train_step(self): # noqa: C901 + self.optim_params.crps_coeff * crps + self.optim_params.spectral_coeff * spectral_loss + self.optim_params.temporal_spectral_coeff * temporal_spectral_loss + + self.optim_params.forcing_co2_coeff * forcing_loss_co2 + + self.optim_params.forcing_aerosol_coeff * forcing_loss_aerosol + + self.optim_params.forcing_latent_supervision_coeff * forcing_latent_loss + + decoder_utilization_penalty ) else: coef = 0 @@ -641,10 +735,10 @@ def train_step(self): # noqa: C901 if self.iteration >= iter_schedule: coef = new_coef if self.iteration == iter_schedule: - print( + logger.info( f"Scheduling spectrum coefficient at iterations {self.optim_params.scheduler_spectra} at coefficients {self.coefs_scheduler_spectra}" ) - print(f"Updating spectral coefficient to {coef} at iteration {self.iteration}!!") + logger.info(f"Updating spectral coefficient to {coef} at iteration {self.iteration}!!") loss = ( loss + self.optim_params.crps_coeff * crps @@ -653,6 +747,10 @@ def 
train_step(self): # noqa: C901 self.optim_params.spectral_coeff * spectral_loss + self.optim_params.temporal_spectral_coeff * temporal_spectral_loss ) + + self.optim_params.forcing_co2_coeff * forcing_loss_co2 + + self.optim_params.forcing_aerosol_coeff * forcing_loss_aerosol + + self.optim_params.forcing_latent_supervision_coeff * forcing_latent_loss + + decoder_utilization_penalty ) # backprop # mask_prev = self.model.mask.param.clone() @@ -673,7 +771,47 @@ def train_step(self): # noqa: C901 # assert torch.min(self.model.autoencoder.get_w_decoder()) >= 0.0 - self.train_loss = nll.item() if self.optim_params.udpate_ALM_using_nll else loss.item() + # Log gradient norms for diagnostic purposes (every 100 iterations) + if self.iteration % 100 == 0: + # Compute climate encoder gradient norm + climate_encoder_grads = ( + [p.grad for p in self.model.autoencoder.encoder.parameters() if p.grad is not None] + if hasattr(self.model.autoencoder, "encoder") + else [] + ) + if climate_encoder_grads: + climate_encoder_grad_norm = torch.norm(torch.stack([torch.norm(g) for g in climate_encoder_grads])) + else: + climate_encoder_grad_norm = torch.tensor(0.0) + + # Compute forcing encoder gradient norms if using forced latents + if self.model.use_forced_latents: + co2_encoder_grads = [ + p.grad for p in self.model.autoencoder.co2_forcing_encoder_mu.parameters() if p.grad is not None + ] + if co2_encoder_grads: + co2_encoder_grad_norm = torch.norm(torch.stack([torch.norm(g) for g in co2_encoder_grads])) + else: + co2_encoder_grad_norm = torch.tensor(0.0) + + aerosol_encoder_grads = [ + p.grad for p in self.model.autoencoder.aerosol_forcing_encoder_mu.parameters() if p.grad is not None + ] + if aerosol_encoder_grads: + aerosol_encoder_grad_norm = torch.norm(torch.stack([torch.norm(g) for g in aerosol_encoder_grads])) + else: + aerosol_encoder_grad_norm = torch.tensor(0.0) + + self.accelerator.log( + { + "grad_norm/climate_encoder": climate_encoder_grad_norm.item(), + 
"grad_norm/co2_encoder": co2_encoder_grad_norm.item(), + "grad_norm/aerosol_encoder": aerosol_encoder_grad_norm.item(), + }, + step=self.iteration, + ) + + self.train_loss = loss.item() self.train_nll = nll.item() self.train_recons = recons.item() self.train_kl = kl.item() @@ -698,6 +836,31 @@ def train_step(self): # noqa: C901 else: self.train_temporal_spectral_loss = torch.as_tensor([0.0]) + # adding the forcing reconstruction losses to the logs + self.train_forcing_co2_loss = forcing_loss_co2.item() + self.train_forcing_aerosol_loss = forcing_loss_aerosol.item() + self.train_forcing_latent_loss = forcing_latent_loss.item() + + # Debug logging for forcing latent supervision + if self.iteration % 1000 == 0 and self.model.use_forced_latents: + logger.debug( + f"[DEBUG iter {self.iteration}] Forcing losses: " + f"CO2={forcing_loss_co2.item():.6f}, " + f"Aerosol={forcing_loss_aerosol.item():.6f}, " + f"Supervision={forcing_latent_loss.item():.6f}" + ) + if gt_co2_latent is not None: + logger.debug(f" GT latent ranges: CO2=[{gt_co2_latent.min():.3f}, {gt_co2_latent.max():.3f}]") + else: + logger.warning(" gt_co2_latent is None - supervision loss not computed!") + if gt_aerosol_latent is not None: + logger.debug( + f" GT latent ranges: Aerosol=[{gt_aerosol_latent.min():.3f}, {gt_aerosol_latent.max():.3f}]" + ) + else: + logger.warning(" gt_aerosol_latent is None - supervision loss not computed!") + logger.debug(f" Decoder utilization penalty: {decoder_utilization_penalty.item():.6f}") + # # NOTE: here we have the saving, prediction, and analysis of some metrics, which comes at every print_freq # # This can be cut if we want faster training... 
# print(f"[GPU] Peak allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB") @@ -761,7 +924,50 @@ def train_step(self): # noqa: C901 # Plotting the predictions for three different samples, including the reconstructions and the true values # if the shape of the data is icosahedral, we can plot like this: if self.iteration % self.plot_params.plot_freq == 0: - if not self.plot_params.savar and (self.d == 1 or self.d == 2 or self.d == 3 or self.d == 4): + if self.plot_params.savar: + # Use SAVAR-specific plotting for synthetic data + self.plotter.plot_compare_predictions_savar( + x_past=x_original[:, -1, :, :].cpu().detach().numpy(), + y_true=y_original.cpu().detach().numpy(), + y_recons=y_original_recons.cpu().detach().numpy(), + y_hat=y_original_pred.cpu().detach().numpy(), + sample=np.random.randint(0, self.batch_size), + lat=self.lat, + lon=self.lon, + path=self.plots_path, + iteration=self.iteration, + valid=False, + plot_through_time=self.plot_params.plot_through_time, + ) + + self.plotter.plot_compare_predictions_savar( + x_past=x_original[:, -1, :, :].cpu().detach().numpy(), + y_true=y_original.cpu().detach().numpy(), + y_recons=y_original_recons.cpu().detach().numpy(), + y_hat=y_original_pred.cpu().detach().numpy(), + sample=np.random.randint(0, self.batch_size), + lat=self.lat, + lon=self.lon, + path=self.plots_path, + iteration=self.iteration, + valid=False, + plot_through_time=self.plot_params.plot_through_time, + ) + + self.plotter.plot_compare_predictions_savar( + x_past=x_original[:, -1, :, :].cpu().detach().numpy(), + y_true=y_original.cpu().detach().numpy(), + y_recons=y_original_recons.cpu().detach().numpy(), + y_hat=y_original_pred.cpu().detach().numpy(), + sample=np.random.randint(0, self.batch_size), + lat=self.lat, + lon=self.lon, + path=self.plots_path, + iteration=self.iteration, + valid=False, + plot_through_time=self.plot_params.plot_through_time, + ) + elif self.d == 1 or self.d == 2 or self.d == 3 or self.d == 4: 
self.plotter.plot_compare_predictions_icosahedral( x_past=x_original[:, -1, :, :].cpu().detach().numpy(), y_true=y_original.cpu().detach().numpy(), @@ -801,7 +1007,7 @@ def train_step(self): # noqa: C901 plot_through_time=True, ) else: - print("Not plotting predictions.") + logger.debug("Not plotting predictions.") # note that this has been changed to y_pred_recons # return x, y, y_pred_all @@ -816,14 +1022,37 @@ def valid_step(self): # noqa: C901 # noqa: C901 with torch.no_grad(): # sample data try: - x, y = next(self.data_loader_val) - x = torch.nan_to_num(x) - y = torch.nan_to_num(y) + batch = next(self.data_loader_val) except StopIteration: self.data_loader_val = iter(self.datamodule.val_dataloader()) - x, y = next(self.data_loader_val) - x = torch.nan_to_num(x) - y = torch.nan_to_num(y) + batch = next(self.data_loader_val) + + # Extract data from batch (handles both dict and tuple formats) + if isinstance(batch, dict): + # New format with forcings + x = batch["x"] + y = batch["y"] + y_co2 = batch.get("co2_forcing", None) + y_aerosol = batch.get("aerosol_forcing", None) + # Extract ground truth forcing latents for supervision + gt_co2_latent = batch.get("gt_co2_latent", None) + gt_aerosol_latent = batch.get("gt_aerosol_latent", None) + else: + # Legacy format (tuple) + x, y = batch + y_co2, y_aerosol = None, None + gt_co2_latent, gt_aerosol_latent = None, None + + x = torch.nan_to_num(x) + y = torch.nan_to_num(y) + if y_co2 is not None: + y_co2 = torch.nan_to_num(y_co2) + if y_aerosol is not None: + y_aerosol = torch.nan_to_num(y_aerosol) + if gt_co2_latent is not None: + gt_co2_latent = torch.nan_to_num(gt_co2_latent) + if gt_aerosol_latent is not None: + gt_aerosol_latent = torch.nan_to_num(gt_aerosol_latent) # x, y = next(self.data_loader_val) #.sample(self.data_loader_val.n_valid - self.data_loader_val.tau, valid=True) #Check they have these features @@ -835,17 +1064,31 @@ def valid_step(self): # noqa: C901 # noqa: C901 nll = 0 recons = 0 kl = 0 + 
forcing_loss_co2 = torch.tensor(0.0, device=self.accelerator.device) + forcing_loss_aerosol = torch.tensor(0.0, device=self.accelerator.device) + forcing_latent_loss = torch.tensor(0.0, device=self.accelerator.device) # also make the proper prediction, not the reconstruction as we do above # With multiple future timesteps we append the prediction to x and compute the nll of next timestep etc.. # We add to the loss the sum multiplied by the decay in future timesteps # we have to take care here to make sure that we have the right tensors with requires_grad for k in range(self.future_timesteps): - nll_bis, recons_bis, kl_bis, y_pred_recons = self.get_nll(x_bis, y[:, k], z) + ( + nll_bis, + recons_bis, + kl_bis, + y_pred_recons, + forcing_co2_bis, + forcing_aerosol_bis, + forcing_latent_bis, + ) = self.get_nll(x_bis, y[:, k], z, y_co2, y_aerosol, gt_co2_latent, gt_aerosol_latent) nll += (self.optim_params.loss_decay_future_timesteps**k) * nll_bis recons += (self.optim_params.loss_decay_future_timesteps**k) * recons_bis kl += (self.optim_params.loss_decay_future_timesteps**k) * kl_bis - y_pred, y_spare, z_spare, pz_mu, pz_std = self.model.predict(x_bis, y[:, k]) + forcing_loss_co2 += (self.optim_params.loss_decay_future_timesteps**k) * forcing_co2_bis + forcing_loss_aerosol += (self.optim_params.loss_decay_future_timesteps**k) * forcing_aerosol_bis + forcing_latent_loss += (self.optim_params.loss_decay_future_timesteps**k) * forcing_latent_bis + y_pred, y_spare, z_spare, pz_mu, pz_std = self.model.predict(x_bis, y[:, k], y_co2, y_aerosol) y_pred_all[:, k] = y_pred x_bis = torch.cat((x_bis[:, 1:], y_pred.unsqueeze(1)), dim=1) # print(f"y_pred_recons shape {y_pred_recons.shape}") @@ -903,7 +1146,9 @@ def valid_step(self): # noqa: C901 # noqa: C901 spectral_loss = torch.as_tensor([0.0]) for k in range(self.future_timesteps): # This step (predict) could be removed - need to rewrite predict function, to speed things up - px_mu, px_std = 
self.model.predict_pxmu_pxstd(torch.cat((x[:, k:], y_pred_all[:, :k]), dim=1), y[:, k]) + px_mu, px_std = self.model.predict_pxmu_pxstd( + torch.cat((x[:, k:], y_pred_all[:, :k]), dim=1), y[:, k], y_co2=y_co2, y_aerosol=y_aerosol + ) crps += (self.optim_params.loss_decay_future_timesteps**k) * self.get_crps_loss(y[:, k], px_mu, px_std) if self.optim_params.spectral_coeff > 0: spectral_loss += ( @@ -968,6 +1213,9 @@ def valid_step(self): # noqa: C901 # noqa: C901 self.valid_ortho_cons = h_ortho.detach() # .detach() self.valid_connect_reg = connect_reg.item() self.valid_acyclic_cons = h_acyclic.item() + self.valid_forcing_co2_loss = forcing_loss_co2.item() + self.valid_forcing_aerosol_loss = forcing_loss_aerosol.item() + self.valid_forcing_latent_loss = forcing_latent_loss.item() # adding the sparsity constraint to the logs self.valid_sparsity_cons = h_sparsity.item() # .detach() @@ -1035,7 +1283,50 @@ def valid_step(self): # noqa: C901 # noqa: C901 # also plot a comparison of the past true, true, reconstructed and the predicted values for the validation data # self.plotter.plot_compare_predictions_icosahedral(self, lots of arguments! 
save=True) if self.iteration % self.plot_params.plot_freq == 0: - if not self.plot_params.savar and (self.d == 1 or self.d == 2 or self.d == 3 or self.d == 4): + if self.plot_params.savar: + # Use SAVAR-specific plotting for synthetic data + self.plotter.plot_compare_predictions_savar( + x_past=x_original[:, -1, :, :].cpu().detach().numpy(), + y_true=y_original.cpu().detach().numpy(), + y_recons=y_original_recons.cpu().detach().numpy(), + y_hat=y_original_pred.cpu().detach().numpy(), + sample=np.random.randint(0, self.batch_size), + lat=self.lat, + lon=self.lon, + path=self.plots_path, + iteration=self.iteration, + valid=True, + plot_through_time=self.plot_params.plot_through_time, + ) + + self.plotter.plot_compare_predictions_savar( + x_past=x_original[:, -1, :, :].cpu().detach().numpy(), + y_true=y_original.cpu().detach().numpy(), + y_recons=y_original_recons.cpu().detach().numpy(), + y_hat=y_original_pred.cpu().detach().numpy(), + sample=np.random.randint(0, self.batch_size), + lat=self.lat, + lon=self.lon, + path=self.plots_path, + iteration=self.iteration, + valid=True, + plot_through_time=self.plot_params.plot_through_time, + ) + + self.plotter.plot_compare_predictions_savar( + x_past=x_original[:, -1, :, :].cpu().detach().numpy(), + y_true=y_original.cpu().detach().numpy(), + y_recons=y_original_recons.cpu().detach().numpy(), + y_hat=y_original_pred.cpu().detach().numpy(), + sample=np.random.randint(0, self.batch_size), + lat=self.lat, + lon=self.lon, + path=self.plots_path, + iteration=self.iteration, + valid=True, + plot_through_time=self.plot_params.plot_through_time, + ) + elif self.d == 1 or self.d == 2 or self.d == 3 or self.d == 4: self.plotter.plot_compare_predictions_icosahedral( x_past=x_original[:, -1, :, :].cpu().detach().numpy(), y_true=y_original.cpu().detach().numpy(), @@ -1083,7 +1374,7 @@ def has_patience(self, patience_init, valid_loss): if valid_loss < self.best_valid_loss: self.best_valid_loss = valid_loss self.patience = patience_init - 
print(f"Best valid loss: {self.best_valid_loss}") + logger.info(f"Best valid loss: {self.best_valid_loss}") else: self.patience -= 1 return True @@ -1100,7 +1391,7 @@ def threshold(self): thresholded_adj = (self.model.get_adj() > 0.5).type(torch.Tensor) self.model.mask.fix(thresholded_adj) self.thresholded = True - print("Thresholding ================") + logger.info("Thresholding ================") def log_losses(self): """Append in lists values of the losses and more.""" @@ -1178,7 +1469,7 @@ def print_results(self): # print("The self.ALM_sparsity.gamma * h_sparsity is:", self.ALM_sparsity.gamma * self.train_sparsity_cons) # print("The 0.5 * self.ALM_sparsity.mu * h_sparsity**2 is:", (0.5 * self.ALM_sparsity.mu * self.train_sparsity_cons**2)) - print("****************************************************************************************") + logger.info("****************************************************************************************") # print("What are the actual values of the constraints?") # print("The connect reg is:", self.train_connect_reg) # print("The sparsity reg is:", self.train_sparsity_reg) @@ -1187,12 +1478,91 @@ def print_results(self): # print("The sparsity cons is:", self.train_sparsity_cons) # print("****************************************************************************************") - def get_nll(self, x, y, z=None) -> torch.Tensor: - - # this is just running the forward pass of LatentTSDCD... - elbo, recons, kl, preds = self.model(x, y, z, self.iteration) + def get_nll( + self, x, y, z=None, y_co2=None, y_aerosol=None, gt_co2_latent=None, gt_aerosol_latent=None + ) -> torch.Tensor: + """ + Compute negative ELBO (reconstruction + KL) and forcing reconstruction losses. 
- return -elbo, recons, kl, preds + Args: + x: Historical observations, shape (batch, tau, d, d_x) + y: Current observation, shape (batch, d, d_x) or (batch, 1, d, d_x) + z: Ground truth latents (None for real data) + y_co2: CO2 forcing, shape (batch, lat, lon) or (batch, 1) or None + y_aerosol: Aerosol forcing, shape (batch, lat, lon) or (batch, d_x) or None + gt_co2_latent: Ground truth CO2 latent, shape (batch, tau+1, 1) or None + gt_aerosol_latent: Ground truth aerosol latents, shape (batch, tau+1, n_aerosol) or None + + Returns: + Tuple of (-elbo, reconstruction_loss, kl_divergence, predictions, + forcing_co2_loss, forcing_aerosol_loss, forcing_latent_supervision_loss) + """ + # Process exogenous forcings if model uses them (either as MLP conditioning or forced latents) + if (self.model.use_exogenous or self.model.use_forced_latents) and y_co2 is not None and y_aerosol is not None: + # Check if we have temporal forcings (shape: batch, tau+1, spatial_dim) + # vs single-timestep forcings (shape: batch, spatial_dim) + has_temporal_dim = len(y_co2.shape) == 3 and y_co2.shape[1] == self.tau + 1 + + if not has_temporal_dim: + # Legacy path: single-timestep forcings that need spatial processing + # CO2: Keep spatial structure, flatten if needed (same as aerosol) + if len(y_co2.shape) == 3: # (batch, lat, lon) + y_co2 = y_co2.reshape(y_co2.shape[0], -1) # -> (batch, lat*lon) + # else: already (batch, spatial_dim) - keep as is + + # Aerosols: Flatten spatial structure to match d_x + if len(y_aerosol.shape) == 3: # (batch, lat, lon) + y_aerosol = y_aerosol.reshape(y_aerosol.shape[0], -1) # -> (batch, lat*lon) + # else: temporal forcings are already spatially processed by the dataset + # shapes are (batch, tau+1, d_x) for both CO2 and aerosols + else: + # Model doesn't use exogenous or forcings not provided + y_co2, y_aerosol = None, None + + # Forward pass through LatentTSDCD model + ( + elbo, + recons, + kl, + preds, + forcing_recons_loss_co2, + 
forcing_recons_loss_aerosol, + encoded_forcing_mu, + ) = self.model(x, y, z, self.iteration, y_co2=y_co2, y_aerosol=y_aerosol) + + # Compute forcing latent supervision loss if ground truth latents available + forcing_latent_supervision_loss = torch.tensor(0.0, device=x.device) + forcing_arch = getattr(self.model, "forcing_arch", "baseline") + if ( + forcing_arch != "predefined" + and encoded_forcing_mu is not None + and gt_co2_latent is not None + and gt_aerosol_latent is not None + ): + # Extract ground truth latents for the target timestep (last timestep, tau+1 index) + # gt_co2_latent shape: (batch, tau+1, 1), we want [:, -1, :] + # gt_aerosol_latent shape: (batch, tau+1, n_aerosol_latents), we want [:, -1, :] + gt_co2_target = gt_co2_latent[:, -1, :] # (batch, 1) + gt_aerosol_target = gt_aerosol_latent[:, -1, :] # (batch, n_aerosol_latents) + + # Concatenate ground truth forcing latents: [CO2, aerosol] + gt_forcing_target = torch.cat([gt_co2_target, gt_aerosol_target], dim=1) # (batch, n_forced_latents_total) + + # Compute MSE between encoded and ground truth forcing latents + forcing_latent_supervision_loss = torch.mean((encoded_forcing_mu - gt_forcing_target) ** 2) + elif forcing_arch == "predefined" and not hasattr(self, "_forcing_arch_supervision_logged"): + logger.info("[ForcingArch] Skipping forcing latent supervision (predefined arch)") + self._forcing_arch_supervision_logged = True + + return ( + -elbo, + recons, + kl, + preds, + forcing_recons_loss_co2, + forcing_recons_loss_aerosol, + forcing_latent_supervision_loss, + ) def get_regularisation(self) -> float: if self.iteration > self.optim_params.schedule_reg: @@ -1218,15 +1588,18 @@ def get_acyclicity_violation(self) -> torch.Tensor: def get_ortho_violation(self, w: torch.Tensor) -> float: if self.iteration > self.optim_params.schedule_ortho: - # What should be the size here? Is the first dimension the different variables? 
- # constraint = torch.tensor([0.]) - k = w.size(2) - # for i in range(w.size(0)): - # constraint = constraint + torch.norm(w[i].T @ w[i] - torch.eye(k), p=2) + # Only apply orthogonality constraint to columns that are actually used in decoding + # When use_forced_latents=True, forcing latent columns are not used, so exclude them + if self.model.use_forced_latents: + n_climate = self.model.d_z - self.model.n_forced_latents_co2 - self.model.n_forced_latents_aerosol + w_climate = w[:, :, :n_climate] # Only climate latent columns + k = n_climate + else: + w_climate = w + k = w.size(2) + i = 0 - # constraint = torch.norm(w[i].T @ w[i] - torch.eye(k), p=2, dim=1) - constraint = w[i].T @ w[i] - torch.eye(k) - # print('What is the ortho constraint shape:', constraint.shape) + constraint = w_climate[i].T @ w_climate[i] - torch.eye(k, device=w.device) h = constraint / self.ortho_normalization else: h = torch.as_tensor([0.0]) @@ -1258,6 +1631,7 @@ def get_sparsity_violation(self, lower_threshold, upper_threshold) -> float: # first get the adj adj = self.model.get_adj() + adj = self.model.get_adj() sum_of_connections = torch.norm(adj, p=1) / self.sparsity_normalization @@ -1412,7 +1786,7 @@ def get_crps_loss(self, y, mu, sigma): crps[idx] = t1 + t2 if torch.isnan(crps).any(): - print("[NaN] Final CRPS") + logger.warning("[NaN] Final CRPS") # Clamp final CRPS to ensure numerical validity crps = torch.nan_to_num(crps, nan=0.0, posinf=1e3, neginf=0.0) @@ -1595,23 +1969,40 @@ def autoregress_prediction_original(self, valid: bool = False, timesteps: int = # Make the iterator again, since otherwise we have iterated through it already... 
train_dataloader = iter(self.datamodule.train_dataloader(accelerator=self.accelerator)) - x, y = next(train_dataloader) + batch = next(train_dataloader) + + # Extract data from batch (handles both dict and tuple formats) + if isinstance(batch, dict): + x = batch["x"] + y = batch["y"] + y_co2 = batch.get("co2_forcing", None) + y_aerosol = batch.get("aerosol_forcing", None) + else: + x, y = batch + y_co2, y_aerosol = None, None x = torch.nan_to_num(x) y = torch.nan_to_num(y) + if y_co2 is not None: + y_co2 = torch.nan_to_num(y_co2) + if y_aerosol is not None: + y_aerosol = torch.nan_to_num(y_aerosol) + y = y[:, 0] z = None # print("First up, I will do the reconstruction effort") - nll, recons, kl, y_pred_recons = self.get_nll(x, y, z) + nll, recons, kl, y_pred_recons, forcing_co2_loss, forcing_aerosol_loss, forcing_latent_supervision_loss = ( + self.get_nll(x, y, z, y_co2, y_aerosol) + ) # ensure these are correct with torch.no_grad(): - y_pred, y, z, pz_mu, pz_std = self.model.predict(x, y) + y_pred, y, z, pz_mu, pz_std = self.model.predict(x, y, y_co2, y_aerosol) # Here we predict, but taking 100 samples from the latents # TODO: make this into an argument - samples_from_xs, samples_from_zs, y = self.model.predict_sample(x, y, 10) + samples_from_xs, samples_from_zs, y = self.model.predict_sample(x, y, 10, y_co2, y_aerosol) # append the first prediction predictions.append(y_pred) @@ -1652,7 +2043,7 @@ def autoregress_prediction_original(self, valid: bool = False, timesteps: int = # then predict the next timestep # y at this point is pointless!!! 
with torch.no_grad(): - y_pred, y, z, pz_mu, pz_std = self.model.predict(x, y) + y_pred, y, z, pz_mu, pz_std = self.model.predict(x, y, y_co2, y_aerosol) # append the prediction predictions.append(y_pred) @@ -1675,18 +2066,18 @@ def autoregress_prediction_original(self, valid: bool = False, timesteps: int = predictions = torch.stack(predictions, dim=1) # the resulting shape of this tensor is (batch_size, timesteps, num_vars, coords) - print("What is the shape of the predictions, once I made it into a tensor?", predictions.shape) + logger.debug("What is the shape of the predictions, once I made it into a tensor? %s", predictions.shape) # then calculate the mean of the predictions along the timesteps y_pred_mean = torch.mean(predictions, dim=1) # calculate the variance of the predictions along the timesteps dimension y_pred_var = torch.var(predictions, dim=1) - print("What is the shape of the mean of the predictions?", y_pred_mean.shape) - print("What is the shape of the variance of the predictions?", y_pred_var.shape) + logger.debug("What is the shape of the mean of the predictions? %s", y_pred_mean.shape) + logger.debug("What is the shape of the variance of the predictions? 
%s", y_pred_var.shape) # take the mean of the predictions along the batch and coordinates dimension: - print( - "What is the shape when I try to take the mean across the batch and coordinates:", + logger.debug( + "What is the shape when I try to take the mean across the batch and coordinates: %s", torch.mean(y_pred_mean, dim=(0, 2)), ) @@ -1705,7 +2096,9 @@ def autoregress_prediction_original(self, valid: bool = False, timesteps: int = # take the mean across the frequencies, the 1st dimension spatial_spectra_score = torch.mean(spatial_spectra_score, dim=1) - print("Spatial spectra score, lower is better...should be a spectra for each var", spatial_spectra_score) + logger.debug( + "Spatial spectra score, lower is better...should be a spectra for each var %s", spatial_spectra_score + ) # if this spatial_spectra_score is the lowest we have seen, then save the predictions if self.best_spatial_spectra_score is None: @@ -1715,19 +2108,21 @@ def autoregress_prediction_original(self, valid: bool = False, timesteps: int = assert self.best_spatial_spectra_score is not None # check if every element of spatial_spectra_score is less than the best_spatial_spectra_score: - print(torch.all(spatial_spectra_score < self.best_spatial_spectra_score)) + logger.debug("%s", torch.all(spatial_spectra_score < self.best_spatial_spectra_score)) - print("new score", spatial_spectra_score) - print("previous best score", self.best_spatial_spectra_score) + logger.debug("new score %s", spatial_spectra_score) + logger.debug("previous best score %s", self.best_spatial_spectra_score) if torch.all(spatial_spectra_score < self.best_spatial_spectra_score): - print("The spatial spectra score is the best we have seen for all variables, I am in the if.") + logger.info("The spatial spectra score is the best we have seen for all variables, I am in the if.") self.best_spatial_spectra_score = spatial_spectra_score - print(f"Best spatial spectra score: {self.best_spatial_spectra_score}") + logger.info(f"Best 
spatial spectra score: {self.best_spatial_spectra_score}") # save the model in its current state - print("Saving the model, since the spatial spectra score is the best we have seen for all variables.") + logger.info( + "Saving the model, since the spatial spectra score is the best we have seen for all variables." + ) torch.save(self.model.state_dict(), self.save_path / "best_model_for_average_spectra.pth") else: @@ -1735,25 +2130,42 @@ def autoregress_prediction_original(self, valid: bool = False, timesteps: int = # bs = np.min([self.data.n_valid, 1000]) # Make the iterator again val_dataloader = iter(self.datamodule.val_dataloader()) - x, y = next(val_dataloader) + batch = next(val_dataloader) + + # Extract data from batch (handles both dict and tuple formats) + if isinstance(batch, dict): + x = batch["x"] + y = batch["y"] + y_co2 = batch.get("co2_forcing", None) + y_aerosol = batch.get("aerosol_forcing", None) + else: + x, y = batch + y_co2, y_aerosol = None, None # old, using existing dataloader # x, y = next(self.data_loader_val) y = torch.nan_to_num(y) x = torch.nan_to_num(x) + if y_co2 is not None: + y_co2 = torch.nan_to_num(y_co2) + if y_aerosol is not None: + y_aerosol = torch.nan_to_num(y_aerosol) + y = y[:, 0] z = None # print("First up, I will do the reconstruction effort") - nll, recons, kl, y_pred_recons = self.get_nll(x, y, z) + nll, recons, kl, y_pred_recons, forcing_co2_loss, forcing_aerosol_loss, forcing_latent_supervision_loss = ( + self.get_nll(x, y, z, y_co2, y_aerosol) + ) # swap with torch.no_grad(): - y_pred, y, z, pz_mu, pz_std = self.model.predict(x, y) + y_pred, y, z, pz_mu, pz_std = self.model.predict(x, y, y_co2, y_aerosol) # predict and take 100 samples too - samples_from_xs, samples_from_zs, y = self.model.predict_sample(x, y, 100) + samples_from_xs, samples_from_zs, y = self.model.predict_sample(x, y, 100, y_co2, y_aerosol) x_original = x.clone().detach() y_original = y.clone().detach() @@ -1783,7 +2195,7 @@ def 
autoregress_prediction_original(self, valid: bool = False, timesteps: int = with torch.no_grad(): # then predict the next timestep - y_pred, y, z, pz_mu, pz_std = self.model.predict(x, y) + y_pred, y, z, pz_mu, pz_std = self.model.predict(x, y, y_co2, y_aerosol) np.save(self.save_path / f"val_x_ar_{i}.npy", x.detach().cpu().numpy()) np.save(self.save_path / f"val_y_ar_{i}.npy", y.detach().cpu().numpy()) @@ -1832,10 +2244,10 @@ def score_the_samples_for_spatial_spectra(self, y_true, y_pred_samples, num_samp """ # calculate the average spatial spectra of the true values, averaging across the batch - print("y_true shape:", y_true.shape) + logger.debug("y_true shape: %s", y_true.shape) fft_true = torch.mean(torch.abs(torch.fft.rfft(y_true, dim=3)), dim=0) # calculate the average spatial spectra of the individual predicted fields - I think this below is wrong - print("y_pred shape:", y_pred_samples.shape) + logger.debug("y_pred shape: %s", y_pred_samples.shape) fft_pred = torch.mean(torch.abs(torch.fft.rfft(y_pred_samples, dim=3)), dim=0) # extend fft_true so it is the same value but extended to the same shape as fft_pred @@ -1855,7 +2267,7 @@ def score_the_samples_for_spatial_spectra(self, y_true, y_pred_samples, num_samp return spatial_spectra_score - def particle_filter(self, x, y, num_particles, timesteps=120): + def particle_filter(self, x, y, num_particles, timesteps=120, y_co2=None, y_aerosol=None): """Implement a particle filter to make a set of autoregressive predictions, where each created sample is evaluated by some score, and we do a particle filter to select only best samples to continue the autoregressive rollout.""" @@ -1866,7 +2278,7 @@ def particle_filter(self, x, y, num_particles, timesteps=120): for _ in range(timesteps): # Prediction # make all the new predictions, taking samples from the latents - _, samples_from_zs, y = self.model.predict_sample(x, y, 100) + _, samples_from_zs, y = self.model.predict_sample(x, y, 100, y_co2, y_aerosol) # then 
calculate the score of each of the samples # Update the weights, where we want the weights to increase as the score improves diff --git a/climatem/model/tsdcd_latent.py b/climatem/model/tsdcd_latent.py index 98cea93..05d8981 100644 --- a/climatem/model/tsdcd_latent.py +++ b/climatem/model/tsdcd_latent.py @@ -1,4 +1,13 @@ -# Adapted from the original code for CDSD, Brouillard et al., 2024. +""" +Core causal discovery model with latent variables (LatentTSDCD). + +This module implements latent causal graph learning with an encoder/decoder architecture and a learnable causal mask. +The model discovers temporal causal structure among latent variables from observed time series, using differentiable +structure learning with Gumbel-softmax relaxation and acyclicity constraints. + +Adapted from the original code for CDSD (Brouillard et al., 2024): "Causal Discovery with Score-based methods for +time series with latent confounders." +""" from collections import OrderedDict from math import pi @@ -8,10 +17,35 @@ import torch.nn as nn from torch.distributions import Distribution +from climatem.utils import get_logger + +logger = get_logger(__name__) + +# Euler-Mascheroni constant, used in the Gumbel distribution CDF for Gumbel-softmax reparameterization euler_mascheroni = 0.57721566490153286060 class Mask(nn.Module): + """ + Learnable causal graph adjacency matrix for differentiable structure learning. + + Parameterizes the causal graph via sigmoid (or Gumbel-softmax) over learnable + logits. During training, edges are sampled stochastically; at evaluation, + the sigmoid probabilities can be thresholded to obtain a binary graph. + + Attributes: + param: Learnable logit tensor of shape ``(tau, d*d_x, d*d_x)`` where + ``tau`` is the number of time lags, and each ``(d*d_x, d*d_x)`` + slice is a source-to-target adjacency matrix over all + (variable, latent) pairs. 
+ drawhard: If True, uses the straight-through estimator to produce + binary mask values in the forward pass while allowing gradient + flow through the soft Gumbel-sigmoid in the backward pass. + fixed_output_fraction: Fraction of mask entries that are fixed (not + learned). When ``fixed=True``, this controls the density of the + random fixed mask. + """ + def __init__( self, d: int, @@ -39,11 +73,15 @@ def __init__( # Here we could change how the mask is instantiated in the causal graph. if self.latent: if not nodiag: + # Initializes logits to 5 so sigmoid(5) ~ 0.993, meaning all edges + # start as "present" and are pruned during training via sparsity penalty. self.param = nn.Parameter(torch.ones((self.tau, d * d_x, d * d_x)) * 5) self.fixed_mask = torch.ones_like(self.param) else: param = torch.ones((self.tau, d * d_x, d * d_x)) param[:, torch.arange(d * d_x), torch.arange(d * d_x)] = -1 + # Initializes logits to 5 so sigmoid(5) ~ 0.993 (all edges "present"); + # diagonal entries are set to -5 so sigmoid(-5) ~ 0.007 (self-loops suppressed). self.param = nn.Parameter(param * 5) self.fixed_mask = torch.ones_like(self.param) self.fixed_mask[:, torch.arange(self.fixed_mask.size(1)), torch.arange(self.fixed_mask.size(2))] = 0 @@ -52,7 +90,7 @@ def __init__( self.fixed_mask[-1, torch.arange(self.fixed_mask.size(1)), torch.arange(self.fixed_mask.size(2))] = 0 else: if self.instantaneous: - # initialize mask as log(mask_ij) = 1 + # Logits initialized to 5: sigmoid(5) ~ 0.993, all edges start "present". self.param = nn.Parameter(torch.ones((self.tau, d, d, d_x)) * 5) self.fixed_mask = torch.ones_like(self.param) # set diagonal 0 for G_t0 @@ -60,7 +98,7 @@ def __init__( # TODO: set neighbors to 0 # self.fixed_mask[:, :, :, d_x] = 0 else: - # initialize mask as log(mask_ij) = 1 + # Logits initialized to 5: sigmoid(5) ~ 0.993, all edges start "present". 
self.param = nn.Parameter(torch.ones((tau, d, d, d_x)) * 5) self.fixed_mask = torch.ones_like(self.param) @@ -123,8 +161,11 @@ class MixingMask(nn.Module): def __init__(self, d: int, d_x: int, d_z: int, gt_mask=None): super().__init__() if gt_mask is not None: + # Clamps ground-truth edges to high logit value (10.0) so + # sigmoid(10) ~ 1.0, effectively fixing known edges as present. self.param = (gt_mask > 0) * 10.0 else: + # Logits initialized to 5: sigmoid(5) ~ 0.993, all edges start "present". self.param = nn.Parameter(torch.ones(d, d_x, d_z) * 5) def forward(self, batch_size): @@ -180,7 +221,23 @@ def forward(self, x) -> torch.Tensor: class LatentTSDCD(nn.Module): - """Differentiable Causal Discovery for time series with latent variables.""" + """ + Differentiable Causal Discovery for time series with latent variables. + + Implements the LatentTSDCD architecture: an encoder maps observations to latent + variables, a learnable causal mask (Gumbel-sigmoid) parameterizes the temporal + causal graph, a transition model predicts future latents conditioned on masked + past latents, and a decoder reconstructs observations from latents. 
+ + Variable name glossary (used throughout forward / loss computation): + px_mu -- predicted x mean (decoder output), shape (batch, d, d_x) + px_std -- predicted x std (decoder output), shape (batch, d, d_x) + pz_mu -- predicted z mean (latent dynamics / transition output), shape (batch, d, d_z) + pz_std -- predicted z std (latent dynamics / transition output), shape (batch, d, d_z) + q_mu_y -- variational posterior mean for the target y, shape (batch, d, d_z) + q_std_y -- variational posterior std for the target y, shape (batch, d, d_z) + qz_mu -- variational posterior mean for z (alias used in some contexts) + """ def __init__( self, @@ -212,9 +269,17 @@ def __init__( # gt_graph: torch.tensor = None, # gt_w: torch.tensor = None, tied_w: bool = False, + reduce_encoding_pos_dim: bool = False, fixed: bool = False, fixed_output_fraction: float = 1.0, gev_learn_xi: bool = False, + use_exogenous: bool = False, + d_y_co2: int = 0, + d_y_aerosol: int = 0, + use_forced_latents: bool = False, + n_forced_latents_co2: int = 1, + n_forced_latents_aerosol: int = 2, + forcing_arch: str = "baseline", ): """ Args: @@ -248,6 +313,9 @@ def __init__( fixed: if True, fix the mask (in simple case to all ones) fixed_output_fraction: fraction of ones in the fixed gev_learn_xi: if True, GEV will take learned xi + use_exogenous: if True, condition on exogenous forcings (CO2 + aerosols) + d_y_co2: dimension of CO2 forcing (typically 1 for global) + d_y_aerosol: dimension of aerosol forcing (typically d_x for local) """ super().__init__() @@ -259,6 +327,7 @@ def __init__( self.num_layers_mixing = num_layers_mixing self.num_hidden_mixing = num_hidden_mixing self.position_embedding_dim = position_embedding_dim + self.reduce_encoding_pos_dim = reduce_encoding_pos_dim self.transition_param_sharing = transition_param_sharing self.position_embedding_transition = position_embedding_transition self.coeff_kl = coeff_kl @@ -278,6 +347,18 @@ def __init__( self.fixed = fixed self.fixed_output_fraction 
= fixed_output_fraction self.gev_learn_xi = gev_learn_xi + self.use_exogenous = use_exogenous + self.d_y_co2 = d_y_co2 + self.d_y_aerosol = d_y_aerosol + self.use_forced_latents = use_forced_latents + self.n_forced_latents_co2 = n_forced_latents_co2 + self.n_forced_latents_aerosol = n_forced_latents_aerosol + self.forcing_arch = forcing_arch + self._forcing_arch_logged = False + if self.forcing_arch == "predefined" and self.use_forced_latents: + raise ValueError( + "forcing_arch='predefined' requires use_forced_latents=False and d_z to include only climate latents." + ) if self.instantaneous: self.total_tau = tau + 1 @@ -325,8 +406,16 @@ def __init__( raise NotImplementedError(f"Decoder distribution '{distr_decoder}' is not implemented.") # self.encoder_decoder = EncoderDecoder(self.d, self.d_x, self.d_z, self.nonlinear_mixing, 4, 1, self.debug_gt_w, self.gt_w, self.tied_w) + # MLP conditioning dims: only non-zero when use_exogenous (raw forcings concatenated to MLP inputs) + d_y_co2_cond = self.d_y_co2 if self.use_exogenous else 0 + d_y_aerosol_cond = self.d_y_aerosol if self.use_exogenous else 0 + # Forcing encoder/decoder spatial dims: need real dims whenever use_forced_latents, + # independent of whether raw forcings are used as MLP conditioning (use_exogenous). + d_y_co2_spatial = self.d_y_co2 if self.use_forced_latents else 0 + d_y_aerosol_spatial = self.d_y_aerosol if self.use_forced_latents else 0 + if self.nonlinear_mixing: - print("NON-LINEAR MIXING") + logger.info("NON-LINEAR MIXING") # NOTE:(seb) using the noloop version of non-linear here to make it much faster. 
self.autoencoder = NonLinearAutoEncoderUniqueMLP_noloop( d, @@ -336,12 +425,32 @@ def __init__( self.num_layers_mixing, tied=tied_w, embedding_dim=self.position_embedding_dim, + reduce_encoding_pos_dim=self.reduce_encoding_pos_dim, gt_w=None, + d_y_co2=d_y_co2_cond, + d_y_aerosol=d_y_aerosol_cond, + use_forced_latents=self.use_forced_latents, + n_forced_latents_co2=self.n_forced_latents_co2, + n_forced_latents_aerosol=self.n_forced_latents_aerosol, + d_y_co2_spatial=d_y_co2_spatial, + d_y_aerosol_spatial=d_y_aerosol_spatial, ) else: # print('Using linear mixing') - print("LINEAR MIXING") - self.autoencoder = LinearAutoEncoder(d, d_x, d_z, tied=tied_w) + logger.info("LINEAR MIXING") + self.autoencoder = LinearAutoEncoder( + d, + d_x, + d_z, + tied=tied_w, + d_y_co2=d_y_co2_cond, + d_y_aerosol=d_y_aerosol_cond, + use_forced_latents=self.use_forced_latents, + n_forced_latents_co2=self.n_forced_latents_co2, + n_forced_latents_aerosol=self.n_forced_latents_aerosol, + d_y_co2_spatial=d_y_co2_spatial, + d_y_aerosol_spatial=d_y_aerosol_spatial, + ) # if debug_gt_w: # self.decoder.w = gt_w @@ -355,7 +464,9 @@ def __init__( self.num_layers, self.num_hidden, self.num_output, - self.position_embedding_transition, + self.position_embedding_dim, + d_y_co2=self.d_y_co2 if self.use_exogenous else 0, + d_y_aerosol=self.d_y_aerosol if self.use_exogenous else 0, ) else: self.transition_model = TransitionModel( @@ -391,51 +502,135 @@ def get_adj(self): """ return self.mask.get_proba() - def encode(self, x, y): - """Encode X and Y into latent variables Z.""" - b = x.size(0) - z = torch.zeros(b, self.tau + 1, self.d, self.d_z) - mu = torch.zeros(b, self.d, self.d_z) - std = torch.zeros(b, self.d, self.d_z) - - # sample Zs - - # TODO: Can we remove this for loop? - for i in range(self.d): - # TODO: Can we remove this for loop? 
- for t in range(self.tau): - # q_mu, q_logvar = self.encoder_decoder(x[:, t, i], i, encoder=True) # torch.matmul(self.W, x) - q_mu, q_logvar = self.autoencoder(x[:, t, i], i, encode=True) - # reparam trick - here we sample from a Gaussian...every time - q_std = torch.exp(0.5 * q_logvar) - z[:, t, i] = q_mu + q_std * self.distr_encoder(0, 1, size=q_mu.size()) + def encode(self, x, y, y_co2=None, y_aerosol=None): + """ + Encode observations X (history) and Y (target) into latent variables Z. - # q_mu, q_logvar = self.encoder_decoder(y[:, i], i, encoder=True) # torch.matmul(self.W, x) + Args: + x: Historical observations, shape (batch, tau, d, d_x). + y: Target observation, shape (batch, d, d_x). + y_co2: Optional CO2 forcing, shape (batch, tau+1, d_y_co2) or (batch, d_y_co2). + y_aerosol: Optional aerosol forcing, shape (batch, tau+1, d_y_aerosol) or (batch, d_y_aerosol). + + Returns: + z: Latent variables, shape (batch, tau+1, d, d_z). + mu: Variational posterior mean for the target timestep, shape (batch, d, d_z). + std: Variational posterior std for the target timestep, shape (batch, d, d_z). + """ + b = x.size(0) # batch size + z = torch.zeros(b, self.tau + 1, self.d, self.d_z, device=x.device) + mu = torch.zeros(b, self.d, self.d_z, device=x.device) + std = torch.zeros(b, self.d, self.d_z, device=x.device) + + # Handle forced latents if enabled + if self.use_forced_latents and y_co2 is not None and y_aerosol is not None: + # Calculate number of climate latents + n_climate_latents = self.d_z - self.n_forced_latents_co2 - self.n_forced_latents_aerosol + + # y_co2 shape: (batch_size, tau+1, spatial_dim) - NOW SPATIAL like aerosol! 
+ # y_aerosol shape: (batch_size, tau+1, spatial_dim) + # We need to process forcings at each timestep separately + + # For SAVAR, d=1, so we only iterate once over i + for i in range(self.d): + # Encode climate latents from observations for all timesteps + for t in range(self.tau): + # Extract forcing at timestep t + y_co2_t = y_co2[:, t] if y_co2.dim() == 3 else y_co2 # Handle both temporal and single timestep + y_aerosol_t = y_aerosol[:, t] if y_aerosol.dim() == 3 else y_aerosol + + # Encode forcings for this timestep + z_forced_t, _, _ = self.autoencoder.encode_forcings(y_co2_t, y_aerosol_t) + + if self.use_exogenous: + q_mu, q_logvar = self.autoencoder( + x[:, t, i], i, encode=True, forcing_co2=y_co2_t, forcing_aerosol=y_aerosol_t + ) + else: + q_mu, q_logvar = self.autoencoder(x[:, t, i], i, encode=True) + + q_std = torch.exp(0.5 * q_logvar) + # Only encode to climate latents + z[:, t, i, :n_climate_latents] = q_mu[:, :n_climate_latents] + q_std[ + :n_climate_latents + ] * self.distr_encoder(0, 1, size=(b, n_climate_latents)) + # Fill forced latents with timestep-specific forcings + z[:, t, i, n_climate_latents:] = z_forced_t + + # Encode the target timestep (y) using final forcing timestep + y_co2_target = y_co2[:, -1] if y_co2.dim() == 3 else y_co2 + y_aerosol_target = y_aerosol[:, -1] if y_aerosol.dim() == 3 else y_aerosol + + z_forced_target, mu_forced, std_forced = self.autoencoder.encode_forcings( + y_co2_target, y_aerosol_target + ) - q_mu, q_logvar = self.autoencoder(y[:, i], i, encode=True) - q_std = torch.exp(0.5 * q_logvar) + if self.use_exogenous: + q_mu, q_logvar = self.autoencoder( + y[:, i], i, encode=True, forcing_co2=y_co2_target, forcing_aerosol=y_aerosol_target + ) + else: + q_mu, q_logvar = self.autoencoder(y[:, i], i, encode=True) - # # e.g. 
z[:, -2, i] - # all_z_except_last = z[:, :-1, i].clone() - # penultimate_z = z[:, -2, i].clone() + q_std = torch.exp(0.5 * q_logvar) + # Only encode climate latents + z[:, -1, i, :n_climate_latents] = q_mu[:, :n_climate_latents] + q_std[ + :n_climate_latents + ] * self.distr_encoder(0, 1, size=(b, n_climate_latents)) + # Fill forced latents with target timestep forcings + z[:, -1, i, n_climate_latents:] = z_forced_target + + # Store full mu and std (including forced latents from target timestep) + mu[:, i, :n_climate_latents] = q_mu[:, :n_climate_latents] + mu[:, i, n_climate_latents:] = mu_forced + std[:, i, :n_climate_latents] = q_std[:n_climate_latents] + std[:, i, n_climate_latents:] = std_forced - # assert torch.mean(z[:, -1, i]) == 0.0 + else: + # Original encoding path (all latents from observations) + for i in range(self.d): + for t in range(self.tau): + if self.use_exogenous and y_co2 is not None and y_aerosol is not None: + q_mu, q_logvar = self.autoencoder( + x[:, t, i], i, encode=True, forcing_co2=y_co2, forcing_aerosol=y_aerosol + ) + else: + q_mu, q_logvar = self.autoencoder(x[:, t, i], i, encode=True) + + q_std = torch.exp(0.5 * q_logvar) + z[:, t, i] = q_mu + q_std * self.distr_encoder(0, 1, size=q_mu.size()) + + if self.use_exogenous and y_co2 is not None and y_aerosol is not None: + q_mu, q_logvar = self.autoencoder( + y[:, i], i, encode=True, forcing_co2=y_co2, forcing_aerosol=y_aerosol + ) + else: + q_mu, q_logvar = self.autoencoder(y[:, i], i, encode=True) - # carry on - z[:, -1, i] = q_mu + q_std * self.distr_encoder(0, 1, size=q_mu.size()) - # assert torch.all(penultimate_z == z[:, -2, i]) - # assert torch.all(all_z_except_last == z[:, :-1, i]) + q_std = torch.exp(0.5 * q_logvar) + z[:, -1, i] = q_mu + q_std * self.distr_encoder(0, 1, size=q_mu.size()) - mu[:, i] = q_mu - std[:, i] = q_std + mu[:, i] = q_mu + std[:, i] = q_std return z, mu, std - def transition(self, z, mask): + def transition(self, z, mask, y_co2=None, y_aerosol=None): + 
"""Compute latent dynamics: predict next-step latent distribution p(z^t | z^{= 0, f"KL={kl} has to be >= 0" - elbo = recons - kl - - return elbo, recons, kl, px_mu + elbo = recons - self.coeff_kl * kl + + # Compute forcing reconstruction losses + forcing_recons_loss_co2 = torch.tensor(0.0, device=x.device) + forcing_recons_loss_aerosol = torch.tensor(0.0, device=x.device) + + if self.use_forced_latents and y_co2 is not None and y_aerosol is not None: + # Extract forcing latents from z (last timestep, first feature dimension, forcing latent indices) + n_climate_latents = self.d_z - self.n_forced_latents_co2 - self.n_forced_latents_aerosol + forcing_arch = getattr(self, "forcing_arch", "baseline") + if forcing_arch == "baseline": + if not self._forcing_arch_logged: + logger.info("[ForcingArch] Using forcing_arch='baseline' (encoded forced latents)") + self._forcing_arch_logged = True + z_forced_target = z[:, -1, 0, n_climate_latents:] # Shape: (batch, n_forced_latents_total) + # Decode forcing latents back to forcing space + forcing_co2_recons, forcing_aerosol_recons = self.autoencoder.decode_forcings(z_forced_target) + elif forcing_arch == "transitioned": + if not self._forcing_arch_logged: + logger.info("[ForcingArch] Using forcing_arch='transitioned' (pz_mu forced latents)") + self._forcing_arch_logged = True + # Use transitioned latents (pz_mu) for forcing reconstruction + z_forced_target = pz_mu[:, 0, n_climate_latents:] # Shape: (batch, n_forced_latents_total) + forcing_co2_recons, forcing_aerosol_recons = self.autoencoder.decode_forcings(z_forced_target) + elif forcing_arch == "predefined": + if not self._forcing_arch_logged: + logger.info("[ForcingArch] Using forcing_arch='predefined' (no forcing reconstruction)") + self._forcing_arch_logged = True + # No forcing reconstruction in predefined conditioning mode + forcing_co2_recons, forcing_aerosol_recons = None, None + else: + raise ValueError(f"Unknown forcing_arch='{forcing_arch}'") + + if forcing_arch != 
"predefined": + # Get target forcings (last timestep) + y_co2_target = y_co2[:, -1] if y_co2.dim() == 3 else y_co2 + y_aerosol_target = y_aerosol[:, -1] if y_aerosol.dim() == 3 else y_aerosol + + # Compute MSE reconstruction losses + forcing_recons_loss_co2 = torch.mean((forcing_co2_recons - y_co2_target) ** 2) + forcing_recons_loss_aerosol = torch.mean((forcing_aerosol_recons - y_aerosol_target) ** 2) + + return ( + elbo, + recons, + kl, + px_mu, + forcing_recons_loss_co2, + forcing_recons_loss_aerosol, + encoded_forcing_mu, + ) # def predict(self, x, y): # b = x.size(0) # with torch.no_grad(): # # sample Zs (based on X) - # z, q_mu_y, q_std_y = self.encode(x, y) + # z, q_mu_y, q_std_y = self.encode(x, y, y_co2, y_aerosol) # # # get params of the transition model p(z^t | z^{ float: # torch.einsum('bd, bd -> b', (mu2 - mu1) * (1 / sigma2), mu2 - mu1)) if torch.sum(kl) < 0: __import__("ipdb").set_trace() - print(sigma2**self.d_z) - print(torch.prod(sigma1, dim=1)) - print(torch.sum(torch.log(sigma2**self.d_z / torch.prod(sigma1, dim=1)))) - print(torch.sum(torch.sum(sigma1 / sigma2, dim=1))) + logger.debug("sigma2**d_z: %s", sigma2**self.d_z) + logger.debug("prod(sigma1): %s", torch.prod(sigma1, dim=1)) + logger.debug( + "sum(log(sigma2**d_z / prod(sigma1))): %s", + torch.sum(torch.log(sigma2**self.d_z / torch.prod(sigma1, dim=1))), + ) + logger.debug("sum(sum(sigma1 / sigma2)): %s", torch.sum(torch.sum(sigma1 / sigma2, dim=1))) # print(torch.sum(torch.einsum('bd, bd -> b', (mu2 - mu1) * (1 / s_p), mu2 - mu1))) return torch.sum(kl) class LinearAutoEncoder(nn.Module): - def __init__(self, d, d_x, d_z, tied): + def __init__( + self, + d, + d_x, + d_z, + tied, + d_y_co2=0, + d_y_aerosol=0, + use_forced_latents=False, + n_forced_latents_co2=1, + n_forced_latents_aerosol=4, + d_y_co2_spatial=None, + d_y_aerosol_spatial=None, + ): super().__init__() + self.d_y_co2 = d_y_co2 + self.d_y_aerosol = d_y_aerosol + # Spatial dims for forcing encoders/decoders (independent of MLP 
conditioning dims). + # When use_exogenous=False but use_forced_latents=True, d_y_co2 may be 0 + # (no MLP conditioning) while d_y_co2_spatial carries the real spatial dimension. + self.d_y_co2_spatial = d_y_co2_spatial if d_y_co2_spatial is not None else d_y_co2 + self.d_y_aerosol_spatial = d_y_aerosol_spatial if d_y_aerosol_spatial is not None else d_y_aerosol self.d_x = d_x self.d_z = d_z self.tied = tied + self.use_grad_project = True + self.use_forced_latents = use_forced_latents + self.n_forced_latents_co2 = n_forced_latents_co2 + self.n_forced_latents_aerosol = n_forced_latents_aerosol + unif = (1 - 0.1) * torch.rand(size=(d, d_x, d_z)) + 0.1 self.w = nn.Parameter(unif / torch.as_tensor(d_z)) if not tied: unif = (1 - 0.1) * torch.rand(size=(d, d_z, d_x)) + 0.1 self.w_encoder = nn.Parameter(unif / torch.as_tensor(d_x)) - # self.logvar_encoder = nn.Parameter(torch.ones(d) * -1) - # self.logvar_decoder = nn.Parameter(torch.ones(d) * -1) self.logvar_encoder = nn.Parameter(torch.ones(d_z) * -1) self.logvar_decoder = nn.Parameter(torch.ones(d_x) * -1) + if use_forced_latents: + self.n_climate_latents = d_z - n_forced_latents_co2 - n_forced_latents_aerosol + + # CO2 forcing encoder: linear projection (uses real spatial dim, not MLP conditioning dim) + self.co2_forcing_encoder_mu = nn.Linear(self.d_y_co2_spatial, n_forced_latents_co2) + self.co2_forcing_encoder_logvar = nn.Parameter(torch.ones(n_forced_latents_co2) * -1) + + # Aerosol forcing encoder: linear projection + self.aerosol_forcing_encoder_mu = nn.Linear(self.d_y_aerosol_spatial, n_forced_latents_aerosol) + self.aerosol_forcing_encoder_logvar = nn.Parameter(torch.ones(n_forced_latents_aerosol) * -1) + + # Forcing decoder spatial weights + self.w_co2 = nn.Parameter(torch.randn(self.d_y_co2_spatial, n_forced_latents_co2)) + self.w_aerosol = nn.Parameter(torch.randn(self.d_y_aerosol_spatial, n_forced_latents_aerosol)) + def get_w_encoder(self): if self.tied: return torch.transpose(self.w, 1, 2) @@ -837,7 
+1197,55 @@ def get_w_encoder(self): def get_w_decoder(self): return self.w - def encode(self, x, i): + def get_w_co2(self): + if self.use_forced_latents: + return self.w_co2.detach() + return None + + def get_w_aerosol(self): + if self.use_forced_latents: + return self.w_aerosol.detach() + return None + + def encode_forcings(self, forcing_co2, forcing_aerosol): + """Encode forcings into latent representations using linear projections.""" + batch_size = forcing_co2.shape[0] + device = forcing_co2.device + + co2_mu = self.co2_forcing_encoder_mu(forcing_co2) + co2_std = torch.exp(0.5 * self.co2_forcing_encoder_logvar).expand(batch_size, -1).to(device) + + aerosol_mu = self.aerosol_forcing_encoder_mu(forcing_aerosol) + aerosol_std = torch.exp(0.5 * self.aerosol_forcing_encoder_logvar).expand(batch_size, -1).to(device) + + co2_z = co2_mu + co2_std * torch.randn_like(co2_std) + aerosol_z = aerosol_mu + aerosol_std * torch.randn_like(aerosol_std) + + z_forced = torch.cat([co2_z, aerosol_z], dim=1) + mu_forced = torch.cat([co2_mu, aerosol_mu], dim=1) + std_forced = torch.cat([co2_std, aerosol_std], dim=1) + + return z_forced, mu_forced, std_forced + + def decode_co2_forcing(self, z_co2): + """Decode CO2 forcing latents using spatial weights.""" + z_expanded = z_co2.unsqueeze(1).expand(-1, self.d_y_co2_spatial, -1) + z_masked = z_expanded * self.w_co2.unsqueeze(0) + return z_masked.sum(dim=-1) + + def decode_aerosol_forcing(self, z_aerosol): + """Decode aerosol forcing latents using spatial weights.""" + z_expanded = z_aerosol.unsqueeze(1).expand(-1, self.d_y_aerosol_spatial, -1) + z_masked = z_expanded * self.w_aerosol.unsqueeze(0) + return z_masked.sum(dim=-1) + + def decode_forcings(self, z_forced_latents): + """Decode forcing latents back to forcing space for reconstruction loss.""" + co2_latents = z_forced_latents[:, : self.n_forced_latents_co2] + aerosol_latents = z_forced_latents[:, self.n_forced_latents_co2 :] + return self.decode_co2_forcing(co2_latents), 
self.decode_aerosol_forcing(aerosol_latents) + + def encode(self, x, i, forcing_co2=None, forcing_aerosol=None): if self.tied: w = self.w[i].T else: @@ -845,16 +1253,20 @@ def encode(self, x, i): mu = torch.matmul(x, w.T) return mu, self.logvar_encoder - def decode(self, z, i): + def decode(self, z, i, forcing_co2=None, forcing_aerosol=None): w = self.w[i] + # When using forced latents, z only contains climate latents (sliced by caller). + # Use only the corresponding climate columns of w. + if self.use_forced_latents and z.shape[-1] < w.shape[-1]: + w = w[:, : z.shape[-1]] mu = torch.matmul(z, w.T) return mu, self.logvar_decoder - def forward(self, x, i, encode: bool = False): + def forward(self, x, i, encode: bool = False, forcing_co2=None, forcing_aerosol=None): if encode: - return self.encode(x, i) + return self.encode(x, i, forcing_co2, forcing_aerosol) else: - return self.decode(x, i) + return self.decode(x, i, forcing_co2, forcing_aerosol) class NonLinearAutoEncoder(nn.Module): @@ -883,6 +1295,20 @@ def get_w_encoder(self): def get_w_decoder(self): return self.w + def get_w_co2(self): + """Get CO2 forcing decoder spatial weights.""" + if self.use_forced_latents: + return self.w_co2.detach() + else: + return None + + def get_w_aerosol(self): + """Get aerosol forcing decoder spatial weights.""" + if self.use_forced_latents: + return self.w_aerosol.detach() + else: + return None + def get_encode_mask(self): if self.tied: return torch.transpose(self.w, 1, 2) @@ -909,16 +1335,179 @@ def __init__( num_layer, tied, embedding_dim, + reduce_encoding_pos_dim=False, gt_w=None, + d_y_co2=0, + d_y_aerosol=0, + use_forced_latents=False, + n_forced_latents_co2=1, + n_forced_latents_aerosol=4, + d_y_co2_spatial=None, + d_y_aerosol_spatial=None, ): super().__init__(d, d_x, d_z, num_hidden, num_layer, tied, gt_w) - self.embedding_encoder = nn.Embedding(d_z, embedding_dim) - self.encoder = MLP(num_layer, num_hidden, d_x + embedding_dim, 1) # embedding_dim_encoding + 
self.d_y_co2 = d_y_co2 + self.d_y_aerosol = d_y_aerosol + self.reduce_encoding_pos_dim = reduce_encoding_pos_dim + # Spatial dims for forcing encoders/decoders (independent of MLP conditioning dims). + # When use_exogenous=False but use_forced_latents=True, d_y_co2 may be 0 + # (no MLP conditioning) while d_y_co2_spatial carries the real spatial dimension. + self.d_y_co2_spatial = d_y_co2_spatial if d_y_co2_spatial is not None else d_y_co2 + self.d_y_aerosol_spatial = d_y_aerosol_spatial if d_y_aerosol_spatial is not None else d_y_aerosol + self.use_forced_latents = use_forced_latents + self.n_forced_latents_co2 = n_forced_latents_co2 + self.n_forced_latents_aerosol = n_forced_latents_aerosol + + # embedding_dim_encoding = d_z // 10 + if not self.reduce_encoding_pos_dim: + self.embedding_encoder = nn.Embedding(d_z, embedding_dim) + self.encoder = MLP( + num_layer, num_hidden, d_x + embedding_dim + d_y_co2 + d_y_aerosol, 1 + ) # embedding_dim_encoding + else: + self.embedding_encoder = nn.Embedding(d_z, embedding_dim // 10) + self.encoder = MLP( + num_layer, num_hidden, d_x + embedding_dim // 10 + d_y_co2 + d_y_aerosol, 1 + ) # embedding_dim_encoding + # self.encoder = MLP(num_layer, num_hidden, d_x + embedding_dim, 1) + # self.embedding_encoder = nn.Embedding(d_z, embedding_dim) - self.decoder = MLP(num_layer, num_hidden, d_z + embedding_dim, 1) self.embedding_decoder = nn.Embedding(d_x, embedding_dim) - def encode(self, x, i): + # Create climate-only decoder if using forced latents + if use_forced_latents: + # Climate-only decoder: only climate latents go through this decoder + n_climate_latents = d_z - n_forced_latents_co2 - n_forced_latents_aerosol + self.n_climate_latents = n_climate_latents + self.decoder = MLP(num_layer, num_hidden, n_climate_latents + embedding_dim + d_y_co2 + d_y_aerosol, 1) + else: + # Original decoder: all latents + self.decoder = MLP(num_layer, num_hidden, d_z + embedding_dim + d_y_co2 + d_y_aerosol, 1) + self.n_climate_latents = d_z + + 
# Add forcing encoders if using forced latents (use real spatial dims, not MLP conditioning dims) + if self.use_forced_latents: + # CO2 forcing encoder: maps spatial CO2 to n_forced_latents_co2 latent means + logvars + self.co2_forcing_encoder_mu = MLP(num_layer, num_hidden, self.d_y_co2_spatial, n_forced_latents_co2) + self.co2_forcing_encoder_logvar = nn.Parameter(torch.ones(n_forced_latents_co2) * -1) + + # Aerosol forcing encoder: maps spatial aerosols to n_forced_latents_aerosol latent means + logvars + self.aerosol_forcing_encoder_mu = MLP( + num_layer, num_hidden, self.d_y_aerosol_spatial, n_forced_latents_aerosol + ) + self.aerosol_forcing_encoder_logvar = nn.Parameter(torch.ones(n_forced_latents_aerosol) * -1) + + # Forcing decoder weights: spatial mask (like climate decoder) + self.w_co2 = nn.Parameter(torch.randn(self.d_y_co2_spatial, n_forced_latents_co2)) + self.w_aerosol = nn.Parameter(torch.randn(self.d_y_aerosol_spatial, n_forced_latents_aerosol)) + + def encode_forcings(self, forcing_co2, forcing_aerosol): + """ + Encode forcings directly into latent representations. 
+ + Args: + forcing_co2: CO2 forcing, shape (batch_size, d_y_co2) + forcing_aerosol: Aerosol forcing, shape (batch_size, d_y_aerosol) + + Returns: + z_forced: Forced latents, shape (batch_size, n_forced_latents_co2 + n_forced_latents_aerosol) + mu_forced: Means of forced latents + std_forced: Stds of forced latents + """ + batch_size = forcing_co2.shape[0] + device = forcing_co2.device + + # Encode CO2 forcing + co2_mu = self.co2_forcing_encoder_mu(forcing_co2) # (batch_size, n_forced_latents_co2) + co2_std = torch.exp(0.5 * self.co2_forcing_encoder_logvar).expand(batch_size, -1).to(device) + + # Encode aerosol forcing + aerosol_mu = self.aerosol_forcing_encoder_mu(forcing_aerosol) # (batch_size, n_forced_latents_aerosol) + aerosol_std = torch.exp(0.5 * self.aerosol_forcing_encoder_logvar).expand(batch_size, -1).to(device) + + # Sample forced latents using reparameterization trick + co2_z = co2_mu + co2_std * torch.randn_like(co2_std) + aerosol_z = aerosol_mu + aerosol_std * torch.randn_like(aerosol_std) + + # Concatenate forced latents + z_forced = torch.cat([co2_z, aerosol_z], dim=1) + mu_forced = torch.cat([co2_mu, aerosol_mu], dim=1) + std_forced = torch.cat([co2_std, aerosol_std], dim=1) + + return z_forced, mu_forced, std_forced + + def decode_co2_forcing(self, z_co2): + """ + Decode CO2 forcing latents using spatial mask (like climate decoder). 
+ + Args: + z_co2: CO2 latents, shape (batch, n_forced_latents_co2) # (batch, 1) + + Returns: + co2_recons: Reconstructed CO2 field, shape (batch, d_y_co2) # (batch, 400) + """ + # z_co2: (batch, 1) + # self.w_co2: (400, 1) + # Output: (batch, 400) + + # Expand z to (batch, 400, 1) + z_expanded = z_co2.unsqueeze(1).expand(-1, self.d_y_co2_spatial, -1) + + # Apply mask: element-wise multiply with w_co2 + z_masked = z_expanded * self.w_co2.unsqueeze(0) # (batch, 400, 1) + + # Sum over latents (linear combination) + co2_recons = z_masked.sum(dim=-1) # (batch, 400) + + return co2_recons + + def decode_aerosol_forcing(self, z_aerosol): + """ + Decode aerosol forcing latents using spatial mask (like climate decoder). + + Args: + z_aerosol: Aerosol latents, shape (batch, n_forced_latents_aerosol) # (batch, 4) + + Returns: + aerosol_recons: Reconstructed aerosol field, shape (batch, d_y_aerosol) # (batch, 400) + """ + # z_aerosol: (batch, 4) + # self.w_aerosol: (400, 4) + # Output: (batch, 400) + + # Expand z to (batch, 400, 4) + z_expanded = z_aerosol.unsqueeze(1).expand(-1, self.d_y_aerosol_spatial, -1) + + # Apply mask: element-wise multiply with w_aerosol + z_masked = z_expanded * self.w_aerosol.unsqueeze(0) # (batch, 400, 4) + + # Sum over latents (linear combination) + aerosol_recons = z_masked.sum(dim=-1) # (batch, 400) + + return aerosol_recons + + def decode_forcings(self, z_forced_latents): + """ + Decode forcing latents back to forcing space for reconstruction loss. 
+ + Args: + z_forced_latents: Forced latents, shape (batch_size, n_forced_latents_co2 + n_forced_latents_aerosol) + + Returns: + forcing_co2_recons: Reconstructed CO2 forcing, shape (batch_size, d_y_co2) + forcing_aerosol_recons: Reconstructed aerosol forcing, shape (batch_size, d_y_aerosol) + """ + # Split forced latents into CO2 and aerosol components + co2_latents = z_forced_latents[:, : self.n_forced_latents_co2] + aerosol_latents = z_forced_latents[:, self.n_forced_latents_co2 :] + + # Decode to forcing space using spatial mask decoders + forcing_co2_recons = self.decode_co2_forcing(co2_latents) + forcing_aerosol_recons = self.decode_aerosol_forcing(aerosol_latents) + + return forcing_co2_recons, forcing_aerosol_recons + + def encode(self, x, i, forcing_co2=None, forcing_aerosol=None): mask = super().get_encode_mask(x.shape[0]) mu = torch.zeros((x.shape[0], self.d_z), device=x.device) @@ -937,9 +1526,14 @@ def encode(self, x, i): # each location create a lask in latents b * d_z * d_x # Then concatenate in the last axis (d_x) with the embedding of the latents? # x_ = mask_ * x.unsqueeze(1) - x_ = torch.cat( - (mask_ * x.unsqueeze(1), embedded_x), dim=2 - ) # expand dimensions of x for broadcasting - looks good. + if forcing_co2 is not None and forcing_aerosol is not None: + forcing_co2_expanded = forcing_co2.unsqueeze(1).expand(-1, self.d_z, -1) + forcing_aerosol_expanded = forcing_aerosol.unsqueeze(1).expand(-1, self.d_z, -1) + x_ = torch.cat((mask_ * x.unsqueeze(1), embedded_x, forcing_co2_expanded, forcing_aerosol_expanded), dim=2) + else: + x_ = torch.cat( + (mask_ * x.unsqueeze(1), embedded_x), dim=2 + ) # expand dimensions of x for broadcasting - looks good. del embedded_x del mask_ @@ -948,8 +1542,22 @@ def encode(self, x, i): return mu, self.logvar_encoder - def decode(self, z, i): + def decode(self, z, i, forcing_co2=None, forcing_aerosol=None): + """ + Decode latents to observations. 
+ + When use_forced_latents=True, this method expects ONLY climate latents + (the caller should slice z to only include climate latents). + Raw forcing fields (forcing_co2, forcing_aerosol) are still used as + conditioning inputs to the MLP but don't go through the learnable mask. + Args: + z: Latent variables. If use_forced_latents, should be climate latents only + with shape (batch, n_climate_latents). Otherwise (batch, d_z). + i: Feature index + forcing_co2: Raw CO2 forcing field for conditioning (NOT forcing latent!) + forcing_aerosol: Raw aerosol forcing field for conditioning (NOT forcing latent!) + """ mask = super().get_decode_mask(z.shape[0]) mu = torch.zeros((z.shape[0], self.d_x), device=z.device) @@ -959,18 +1567,32 @@ def decode(self, z, i): # Embed all j_values at once embedded_z = self.embedding_decoder(j_values) - # Select all decoder masks at once + # Select decoder masks - only use climate portion of mask mask_ = super().select_decoder_mask(mask, i, j_values) - if z.ndim < mask_.ndim: + # Only use climate latents (first n_climate columns of mask) + n_climate = self.n_climate_latents + if mask_.dim() == 3: + # mask_ shape: (batch, d_x, d_z) -> slice to (batch, d_x, n_climate) + mask_climate = mask_[:, :, :n_climate] + else: + # mask_ shape: (d_x, d_z) -> slice to (d_x, n_climate) + mask_climate = mask_[:, :n_climate] + + if z.ndim < mask_climate.ndim: z_expanded = z.unsqueeze(1).expand(-1, self.d_x, -1) else: z_expanded = z.expand(-1, self.d_x, -1) z_expanded_copy = z_expanded.clone() - z_expanded_copy.mul_(mask_) - z_expanded_copy.unsqueeze(2) + z_expanded_copy.mul_(mask_climate) - z_ = torch.cat((z_expanded_copy, embedded_z), dim=2) + # Raw forcing fields as conditioning (correct design - doesn't go through mask) + if forcing_co2 is not None and forcing_aerosol is not None: + forcing_co2_expanded = forcing_co2.unsqueeze(1).expand(-1, self.d_x, -1) + forcing_aerosol_expanded = forcing_aerosol.unsqueeze(1).expand(-1, self.d_x, -1) + z_ = 
torch.cat((z_expanded_copy, embedded_z, forcing_co2_expanded, forcing_aerosol_expanded), dim=2) + else: + z_ = torch.cat((z_expanded_copy, embedded_z), dim=2) del z_expanded del z_expanded_copy @@ -980,11 +1602,11 @@ def decode(self, z, i): return mu, self.logvar_decoder - def forward(self, x, i, encode: bool = False): + def forward(self, x, i, encode: bool = False, forcing_co2=None, forcing_aerosol=None): if encode: - return self.encode(x, i) + return self.encode(x, i, forcing_co2, forcing_aerosol) else: - return self.decode(x, i) + return self.decode(x, i, forcing_co2, forcing_aerosol) class TransitionModel(nn.Module): @@ -1027,10 +1649,10 @@ def __init__( # self.logvar = nn.Parameter(torch.ones(d) * -4) self.logvar = nn.Parameter(torch.ones(d, d_z) * -4) if self.nonlinear_dynamics: - print("NON LINEAR DYNAMICS") + logger.info("NON LINEAR DYNAMICS") self.nn = nn.ModuleList(MLP(num_layers, num_hidden, d * d_z * tau, self.num_output) for i in range(d * d_z)) else: - print("LINEAR DYNAMICS") + logger.info("LINEAR DYNAMICS") self.nn = nn.ModuleList(MLP(0, 0, d * d_z * tau, self.num_output) for i in range(d * d_z)) # self.nn = MLP(num_layers, num_hidden, d * k * k, self.num_output) @@ -1102,6 +1724,8 @@ def __init__( num_hidden: int, num_output: int = 2, embedding_dim: int = 100, + d_y_co2: int = 0, + d_y_aerosol: int = 0, ): """ Args: @@ -1114,6 +1738,8 @@ def __init__( """ super().__init__() + self.d_y_co2 = d_y_co2 + self.d_y_aerosol = d_y_aerosol self.d = d # number of variables self.d_z = d_z self.tau = tau @@ -1134,16 +1760,19 @@ def __init__( # self.logvar = nn.Parameter(torch.ones(d) * -4) self.logvar = nn.Parameter(torch.ones(d, d_z) * -4) if self.nonlinear_dynamics: - print("NON LINEAR DYNAMICS") + logger.info("NON LINEAR DYNAMICS") self.nn = nn.ModuleList( - MLP(num_layers, num_hidden, d * d_z * tau + embedding_dim, self.num_output) for i in range(d) + MLP(num_layers, num_hidden, d * d_z * tau + embedding_dim + d_y_co2 + d_y_aerosol, self.num_output) + for 
i in range(d) ) else: - print("LINEAR DYNAMICS") - self.nn = nn.ModuleList(MLP(0, 0, d * d_z * tau + embedding_dim, self.num_output) for i in range(d)) + logger.info("LINEAR DYNAMICS") + self.nn = nn.ModuleList( + MLP(0, 0, d * d_z * tau + embedding_dim + d_y_co2 + d_y_aerosol, self.num_output) for i in range(d) + ) # self.nn = MLP(num_layers, num_hidden, d * k * k, self.num_output) - def forward(self, z, mask, i): + def forward(self, z, mask, i, forcing_co2=None, forcing_aerosol=None): """Returns the params of N(z_t | z_{ 0 + + nrows = 1 if effective_no_gt else 3 + titles = ["Learned", "Ground Truth", "Difference"] + + axes = fig.subplots(nrows=nrows, ncols=1) + for row in range(nrows): + ax = axes if effective_no_gt else axes[row] + ax.set_title(titles[row]) + + if row == 0: + rgb = self._create_colored_rect_adjacency(mat1[0], n_climate, n_co2, n_aerosol) + ax.imshow(rgb, aspect="auto", interpolation="nearest") + elif row == 1: + rgb = self._create_colored_rect_adjacency(mat2_aligned[0][::-1], n_climate, n_co2, n_aerosol) + ax.imshow(rgb, aspect="auto", interpolation="nearest") + elif row == 2: + diff = mat1[0][:n_climate, : len(col_labels)] - mat2_aligned[0][::-1][:n_climate, : len(col_labels)] + sns.heatmap( + diff, + ax=ax, + cbar=False, + vmin=-1, + vmax=1, + cmap="RdBu_r", + center=0, + xticklabels=False, + yticklabels=False, + ) + + if row < 2: + ax.set_yticks(range(n_climate)) + ax.set_yticklabels(row_labels, fontsize=7) + ax.set_xticks(range(len(col_labels))) + ax.set_xticklabels(col_labels, fontsize=7, rotation=45) + if has_forcing: + ax.axvline(x=n_climate - 0.5, color="black", linewidth=1.5) + if n_co2 > 0 and n_aerosol > 0: + ax.axvline(x=n_climate + n_co2 - 0.5, color="black", linewidth=0.8, linestyle=":") + + def _get_forcing_layout(self, n_latents, forcing_indices=None): """ + Compute layout info from forcing_indices. 
- lat = learner.lat - lon = learner.lon - tau = mat1.shape[0] + Returns: + n_climate, n_co2, n_aerosol, row_labels (climate only), col_labels (all sources) + """ + if forcing_indices is not None: + co2_idx = forcing_indices.get("co2", []) + aerosol_idx = forcing_indices.get("aerosol", []) + n_co2 = len(co2_idx) + n_aerosol = len(aerosol_idx) + n_climate = n_latents - n_co2 - n_aerosol + else: + n_climate = n_latents + n_co2 = 0 + n_aerosol = 0 + + row_labels = [f"C{i+1}" for i in range(n_climate)] + col_labels = [f"C{i+1}" for i in range(n_climate)] + if n_co2 == 1: + col_labels.append("CO2") + else: + col_labels += [f"CO2_{i}" for i in range(n_co2)] + col_labels += [f"A{i+1}" for i in range(n_aerosol)] - if savar and modes_gt is not None and modes_inferred is not None: + return n_climate, n_co2, n_aerosol, row_labels, col_labels - mat1 = permute_matrices( - lat, - lon, - modes_inferred, - modes_gt, - mat1, - tau, - ) + def _create_colored_rect_adjacency(self, adj_matrix, n_climate, n_co2, n_aerosol): + """ + Create rectangular RGB adjacency image: rows=climate targets, cols=all sources. + + Colors: Climate->Climate = Blue, CO2->Climate = Red, Aerosol->Climate = Orange. 
+ """ + n_total = n_climate + n_co2 + n_aerosol + n_rows = n_climate + n_cols = min(n_total, adj_matrix.shape[1]) + rect = adj_matrix[:n_rows, :n_cols] + + rgb = np.ones((n_rows, n_cols, 3)) + adj_norm = np.clip(np.abs(rect), 0, 1) + + for i in range(n_rows): + for j in range(n_cols): + val = adj_norm[i, j] + if val < 0.01: + continue + if j < n_climate: + rgb[i, j] = [1 - val, 1 - val, 1.0] # Blue + elif j < n_climate + n_co2: + rgb[i, j] = [1.0, 1 - val, 1 - val] # Red + else: + rgb[i, j] = [1.0, 0.65 + 0.35 * (1 - val), 1 - val] # Orange + + return rgb + + def _plot_adjacency_through_time(self, fig, mat1, mat2_aligned, effective_no_gt, tau, forcing_indices=None): + """Plot rectangular adjacency matrices: rows=climate targets, cols=all sources.""" + n_climate, n_co2, n_aerosol, row_labels, col_labels = self._get_forcing_layout(mat1.shape[1], forcing_indices) + has_forcing = n_co2 + n_aerosol > 0 subfig_names = [ - f"Learned, latent dimensions = {mat1.shape[1], mat1.shape[2]}", + f"Learned ({n_climate} climate" + + (f" + {n_co2} CO2 + {n_aerosol} aerosol" if n_co2 + n_aerosol > 0 else "") + + ")", "Ground Truth", - "Difference: Learned - GT", + "Difference", ] - fig = plt.figure(constrained_layout=True) - fig.suptitle("Adjacency matrices: learned vs ground-truth") + nrows = 1 if effective_no_gt else 3 - if no_gt: - nrows = 1 - else: - nrows = 3 + subfigs = fig.subfigures(nrows=nrows, ncols=1) + for row in range(nrows): + subfig = subfigs if nrows == 1 else subfigs[row] + subfig.suptitle(subfig_names[row]) + + axes = subfig.subplots(nrows=1, ncols=tau) + for i in range(tau): + ax = axes[i] + ax.set_title(f"t-{i+1}") + mat_idx = tau - i - 1 - if tau == 1: - axes = fig.subplots(nrows=nrows, ncols=1) - for row in range(nrows): - if no_gt: - ax = axes - else: - ax = axes[row] - # axes.set_title(f"t - {i+1}") if row == 0: - sns.heatmap( - mat1[0], ax=ax, cbar=False, vmin=-1, vmax=1, cmap="Blues", xticklabels=False, yticklabels=False - ) - elif row == 1: - sns.heatmap( - 
mat2[0][::-1], - ax=ax, - cbar=False, - vmin=-1, - vmax=1, - cmap="Blues", - xticklabels=False, - yticklabels=False, - ) - elif row == 2: - sns.heatmap( - mat1[0] - mat2[0][::-1], - ax=ax, - cbar=False, - vmin=-1, - vmax=1, - cmap="Blues", - xticklabels=False, - yticklabels=False, + rgb = self._create_colored_rect_adjacency(mat1[mat_idx], n_climate, n_co2, n_aerosol) + ax.imshow(rgb, aspect="auto", interpolation="nearest") + elif row == 1 and mat2_aligned is not None: + rgb = self._create_colored_rect_adjacency(mat2_aligned[mat_idx], n_climate, n_co2, n_aerosol) + ax.imshow(rgb, aspect="auto", interpolation="nearest") + elif row == 2 and mat2_aligned is not None: + diff = ( + mat1[mat_idx][:n_climate, : len(col_labels)] + - mat2_aligned[mat_idx][:n_climate, : len(col_labels)] ) + ax.imshow(diff, aspect="auto", interpolation="nearest", vmin=-1, vmax=1, cmap="RdBu_r") - else: - subfigs = fig.subfigures(nrows=nrows, ncols=1) - for row in range(nrows): - if nrows == 1: - subfig = subfigs + # Labels and separator lines + if i == 0: + ax.set_yticks(range(n_climate)) + ax.set_yticklabels(row_labels, fontsize=7) else: - subfig = subfigs[row] - subfig.suptitle(f"{subfig_names[row]}") + ax.set_yticks([]) + ax.set_xticks(range(len(col_labels))) + ax.set_xticklabels(col_labels, fontsize=6, rotation=45) + # Separator between climate and forcing columns + if has_forcing: + ax.axvline(x=n_climate - 0.5, color="black", linewidth=1.5) + if n_co2 > 0 and n_aerosol > 0: + ax.axvline(x=n_climate + n_co2 - 0.5, color="black", linewidth=0.8, linestyle=":") + + legend_elements = [ + Patch(facecolor="blue", edgecolor="black", label="Climate \u2192 Climate"), + Patch(facecolor="red", edgecolor="black", label="CO2 \u2192 Climate"), + Patch(facecolor="orange", edgecolor="black", label="Aerosol \u2192 Climate"), + ] + fig.legend(handles=legend_elements, loc="upper right", fontsize=8, framealpha=0.9) - axes = subfig.subplots(nrows=1, ncols=tau) - for i in range(tau): - axes[i].set_title(f"t - 
{i+1}") - if row == 0: - sns.heatmap( - mat1[tau - i - 1], - ax=axes[i], - cbar=False, - vmin=-1, - vmax=1, - cmap="Blues", - xticklabels=False, - yticklabels=False, - ) - # add a horizontal line every 50 columns - for j in range(0, mat1.shape[1], 50): - axes[i].axhline(y=j, color="black", linewidth=0.4) - # add a vertical line every 50 columns - for j in range(0, mat1.shape[1], 50): - axes[i].axvline(x=j, color="black", linewidth=0.4) + def plot_adjacency_matrix( + self, + learner=None, + mat1: np.ndarray = None, + mat2: np.ndarray = None, + modes_gt=None, + modes_inferred=None, + path=None, + name_suffix: str = "", + savar: bool = False, + no_gt: bool = False, + iteration: int = 0, + plot_through_time: bool = True, + ): + """ + Plot the adjacency matrices learned and compare it to the ground truth. - elif row == 1: - sns.heatmap( - mat2[i], - ax=axes[i], - cbar=False, - vmin=-1, - vmax=1, - cmap="Blues", - xticklabels=False, - yticklabels=False, - ) - elif row == 2: - sns.heatmap( - mat1[tau - i - 1] - mat2[i], - ax=axes[i], - cbar=False, - vmin=-1, - vmax=1, - cmap="Blues", - xticklabels=False, - yticklabels=False, - ) + The first dimension of the matrix should be the time (tau). 
+ + Args: + mat1: learned adjacency matrices + mat2: ground-truth adjacency matrices + path: path where to save the plot + name_suffix: suffix for the name of the plot + no_gt: if True, does not use the ground-truth graph + """ + mat1 = np.array(mat1, copy=True) + tau = mat1.shape[0] + effective_no_gt = no_gt or mat2 is None + forcing_indices = None + if learner is not None and hasattr(learner, "datamodule"): + forcing_indices = getattr(learner.datamodule, "forcing_indices", None) + + mat2_aligned = None + if mat2 is not None: + mat2 = np.array(mat2, copy=True) + if mat2.shape[0] >= tau: + mat2_aligned = mat2[-tau:] + else: + mat2_aligned = mat2 + + if savar and modes_gt is not None and modes_inferred is not None: + lat = None + lon = None + if learner is not None and hasattr(learner, "exp_params"): + lat = getattr(learner.exp_params, "lat", None) + lon = getattr(learner.exp_params, "lon", None) + if lat is not None and lon is not None: + mat1 = permute_matrices( + lat, + lon, + modes_inferred, + modes_gt, + mat1, + tau, + ) + else: + logger.warning("Skipping SAVAR permutation in adjacency plot: lat/lon not available.") + + # Create figure + fig = plt.figure(constrained_layout=True) + fig.suptitle("Adjacency matrices: learned vs ground-truth") + + # Plot based on number of timesteps + if tau == 1: + self._plot_adjacency_single_time(fig, mat1, mat2_aligned, effective_no_gt, forcing_indices) + else: + self._plot_adjacency_through_time(fig, mat1, mat2_aligned, effective_no_gt, tau, forcing_indices) # new if plot_through_time: @@ -1403,6 +1405,11 @@ def plot_adjacency_through_time_w(self, w_adj: np.ndarray, gt_dag: np.ndarray, t fig.clf() def save_mcc_and_assignement(self, exp_path): + """ + Save MCC score history and latent assignment history to disk, and plot MCC over time. + + Saves mcc.npy, assignments.npy, and mcc.png. 
+ """ np.save(exp_path / "mcc", np.array(self.mcc)) np.save(exp_path / "assignments", np.array(self.assignments)) if len(self.mcc) > 1: @@ -1411,180 +1418,3 @@ def save_mcc_and_assignement(self, exp_path): plt.title("MCC score through time") fig.savefig(exp_path / "mcc.png") fig.clf() - - def plot_original_savar(self, data, lat, lon, path): - """Plotting the original savar data.""" - print(f"data shape {data.shape}") - # Get the dimensions - time_steps = data.shape[1] - data_reshaped = data.T.reshape((time_steps, lat, lon)) - - # Calculate the average over the time axis - avg_data = np.mean(data_reshaped, axis=0) - - # Determine the global min and max from the averaged data for consistent color scaling - vmin = np.min(avg_data) - vmax = np.max(avg_data) - - fig, ax = plt.subplots(figsize=(lon / 10, lat / 10)) - cax = ax.imshow(data_reshaped[0], aspect="auto", cmap="viridis", vmin=vmin, vmax=vmax) - # cbar = fig.colorbar(cax, ax=ax) - - def animate(i): - cax.set_data(data_reshaped[i]) - ax.set_title(f"Time step: {i+1}") - return (cax,) - - # Create an animation - ani = animation.FuncAnimation(fig, animate, frames=100, blit=True) - - # Save the animation as a video file - ani.save(path, writer="pillow", fps=10) - - plt.close() - - # # Below are functions used for plotting savar results / metrics. 
Not used yet but could be useful / integrated into the savar pipeline - - # def plot_original_savar(self, data, lon, lat, path): - # """Plotting the original savar data.""" - # print(f"data shape {data.shape}") - # # Get the dimensions - # time_steps = data.shape[0] - # data_reshaped = data.T.reshape((time_steps, lat, lon)) - - # # Calculate the average over the time axis - # avg_data = np.mean(data_reshaped, axis=0) - - # # Determine the global min and max from the averaged data for consistent color scaling - # vmin = np.min(avg_data) - # vmax = np.max(avg_data) - - # fig, ax = plt.subplots(figsize=(lon / 10, lat / 10)) - # cax = ax.imshow(data_reshaped[0], aspect="auto", cmap="viridis", vmin=vmin, vmax=vmax) - # cbar = fig.colorbar(cax, ax=ax) - - # def animate(i): - # cax.set_data(data_reshaped[i]) - # ax.set_title(f"Time step: {i+1}") - # return (cax,) - - # # Create an animation - # ani = animation.FuncAnimation(fig, animate, frames=100, blit=True) - - # fname = "original_savar_data.gif" - # # Save the animation as a video file - # ani.save(os.path.join(path, fname), writer="pillow", fps=10) - - # plt.close() - - # def compute_time_averaged_pixel_error(self, learner, cdsd_data, savar_data, iteration, path): - # """ - # Computes the pixel error between time-averaged SAVAR ground truth and reconstructed CDSD latent variables. - - # Args: - # cdsd_data (numpy.ndarray): CDSD latent variables of shape (1, lon*lat, d_z). - # savar_data (numpy.ndarray): SAVAR ground truth data of shape (time_steps, lat, lon). - - # Returns: - # float: The mean squared error between time-averaged SAVAR and reconstructed CDSD. 
- # """ - # # Step 1: Time-average the SAVAR data over time_steps - # savar_avg = np.mean(savar_data, axis=0) # Shape becomes (lat, lon) - - # # Step 2: Reshape cdsd_data to (lat, lon, d_z) based on savar spatial dimensions - # lat, lon = savar_avg.shape - # d_z = cdsd_data.shape[2] - - # # Assuming lon*lat matches the savar grid - # cdsd_reshaped = cdsd_data.reshape(lat, lon, d_z) # Shape becomes (lat, lon, d_z) - - # # Step 3: Reconstruct CDSD by summing over the latent dimension (d_z) - # cdsd_reconstructed = np.sum(cdsd_reshaped, axis=2) # Shape becomes (lat, lon) - - # # Step 4: Compute pixel-wise error (Mean Squared Error) - # pixel_error = np.mean((cdsd_reconstructed - savar_avg) ** 2) - - # print(f"Pixel error: {pixel_error}") - - # combined_min = min(np.min(cdsd_reconstructed), np.min(savar_avg)) - # combined_max = max(np.max(cdsd_reconstructed), np.max(savar_avg)) - - # # Step 5: Plot both the reconstructed CDSD data and the time-averaged SAVAR data - # fig, axes = plt.subplots(1, 2, figsize=(learner.hp.compute_pixel_figsize_x, learner.hp.compute_pixel_figsize_y)) - - # # Plot the reconstructed CDSD data - # im1 = axes[0].imshow(cdsd_reconstructed, cmap="viridis", aspect="auto", vmin=combined_min, vmax=combined_max) - # axes[0].set_title(f"Reconstructed CDSD Data (Pixel Error: {pixel_error:.4f})") - # axes[0].set_xlabel("Longitude") - # axes[0].set_ylabel("Latitude") - # plt.colorbar(im1, ax=axes[0], label="Reconstructed Value") - - # # Plot the time-averaged SAVAR data - # im2 = axes[1].imshow(savar_avg, cmap="viridis", aspect="auto", vmin=combined_min, vmax=combined_max) - # axes[1].set_title("Time-Averaged SAVAR Data") - # axes[1].set_xlabel("Longitude") - # axes[1].set_ylabel("Latitude") - # plt.colorbar(im2, ax=axes[1], label="SAVAR Value") - # plt.tight_layout() - - # fname = f"cdsd_reconstructed_{iteration}.png" - - # plt.savefig(os.path.join(path, fname)) - # plt.close() - - # return pixel_error - - # def calculate_mcc_with_savar(self, cdsd_data, 
savar_data): - # """ - # Calculates the Mean Correlation Coefficient (MCC) between discovered latents (CDSD) and ground truth SAVAR data, - # where SAVAR data is reshaped and projected into the same number of latents as the CDSD discovered data. - - # Args: - # cdsd_latents (numpy array): Discovered latent variables from CDSD with shape (n_samples, n_latents). - # savar_data (numpy array): Ground-truth SAVAR data with shape (time_steps, longitude, latitude). - # num_latents (int): The number of latent variables (e.g., 3 in your case). - - # Returns: - # float: The Mean Correlation Coefficient (MCC) between the CDSD latents and projected SAVAR latents. - # """ - # num_latents = cdsd_data.shape[2] - - # # Reshape SAVAR data from (time_steps, longitude, latitude) to (time_steps, longitude * latitude) - # time_steps, lat, lon = savar_data.shape - # savar_data_reshaped = savar_data.reshape(time_steps, lon * lat) - - # # Apply ICA - # ica = FastICA(n_components=num_latents) - # savar_latents = ica.fit_transform(savar_data_reshaped.T).T - # print(savar_latents.shape) - - # # Now, reshape the latents back into (time_steps, lat, lon, num_latents) - # savar_latents_reshaped = savar_latents.reshape(num_latents, lat, lon) - - # for i in range(num_latents): - # plt.figure(figsize=(6, 6)) # Create a new figure for each latent - # latent_component = savar_latents_reshaped[i] # Shape: (lat, lon) - # plt.imshow(latent_component, cmap="viridis", aspect="auto") - # plt.title(f"Latent {i + 1} after PCA") - # plt.colorbar() - # plt.show() # Show each plot separately - - # # Ensure CDSD latents and SAVAR latents have the same shape - # assert cdsd_data.shape == savar_latents.shape, "CDSD and SAVAR latent representations must have the same shape" - - # # Number of latent variables - # n_latents = cdsd_data.shape[1] - - # # Compute the correlation matrix between each latent variable of CDSD and SAVAR - # correlation_matrix = np.corrcoef(cdsd_data, savar_latents, rowvar=False)[:n_latents, 
n_latents:] - - # # Use the Hungarian algorithm to find the best matching between CDSD and SAVAR latents - # row_ind, col_ind = linear_sum_assignment(-np.abs(correlation_matrix)) - - # # Extract the corresponding correlations - # matched_correlations = correlation_matrix[row_ind, col_ind] - - # # Calculate the Mean Correlation Coefficient (MCC) - # mcc = np.mean(np.abs(matched_correlations)) - - # return mcc diff --git a/climatem/plotting/plot_savar_output.py b/climatem/plotting/plot_savar_output.py new file mode 100644 index 0000000..1317ada --- /dev/null +++ b/climatem/plotting/plot_savar_output.py @@ -0,0 +1,2989 @@ +"""Visualization suite for SAVAR synthetic data: plots modes, causal graphs, time series, forcing trajectories, and evaluation metrics.""" + +from pathlib import Path +from typing import Dict, List, Optional + +import matplotlib.animation as animation +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.patches import Patch +from mpl_toolkits.axes_grid1 import make_axes_locatable +from scipy import signal + +from climatem.plotting.plot_model_output import Plotter +from climatem.synthetic_data.utils import permute_matrices +from climatem.utils import get_logger + +# Optional tigramite import for transfer entropy +try: + from tigramite import data_processing as pp + from tigramite import plotting as tp + from tigramite.independence_tests.parcorr import ParCorr + from tigramite.pcmci import PCMCI + + TIGRAMITE_AVAILABLE = True +except ImportError: + TIGRAMITE_AVAILABLE = False + +logger = get_logger(__name__) + + +class SavarPlotter(Plotter): + """ + Specialized plotter for SAVAR synthetic data experiments. + + Inherits from the base Plotter class and adds SAVAR-specific visualization methods including feature map plotting, + adjacency matrix alignment, and forcing diagnostics. + """ + + def __init__(self): + """ + Initialize SavarPlotter, inheriting MCC/assignment tracking from Plotter. 
+ + No additional state is stored; SAVAR-specific context is loaded lazily via prepare_savar_context. + """ + super().__init__() + logger.info("Initialized SavarPlotter for synthetic data visualization") + + def prepare_savar_context(self, learner): + """Load the SAVAR ground-truth artifacts needed for plotting.""" + if not learner.plot_params.savar: + return None + + savar_folder = learner.data_params.data_dir + savar_params = learner.savar_params + savar_fname = ( + f"m_{savar_params.n_per_col**2}_tl_{savar_params.time_len}_ifd_{savar_params.is_forced}_dif_{savar_params.difficulty}_ns_" + f"{savar_params.noise_val}_ses_{savar_params.seasonality}_ol_{savar_params.overlap}_f1_{savar_params.f_1}_f2_{savar_params.f_2}" + f"_ft1_{savar_params.f_time_1}_ft2_{savar_params.f_time_2}_rmp_{savar_params.ramp_type}_lin_{savar_params.linearity}" + f"_pds_{savar_params.poly_degrees}_asp_{savar_params.aerosol_scale}_asc_{savar_params.aerosol_spatial_contrast}" + f"_art_{savar_params.aerosol_ramp_up_time}_apt_{savar_params.aerosol_peak_time}_adt_{savar_params.aerosol_decline_time}" + ) + + savar_dataset_dir = Path(savar_folder) / savar_fname + + # --- Load core SAVAR data --- + modes_gt = np.load(savar_dataset_dir / "modes.npy") + savar_data = np.load(savar_dataset_dir / "savar.npy") # (spatial, time) + + # --- Initialize GT holders --- + learner.co2_gt_spatial = None + learner.aerosol_gt_spatial = None + learner.aerosol_gt_templates = None # List of separate spatial templates (one per aerosol latent) + learner.forcing_indices = None + + co2_forcing = None + aerosol_forcing = None + + # --- Load forcing ground truth --- + if savar_params.is_forced: + co2_path = savar_dataset_dir / "co2_forcing.npy" + aerosol_path = savar_dataset_dir / "aerosol_forcing.npy" + aerosol_templates_path = savar_dataset_dir / "aerosol_spatial_templates.npy" + + if co2_path.exists(): + co2_forcing = np.load(co2_path) + learner.co2_gt_spatial = co2_forcing.mean(axis=1).reshape(learner.lat, learner.lon) + 
+ if aerosol_path.exists(): + aerosol_forcing = np.load(aerosol_path) + learner.aerosol_gt_spatial = aerosol_forcing.mean(axis=1).reshape(learner.lat, learner.lon) + + # Load separate aerosol spatial templates (one per aerosol latent) + if aerosol_templates_path.exists(): + templates = np.load(aerosol_templates_path) # Shape: (n_aerosol, spatial_resolution) + n_aerosol = savar_params.n_aerosol_latents + # Only load as many templates as there are aerosol latents + n_to_load = min(n_aerosol, templates.shape[0]) + learner.aerosol_gt_templates = [ + templates[i].reshape(learner.lat, learner.lon) for i in range(n_to_load) + ] + logger.info( + f"Loaded {len(learner.aerosol_gt_templates)} aerosol spatial templates " + f"(n_aerosol_latents={n_aerosol}, file has {templates.shape[0]} templates)" + ) + + if hasattr(learner, "datamodule"): + if hasattr(learner.datamodule, "forcing_indices") and learner.datamodule.forcing_indices is not None: + learner.forcing_indices = learner.datamodule.forcing_indices + elif hasattr(learner.datamodule, "savar") and hasattr(learner.datamodule.savar, "forcing_indices"): + learner.forcing_indices = learner.datamodule.savar.forcing_indices + + # --- Diagnostic plots (run once, first time context is prepared) --- + if ( + not getattr(learner, "_savar_gt_plots_done", False) + and hasattr(learner, "datamodule") + and hasattr(learner.datamodule, "savar") + and learner.datamodule.savar is not None + ): + savar = learner.datamodule.savar + + # Signal-noise range plots + # Get deterministic data from SAVAR object (generated without noise) + deterministic_component = getattr(savar, "deterministic_data_field", None) + logger.info(f"SAVAR data shape: {savar_data.shape}") + logger.info(f"Deterministic component: {deterministic_component is not None}") + + if deterministic_component is not None: + logger.info(f"Deterministic component shape: {deterministic_component.shape}") + logger.info( + f"Deterministic min/max: {deterministic_component.min():.4f} / 
{deterministic_component.max():.4f}" + ) + + # Compute noise as (full_data - deterministic_data) + if deterministic_component.shape == savar_data.shape: + noise_component = savar_data - deterministic_component + logger.info( + f"✓ Computed noise component from (data - deterministic), shape: {noise_component.shape}" + ) + logger.info(f"Noise min/max: {noise_component.min():.4f} / {noise_component.max():.4f}") + else: + logger.warning( + f"Shape mismatch: savar_data {savar_data.shape} vs deterministic {deterministic_component.shape}. " + "Cannot compute noise component." + ) + noise_component = None + else: + logger.warning( + "deterministic_data_field not found in SAVAR object - cannot compute noise decomposition" + ) + noise_component = None + + # Get linearity information + linearity = getattr(savar, "linearity", "linear") + poly_degrees = getattr(savar, "poly_degrees", None) + + self.plot_savar_signal_noise_decomposition( + savar_data=savar_data, + deterministic_component=deterministic_component, + noise_component=noise_component, + path=savar_dataset_dir, + linearity=linearity, + poly_degrees=poly_degrees, + ) + + # Forcing diagnostics + if savar_params.is_forced: + # Comprehensive forcing diagnostics + self.plot_forcing_diagnostics( + savar_data=savar_data, + co2_forcing=co2_forcing, + aerosol_forcing=aerosol_forcing, + gt_co2_latent=getattr(savar, "co2_latent_trajectory", None), + gt_aerosol_latent=getattr(savar, "aerosol_latent_trajectory", None), + lat=learner.lat, + lon=learner.lon, + path=savar_dataset_dir, + ) + + # Simple forcing diagnostic plots + if co2_forcing is not None: + self.plot_forcing_diagnostic(co2_forcing, "CO2", savar_dataset_dir, learner.lat, learner.lon) + if aerosol_forcing is not None: + self.plot_forcing_diagnostic( + aerosol_forcing, "Aerosol", savar_dataset_dir, learner.lat, learner.lon + ) + + if hasattr(savar, "aerosol_latent_trajectory") and savar.aerosol_latent_trajectory is not None: + self.plot_aerosol_latent_trajectories( + 
savar.aerosol_latent_trajectory, + savar_dataset_dir, + ) + + learner._savar_gt_plots_done = True + + return modes_gt + + def plot_forcing_diagnostic(self, forcing_data, forcing_name, path, lat, lon): + """ + Create comprehensive forcing diagnostic plots. + + Args: + forcing_data: Forcing field of shape (spatial_resolution, time) + forcing_name: Name of forcing (e.g., "CO2", "Aerosol") + path: Save path (SAVAR data directory) + lat: Latitude dimension + lon: Longitude dimension + """ + if forcing_data is None: + return + + forcing_np = forcing_data # Shape: (spatial_resolution, time) + spatial_len = forcing_np.shape[0] + time_len = forcing_np.shape[1] + + # Use consistent color for each forcing type + color = "tab:red" if forcing_name == "CO2" else "tab:blue" + + logger.info(f"Creating comprehensive {forcing_name} forcing diagnostics") + + # 1. Spatial pattern plot (average over time) + spatial_pattern_avg = forcing_np.mean(axis=1) # Average over time + + fig, ax = plt.subplots(figsize=(10, 6)) + ax.plot(np.arange(spatial_len), spatial_pattern_avg, color=color, linewidth=2) + ax.set_title(f"{forcing_name} Forcing: Spatial Pattern (Time-Averaged)") + ax.set_xlabel("Grid point") + ax.set_ylabel("Relative intensity") + ax.grid(alpha=0.3) + fig.tight_layout() + filename = f"{forcing_name.lower()}_spatial_pattern.png" + fig.savefig(path / filename, dpi=150) + plt.close(fig) + logger.info(f"Saved {filename}") + + # 2. Heatmap over space and time + fig, ax = plt.subplots(figsize=(12, 8)) + im = ax.imshow(forcing_np, aspect="auto", interpolation="nearest", cmap="RdBu_r") + ax.set_title(f"{forcing_name} Forcing Over Space and Time") + ax.set_xlabel("Time step") + ax.set_ylabel("Grid point") + fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04, label="Forcing magnitude") + fig.tight_layout() + filename = f"{forcing_name.lower()}_heatmap.png" + fig.savefig(path / filename, dpi=150) + plt.close(fig) + logger.info(f"Saved {filename}") + + # 3. 
Timeline with immediate and cumulative forcing + time_axis = np.arange(time_len) + mean_intensity = forcing_np.mean(axis=0) # Spatial average at each timestep + cumulative_intensity = np.cumsum(mean_intensity) + + fig, ax1 = plt.subplots(figsize=(12, 6)) + (line_immediate,) = ax1.plot( + time_axis, mean_intensity, color=color, linewidth=2, label=f"Immediate {forcing_name}" + ) + ax1.set_xlabel("Time step") + ax1.set_ylabel("Average intensity", color=color) + ax1.tick_params(axis="y", labelcolor=color) + ax1.grid(alpha=0.3) + + ax2 = ax1.twinx() + (line_cumulative,) = ax2.plot( + time_axis, cumulative_intensity, color="tab:orange", linewidth=2, label=f"Cumulative {forcing_name}" + ) + ax2.set_ylabel("Cumulative intensity", color="tab:orange") + ax2.tick_params(axis="y", labelcolor="tab:orange") + + ax1.set_title(f"{forcing_name} Forcing Timeline") + ax1.legend(handles=[line_immediate, line_cumulative], loc="upper left") + fig.tight_layout() + filename = f"{forcing_name.lower()}_timeline.png" + fig.savefig(path / filename, dpi=150) + plt.close(fig) + logger.info(f"Saved {filename}") + + # 4. 
Spatial pattern at peak forcing (2D grid view) + forcing_reshaped = forcing_np.T.reshape((time_len, lat, lon)) # (time, lat, lon) + + # Find peak time + peak_idx = np.argmax(np.abs(mean_intensity)) + + fig, ax = plt.subplots(figsize=(10, 8)) + im = ax.imshow(forcing_reshaped[peak_idx], cmap="RdBu_r", aspect="auto", origin="upper") + ax.set_title(f"{forcing_name} Forcing: Spatial Pattern at Peak (t={peak_idx})") + ax.set_xlabel("Longitude index") + ax.set_ylabel("Latitude index") + plt.colorbar(im, ax=ax, label=f"{forcing_name} magnitude") + fig.tight_layout() + filename = f"{forcing_name.lower()}_peak_spatial.png" + fig.savefig(path / filename, dpi=150) + plt.close(fig) + logger.info(f"Saved {filename}") + + logger.info(f"Completed comprehensive {forcing_name} forcing diagnostics") + + def plot_aerosol_latent_trajectories(self, aerosol_latent_traj, path, n_aerosol_latents=None): + """ + Plot individual aerosol latent trajectories to verify they have distinct temporal patterns. + + Args: + aerosol_latent_traj: Aerosol latent trajectories of shape (n_latents, time) + path: Save path + n_aerosol_latents: Number of aerosol latents (if None, inferred from data shape) + """ + if aerosol_latent_traj is None: + return + + n_latents = aerosol_latent_traj.shape[0] + if n_aerosol_latents is None: + n_aerosol_latents = n_latents + time_len = aerosol_latent_traj.shape[1] + time_axis = np.arange(time_len) + + # Plot 1: All latent trajectories on same axes + fig, ax = plt.subplots(figsize=(14, 6)) + colors = plt.cm.viridis(np.linspace(0, 1, n_latents)) + for i in range(n_latents): + ax.plot( + time_axis, + aerosol_latent_traj[i], + color=colors[i], + linewidth=1.5, + label=f"Aerosol Latent {i}", + alpha=0.8, + ) + ax.set_title("Aerosol Latent Trajectories (Should Show Distinct Temporal Patterns)") + ax.set_xlabel("Time step") + ax.set_ylabel("Latent value") + ax.legend(loc="upper right") + ax.grid(alpha=0.3) + fig.tight_layout() + fig.savefig(path / 
"aerosol_latent_trajectories.png", dpi=150) + plt.close(fig) + logger.info("Saved aerosol_latent_trajectories.png") + + # Plot 2: Separate subplots for each latent + fig, axes = plt.subplots(n_latents, 1, figsize=(14, 3 * n_latents), sharex=True) + if n_latents == 1: + axes = [axes] + for i, ax in enumerate(axes): + ax.plot(time_axis, aerosol_latent_traj[i], color=colors[i], linewidth=1.5) + ax.set_ylabel(f"Latent {i}") + ax.grid(alpha=0.3) + # Add peak marker + peak_idx = np.argmin(aerosol_latent_traj[i]) # Aerosol is negative + ax.axvline(x=peak_idx, color="red", linestyle="--", alpha=0.5, label=f"Peak at t={peak_idx}") + ax.legend(loc="upper right") + axes[-1].set_xlabel("Time step") + fig.suptitle("Individual Aerosol Latent Trajectories with Peak Markers", fontsize=12) + fig.tight_layout() + fig.savefig(path / "aerosol_latent_trajectories_separate.png", dpi=150) + plt.close(fig) + logger.info("Saved aerosol_latent_trajectories_separate.png") + + # Plot 3: Correlation matrix + corr_matrix = np.corrcoef(aerosol_latent_traj) + fig, ax = plt.subplots(figsize=(8, 6)) + im = ax.imshow(corr_matrix, cmap="RdBu_r", vmin=-1, vmax=1) + ax.set_title("Aerosol Latent Correlation Matrix\n(Target: Off-diagonal < 0.5)") + ax.set_xticks(np.arange(n_latents)) + ax.set_yticks(np.arange(n_latents)) + ax.set_xticklabels([f"L{i}" for i in range(n_latents)]) + ax.set_yticklabels([f"L{i}" for i in range(n_latents)]) + # Add correlation values as text + for i in range(n_latents): + for j in range(n_latents): + text_color = "white" if abs(corr_matrix[i, j]) > 0.5 else "black" + ax.text(j, i, f"{corr_matrix[i, j]:.2f}", ha="center", va="center", color=text_color, fontsize=10) + fig.colorbar(im, ax=ax, label="Correlation") + fig.tight_layout() + fig.savefig(path / "aerosol_latent_correlation_matrix.png", dpi=150) + plt.close(fig) + logger.info("Saved aerosol_latent_correlation_matrix.png") + + # Print correlation summary + off_diag_corrs = [] + for i in range(n_latents): + for j in range(i 
+ 1, n_latents): + off_diag_corrs.append(abs(corr_matrix[i, j])) + avg_corr = np.mean(off_diag_corrs) if off_diag_corrs else 0 + max_corr = np.max(off_diag_corrs) if off_diag_corrs else 0 + logger.info(f"Aerosol latent correlations: avg={avg_corr:.4f}, max={max_corr:.4f}") + logger.info(f"Aerosol latent correlations: avg={avg_corr:.4f}, max={max_corr:.4f} (target: < 0.5)") + + def plot_compare_predictions_savar( + self, + x_past: np.ndarray, + y_true: np.ndarray, + y_recons: np.ndarray, + y_hat: np.ndarray, + sample: int, + lat: int, + lon: int, + path, + filename_prefix=None, + iteration: int = 0, + valid: str = False, + plot_through_time: bool = True, + ): + """Plot SAVAR predictions alongside reconstruction and ground truth on a latitude/longitude grid.""" + + def _reshape(arr: np.ndarray, var_idx: int) -> np.ndarray: + arr = np.asarray(arr) + if arr.ndim == 2: + if sample >= arr.shape[0]: + raise ValueError("Sample index out of bounds for provided array.") + if var_idx != 0: + raise ValueError("Variable index must be 0 for 2D inputs shaped (n_samples, lat*lon).") + flat = arr[sample] + elif arr.ndim >= 3: + if sample >= arr.shape[0] or var_idx >= arr.shape[1]: + raise ValueError("Sample or variable index out of bounds for provided array.") + flat = arr[sample, var_idx] + else: + raise ValueError("Expected arrays with at least 2 dimensions.") + flat = np.nan_to_num(flat) + return flat.reshape(lat, lon) + + y_true = np.asarray(y_true) + y_recons = np.asarray(y_recons) + y_hat = np.asarray(y_hat) + + if y_true.ndim >= 4: + y_true_current = y_true[:, 0] + y_true_next = y_true[:, 1] if y_true.shape[1] > 1 else None + else: + y_true_current = y_true + y_true_next = None + + n_vars = y_true_current.shape[1] if y_true_current.ndim > 2 else 1 + num_cols = 5 if y_true_next is not None else 4 + fig_width = 8 * num_cols + fig_height = 16 if n_vars > 1 else 8 + + fig, axs = plt.subplots(n_vars, num_cols, figsize=(fig_width, fig_height), layout="constrained") + ax_rows = axs 
if n_vars > 1 else [axs] + + # Panel descriptions: + # x_past: Last timestep from history (t-1) + # y_true_current: Ground truth target (t) + # y_recons: Reconstruction of y through encoder-decoder (t) + # y_hat: Model prediction of y from history (t) + panels = [ + ("Ground truth (t-1)\n[Last history step]", x_past), + ("Ground truth (t)\n[Target]", y_true_current), + ("Reconstruction (t)\n[Encode-Decode]", y_recons), + ("Prediction (t)\n[Model output]", y_hat), + ] + if y_true_next is not None: + panels.append(("Ground truth (t+1)\n[Future step]", y_true_next)) + + for var_idx, ax_row in enumerate(ax_rows): + grids = [_reshape(arr, var_idx) for _, arr in panels] + vmin = min(grid.min() for grid in grids) + vmax = max(grid.max() for grid in grids) + im = None + for ax, (title, _), grid in zip(ax_row, panels, grids): + im = ax.imshow(grid, cmap="RdBu_r", vmin=vmin, vmax=vmax, aspect="auto") + ax.set_title(title, fontsize=10) + ax.set_xlabel("Longitude index") + ax.set_ylabel("Latitude index") + if im is not None: + fig.colorbar(im, ax=ax_row[-1], orientation="vertical", shrink=1.0, label=f"Variable {var_idx}") + + if not valid: + if plot_through_time: + fname = f"compare_predictions_savar_{iteration}_sample_{sample}_train.png" + else: + fname = "compare_predictions_savar_train.png" + else: + if plot_through_time: + fname = f"compare_predictions_savar_{iteration}_sample_{sample}_valid.png" + else: + fname = "compare_predictions_savar_valid.png" + + # Create descriptive overall title + if y_true_next is not None: + title = "SAVAR Prediction Comparison: History (t-1) | Target (t) | Reconstruction (t) | Prediction (t) | Future (t+1)" + else: + title = ( + "SAVAR Prediction Comparison: History (t-1) | Ground Truth (t) | Reconstruction (t) | Prediction (t)" + ) + plt.suptitle(title, fontsize=24) + if filename_prefix: + fname = f"{filename_prefix}_{fname}" + plt.savefig(path / fname, format="png") + plt.close() + + def plot_savar_feature_maps( + self, + learner, + w_adj, 
+ coordinates: np.ndarray, + iteration: int, + plot_through_time: bool, + path, + ): + """ + Plot learned latent feature maps for SAVAR data. + + Creates separate visualizations for climate latents and forcing latents. + """ + grid_shape = (learner.lat, learner.lon) + logger.info("Creating SAVAR feature maps visualization") + w_adj = w_adj[0] # Now w_adj_mean should be (lat*lon, num_latents) + d_z = w_adj.shape[1] + logger.info(f"Model dimension: d_z = {d_z}") + + # Get model reference (needed for decoder weight visualization) + model = learner.model.module if hasattr(learner.model, "module") else learner.model + + # Get forcing configuration from SAVAR params if available, otherwise from model + logger.info( + f"Checking for savar_params: hasattr={hasattr(learner, 'savar_params')}, is not None={getattr(learner, 'savar_params', None) is not None}" + ) + if hasattr(learner, "savar_params") and learner.savar_params is not None: + # Use ground truth configuration from SAVAR params + n_co2 = learner.savar_params.n_co2_latents + n_aerosol = learner.savar_params.n_aerosol_latents + # Get ground truth number of climate modes (n_per_col^2) + n_climate = learner.savar_params.n_per_col**2 + logger.info( + f"Using SAVAR params (ground truth): n_co2={n_co2}, n_aerosol={n_aerosol}, n_climate={n_climate}" + ) + + # Sanity check: warn if model dimension doesn't match ground truth + expected_d_z = n_climate + n_co2 + n_aerosol + if d_z != expected_d_z: + logger.warning( + f"Model dimension mismatch! Model has d_z={d_z} latents, " + f"but ground truth has {expected_d_z} ({n_climate} climate + {n_co2} CO2 + {n_aerosol} aerosol). " + f"Will plot only the first {n_climate} as climate latents." 
+ ) + else: + # Fall back to model configuration + use_forced_latents = getattr(model, "use_forced_latents", False) + n_co2 = getattr(model, "n_forced_latents_co2", 0) if use_forced_latents else 0 + n_aerosol = getattr(model, "n_forced_latents_aerosol", 0) if use_forced_latents else 0 + n_climate = d_z - n_co2 - n_aerosol + logger.info(f"Using model config: n_co2={n_co2}, n_aerosol={n_aerosol}, n_climate={n_climate}") + + # Split latent indices + climate_indices = list(range(n_climate)) + co2_indices = list(range(n_climate, n_climate + n_co2)) + aerosol_indices = list(range(n_climate + n_co2, d_z)) + + logger.info( + f"Climate latents: {climate_indices}, CO2 latents: {co2_indices}, Aerosol latents: {aerosol_indices}" + ) + + # ==== Figure 1: Climate Latents vs Ground Truth ==== + if len(climate_indices) > 0: + self._plot_climate_feature_maps( + learner, w_adj, grid_shape, climate_indices, iteration, plot_through_time, path + ) + + # ==== Figure 2: CO2 Forcing - Ground Truth vs Forcing Decoder Reconstruction ==== + if len(co2_indices) > 0 and learner.co2_gt_spatial is not None: + self._plot_co2_feature_maps(learner, model, grid_shape, iteration, plot_through_time, path) + + # ==== Figure 3: Aerosol Forcing - Ground Truth vs Forcing Decoder Reconstruction ==== + has_templates = learner.aerosol_gt_templates is not None and len(learner.aerosol_gt_templates) > 0 + if len(aerosol_indices) > 0 and not (has_templates or learner.aerosol_gt_spatial is not None): + logger.warning( + f"Aerosol latent indices {aerosol_indices} found but no GT spatial data " + "(aerosol_forcing.npy / aerosol_spatial_templates.npy missing). " + "Regenerate SAVAR data to include aerosol GT files." 
+ ) + if len(aerosol_indices) > 0 and (has_templates or learner.aerosol_gt_spatial is not None): + self._plot_aerosol_feature_maps( + learner, model, grid_shape, aerosol_indices, has_templates, iteration, plot_through_time, path + ) + + def _plot_climate_feature_maps( + self, learner, w_adj, grid_shape, climate_indices, iteration, plot_through_time, path + ): + """Plot climate latent feature maps vs ground truth.""" + n_climate_plots = len(climate_indices) + 1 # +1 for ground truth + combined_map_n_rows = int(np.sqrt(n_climate_plots)) + 1 + combined_map_n_columns = int(np.ceil(n_climate_plots / combined_map_n_rows)) + + fig, axs = plt.subplots( + nrows=combined_map_n_rows, + ncols=combined_map_n_columns, + figsize=(combined_map_n_columns * 3, combined_map_n_rows * 3), + ) + if combined_map_n_rows == 1: + axs = axs.reshape(1, -1) + + # Plot ground truth climate modes + ax = axs.flat[0] + gt_modes_sum = ( + learner.datamodule.savar_gt_modes.sum(axis=0) + if learner.datamodule.savar_gt_modes.ndim == 3 + else learner.datamodule.savar_gt_modes + ) + im = ax.imshow(learner.datamodule.savar_gt_noise + gt_modes_sum, cmap="viridis") + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="5%", pad=0.05) + plt.colorbar(im, cax=cax) + ax.set_title("Ground-Truth\nClimate Modes", fontsize="large") + ax.tick_params(axis="both", labelsize="large") + + # Plot climate latent features + for plot_idx, latent_idx in enumerate(climate_indices): + ax = axs.flat[plot_idx + 1] + data = w_adj[:, latent_idx].reshape(grid_shape) + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="5%", pad=0.05) + im = ax.imshow(data, cmap="viridis") + plt.colorbar(im, cax=cax) + ax.set_title(f"Climate Latent {latent_idx}", fontsize="large") + ax.tick_params(axis="both", labelsize="large") + + for ax in axs.flat[n_climate_plots:]: + fig.delaxes(ax) + + fig.tight_layout() + fname = ( + f"spatial_aggregation_climate_{iteration}.png" if plot_through_time else 
"spatial_aggregation_climate.png" + ) + plt.savefig(path / fname) + plt.close() + logger.info(f"Saved climate latent feature maps to {fname}") + + def _plot_co2_feature_maps(self, learner, model, grid_shape, iteration, plot_through_time, path): + """Plot CO2 forcing ground truth vs learned decoder reconstruction.""" + fig, axs = plt.subplots(1, 2, figsize=(10, 4)) + + # Left: Ground truth + im0 = axs[0].imshow(learner.co2_gt_spatial, cmap="RdBu_r") + divider0 = make_axes_locatable(axs[0]) + cax0 = divider0.append_axes("right", size="5%", pad=0.05) + plt.colorbar(im0, cax=cax0) + axs[0].set_title("Ground-Truth CO2 Forcing", fontsize="large") + + # Right: Learned decoder weights (w_co2) + try: + autoencoder = model.autoencoder if hasattr(model, "autoencoder") else None + if autoencoder is not None and hasattr(autoencoder, "get_w_co2"): + w_co2 = autoencoder.get_w_co2() + if w_co2 is not None: + w_co2_np = w_co2.cpu().numpy() + spatial_pattern = w_co2_np[:, 0].reshape(grid_shape) + im1 = axs[1].imshow(spatial_pattern, cmap="RdBu_r") + divider1 = make_axes_locatable(axs[1]) + cax1 = divider1.append_axes("right", size="5%", pad=0.05) + plt.colorbar(im1, cax=cax1) + axs[1].set_title("Learned CO2 Decoder Weights", fontsize="large") + else: + axs[1].text(0.5, 0.5, "w_co2 not available", ha="center", va="center", transform=axs[1].transAxes) + axs[1].set_title("CO2 Decoder Weights", fontsize="large") + else: + axs[1].text(0.5, 0.5, "Old model\n(no w_co2)", ha="center", va="center", transform=axs[1].transAxes) + axs[1].set_title("CO2 Decoder Weights", fontsize="large") + except Exception as e: + axs[1].text(0.5, 0.5, "Error", ha="center", va="center", transform=axs[1].transAxes) + logger.warning(f"Could not visualize CO2 decoder weights: {e}") + + for ax in axs: + ax.tick_params(axis="both", labelsize="large") + fig.tight_layout() + fname = f"spatial_aggregation_co2_{iteration}.png" if plot_through_time else "spatial_aggregation_co2.png" + plt.savefig(path / fname) + 
plt.close() + logger.info(f"Saved CO2 forcing comparison to {fname}") + + def _plot_aerosol_feature_maps( + self, learner, model, grid_shape, aerosol_indices, has_templates, iteration, plot_through_time, path + ): + """Plot aerosol forcing ground truth vs learned decoder reconstruction.""" + n_aerosol = len(aerosol_indices) + fig, axs = plt.subplots(2, n_aerosol, figsize=(4 * n_aerosol, 8)) + if n_aerosol == 1: + axs = axs.reshape(2, 1) + + # Top row: Ground truth aerosol patterns + for i in range(n_aerosol): + if has_templates and i < len(learner.aerosol_gt_templates): + im = axs[0, i].imshow(learner.aerosol_gt_templates[i], cmap="RdBu_r") + divider = make_axes_locatable(axs[0, i]) + cax = divider.append_axes("right", size="5%", pad=0.05) + plt.colorbar(im, cax=cax) + axs[0, i].set_title(f"GT Aerosol Template {i}", fontsize="large") + elif learner.aerosol_gt_spatial is not None and i == 0: + im = axs[0, i].imshow(learner.aerosol_gt_spatial, cmap="RdBu_r") + divider = make_axes_locatable(axs[0, i]) + cax = divider.append_axes("right", size="5%", pad=0.05) + plt.colorbar(im, cax=cax) + axs[0, i].set_title("GT Aerosol (combined)", fontsize="large") + else: + axs[0, i].text( + 0.5, 0.5, "No template\n(regenerate data)", ha="center", va="center", transform=axs[0, i].transAxes + ) + axs[0, i].set_title(f"GT Aerosol {i}", fontsize="large") + + # Bottom row: Forcing decoder weights + try: + autoencoder = model.autoencoder if hasattr(model, "autoencoder") else None + if autoencoder is not None and hasattr(autoencoder, "get_w_aerosol"): + w_aerosol = autoencoder.get_w_aerosol() + if w_aerosol is not None: + w_aerosol_np = w_aerosol.cpu().numpy() + for i in range(n_aerosol): + spatial_pattern = w_aerosol_np[:, i].reshape(grid_shape) + im = axs[1, i].imshow(spatial_pattern, cmap="RdBu_r") + divider = make_axes_locatable(axs[1, i]) + cax = divider.append_axes("right", size="5%", pad=0.05) + plt.colorbar(im, cax=cax) + axs[1, i].set_title(f"Learned Decoder Weights {i}", 
fontsize="large") + else: + for i in range(n_aerosol): + axs[1, i].text( + 0.5, 0.5, "w_aerosol not available", ha="center", va="center", transform=axs[1, i].transAxes + ) + else: + for i in range(n_aerosol): + axs[1, i].text( + 0.5, 0.5, "Old model\n(no w_aerosol)", ha="center", va="center", transform=axs[1, i].transAxes + ) + except Exception as e: + logger.warning(f"Could not visualize aerosol decoder weights: {e}") + for i in range(n_aerosol): + axs[1, i].text(0.5, 0.5, "Error", ha="center", va="center", transform=axs[1, i].transAxes) + + for ax in axs.flat: + ax.tick_params(axis="both", labelsize="large") + fig.tight_layout() + fname = ( + f"spatial_aggregation_aerosol_{iteration}.png" if plot_through_time else "spatial_aggregation_aerosol.png" + ) + plt.savefig(path / fname) + plt.close() + logger.info(f"Saved aerosol forcing comparison to {fname}") + + def plot_decoder_connectivity_heatmap( + self, + learner, + w_adj, + iteration: int, + plot_through_time: bool, + path, + ): + """ + Plot decoder connectivity heatmap showing spatial × latents connections. + + NOTE: With the architectural fix, only CLIMATE latents are used in observation decoding. + Forcing latents (CO2, aerosol) are excluded from the observation decoder and only + contribute through forcing decoders and the causal transition model. When forced + latents are present, they are shown here for reference. 
+ """ + logger.info("Creating decoder connectivity heatmap") + w_adj = w_adj[0] # Shape: (spatial_resolution, num_latents) + d_z = w_adj.shape[1] + + # Detect forcing configuration from model + model = learner.model.module if hasattr(learner.model, "module") else learner.model + use_forced_latents = getattr(model, "use_forced_latents", False) + n_co2 = getattr(model, "n_forced_latents_co2", 0) if use_forced_latents else 0 + n_aerosol = getattr(model, "n_forced_latents_aerosol", 0) if use_forced_latents else 0 + n_climate = d_z - n_co2 - n_aerosol + + include_forcings = use_forced_latents and (n_co2 + n_aerosol) > 0 + + if include_forcings: + w_adj_plot = w_adj + latent_labels = ( + [f"Climate {i}" for i in range(n_climate)] + + [f"CO2 {i}" for i in range(n_co2)] + + [f"Aerosol {i}" for i in range(n_aerosol)] + ) + latent_types = ["Climate"] * n_climate + ["CO2"] * n_co2 + ["Aerosol"] * n_aerosol + else: + # Only show climate latent columns (actually used in observation decoding) + w_adj_plot = w_adj[:, :n_climate] + latent_labels = [f"Climate {i}" for i in range(n_climate)] + latent_types = ["Climate"] * n_climate + + # Create heatmap + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6)) + + # Left plot: Decoder connectivity heatmap (spatial × latents) + im1 = ax1.imshow(np.abs(w_adj_plot), aspect="auto", cmap="viridis", interpolation="nearest") + ax1.set_xlabel("Latent Index", fontsize=12) + ax1.set_ylabel("Spatial Location", fontsize=12) + if include_forcings: + ax1.set_title("Observation Decoder Connectivity\n(Forcing latents shown for reference)", fontsize=14) + else: + ax1.set_title("Observation Decoder Connectivity\n(Climate Latents Only)", fontsize=14) + ax1.set_xticks(range(len(latent_labels))) + ax1.set_xticklabels(latent_labels, rotation=45, ha="right", fontsize=10) + plt.colorbar(im1, ax=ax1, label="Weight Magnitude") + + # Right plot: Latent-wise L2 norms (bar chart) + latent_norms = np.linalg.norm(w_adj_plot, axis=0) # L2 norm for each latent + if 
include_forcings: + type_colors = {"Climate": "tab:green", "CO2": "tab:red", "Aerosol": "tab:blue"} + else: + type_colors = {"Climate": "tab:blue"} + bar_colors = [type_colors[label] for label in latent_types] + + ax2.bar(range(len(latent_norms)), latent_norms, color=bar_colors, alpha=0.7) + ax2.set_xlabel("Latent Index", fontsize=12) + ax2.set_ylabel("Decoder Weight L2 Norm", fontsize=12) + if include_forcings: + ax2.set_title("Latent Usage\n(Forcing latents shown for reference)", fontsize=14) + else: + ax2.set_title("Climate Latent Usage\n(Forcing latents excluded from obs decoder)", fontsize=14) + ax2.set_xticks(range(len(latent_labels))) + ax2.set_xticklabels(latent_labels, rotation=45, ha="right", fontsize=10) + ax2.grid(axis="y", alpha=0.3) + if include_forcings: + legend_patches = [ + Patch(facecolor=type_colors[name], label=name) + for name in ["Climate", "CO2", "Aerosol"] + if name in latent_types + ] + ax2.legend(handles=legend_patches, loc="upper right", fontsize=9) + + # Log latent norms + logger.info(f"Latent norms: {latent_norms}") + if use_forced_latents and include_forcings: + logger.info(f"Included forcing latents in plot for reference ({n_co2} CO2, {n_aerosol} aerosol)") + elif use_forced_latents: + logger.info(f"Note: Forcing latent columns ({n_co2} CO2, {n_aerosol} aerosol) excluded from obs decoder") + + fig.tight_layout() + + if plot_through_time: + fname = f"decoder_connectivity_{iteration}.png" + else: + fname = "decoder_connectivity.png" + + plt.savefig(path / fname, dpi=150) + plt.close() + logger.info(f"Saved decoder connectivity heatmap to {fname}") + + def plot_adjacency_matrix_savar( + self, + learner, + mat1: np.ndarray, + mat2: np.ndarray, + modes_gt, + modes_inferred, + path, + name_suffix: str, + no_gt: bool = False, + iteration: int = 0, + plot_through_time: bool = True, + ): + """ + Plot adjacency matrices for SAVAR runs after aligning inferred modes with the ground truth. 
+ + Uses spatial proximity of mode centroids to find the optimal permutation before plotting. Fully self-contained — + does not delegate to Plotter.plot_adjacency_matrix. + """ + effective_no_gt = no_gt or mat2 is None + lat = getattr(learner, "lat", None) + lon = getattr(learner, "lon", None) + tau = mat1.shape[0] + + mat1_to_plot = np.array(mat1, copy=True) + + # Permute learned adjacency to align with ground-truth modes. + # Only permute climate-mode submatrix; forcing latent rows/columns stay in place. + if ( + not effective_no_gt + and lat is not None + and lon is not None + and modes_gt is not None + and modes_inferred is not None + ): + # modes_inferred may be (d, d_x, d_z) from get_w_decoder(); squeeze to (d_x, d_z) + mi = np.squeeze(modes_inferred) + + n_climate = modes_gt.shape[0] + d_total = mat1_to_plot.shape[1] + if n_climate < d_total: + # Slice latent axis (last) to climate-only columns, then permute climate submatrix + climate_sub = mat1_to_plot[:, :n_climate, :n_climate].copy() + mi_climate = mi[..., :n_climate] + climate_sub = permute_matrices(lat, lon, mi_climate, modes_gt, climate_sub, tau) + mat1_to_plot[:, :n_climate, :n_climate] = climate_sub + else: + mat1_to_plot = permute_matrices(lat, lon, mi, modes_gt, mat1_to_plot, tau) + + # Prepare ground truth adjacency matrix + forcing_indices = getattr(learner, "forcing_indices", None) + mat2_aligned = None + if not effective_no_gt: + mat2_aligned = np.array(mat2, copy=True) + if mat2_aligned.ndim == 2: + mat2_aligned = mat2_aligned[None, ...] 
+ if mat2_aligned.shape[0] != tau: + if mat2_aligned.shape[0] == 1: + mat2_aligned = np.repeat(mat2_aligned, tau, axis=0) + else: + mat2_aligned = mat2_aligned[:tau] + target_d = mat1_to_plot.shape[1] + if mat2_aligned.shape[1] != target_d or mat2_aligned.shape[2] != target_d: + resized = np.zeros((tau, target_d, target_d), dtype=mat2_aligned.dtype) + min_d = min(target_d, mat2_aligned.shape[1], mat2_aligned.shape[2]) + resized[:, :min_d, :min_d] = mat2_aligned[:, :min_d, :min_d] + mat2_aligned = resized + + # Create figure and plot + fig = plt.figure(constrained_layout=True) + fig.suptitle("Adjacency matrices: learned vs ground-truth") + + if tau == 1: + self._plot_adjacency_single_time(fig, mat1_to_plot, mat2_aligned, effective_no_gt, forcing_indices) + else: + self._plot_adjacency_through_time(fig, mat1_to_plot, mat2_aligned, effective_no_gt, tau, forcing_indices) + + if plot_through_time: + fname = f"adjacency_{name_suffix}_{iteration}.png" + else: + fname = f"adjacency_{name_suffix}.png" + + plt.savefig(path / fname, format="png") + plt.close() + + def plot_original_savar(self, data, lat, lon, path): + """ + Create an animated GIF of the original SAVAR data over time. 
+ + Args: + data: SAVAR data of shape (n_modes, spatial_resolution, time) or similar + lat: Latitude dimension + lon: Longitude dimension + path: Save path for the GIF + """ + logger.info(f"Creating SAVAR original data animation - data shape: {data.shape}") + + # Get the dimensions + time_steps = data.shape[1] + data_reshaped = data.T.reshape((time_steps, lat, lon)) + + # Calculate the average over the time axis + avg_data = np.mean(data_reshaped, axis=0) + + # Determine the global min and max from the averaged data for consistent color scaling + vmin = np.min(avg_data) + vmax = np.max(avg_data) + + fig, ax = plt.subplots(figsize=(lon / 10, lat / 10)) + cax = ax.imshow(data_reshaped[0], aspect="auto", cmap="viridis", vmin=vmin, vmax=vmax) + + def animate(i): + cax.set_data(data_reshaped[i]) + ax.set_title(f"SAVAR Original Data - Time step: {i+1}/{time_steps}") + return (cax,) + + # Create an animation (first 100 timesteps to keep file size reasonable) + n_frames = min(100, time_steps) + ani = animation.FuncAnimation(fig, animate, frames=n_frames, blit=False) + + # Save the animation as a GIF + ani.save(path, writer="pillow", fps=10) + plt.close() + + logger.info(f"Saved SAVAR original data animation to {path}") + + def compute_snr_metrics(self, signal_data: np.ndarray, noise_data: np.ndarray) -> dict: + """ + Compute signal-to-noise ratio metrics. 
+ + Args: + signal_data: Signal component (spatial, time) + noise_data: Noise component (spatial, time) + + Returns: + Dictionary with SNR metrics + """ + # Compute standard deviations + signal_std = np.std(signal_data) + noise_std = np.std(noise_data) + + # SNR by standard deviation + snr_std = signal_std / noise_std if noise_std > 0 else np.inf + + # SNR in dB + snr_db = 10 * np.log10(snr_std) if snr_std > 0 else -np.inf + + # Amplitude-based SNR (max range) + signal_range = signal_data.max() - signal_data.min() + noise_range = noise_data.max() - noise_data.min() + snr_amplitude = signal_range / noise_range if noise_range > 0 else np.inf + + # Root mean square + signal_rms = np.sqrt(np.mean(signal_data**2)) + noise_rms = np.sqrt(np.mean(noise_data**2)) + snr_rms = signal_rms / noise_rms if noise_rms > 0 else np.inf + + return { + "signal_std": signal_std, + "noise_std": noise_std, + "snr_std": snr_std, + "snr_db": snr_db, + "signal_range": signal_range, + "noise_range": noise_range, + "snr_amplitude": snr_amplitude, + "signal_rms": signal_rms, + "noise_rms": noise_rms, + "snr_rms": snr_rms, + } + + def plot_savar_data_noise_ranges( + self, + savar_data: np.ndarray, + noise_data: Optional[np.ndarray], + path: Path, + title: str = "SAVAR data", + ) -> None: + """ + Plot data field range and noise field range over time, similar to signal-to-noise diagnostics. + + Creates a filled area plot showing the min-max range of data values and noise values + across spatial dimensions at each time step. 
+ + Args: + savar_data: Climate data of shape (spatial, time) + noise_data: Noise data of shape (spatial, time) or None + path: Save directory path + title: Title prefix for the plot + """ + logger.info(f"Plotting data/noise field ranges for {title}") + logger.info(f" savar_data shape: {savar_data.shape}") + logger.info(f" noise_data: {noise_data is not None}") + if noise_data is not None: + logger.info(f" noise_data shape: {noise_data.shape}") + + spatial_dim, time_len = savar_data.shape + + # Compute min/max range across spatial dimension at each time step + data_min = savar_data.min(axis=0) + data_max = savar_data.max(axis=0) + time_axis = np.arange(time_len) + + fig, ax = plt.subplots(figsize=(14, 5)) + + # Plot data field range + ax.fill_between(time_axis, data_min, data_max, color="steelblue", alpha=0.7, label="data field range") + + # Plot noise field range if available and compute SNR + snr_text = "" + if noise_data is not None: + logger.info(" → Plotting noise field range") + noise_min = noise_data.min(axis=0) + noise_max = noise_data.max(axis=0) + ax.fill_between(time_axis, noise_min, noise_max, color="sandybrown", alpha=0.7, label="noise field range") + + # Compute SNR metrics + snr_metrics = self.compute_snr_metrics(savar_data, noise_data) + logger.info(f" → SNR (std): {snr_metrics['snr_std']:.3f} ({snr_metrics['snr_db']:.2f} dB)") + logger.info(f" → Signal std: {snr_metrics['signal_std']:.4f}, Noise std: {snr_metrics['noise_std']:.4f}") + logger.info(f" → SNR (amplitude): {snr_metrics['snr_amplitude']:.3f}") + + # Create SNR annotation text + snr_text = ( + f"SNR: {snr_metrics['snr_std']:.2f} ({snr_metrics['snr_db']:.1f} dB)\n" + f"Signal std: {snr_metrics['signal_std']:.4f} | Noise std: {snr_metrics['noise_std']:.4f}" + ) + else: + logger.warning(" → Noise data is None, skipping noise plot") + + ax.set_xlabel("Time step", fontsize=12) + ax.set_ylabel("Value", fontsize=12) + ax.set_title(title, fontsize=14, fontweight="bold") + ax.legend(loc="upper 
right", fontsize=11) + ax.grid(alpha=0.3) + + # Add SNR text box if available + if snr_text: + ax.text( + 0.02, + 0.02, + snr_text, + transform=ax.transAxes, + fontsize=10, + verticalalignment="bottom", + bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8), + ) + + fig.tight_layout() + filename = f"{title.lower().replace(' ', '_')}_field_ranges.png" + fig.savefig(path / filename, dpi=150) + plt.close(fig) + logger.info(f"Saved {filename}") + + def plot_savar_signal_noise_decomposition( + self, + savar_data: np.ndarray, + deterministic_component: Optional[np.ndarray], + noise_component: Optional[np.ndarray], + path: Path, + linearity: str = "linear", + poly_degrees: Optional[List[int]] = None, + ) -> None: + """ + Create separate plots for linear/nonlinear/polynomial components showing data vs noise ranges. + + This mimics the style of the reference plots showing signal-to-noise characteristics + for different data types. + + Args: + savar_data: Full SAVAR climate data (spatial, time) + deterministic_component: Deterministic signal component (spatial, time) or None + noise_component: Noise component (spatial, time) or None + path: Save directory path + linearity: Type of data ("linear", "nonlinear", "polynomial") + poly_degrees: List of polynomial degrees if linearity is "polynomial" + """ + logger.info(f"Creating SAVAR signal-noise decomposition plots (linearity={linearity})") + + # Determine title prefix based on linearity type + if linearity == "linear": + data_type = "Linear data" + elif linearity == "nonlinear": + data_type = "Nonlinear data" + elif linearity == "polynomial": + if poly_degrees: + deg_str = ",".join(map(str, poly_degrees)) + data_type = f"Polynomial data (degrees {deg_str})" + else: + data_type = "Polynomial data" + else: + data_type = f"{linearity.capitalize()} data" + + # Plot 1: Full data range + self.plot_savar_data_noise_ranges( + savar_data=savar_data, + noise_data=noise_component, + path=path, + title=data_type, + ) + + # Plot 2: If 
we have deterministic component, plot it separately + if deterministic_component is not None: + self.plot_savar_data_noise_ranges( + savar_data=deterministic_component, + noise_data=noise_component, + path=path, + title=f"{data_type} - Deterministic Component", + ) + + # Plot 3: Residual (data - deterministic) vs noise + residual = savar_data - deterministic_component + self.plot_savar_data_noise_ranges( + savar_data=residual, + noise_data=noise_component, + path=path, + title=f"{data_type} - Residual", + ) + + logger.info(f"Completed SAVAR signal-noise decomposition plots for {data_type}") + + # ========================================================================= + # FORCING DIAGNOSTIC PLOTS - Global Correlation Analysis + # ========================================================================= + + # needs to be edited, gotta ask julien again how he wants it + def plot_correlation_heatmap_over_time( + self, + forcing: np.ndarray, + climate: np.ndarray, + forcing_name: str, + window_size: int, + path: Path, + ) -> None: + """ + Plot sliding window correlation between forcing and climate over time. + + Args: + forcing: Forcing data of shape (spatial, time) or (time,) + climate: Climate data of shape (spatial, time) + forcing_name: Name of forcing + window_size: Size of sliding window + path: Save path + """ + logger.info(f"Computing sliding window correlation for {forcing_name}") + + # Ensure forcing and climate have matching time dimensions + time_forcing = forcing.shape[-1] + time_climate = climate.shape[-1] + if time_forcing != time_climate: + min_time = min(time_forcing, time_climate) + logger.warning( + f"Time dimension mismatch: forcing has {time_forcing} timesteps, " + f"climate has {time_climate} timesteps. Trimming both to {min_time}." 
+ ) + forcing = forcing[..., :min_time] + climate = climate[..., :min_time] + + # Aggregate spatially + forcing_ts = forcing.mean(axis=0) if forcing.ndim == 2 else forcing + climate_ts = climate.mean(axis=0) if climate.ndim == 2 else climate + + time_len = len(forcing_ts) + n_windows = time_len - window_size + 1 + + if n_windows < 10: + logger.warning(f"Too few windows ({n_windows}), skipping correlation heatmap") + return + + # Compute correlation in sliding windows + correlations = [] + window_centers = [] + + for i in range(n_windows): + window_forcing = forcing_ts[i : i + window_size] + window_climate = climate_ts[i : i + window_size] + corr = np.corrcoef(window_forcing, window_climate)[0, 1] + correlations.append(corr) + window_centers.append(i + window_size // 2) + + correlations = np.array(correlations) + window_centers = np.array(window_centers) + + # Create figure with two subplots + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 8), sharex=True) + color = "tab:red" if forcing_name == "CO2" else "tab:blue" + + # Top: forcing and climate time series + ax1_twin = ax1.twinx() + ax1.plot(np.arange(time_len), forcing_ts, color=color, linewidth=1, label=f"{forcing_name}", alpha=0.8) + ax1_twin.plot(np.arange(time_len), climate_ts, color="green", linewidth=1, label="Climate", alpha=0.8) + ax1.set_ylabel(f"{forcing_name} (normalized)", color=color) + ax1_twin.set_ylabel("Climate (normalized)", color="green") + ax1.legend(loc="upper left") + ax1_twin.legend(loc="upper right") + ax1.set_title(f"{forcing_name} and Climate Time Series", fontsize=12) + + # Bottom: sliding window correlation + ax2.fill_between(window_centers, 0, correlations, where=(correlations > 0), color="green", alpha=0.5) + ax2.fill_between(window_centers, 0, correlations, where=(correlations < 0), color="red", alpha=0.5) + ax2.plot(window_centers, correlations, color="black", linewidth=1) + ax2.axhline(y=0, color="gray", linestyle="-", linewidth=0.5) + ax2.set_xlabel("Time step", fontsize=12) + 
ax2.set_ylabel(f"Correlation (window={window_size})", fontsize=12) + ax2.set_title(f"Sliding Window Correlation: {forcing_name} vs Climate", fontsize=12) + ax2.set_ylim(-1, 1) + ax2.grid(alpha=0.3) + + fig.tight_layout() + filename = f"{forcing_name.lower()}_correlation_over_time.png" + fig.savefig(path / filename, dpi=150) + plt.close(fig) + logger.info(f"Saved {filename}") + + # ========================================================================= + # FORCING DIAGNOSTIC PLOTS - Spatial Correlation Analysis + # ========================================================================= + + def plot_pointwise_correlation_map( + self, + forcing: np.ndarray, + climate: np.ndarray, + forcing_name: str, + lat: int, + lon: int, + path: Path, + lag: int = 0, + ) -> None: + """ + Plot correlation between forcing and climate at each grid point. + + Args: + forcing: Forcing data (spatial, time) or (time,) - will broadcast if 1D + climate: Climate data (spatial, time) + forcing_name: Name of forcing + lat, lon: Grid dimensions + path: Save path + lag: Time lag to apply (positive = forcing leads) + """ + logger.info(f"Computing pointwise correlation map for {forcing_name} (lag={lag})") + + # Ensure forcing and climate have matching time dimensions + time_forcing = forcing.shape[-1] + time_climate = climate.shape[-1] + if time_forcing != time_climate: + min_time = min(time_forcing, time_climate) + logger.warning( + f"Time dimension mismatch: forcing has {time_forcing} timesteps, " + f"climate has {time_climate} timesteps. Trimming both to {min_time}." 
+ ) + forcing = forcing[..., :min_time] + climate = climate[..., :min_time] + + spatial_size = lat * lon + + # Handle forcing shape + if forcing.ndim == 1: + # Broadcast 1D forcing to all spatial points + forcing_broadcast = np.tile(forcing, (spatial_size, 1)) + else: + forcing_broadcast = forcing + + # Apply lag + if lag > 0: + forcing_aligned = forcing_broadcast[:, lag:] + climate_aligned = climate[:, :-lag] + elif lag < 0: + forcing_aligned = forcing_broadcast[:, :lag] + climate_aligned = climate[:, -lag:] + else: + forcing_aligned = forcing_broadcast + climate_aligned = climate + + # Compute correlation at each spatial point + correlations = np.zeros(spatial_size) + for i in range(spatial_size): + f = forcing_aligned[i] if forcing_aligned.ndim == 2 else forcing_aligned + c = climate_aligned[i] + if np.std(f) > 1e-10 and np.std(c) > 1e-10: + correlations[i] = np.corrcoef(f, c)[0, 1] + else: + correlations[i] = 0 + + # Reshape to grid + corr_map = correlations.reshape(lat, lon) + + # Plot + fig, ax = plt.subplots(figsize=(10, 8)) + im = ax.imshow(corr_map, cmap="RdBu_r", vmin=-1, vmax=1, aspect="auto") + ax.set_xlabel("Longitude index", fontsize=12) + ax.set_ylabel("Latitude index", fontsize=12) + lag_str = f" (lag={lag})" if lag != 0 else "" + ax.set_title(f"Pointwise Correlation: {forcing_name} vs Climate{lag_str}", fontsize=14) + fig.colorbar(im, ax=ax, label="Correlation") + + # Add statistics + mean_corr = np.nanmean(corr_map) + max_corr = np.nanmax(np.abs(corr_map)) + ax.text( + 0.02, + 0.98, + f"Mean: {mean_corr:.3f}\nMax |r|: {max_corr:.3f}", + transform=ax.transAxes, + verticalalignment="top", + fontsize=10, + bbox=dict(boxstyle="round", facecolor="white", alpha=0.8), + ) + + fig.tight_layout() + filename = f"{forcing_name.lower()}_pointwise_correlation_lag{lag}.png" + fig.savefig(path / filename, dpi=150) + plt.close(fig) + logger.info(f"Saved {filename} (mean r={mean_corr:.3f})") + + # 
========================================================================= + # FORCING DIAGNOSTIC PLOTS - Joint Animations + # ========================================================================= + + def plot_joint_forcing_climate_animation( + self, + co2_forcing: Optional[np.ndarray], + aerosol_forcing: Optional[np.ndarray], + climate: np.ndarray, + lat: int, + lon: int, + path: Path, + max_frames: int = 200, + ) -> None: + """ + Create side-by-side animation of forcing fields and climate field. + + Args: + co2_forcing: CO2 forcing (spatial, time) or None + aerosol_forcing: Aerosol forcing (spatial, time) or None + climate: Climate data (spatial, time) + lat, lon: Grid dimensions + path: Save path + max_frames: Maximum number of frames + """ + logger.info("Creating joint forcing-climate animation") + + # Ensure forcing and climate have matching time dimensions + time_climate = climate.shape[-1] + if co2_forcing is not None: + time_co2 = co2_forcing.shape[-1] + if time_co2 != time_climate: + min_time = min(time_co2, time_climate) + logger.warning( + f"Time dimension mismatch: CO2 forcing has {time_co2} timesteps, " + f"climate has {time_climate} timesteps. Trimming both to {min_time}." + ) + co2_forcing = co2_forcing[..., :min_time] + climate = climate[..., :min_time] + time_climate = min_time + + if aerosol_forcing is not None: + time_aerosol = aerosol_forcing.shape[-1] + if time_aerosol != time_climate: + min_time = min(time_aerosol, time_climate) + logger.warning( + f"Time dimension mismatch: aerosol forcing has {time_aerosol} timesteps, " + f"climate has {time_climate} timesteps. Trimming both to {min_time}." 
+ ) + aerosol_forcing = aerosol_forcing[..., :min_time] + climate = climate[..., :min_time] + + time_len = climate.shape[1] + frame_stride = max(1, time_len // max_frames) + frame_indices = np.arange(0, time_len, frame_stride, dtype=int) + + # Reshape data + climate_reshaped = climate.T.reshape(time_len, lat, lon) + + # Determine number of columns + n_cols = 1 # Climate is always shown + forcings_to_plot = [] + if co2_forcing is not None: + co2_reshaped = co2_forcing.T.reshape(time_len, lat, lon) + forcings_to_plot.append(("CO2", co2_reshaped, "Reds")) + n_cols += 1 + if aerosol_forcing is not None: + aerosol_reshaped = aerosol_forcing.T.reshape(time_len, lat, lon) + forcings_to_plot.append(("Aerosol", aerosol_reshaped, "Blues")) + n_cols += 1 + + # Create figure + fig, axes = plt.subplots(1, n_cols, figsize=(5 * n_cols, 4)) + if n_cols == 1: + axes = [axes] + + # Initialize images + images = [] + for idx, (name, data, cmap) in enumerate(forcings_to_plot): + vmin, vmax = data.min(), data.max() + if vmin == vmax: + vmax = vmin + 1e-6 + im = axes[idx].imshow(data[0], cmap=cmap, vmin=vmin, vmax=vmax, aspect="auto", animated=True) + axes[idx].set_title(f"{name} (t=0)") + axes[idx].set_xlabel("Lon") + axes[idx].set_ylabel("Lat") + fig.colorbar(im, ax=axes[idx], shrink=0.8) + images.append((im, name, data)) + + # Climate + climate_vmin, climate_vmax = climate_reshaped.min(), climate_reshaped.max() + climate_im = axes[-1].imshow( + climate_reshaped[0], cmap="RdBu_r", vmin=climate_vmin, vmax=climate_vmax, aspect="auto", animated=True + ) + axes[-1].set_title("Climate (t=0)") + axes[-1].set_xlabel("Lon") + fig.colorbar(climate_im, ax=axes[-1], shrink=0.8) + + def update(frame_idx): + t = frame_indices[frame_idx] + for im, name, data in images: + im.set_array(data[t]) + climate_im.set_array(climate_reshaped[t]) + for idx, (name, _, _) in enumerate(forcings_to_plot): + axes[idx].set_title(f"{name} (t={t})") + axes[-1].set_title(f"Climate (t={t})") + return [im for im, _, _ in 
images] + [climate_im] + + anim = animation.FuncAnimation(fig, update, frames=len(frame_indices), interval=100, blit=True) + writer = animation.PillowWriter(fps=10) + filename = "joint_forcing_climate.gif" + anim.save(path / filename, writer=writer) + plt.close(fig) + logger.info(f"Saved {filename}") + + # ========================================================================= + # FORCING DIAGNOSTIC PLOTS - Variance Attribution + # ========================================================================= + + # rescale this vertically - variance difference too big? + def plot_variance_explained_by_forcing( + self, + co2_forcing: Optional[np.ndarray], + aerosol_forcing: Optional[np.ndarray], + climate: np.ndarray, + lat: int, + lon: int, + path: Path, + window_size: int = 100, + ) -> None: + """ + Plot R² of forcing → climate regression over time (sliding window). + + Args: + co2_forcing: CO2 forcing (spatial, time) or None + aerosol_forcing: Aerosol forcing (spatial, time) or None + climate: Climate data (spatial, time) + lat, lon: Grid dimensions + path: Save path + window_size: Sliding window size + """ + logger.info("Computing variance explained by forcings over time") + + # Ensure forcing and climate have matching time dimensions + time_climate = climate.shape[-1] + if co2_forcing is not None: + time_co2 = co2_forcing.shape[-1] + if time_co2 != time_climate: + min_time = min(time_co2, time_climate) + logger.warning( + f"Time dimension mismatch: CO2 forcing has {time_co2} timesteps, " + f"climate has {time_climate} timesteps. Trimming both to {min_time}." + ) + co2_forcing = co2_forcing[..., :min_time] + climate = climate[..., :min_time] + time_climate = min_time + + if aerosol_forcing is not None: + time_aerosol = aerosol_forcing.shape[-1] + if time_aerosol != time_climate: + min_time = min(time_aerosol, time_climate) + logger.warning( + f"Time dimension mismatch: aerosol forcing has {time_aerosol} timesteps, " + f"climate has {time_climate} timesteps. 
Trimming both to {min_time}." + ) + aerosol_forcing = aerosol_forcing[..., :min_time] + climate = climate[..., :min_time] + + # Spatially average + climate_ts = climate.mean(axis=0) + time_len = len(climate_ts) + n_windows = time_len - window_size + 1 + + if n_windows < 10: + logger.warning("Too few windows for variance explained plot") + return + + window_centers = np.arange(window_size // 2, time_len - window_size // 2 + 1) + + results = {} + + # Process each forcing + forcings = [] + if co2_forcing is not None: + forcings.append(("CO2", co2_forcing.mean(axis=0) if co2_forcing.ndim == 2 else co2_forcing, "tab:red")) + if aerosol_forcing is not None: + forcings.append( + ("Aerosol", aerosol_forcing.mean(axis=0) if aerosol_forcing.ndim == 2 else aerosol_forcing, "tab:blue") + ) + + for name, forcing_ts, color in forcings: + r2_values = [] + for i in range(n_windows): + window_forcing = forcing_ts[i : i + window_size] + window_climate = climate_ts[i : i + window_size] + + # Simple linear regression R² + corr = np.corrcoef(window_forcing, window_climate)[0, 1] + r2 = corr**2 + r2_values.append(r2) + + results[name] = (np.array(r2_values), color) + + # Combined model (if both forcings available) + if len(forcings) == 2: + r2_combined = [] + f1_ts = forcings[0][1] + f2_ts = forcings[1][1] + + for i in range(n_windows): + w_f1 = f1_ts[i : i + window_size] + w_f2 = f2_ts[i : i + window_size] + w_climate = climate_ts[i : i + window_size] + + # Multiple regression: climate = a*f1 + b*f2 + c + X = np.column_stack([w_f1, w_f2, np.ones(window_size)]) + try: + coeffs, residuals, rank, s = np.linalg.lstsq(X, w_climate, rcond=None) + y_pred = X @ coeffs + ss_res = np.sum((w_climate - y_pred) ** 2) + ss_tot = np.sum((w_climate - w_climate.mean()) ** 2) + r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0 + except Exception: + r2 = 0 + r2_combined.append(r2) + + results["Combined"] = (np.array(r2_combined), "green") + + # Plot + fig, ax = plt.subplots(figsize=(14, 6)) + + for name, 
(r2_vals, color) in results.items(): + linestyle = "-" if name != "Combined" else "--" + ax.plot(window_centers, r2_vals, color=color, linewidth=2, label=f"{name} R²", linestyle=linestyle) + + ax.set_xlabel("Time step (window center)", fontsize=12) + ax.set_ylabel("Variance Explained (R²)", fontsize=12) + ax.set_title(f"Variance Explained by Forcings Over Time (window={window_size})", fontsize=14) + ax.set_ylim(0, 1) + ax.legend(loc="best") + ax.grid(alpha=0.3) + + fig.tight_layout() + filename = "variance_explained_over_time.png" + fig.savefig(path / filename, dpi=150) + plt.close(fig) + logger.info(f"Saved {filename}") + + def plot_forcing_attribution_summary( + self, + co2_forcing: Optional[np.ndarray], + aerosol_forcing: Optional[np.ndarray], + climate: np.ndarray, + path: Path, + ) -> None: + """ + Create summary bar chart of variance explained by CO2, aerosol, and internal variability. + + Args: + co2_forcing: CO2 forcing (spatial, time) or None + aerosol_forcing: Aerosol forcing (spatial, time) or None + climate: Climate data (spatial, time) + path: Save path + """ + logger.info("Creating forcing attribution summary") + + # Ensure forcing and climate have matching time dimensions + time_climate = climate.shape[-1] + if co2_forcing is not None: + time_co2 = co2_forcing.shape[-1] + if time_co2 != time_climate: + min_time = min(time_co2, time_climate) + logger.warning( + f"Time dimension mismatch: CO2 forcing has {time_co2} timesteps, " + f"climate has {time_climate} timesteps. Trimming both to {min_time}." + ) + co2_forcing = co2_forcing[..., :min_time] + climate = climate[..., :min_time] + time_climate = min_time + + if aerosol_forcing is not None: + time_aerosol = aerosol_forcing.shape[-1] + if time_aerosol != time_climate: + min_time = min(time_aerosol, time_climate) + logger.warning( + f"Time dimension mismatch: aerosol forcing has {time_aerosol} timesteps, " + f"climate has {time_climate} timesteps. Trimming both to {min_time}." 
+ ) + aerosol_forcing = aerosol_forcing[..., :min_time] + climate = climate[..., :min_time] + + # Spatially average + climate_ts = climate.mean(axis=0) + _ = np.var(climate_ts) + + # Build feature matrix + features = [] + feature_names = [] + feature_colors = [] + + if co2_forcing is not None: + co2_ts = co2_forcing.mean(axis=0) if co2_forcing.ndim == 2 else co2_forcing + features.append(co2_ts) + feature_names.append("CO2") + feature_colors.append("tab:red") + + if aerosol_forcing is not None: + aerosol_ts = aerosol_forcing.mean(axis=0) if aerosol_forcing.ndim == 2 else aerosol_forcing + features.append(aerosol_ts) + feature_names.append("Aerosol") + feature_colors.append("tab:blue") + + if len(features) == 0: + logger.warning("No forcing data for attribution summary") + return + + # Full model R² + X = np.column_stack(features + [np.ones(len(climate_ts))]) + try: + coeffs, _, _, _ = np.linalg.lstsq(X, climate_ts, rcond=None) + y_pred = X @ coeffs + ss_res = np.sum((climate_ts - y_pred) ** 2) + ss_tot = np.sum((climate_ts - climate_ts.mean()) ** 2) + r2_full = 1 - ss_res / ss_tot if ss_tot > 0 else 0 + except Exception: + r2_full = 0 + + # Individual R² (marginal contribution) + r2_individual = [] + for i, f in enumerate(features): + corr = np.corrcoef(f, climate_ts)[0, 1] + r2_individual.append(corr**2) + + # Internal variability = 1 - R²_full + internal_var = 1 - r2_full + + # Create pie chart and bar chart + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5)) + + # Pie chart + pie_values = r2_individual + [internal_var] + pie_labels = feature_names + ["Internal"] + pie_colors = feature_colors + ["gray"] + + ax1.pie( + pie_values, + labels=pie_labels, + colors=pie_colors, + autopct="%1.1f%%", + startangle=90, + explode=[0.05] * len(pie_values), + ) + ax1.set_title("Variance Attribution\n(Marginal R²)", fontsize=12) + + # Bar chart + bar_positions = np.arange(len(feature_names) + 2) + bar_heights = r2_individual + [r2_full, internal_var] + bar_labels = 
feature_names + ["Combined", "Internal"] + bar_colors = feature_colors + ["green", "gray"] + + ax2.bar(bar_positions, bar_heights, color=bar_colors, alpha=0.7, edgecolor="black") + ax2.set_xticks(bar_positions) + ax2.set_xticklabels(bar_labels, rotation=15) + ax2.set_ylabel("Fraction of Variance", fontsize=11) + ax2.set_title("Variance Explained by Each Source", fontsize=12) + ax2.set_ylim(0, 1) + ax2.grid(axis="y", alpha=0.3) + + # Add values on bars + for i, h in enumerate(bar_heights): + ax2.text(i, h + 0.02, f"{h:.2f}", ha="center", fontsize=10) + + fig.suptitle("Forcing Attribution Summary", fontsize=14) + fig.tight_layout() + filename = "forcing_attribution_summary.png" + fig.savefig(path / filename, dpi=150) + plt.close(fig) + logger.info(f"Saved {filename} (R²_full={r2_full:.3f})") + + # ========================================================================= + # FORCING DIAGNOSTIC PLOTS - Spectral Analysis + # ========================================================================= + + def plot_phase_relationship( + self, + forcing: np.ndarray, + climate: np.ndarray, + forcing_name: str, + path: Path, + fs: float = 1.0, + n_dominant_freqs: int = 5, + ) -> None: + """ + Plot detailed phase lag analysis between forcing and climate oscillations. + + Extracts dominant frequencies and shows phase lags with interpretation (leading/lagging). 
+ + Args: + forcing: Forcing data (spatial, time) or (time,) + climate: Climate data (spatial, time) + forcing_name: Name of forcing + path: Save path + fs: Sampling frequency (default 1.0 = 1 sample per timestep) + n_dominant_freqs: Number of dominant frequencies to highlight + """ + logger.info(f"Analyzing phase relationship for {forcing_name}") + + # Ensure forcing and climate have matching time dimensions + time_forcing = forcing.shape[-1] + time_climate = climate.shape[-1] + if time_forcing != time_climate: + min_time = min(time_forcing, time_climate) + logger.warning( + f"Time dimension mismatch: forcing has {time_forcing} timesteps, " + f"climate has {time_climate} timesteps. Trimming both to {min_time}." + ) + forcing = forcing[..., :min_time] + climate = climate[..., :min_time] + + # Get time series + forcing_ts = forcing.mean(axis=0) if forcing.ndim == 2 else forcing + climate_ts = climate.mean(axis=0) + + # Compute cross-spectrum + nperseg = min(len(forcing_ts) // 4, 256) + f, Pxy = signal.csd(forcing_ts, climate_ts, fs=fs, nperseg=nperseg) + phase = np.angle(Pxy, deg=True) + + # Also compute coherence to identify significant frequencies + f_coh, Cxy = signal.coherence(forcing_ts, climate_ts, fs=fs, nperseg=nperseg) + + # Find dominant frequencies (high coherence) + coherence_threshold = 0.3 + significant_mask = Cxy > coherence_threshold + if not significant_mask.any(): + logger.warning(f"No significant coherence found for {forcing_name}, using top {n_dominant_freqs} peaks") + dominant_indices = np.argsort(Cxy)[-n_dominant_freqs:] + else: + coherent_indices = np.where(significant_mask)[0] + # Among coherent frequencies, pick top ones by coherence + coherent_indices_sorted = coherent_indices[np.argsort(Cxy[coherent_indices])[-n_dominant_freqs:]] + dominant_indices = coherent_indices_sorted + + # Create comprehensive figure + fig = plt.figure(figsize=(14, 10)) + gs = fig.add_gridspec(3, 2, hspace=0.35, wspace=0.3) + + color = "tab:red" if forcing_name == 
"CO2" else "tab:blue" + + # Panel 1: Full phase spectrum + ax1 = fig.add_subplot(gs[0, :]) + ax1.plot(f, phase, color=color, linewidth=1.5, alpha=0.7) + ax1.scatter( + f[dominant_indices], + phase[dominant_indices], + s=100, + c="orange", + edgecolors="black", + zorder=10, + label="Dominant freqs", + ) + ax1.axhline(y=0, color="gray", linestyle="--", linewidth=1) + ax1.axhline(y=90, color="green", linestyle=":", linewidth=1, alpha=0.5, label="90° (quadrature)") + ax1.axhline(y=-90, color="green", linestyle=":", linewidth=1, alpha=0.5) + ax1.set_xlabel("Frequency", fontsize=11) + ax1.set_ylabel("Phase (degrees)", fontsize=11) + ax1.set_title(f"Phase Spectrum: {forcing_name} → Climate", fontsize=13, fontweight="bold") + ax1.set_ylim(-180, 180) + ax1.legend(loc="upper right") + ax1.grid(alpha=0.3) + + # Panel 2: Coherence (for reference) + ax2 = fig.add_subplot(gs[1, 0]) + ax2.plot(f_coh, Cxy, color=color, linewidth=1.5) + ax2.scatter(f[dominant_indices], Cxy[dominant_indices], s=100, c="orange", edgecolors="black", zorder=10) + ax2.axhline( + y=coherence_threshold, color="gray", linestyle="--", alpha=0.7, label=f"Threshold ({coherence_threshold})" + ) + ax2.set_xlabel("Frequency", fontsize=11) + ax2.set_ylabel("Coherence", fontsize=11) + ax2.set_title("Spectral Coherence", fontsize=12) + ax2.set_ylim(0, 1) + ax2.legend(loc="upper right") + ax2.grid(alpha=0.3) + + # Panel 3: Phase at dominant frequencies (bar chart) + ax3 = fig.add_subplot(gs[1, 1]) + dominant_freqs = f[dominant_indices] + dominant_phases = phase[dominant_indices] + _ = Cxy[dominant_indices] + + colors_bars = [ + "green" if -45 < p < 45 else "orange" if abs(abs(p) - 90) < 45 else "red" for p in dominant_phases + ] + ax3.bar(range(len(dominant_indices)), dominant_phases, color=colors_bars, alpha=0.7, edgecolor="black") + ax3.axhline(y=0, color="black", linestyle="-", linewidth=1) + ax3.set_xticks(range(len(dominant_indices))) + ax3.set_xticklabels([f"{freq:.3f}" for freq in dominant_freqs], 
rotation=45, ha="right") + ax3.set_xlabel("Frequency", fontsize=11) + ax3.set_ylabel("Phase (degrees)", fontsize=11) + ax3.set_title("Phase at Dominant Frequencies", fontsize=12) + ax3.set_ylim(-180, 180) + ax3.grid(axis="y", alpha=0.3) + + # Panel 4: Interpretation table + ax4 = fig.add_subplot(gs[2, :]) + ax4.axis("off") + + # Create interpretation text + table_data = [] + table_data.append(["Freq", "Period", "Phase", "Coherence", "Interpretation"]) + table_data.append(["-" * 8, "-" * 8, "-" * 8, "-" * 10, "-" * 40]) + + for i, idx in enumerate(dominant_indices): + freq = f[idx] + period = 1 / freq if freq > 0 else np.inf + ph = phase[idx] + coh = Cxy[idx] + + # Interpretation + if -45 < ph < 45: + interp = f"In-phase: {forcing_name} and climate move together" + elif 45 <= ph < 135: + interp = f"{forcing_name} leads climate by ~1/4 period" + elif ph >= 135 or ph <= -135: + interp = f"Anti-phase: {forcing_name} and climate opposite" + else: # -135 < ph <= -45 + interp = f"{forcing_name} lags climate by ~1/4 period" + + table_data.append( + [ + f"{freq:.4f}", + f"{period:.1f}" if period != np.inf else "∞", + f"{ph:.1f}°", + f"{coh:.3f}", + interp, + ] + ) + + # Render table + table = ax4.table( + cellText=table_data, + cellLoc="left", + loc="center", + colWidths=[0.12, 0.12, 0.12, 0.14, 0.5], + ) + table.auto_set_font_size(False) + table.set_fontsize(9) + table.scale(1, 2) + + # Style header row + for i in range(5): + cell = table[(0, i)] + cell.set_facecolor("#4CAF50") + cell.set_text_props(weight="bold", color="white") + + fig.suptitle(f"Phase Relationship Analysis: {forcing_name} ↔ Climate", fontsize=15, fontweight="bold") + filename = f"{forcing_name.lower()}_phase_relationship.png" + fig.savefig(path / filename, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {filename}") + + # ========================================================================= + # FORCING DIAGNOSTIC PLOTS - Tigramite Transfer Entropy + # 
========================================================================= + + def plot_transfer_entropy_matrix( + self, + co2_latent: Optional[np.ndarray], + aerosol_latents: Optional[np.ndarray], + climate_modes: np.ndarray, + path: Path, + tau_max: int = 5, + significance_level: float = 0.05, + ) -> None: + """ + Compute and visualize transfer entropy from forcings to climate modes using Tigramite. + + Args: + co2_latent: CO2 latent trajectory (time,) or None + aerosol_latents: Aerosol latent trajectories (n_aerosol, time) or None + climate_modes: Climate mode trajectories (n_modes, time) + path: Save path + tau_max: Maximum time lag for PCMCI + significance_level: Significance level for link detection + """ + if not TIGRAMITE_AVAILABLE: + logger.warning("Tigramite not available, skipping transfer entropy analysis") + return + + logger.info("Computing transfer entropy using Tigramite PCMCI") + + # Build data array (n_vars, time) + var_list = [] + var_names = [] + + # Add forcing latents + if co2_latent is not None: + if co2_latent.ndim == 1: + var_list.append(co2_latent) + var_names.append("CO2") + else: + for i in range(co2_latent.shape[0]): + var_list.append(co2_latent[i]) + var_names.append(f"CO2_{i}") + + if aerosol_latents is not None: + if aerosol_latents.ndim == 1: + var_list.append(aerosol_latents) + var_names.append("Aero") + else: + for i in range(aerosol_latents.shape[0]): + var_list.append(aerosol_latents[i]) + var_names.append(f"A{i}") + + # Add climate modes + if climate_modes.ndim == 1: + var_list.append(climate_modes) + var_names.append("M0") + else: + for i in range(climate_modes.shape[0]): + var_list.append(climate_modes[i]) + var_names.append(f"M{i}") + + if len(var_list) < 2: + logger.warning("Need at least 2 variables for transfer entropy") + return + + # Stack into (time, n_vars) array for tigramite + data_array = np.column_stack(var_list) + + # Create tigramite dataframe + dataframe = pp.DataFrame(data_array, var_names=var_names) + + # Run 
PCMCI with partial correlation test + parcorr = ParCorr(significance="analytic") + pcmci = PCMCI(dataframe=dataframe, cond_ind_test=parcorr, verbosity=0) + + try: + results = pcmci.run_pcmci(tau_max=tau_max, pc_alpha=significance_level) + except Exception as e: + logger.error(f"PCMCI failed: {e}") + return + + # Extract link matrix + q_matrix = pcmci.get_corrected_pvalues(p_matrix=results["p_matrix"], fdr_method="fdr_bh") + link_matrix = np.where(q_matrix < significance_level, results["val_matrix"], 0) + + # Create custom visualization + n_vars = len(var_names) + n_forcing = (1 if co2_latent is not None else 0) + ( + aerosol_latents.shape[0] + if aerosol_latents is not None and aerosol_latents.ndim > 1 + else (1 if aerosol_latents is not None else 0) + ) + + # Summary: forcing → climate links (sum over lags) + forcing_to_climate = np.zeros((n_forcing, n_vars - n_forcing)) + for i in range(n_forcing): + for j in range(n_forcing, n_vars): + # Sum absolute link strengths over all lags + forcing_to_climate[i, j - n_forcing] = np.sum(np.abs(link_matrix[i, j, :])) + + # Plot heatmap + fig, ax = plt.subplots(figsize=(10, 6)) + + forcing_names = var_names[:n_forcing] + climate_names = var_names[n_forcing:] + + im = ax.imshow(forcing_to_climate, cmap="YlOrRd", aspect="auto") + ax.set_xticks(range(len(climate_names))) + ax.set_yticks(range(len(forcing_names))) + ax.set_xticklabels(climate_names, rotation=45, ha="right") + ax.set_yticklabels(forcing_names) + ax.set_xlabel("Climate Modes", fontsize=12) + ax.set_ylabel("Forcing Latents", fontsize=12) + ax.set_title(f"Causal Influence: Forcings → Climate (PCMCI, τ_max={tau_max})", fontsize=14) + + # Add values + for i in range(len(forcing_names)): + for j in range(len(climate_names)): + val = forcing_to_climate[i, j] + color = "white" if val > forcing_to_climate.max() / 2 else "black" + ax.text(j, i, f"{val:.2f}", ha="center", va="center", color=color, fontsize=9) + + fig.colorbar(im, ax=ax, label="Summed |link strength|") + 
fig.tight_layout() + filename = "transfer_entropy_matrix.png" + fig.savefig(path / filename, dpi=150) + plt.close(fig) + logger.info(f"Saved {filename}") + + # Also save Tigramite's built-in graph plot + try: + fig, ax = plt.subplots(figsize=(12, 8)) + tp.plot_graph( + val_matrix=results["val_matrix"], + graph=results["graph"], + var_names=var_names, + link_colorbar_label="Cross-MCI", + node_colorbar_label="Auto-MCI", + fig_ax=(fig, ax), + ) + filename_graph = "tigramite_causal_graph.png" + fig.savefig(path / filename_graph, dpi=150) + plt.close(fig) + logger.info(f"Saved {filename_graph}") + except Exception as e: + logger.warning(f"Could not save Tigramite graph plot: {e}") + + def plot_mutual_information_matrix( + self, + co2_latent: Optional[np.ndarray], + aerosol_latents: Optional[np.ndarray], + climate_modes: np.ndarray, + path: Path, + n_bins: int = 20, + ) -> None: + """ + Compute and visualize mutual information between forcing latents and climate modes. + + Args: + co2_latent: CO2 latent trajectory (time,) or None + aerosol_latents: Aerosol latent trajectories (n_aerosol, time) or None + climate_modes: Climate mode trajectories (n_modes, time) + path: Save path + n_bins: Number of bins for histogram-based MI estimation + """ + logger.info("Computing mutual information matrix") + + # Build variable lists + forcing_list = [] + forcing_names = [] + + if co2_latent is not None: + if co2_latent.ndim == 1: + forcing_list.append(co2_latent) + forcing_names.append("CO2") + else: + for i in range(co2_latent.shape[0]): + forcing_list.append(co2_latent[i]) + forcing_names.append(f"CO2_{i}") + + if aerosol_latents is not None: + if aerosol_latents.ndim == 1: + forcing_list.append(aerosol_latents) + forcing_names.append("Aero") + else: + for i in range(aerosol_latents.shape[0]): + forcing_list.append(aerosol_latents[i]) + forcing_names.append(f"A{i}") + + climate_list = [] + climate_names = [] + if climate_modes.ndim == 1: + climate_list.append(climate_modes) + 
climate_names.append("M0") + else: + for i in range(climate_modes.shape[0]): + climate_list.append(climate_modes[i]) + climate_names.append(f"M{i}") + + if len(forcing_list) == 0 or len(climate_list) == 0: + logger.warning("Need forcing and climate variables for MI computation") + return + + # Compute MI matrix (forcing × climate) + mi_matrix = np.zeros((len(forcing_list), len(climate_list))) + + for i, f_var in enumerate(forcing_list): + for j, c_var in enumerate(climate_list): + # Mutual information using histogram method + # MI(X,Y) = H(X) + H(Y) - H(X,Y) + # Where H is entropy + c_xy, _, _ = np.histogram2d(f_var, c_var, bins=n_bins) + c_x = np.histogram(f_var, bins=n_bins)[0] + c_y = np.histogram(c_var, bins=n_bins)[0] + + # Normalize to probabilities + p_xy = c_xy / np.sum(c_xy) + p_x = c_x / np.sum(c_x) + p_y = c_y / np.sum(c_y) + + # Compute entropies (ignore zero probabilities) + h_x = -np.sum(p_x[p_x > 0] * np.log2(p_x[p_x > 0])) + h_y = -np.sum(p_y[p_y > 0] * np.log2(p_y[p_y > 0])) + h_xy = -np.sum(p_xy[p_xy > 0] * np.log2(p_xy[p_xy > 0])) + + mi_matrix[i, j] = h_x + h_y - h_xy + + # Plot heatmap + fig, ax = plt.subplots(figsize=(10, 6)) + im = ax.imshow(mi_matrix, cmap="YlOrRd", aspect="auto") + ax.set_xticks(range(len(climate_names))) + ax.set_yticks(range(len(forcing_names))) + ax.set_xticklabels(climate_names, rotation=45, ha="right") + ax.set_yticklabels(forcing_names) + ax.set_xlabel("Climate Modes", fontsize=12) + ax.set_ylabel("Forcing Latents", fontsize=12) + ax.set_title("Mutual Information: Forcings ↔ Climate", fontsize=14) + + # Add values + for i in range(len(forcing_names)): + for j in range(len(climate_names)): + val = mi_matrix[i, j] + color = "white" if val > mi_matrix.max() / 2 else "black" + ax.text(j, i, f"{val:.2f}", ha="center", va="center", color=color, fontsize=9) + + fig.colorbar(im, ax=ax, label="MI (bits)") + fig.tight_layout() + filename = "mutual_information_matrix.png" + fig.savefig(path / filename, dpi=150) + plt.close(fig) 
+ logger.info(f"Saved {filename} (mean MI={np.mean(mi_matrix):.3f} bits)") + + def plot_conditional_correlation_network( + self, + co2_latent: Optional[np.ndarray], + aerosol_latents: Optional[np.ndarray], + climate_modes: np.ndarray, + path: Path, + threshold: float = 0.2, + ) -> None: + """ + Plot network graph showing partial correlations to distinguish direct vs indirect effects. + + Args: + co2_latent: CO2 latent trajectory (time,) or None + aerosol_latents: Aerosol latent trajectories (n_aerosol, time) or None + climate_modes: Climate mode trajectories (n_modes, time) + path: Save path + threshold: Minimum absolute partial correlation to display + """ + logger.info("Computing conditional correlation network") + + # Build data array + var_list = [] + var_names = [] + var_types = [] # 'forcing' or 'climate' + + if co2_latent is not None: + if co2_latent.ndim == 1: + var_list.append(co2_latent) + var_names.append("CO2") + var_types.append("forcing") + else: + for i in range(co2_latent.shape[0]): + var_list.append(co2_latent[i]) + var_names.append(f"CO2_{i}") + var_types.append("forcing") + + if aerosol_latents is not None: + if aerosol_latents.ndim == 1: + var_list.append(aerosol_latents) + var_names.append("Aero") + var_types.append("forcing") + else: + for i in range(aerosol_latents.shape[0]): + var_list.append(aerosol_latents[i]) + var_names.append(f"A{i}") + var_types.append("forcing") + + if climate_modes.ndim == 1: + var_list.append(climate_modes) + var_names.append("M0") + var_types.append("climate") + else: + for i in range(climate_modes.shape[0]): + var_list.append(climate_modes[i]) + var_names.append(f"M{i}") + var_types.append("climate") + + if len(var_list) < 2: + logger.warning("Need at least 2 variables for correlation network") + return + + # Stack into array (time, n_vars) + data_array = np.column_stack(var_list) + n_vars = len(var_list) + + # Compute correlation matrix + corr_matrix = np.corrcoef(data_array.T) + + # Compute partial correlations 
(simple inverse covariance method) + # Partial corr: -cov_inv[i,j] / sqrt(cov_inv[i,i] * cov_inv[j,j]) + try: + precision = np.linalg.inv(corr_matrix + 1e-6 * np.eye(n_vars)) + partial_corr = np.zeros((n_vars, n_vars)) + for i in range(n_vars): + for j in range(n_vars): + if i != j: + partial_corr[i, j] = -precision[i, j] / np.sqrt(precision[i, i] * precision[j, j]) + except np.linalg.LinAlgError: + logger.warning("Could not compute partial correlations, using regular correlations") + partial_corr = corr_matrix + + # Create network visualization + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7)) + + # Plot 1: Heatmap of partial correlations + im1 = ax1.imshow(partial_corr, cmap="RdBu_r", vmin=-1, vmax=1, aspect="auto") + ax1.set_xticks(range(n_vars)) + ax1.set_yticks(range(n_vars)) + ax1.set_xticklabels(var_names, rotation=45, ha="right") + ax1.set_yticklabels(var_names) + ax1.set_title("Partial Correlation Matrix", fontsize=14) + fig.colorbar(im1, ax=ax1, label="Partial Correlation") + + # Plot 2: Network graph (simple circular layout) + ax2.set_xlim(-1.5, 1.5) + ax2.set_ylim(-1.5, 1.5) + ax2.set_aspect("equal") + ax2.axis("off") + ax2.set_title(f"Correlation Network (|r| > {threshold})", fontsize=14) + + # Node positions (circular layout) + angles = np.linspace(0, 2 * np.pi, n_vars, endpoint=False) + positions = np.column_stack([np.cos(angles), np.sin(angles)]) + + # Draw edges (links above threshold) + for i in range(n_vars): + for j in range(i + 1, n_vars): + pc = partial_corr[i, j] + if abs(pc) > threshold: + x_vals = [positions[i, 0], positions[j, 0]] + y_vals = [positions[i, 1], positions[j, 1]] + color = "red" if pc > 0 else "blue" + width = abs(pc) * 3 + ax2.plot(x_vals, y_vals, color=color, linewidth=width, alpha=0.6) + + # Draw nodes + for i, (name, typ) in enumerate(zip(var_names, var_types)): + color = "orange" if typ == "forcing" else "lightblue" + ax2.scatter(positions[i, 0], positions[i, 1], s=500, c=color, edgecolors="black", linewidths=2, 
zorder=10) + ax2.text( + positions[i, 0] * 1.2, + positions[i, 1] * 1.2, + name, + fontsize=11, + ha="center", + va="center", + fontweight="bold", + ) + + # Legend + from matplotlib.lines import Line2D + + legend_elements = [ + Line2D([0], [0], marker="o", color="w", markerfacecolor="orange", markersize=10, label="Forcing"), + Line2D([0], [0], marker="o", color="w", markerfacecolor="lightblue", markersize=10, label="Climate"), + Line2D([0], [0], color="red", linewidth=2, label="Positive corr"), + Line2D([0], [0], color="blue", linewidth=2, label="Negative corr"), + ] + ax2.legend(handles=legend_elements, loc="upper right", fontsize=10) + + fig.tight_layout() + filename = "conditional_correlation_network.png" + fig.savefig(path / filename, dpi=150) + plt.close(fig) + logger.info(f"Saved {filename}") + + def plot_causal_graph_tigramite( + self, + data_array: np.ndarray, + var_names: List[str], + path: Path, + tau_max: int = 5, + pc_alpha: float = 0.05, + ) -> None: + """ + Generate Tigramite-style causal graph visualization with lag information. + + Args: + data_array: Data (time, n_vars) or (n_vars, time) - will be transposed if needed + var_names: Variable names ['CO2', 'A0', 'A1', ..., 'M0', 'M1', ...] 
+ path: Save path + tau_max: Maximum time lag for PCMCI + pc_alpha: Significance level for link detection + """ + if not TIGRAMITE_AVAILABLE: + logger.warning("Tigramite not available, skipping causal graph") + return + + logger.info("Generating Tigramite causal graph") + + # Ensure data is (time, n_vars) + if data_array.shape[0] < data_array.shape[1]: + data_array = data_array.T + + # Create tigramite dataframe + dataframe = pp.DataFrame(data_array, var_names=var_names) + + # Run PCMCI + parcorr = ParCorr(significance="analytic") + pcmci = PCMCI(dataframe=dataframe, cond_ind_test=parcorr, verbosity=0) + + try: + results = pcmci.run_pcmci(tau_max=tau_max, pc_alpha=pc_alpha) + except Exception as e: + logger.error(f"PCMCI failed: {e}") + return + + # Create Tigramite's graph visualization + fig, ax = plt.subplots(figsize=(14, 10)) + + try: + tp.plot_graph( + val_matrix=results["val_matrix"], + graph=results["graph"], + var_names=var_names, + link_colorbar_label="MCI value", + node_colorbar_label="Auto-MCI", + fig_ax=(fig, ax), + ) + + ax.set_title( + f"Tigramite Causal Graph (τ_max={tau_max}, α={pc_alpha})", fontsize=14, fontweight="bold", pad=20 + ) + + filename = "tigramite_causal_graph_full.png" + fig.savefig(path / filename, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {filename}") + + except Exception as e: + logger.error(f"Could not create Tigramite graph: {e}") + plt.close(fig) + + def plot_information_flow_arrows( + self, + transfer_entropy_matrix: np.ndarray, + source_names: List[str], + target_names: List[str], + lat: int, + lon: int, + path: Path, + source_positions: Optional[np.ndarray] = None, + target_positions: Optional[np.ndarray] = None, + ) -> None: + """ + Overlay information flow arrows on spatial grid. + + Arrow thickness represents transfer entropy strength from sources to targets. 
+ + Args: + transfer_entropy_matrix: Transfer entropy (n_sources, n_targets) + source_names: Source variable names (e.g., ['CO2', 'A0', ...]) + target_names: Target variable names (e.g., ['M0', 'M1', ...]) + lat, lon: Grid dimensions + path: Save path + source_positions: Optional source positions on grid (n_sources, 2) as [lat_idx, lon_idx] + target_positions: Optional target positions on grid (n_targets, 2) as [lat_idx, lon_idx] + """ + logger.info("Creating information flow arrow visualization") + + n_sources = len(source_names) + n_targets = len(target_names) + + # If positions not provided, create default positions + if source_positions is None: + # Place sources on left side + source_positions = np.zeros((n_sources, 2)) + for i in range(n_sources): + source_positions[i] = [lat // 2, lon // 4] # Same position for all (will offset in plot) + + if target_positions is None: + # Place targets on right side, distributed vertically + target_positions = np.zeros((n_targets, 2)) + for i in range(n_targets): + target_positions[i] = [int(lat * (i + 1) / (n_targets + 1)), 3 * lon // 4] + + # Create grid background + fig, ax = plt.subplots(figsize=(12, 8)) + + # Draw a simple grid + ax.set_xlim(0, lon) + ax.set_ylim(0, lat) + ax.set_aspect("equal") + ax.grid(alpha=0.2) + ax.set_xlabel("Longitude index", fontsize=12) + ax.set_ylabel("Latitude index", fontsize=12) + ax.set_title("Information Flow: Forcings → Climate Modes", fontsize=14, fontweight="bold") + + # Normalize transfer entropy for arrow widths + te_max = np.max(transfer_entropy_matrix) + te_min = np.min(transfer_entropy_matrix) + + if te_max > te_min: + te_normalized = (transfer_entropy_matrix - te_min) / (te_max - te_min) + else: + te_normalized = np.zeros_like(transfer_entropy_matrix) + + # Draw arrows + for i in range(n_sources): + for j in range(n_targets): + te_val = transfer_entropy_matrix[i, j] + if te_val > 0.01: # Only draw significant flows + # Source and target positions + src_pos = source_positions[i] 
+ tgt_pos = target_positions[j] + + # Arrow properties + arrow_width = te_normalized[i, j] * 5 # Scale width + color_intensity = te_normalized[i, j] + arrow_color = plt.cm.YlOrRd(color_intensity) + + # Draw arrow + ax.annotate( + "", + xy=(tgt_pos[1], tgt_pos[0]), + xytext=(src_pos[1], src_pos[0]), + arrowprops=dict( + arrowstyle="->", lw=arrow_width, color=arrow_color, alpha=0.7, shrinkA=10, shrinkB=10 + ), + ) + + # Draw source nodes + for i, name in enumerate(source_names): + pos = source_positions[i] + ax.scatter(pos[1], pos[0], s=300, c="orange", edgecolors="black", linewidths=2, zorder=10) + ax.text(pos[1], pos[0] + lat * 0.05, name, fontsize=11, ha="center", fontweight="bold") + + # Draw target nodes + for j, name in enumerate(target_names): + pos = target_positions[j] + ax.scatter(pos[1], pos[0], s=300, c="lightblue", edgecolors="black", linewidths=2, zorder=10) + ax.text(pos[1], pos[0] + lat * 0.05, name, fontsize=11, ha="center", fontweight="bold") + + # Add colorbar for transfer entropy + sm = plt.cm.ScalarMappable(cmap=plt.cm.YlOrRd, norm=plt.Normalize(vmin=te_min, vmax=te_max)) + sm.set_array([]) + fig.colorbar(sm, ax=ax, label="Transfer Entropy", shrink=0.7) + + # Legend + from matplotlib.lines import Line2D + + legend_elements = [ + Line2D([0], [0], marker="o", color="w", markerfacecolor="orange", markersize=10, label="Forcing"), + Line2D([0], [0], marker="o", color="w", markerfacecolor="lightblue", markersize=10, label="Climate Mode"), + ] + ax.legend(handles=legend_elements, loc="upper left", fontsize=10) + + fig.tight_layout() + filename = "information_flow_arrows.png" + fig.savefig(path / filename, dpi=150) + plt.close(fig) + logger.info(f"Saved {filename}") + + # ========================================================================= + # FORCING DIAGNOSTIC PLOTS - Ground Truth Comparison + # ========================================================================= + + def plot_gt_vs_learned_forcing_effect( + self, + gt_adj: np.ndarray, + 
learned_adj: np.ndarray, + forcing_indices: Dict[str, List[int]], + path: Path, + iteration: int, + ) -> None: + """ + Compare known ground truth causal coefficients with learned adjacency. + + Args: + gt_adj: Ground truth adjacency matrix (tau, n_latents, n_latents) + learned_adj: Learned adjacency matrix (tau, n_latents, n_latents) + forcing_indices: Dict with 'co2' and 'aerosol' latent indices + path: Save path + iteration: Current iteration + """ + logger.info("Comparing GT vs learned forcing effects") + + if forcing_indices is None: + logger.warning("No forcing indices available for GT comparison") + return + + # Get forcing and climate indices + co2_idx = forcing_indices.get("co2", []) + aerosol_idx = forcing_indices.get("aerosol", []) + n_total = forcing_indices.get("n_total", learned_adj.shape[1]) + n_climate = n_total - len(co2_idx) - len(aerosol_idx) + climate_idx = list(range(n_climate)) + + all_forcing_idx = co2_idx + aerosol_idx + + # Extract forcing → climate submatrices (sum over tau) + gt_forcing_to_climate = np.zeros((len(all_forcing_idx), n_climate)) + learned_forcing_to_climate = np.zeros((len(all_forcing_idx), n_climate)) + + for fi, forcing_i in enumerate(all_forcing_idx): + for ci, climate_i in enumerate(climate_idx): + # Sum over all lags + gt_forcing_to_climate[fi, ci] = np.sum(np.abs(gt_adj[:, forcing_i, climate_i])) + learned_forcing_to_climate[fi, ci] = np.sum(np.abs(learned_adj[:, forcing_i, climate_i])) + + # Create comparison figure + fig, axes = plt.subplots(1, 3, figsize=(15, 5)) + + # GT + im0 = axes[0].imshow(gt_forcing_to_climate, cmap="YlOrRd", aspect="auto") + axes[0].set_title("Ground Truth\nForcing → Climate", fontsize=12) + axes[0].set_xlabel("Climate Mode") + axes[0].set_ylabel("Forcing Latent") + fig.colorbar(im0, ax=axes[0], shrink=0.8) + + # Learned + im1 = axes[1].imshow(learned_forcing_to_climate, cmap="YlOrRd", aspect="auto") + axes[1].set_title("Learned\nForcing → Climate", fontsize=12) + axes[1].set_xlabel("Climate 
Mode") + fig.colorbar(im1, ax=axes[1], shrink=0.8) + + # Difference + diff = learned_forcing_to_climate - gt_forcing_to_climate + vabs = max(np.abs(diff).max(), 0.01) + im2 = axes[2].imshow(diff, cmap="RdBu_r", vmin=-vabs, vmax=vabs, aspect="auto") + axes[2].set_title("Difference\n(Learned - GT)", fontsize=12) + axes[2].set_xlabel("Climate Mode") + fig.colorbar(im2, ax=axes[2], shrink=0.8) + + # Add y-axis labels + forcing_labels = [f"CO2_{i}" for i in range(len(co2_idx))] + [f"A{i}" for i in range(len(aerosol_idx))] + for ax in axes: + ax.set_yticks(range(len(forcing_labels))) + ax.set_yticklabels(forcing_labels) + ax.set_xticks(range(n_climate)) + ax.set_xticklabels([f"M{i}" for i in range(n_climate)]) + + fig.suptitle(f"GT vs Learned Forcing Effects (iter={iteration})", fontsize=14) + fig.tight_layout() + filename = f"gt_vs_learned_forcing_{iteration}.png" + fig.savefig(path / filename, dpi=150) + plt.close(fig) + logger.info(f"Saved {filename}") + + def plot_forcing_latent_reconstruction_error( + self, + gt_co2_latent: Optional[np.ndarray], + gt_aerosol_latent: Optional[np.ndarray], + learned_co2_latent: Optional[np.ndarray], + learned_aerosol_latent: Optional[np.ndarray], + path: Path, + iteration: int, + ) -> None: + """ + Track how well learned forcing latents match ground truth trajectories. 
+ + Args: + gt_co2_latent: Ground truth CO2 latent (time,) or (n, time) + gt_aerosol_latent: Ground truth aerosol latents (n_aerosol, time) + learned_co2_latent: Learned CO2 latent (time,) or (n, time) + learned_aerosol_latent: Learned aerosol latents (n_aerosol, time) + path: Save path + iteration: Current iteration + """ + logger.info("Computing forcing latent reconstruction error") + + n_plots = sum([gt_co2_latent is not None, gt_aerosol_latent is not None]) + if n_plots == 0: + logger.warning("No ground truth latents available for comparison") + return + + fig, axes = plt.subplots(1, n_plots, figsize=(7 * n_plots, 5)) + if n_plots == 1: + axes = [axes] + + plot_idx = 0 + + # CO2 comparison + if gt_co2_latent is not None and learned_co2_latent is not None: + ax = axes[plot_idx] + gt = gt_co2_latent.flatten() if gt_co2_latent.ndim > 1 else gt_co2_latent + learned = learned_co2_latent.flatten() if learned_co2_latent.ndim > 1 else learned_co2_latent + + # Align lengths + min_len = min(len(gt), len(learned)) + gt = gt[:min_len] + learned = learned[:min_len] + + # Normalize for comparison (correlation doesn't care about scale) + gt_norm = (gt - gt.mean()) / (gt.std() + 1e-8) + learned_norm = (learned - learned.mean()) / (learned.std() + 1e-8) + + time_axis = np.arange(min_len) + ax.plot(time_axis, gt_norm, label="GT", color="black", linewidth=1.5) + ax.plot(time_axis, learned_norm, label="Learned", color="tab:red", linewidth=1.5, alpha=0.8) + + corr = np.corrcoef(gt_norm, learned_norm)[0, 1] + mse = np.mean((gt_norm - learned_norm) ** 2) + + ax.set_title(f"CO2 Latent (r={corr:.3f}, MSE={mse:.3f})", fontsize=12) + ax.set_xlabel("Time") + ax.set_ylabel("Normalized value") + ax.legend() + ax.grid(alpha=0.3) + plot_idx += 1 + + # Aerosol comparison + if gt_aerosol_latent is not None and learned_aerosol_latent is not None: + ax = axes[plot_idx] + + # Handle different shapes + gt_aero = gt_aerosol_latent + learned_aero = learned_aerosol_latent + + if gt_aero.ndim == 1: + 
gt_aero = gt_aero.reshape(1, -1) + if learned_aero.ndim == 1: + learned_aero = learned_aero.reshape(1, -1) + + n_gt = gt_aero.shape[0] + n_learned = learned_aero.shape[0] + n_plot = min(n_gt, n_learned) + min_len = min(gt_aero.shape[1], learned_aero.shape[1]) + + colors_gt = plt.cm.Blues(np.linspace(0.4, 0.9, n_plot)) + colors_learned = plt.cm.Reds(np.linspace(0.4, 0.9, n_plot)) + + time_axis = np.arange(min_len) + corrs = [] + + for i in range(n_plot): + gt_i = gt_aero[i, :min_len] + learned_i = learned_aero[i, :min_len] + + gt_norm = (gt_i - gt_i.mean()) / (gt_i.std() + 1e-8) + learned_norm = (learned_i - learned_i.mean()) / (learned_i.std() + 1e-8) + + ax.plot(time_axis, gt_norm, color=colors_gt[i], linewidth=1, label=f"GT A{i}" if i == 0 else None) + ax.plot( + time_axis, + learned_norm, + color=colors_learned[i], + linewidth=1, + linestyle="--", + label=f"Learned A{i}" if i == 0 else None, + ) + + corrs.append(np.corrcoef(gt_norm, learned_norm)[0, 1]) + + mean_corr = np.mean(corrs) + ax.set_title(f"Aerosol Latents (mean r={mean_corr:.3f})", fontsize=12) + ax.set_xlabel("Time") + ax.set_ylabel("Normalized value") + ax.legend(loc="upper right") + ax.grid(alpha=0.3) + + fig.suptitle(f"Forcing Latent Reconstruction (iter={iteration})", fontsize=14) + fig.tight_layout() + filename = f"forcing_latent_reconstruction_{iteration}.png" + fig.savefig(path / filename, dpi=150) + plt.close(fig) + logger.info(f"Saved {filename}") + + # ========================================================================= + # INTEGRATION METHOD - Call all forcing diagnostics + # ========================================================================= + + def plot_forcing_diagnostics( + self, + savar_data: np.ndarray, + co2_forcing: Optional[np.ndarray], + aerosol_forcing: Optional[np.ndarray], + gt_co2_latent: Optional[np.ndarray], + gt_aerosol_latent: Optional[np.ndarray], + lat: int, + lon: int, + path: Path, + max_lag: int = 20, + window_size: int = 200, + ) -> None: + """ + Generate 
all forcing diagnostic plots. + + This is the main entry point for comprehensive forcing analysis. + Called during data generation and optionally during training. + + Args: + savar_data: Climate data (spatial, time) + co2_forcing: CO2 forcing (spatial, time) or None + aerosol_forcing: Aerosol forcing (spatial, time) or None + gt_co2_latent: Ground truth CO2 latent trajectory + gt_aerosol_latent: Ground truth aerosol latent trajectories + lat, lon: Grid dimensions + path: Save path + max_lag: Maximum lag for cross-correlation + window_size: Window size for sliding correlations + """ + logger.info("=" * 60) + logger.info("GENERATING COMPREHENSIVE FORCING DIAGNOSTICS") + logger.info("=" * 60) + + # 1. Global Correlation Analysis + logger.info("--- Global Correlation Analysis ---") + if co2_forcing is not None: + self.plot_correlation_heatmap_over_time(co2_forcing, savar_data, "CO2", window_size, path) + if aerosol_forcing is not None: + self.plot_correlation_heatmap_over_time(aerosol_forcing, savar_data, "Aerosol", window_size, path) + + # 2. Spatial Correlation Analysis + logger.info("--- Spatial Correlation Analysis ---") + if co2_forcing is not None: + self.plot_pointwise_correlation_map(co2_forcing, savar_data, "CO2", lat, lon, path, lag=0) + self.plot_pointwise_correlation_map(co2_forcing, savar_data, "CO2", lat, lon, path, lag=1) + if aerosol_forcing is not None: + self.plot_pointwise_correlation_map(aerosol_forcing, savar_data, "Aerosol", lat, lon, path, lag=0) + + # 3. Joint Animations + logger.info("--- Joint Animations ---") + self.plot_joint_forcing_climate_animation(co2_forcing, aerosol_forcing, savar_data, lat, lon, path) + + # 4. Variance Attribution + logger.info("--- Variance Attribution ---") + self.plot_variance_explained_by_forcing(co2_forcing, aerosol_forcing, savar_data, lat, lon, path, window_size) + self.plot_forcing_attribution_summary(co2_forcing, aerosol_forcing, savar_data, path) + + # 5. 
Causal/Information-Theoretic Analysis + if gt_co2_latent is not None or gt_aerosol_latent is not None: + logger.info("--- Causal & Information-Theoretic Analysis ---") + # For SAVAR, we use the latent trajectories, not the full spatial data + # We need to extract climate mode trajectories - for now use spatial mean as proxy + climate_proxy = savar_data.mean(axis=0).reshape(1, -1) # (1, time) + + # Mutual information + self.plot_mutual_information_matrix(gt_co2_latent, gt_aerosol_latent, climate_proxy, path) + + # Conditional correlation network + self.plot_conditional_correlation_network(gt_co2_latent, gt_aerosol_latent, climate_proxy, path) + + # Transfer entropy (if Tigramite available) + if TIGRAMITE_AVAILABLE: + self.plot_transfer_entropy_matrix(gt_co2_latent, gt_aerosol_latent, climate_proxy, path) + + # Tigramite causal graph + var_list = [] + var_names = [] + if gt_co2_latent is not None: + var_list.append(gt_co2_latent if gt_co2_latent.ndim == 1 else gt_co2_latent[0]) + var_names.append("CO2") + if gt_aerosol_latent is not None: + if gt_aerosol_latent.ndim == 1: + var_list.append(gt_aerosol_latent) + var_names.append("Aero") + else: + for i in range(gt_aerosol_latent.shape[0]): + var_list.append(gt_aerosol_latent[i]) + var_names.append(f"A{i}") + var_list.append(climate_proxy[0]) + var_names.append("M0") + + if len(var_list) >= 2: + data_array = np.column_stack(var_list) + self.plot_causal_graph_tigramite(data_array, var_names, path, tau_max=5) + + logger.info("=" * 60) + logger.info("FORCING DIAGNOSTICS COMPLETE") + logger.info("=" * 60) + + def plot_training_forcing_diagnostics( + self, + learner, + iteration: int, + path: Path, + ) -> None: + """ + Subset of forcing diagnostics suitable for training checkpoints. + + Avoids expensive computations like full transfer entropy. + Called at plot_freq intervals during training. 
+ + Args: + learner: TrainingLatent instance + iteration: Current iteration + path: Save path + """ + logger.info(f"Generating training forcing diagnostics (iter={iteration})") + + # Get forcing data from datamodule + datamodule = getattr(learner, "datamodule", None) + if datamodule is None: + return + + co2_forcing = getattr(datamodule, "co2_forcing", None) + aerosol_forcing = getattr(datamodule, "aerosol_forcing", None) + + if co2_forcing is None and aerosol_forcing is None: + return + + # Get climate data + savar_data = getattr(datamodule, "savar_data", None) + if savar_data is None: + return + + # GT vs learned comparison (if available) + gt_adj = getattr(datamodule, "savar_gt_adj", None) + forcing_indices = getattr(datamodule, "forcing_indices", None) + + if gt_adj is not None and forcing_indices is not None: + learned_adj = learner.model.module.get_adj().cpu().detach().numpy() + self.plot_gt_vs_learned_forcing_effect(gt_adj, learned_adj, forcing_indices, path, iteration) + + # Forcing latent reconstruction (if model has forcing encoders) + # This would require extracting latent trajectories from the model + # For now, skip - can be added when model provides this interface + + # Less frequent: variance explained (every 5 * plot_freq) + plot_freq = learner.plot_params.plot_freq + if iteration % (5 * plot_freq) == 0: + lat = learner.lat + lon = learner.lon + self.plot_variance_explained_by_forcing( + co2_forcing, aerosol_forcing, savar_data, lat, lon, path, window_size=200 + ) + + def plot_sparsity(self, learner, save=False): + """ + Override parent method to handle SAVAR-specific plotting completely. + + This avoids conflicts with parent's SAVAR handling by implementing the full plotting logic for SAVAR + experiments. + """ + # Save coordinates + np.save(learner.plots_path / "coordinates.npy", learner.coordinates) + + if save: + self.save(learner) + + # 1. 
Plot learning curves (same for all experiments) + if learner.latent: + self.plot_learning_curves( + train_loss=learner.train_loss_list, + train_recons=learner.train_recons_list, + train_kl=learner.train_kl_list, + valid_loss=learner.valid_loss_list, + valid_recons=learner.valid_recons_list, + valid_kl=learner.valid_kl_list, + best_metrics=learner.best_metrics, + iteration=learner.iteration, + plot_through_time=learner.plot_params.plot_through_time, + path=learner.plots_path, + ) + + # Plot penalties and losses + losses = [ + {"name": "tr ortho", "data": learner.train_ortho_cons_list, "s": ":"}, + {"name": "mu ortho", "data": learner.mu_ortho_list, "s": ":"}, + {"name": "tr sparsity", "data": learner.train_sparsity_cons_list, "s": ":"}, + {"name": "tr var adj", "data": learner.train_transition_var_list, "s": ":"}, + {"name": "mu sparsity", "data": learner.mu_sparsity_list, "s": ":"}, + ] + self.plot_learning_curves2( + losses=losses, + iteration=learner.iteration, + plot_through_time=learner.plot_params.plot_through_time, + path=learner.plots_path, + fname="penalties", + yaxis_log=True, + ) + + losses = [ + {"name": "tr loss", "data": learner.train_loss_list, "s": "-."}, + {"name": "tr recons", "data": learner.train_recons_list, "s": "-"}, + {"name": "val recons", "data": learner.valid_recons_list, "s": "-"}, + {"name": "KL", "data": learner.train_kl_list, "s": "-"}, + {"name": "val loss", "data": learner.valid_loss_list, "s": "-."}, + {"name": "tr ELBO", "data": learner.train_elbo_list, "s": "-."}, + {"name": "val ELBO", "data": learner.valid_elbo_list, "s": "-."}, + ] + self.plot_learning_curves2( + losses=losses, + iteration=learner.iteration, + plot_through_time=learner.plot_params.plot_through_time, + path=learner.plots_path, + fname="losses", + ) + + logvar = [ + {"name": "logvar encoder", "data": learner.logvar_encoder_tt, "s": "-"}, + {"name": "logvar decoder", "data": learner.logvar_decoder_tt, "s": "-"}, + {"name": "logvar transition", "data": 
learner.logvar_transition_tt, "s": "-"}, + ] + self.plot_learning_curves2( + losses=logvar, + iteration=learner.iteration, + plot_through_time=learner.plot_params.plot_through_time, + path=learner.plots_path, + fname="logvar", + ) + + # 2. SAVAR-specific: prepare context and plot original data + logger.info("Preparing SAVAR-specific visualizations") + modes_gt = self.prepare_savar_context(learner) + + # 3. Get adjacency matrices + adj = learner.model.module.get_adj().cpu().detach().numpy() + adj_w = learner.model.module.autoencoder.get_w_decoder().cpu().detach().numpy() + adj_w2 = learner.model.module.autoencoder.get_w_encoder().cpu().detach().numpy() + + # 4. Plot SAVAR adjacency matrix with spatial alignment + self.plot_adjacency_matrix_savar( + learner=learner, + mat1=adj, + mat2=learner.datamodule.savar_gt_adj, + modes_gt=modes_gt, + modes_inferred=adj_w, + path=learner.plots_path, + name_suffix="transition", + no_gt=False, + iteration=learner.iteration, + plot_through_time=learner.plot_params.plot_through_time, + ) + + # 5. 
Plot weight matrices + if learner.latent: + # Plot decoder and encoder weight matrices + self.plot_adjacency_matrix_w(adj_w, None, learner.plots_path, "w", no_gt=True) + adj_w2 = np.swapaxes(adj_w2, 1, 2) + self.plot_adjacency_matrix_w(adj_w2, None, learner.plots_path, "encoder_w", no_gt=True) + + # Plot SAVAR feature maps (spatial patterns of learned latents) + self.plot_savar_feature_maps( + learner, + adj_w, + coordinates=learner.coordinates, + iteration=learner.iteration, + plot_through_time=learner.plot_params.plot_through_time, + path=learner.plots_path, + ) + + # Plot decoder connectivity heatmap (NEW) + self.plot_decoder_connectivity_heatmap( + learner, + adj_w, + iteration=learner.iteration, + plot_through_time=learner.plot_params.plot_through_time, + path=learner.plots_path, + ) + + logger.info("Completed SAVAR-specific plotting") diff --git a/climatem/synthetic_data/generate_savar_datasets.py b/climatem/synthetic_data/generate_savar_datasets.py index 223b375..79393b7 100644 --- a/climatem/synthetic_data/generate_savar_datasets.py +++ b/climatem/synthetic_data/generate_savar_datasets.py @@ -1,3 +1,20 @@ +""" +Orchestration layer for generating synthetic SAVAR datasets. + +This module creates SAVAR datasets with randomized +causal structure, saves them to disk (as .npy, .csv, and .json), and produces +diagnostic plots of the spatial mode and noise weight patterns. + +The main entry point is :func:`generate_save_savar_data`, which: +1. Constructs randomized spatial modes and noise weights on a 2-D grid. +2. Builds a causal link structure (optionally including CO2/aerosol forcing). +3. Instantiates a :class:`~climatem.synthetic_data.savar.SAVAR` model and + generates both deterministic and noisy data fields. +4. Persists every artefact (parameters, weights, forcing fields, latent + trajectories, background state) into a per-dataset directory. 
+""" + +import copy import csv import json @@ -8,10 +25,22 @@ from climatem.synthetic_data.savar import SAVAR from climatem.synthetic_data.utils import check_stability, create_random_mode +from climatem.utils import get_logger + +logger = get_logger(__name__) -# Before saving the parameters to JSON, convert ndarray to list def convert_ndarray_to_list(d): + """ + Recursively convert all :class:`numpy.ndarray` values in *d* to Python lists. + + This is used to make a parameter dictionary JSON-serialisable before saving. + + Parameters + ---------- + d : dict + Dictionary whose values may contain ndarrays or nested dicts. + """ for key, value in d.items(): if isinstance(value, np.ndarray): d[key] = value.tolist() @@ -20,13 +49,42 @@ def convert_ndarray_to_list(d): def np_encoder(object): + """ + JSON encoder fallback for numpy scalar types. + + Passed as the ``default`` argument to :func:`json.dump` so that + ``np.float64``, ``np.int64``, etc. are converted to native Python types. + + Parameters + ---------- + object : Any + The object that the default JSON encoder could not serialise. + + Returns + ------- + int or float or None + The Python-native equivalent, or ``None`` if *object* is not a numpy + generic type (which will cause :func:`json.dump` to raise). + """ if isinstance(object, np.generic): return object.item() def save_parameters_to_csv(filename, parameters): - # Exclude array data - # excluded_keys = ['modes_weights', 'noise_weights'] + """ + Write experiment parameters to a human-readable CSV file. + + Large array-valued parameters listed in *excluded_keys* are omitted to + keep the CSV concise. Dictionary values (e.g. ``links_coeffs``) are + serialised as JSON strings. + + Parameters + ---------- + filename : str or pathlib.Path + Destination CSV path. + parameters : dict + Flat or shallowly nested parameter dictionary. 
+ """ excluded_keys = ["noise_weights"] # keep noise weights to get permutations filtered_params = {key: value for key, value in parameters.items() if key not in excluded_keys} @@ -44,6 +102,20 @@ def save_parameters_to_csv(filename, parameters): def save_links_coeffs_to_csv(filename, links_coeffs): + """ + Write the causal link structure to a CSV with one row per edge. + + Each row contains the target component index, the source component index, + the time lag, and the edge coefficient. + + Parameters + ---------- + filename : str or pathlib.Path + Destination CSV path. + links_coeffs : dict[int, list[tuple]] + Mapping from target component index to a list of + ``((source_index, lag), coefficient)`` tuples. + """ with open(filename, "w", newline="") as file: writer = csv.writer(file) writer.writerow(["Component", "Link", "Lag", "Coefficient"]) @@ -52,8 +124,23 @@ def save_links_coeffs_to_csv(filename, links_coeffs): writer.writerow([key, value[0][0], value[0][1], value[1]]) -# Function to create a circular mode def create_circular_mode(shape, radius=10): + """ + Create a spatial mode with non-zero random values inside a circular mask. + + Parameters + ---------- + shape : tuple[int, int] + ``(height, width)`` of the output array. + radius : int, optional + Radius (in pixels) of the circular region centred in the grid. + + Returns + ------- + numpy.ndarray + 2-D array of the given *shape* with standard-normal values inside the + circle and zeros outside. + """ mode = np.zeros(shape) center = (shape[0] // 2, shape[1] // 2) Y, X = np.ogrid[: shape[0], : shape[1]] @@ -64,11 +151,44 @@ def create_circular_mode(shape, radius=10): def create_links_coeffs(n_modes, prob_edge=0.2, tau=5, a=4, b=8, difficulty="easy"): + """ + Build a randomised causal link structure for *n_modes* latent components. + + Every component receives exactly one autoregressive (self-)link at a random + lag in ``[1, tau]``. 
Cross-component links are added with probability + *prob_edge*, subject to a stability guard that keeps the total absolute + coefficient for each target below 1. + + Parameters + ---------- + n_modes : int + Number of latent components. + prob_edge : float + Probability that a directed edge j -> k exists (for j != k). + tau : int + Maximum time lag for any causal link. + a, b : float + Shape parameters of the Beta(a, b) distribution used to draw link + coefficient magnitudes. The default ``a=4, b=8`` yields a right-skewed + distribution concentrated around 0.3, producing moderate-strength links + that are unlikely to be very large (keeps the VAR process stable). + difficulty : str + One of ``"easy"``, ``"med_easy"``, ``"med_hard"``, ``"hard"``. + Higher difficulty halves or quarters coefficient magnitudes, making + causal discovery harder. + + Returns + ------- + dict[int, list[tuple]] + Mapping from target component index to a list of + ``((source_index, -lag), coefficient)`` tuples. + """ links_coeffs = {} for k in range(n_modes): val = 0 links_coeffs[k] = [] auto_reg_tau = np.random.choice(np.arange(1, tau + 1)) + # Draw coefficient magnitude from Beta(a=4, b=8); mean ~0.33, right-skewed r = beta.rvs(a, b) if difficulty == "med_hard": r /= 2 @@ -82,17 +202,224 @@ def create_links_coeffs(n_modes, prob_edge=0.2, tau=5, a=4, b=8, difficulty="eas if j != k: auto_reg_tau = np.random.choice(np.arange(1, tau + 1)) if np.random.choice([0, 1], p=[1 - prob_edge, prob_edge]): + # Draw cross-link coefficient from the same Beta(a=4, b=8) r = beta.rvs(a, b) if difficulty == "med_hard": r /= 2 if difficulty == "hard": r /= 4 val += int(r * 100) / 100 + # Guard: total absolute coefficient for target k must stay < 1 + # to help ensure VAR stability if val < 1: links_coeffs[k].append(((j, -auto_reg_tau), int(r * 100) / 100)) return links_coeffs +def create_forcing_links_coeffs( + n_climate_modes, + n_co2_latents=1, + n_aerosol_latents=4, + tau=5, + co2_effect_strength=0.15, + 
aerosol_effect_strength=0.10, + co2_affected_modes=None, + aerosol_affected_modes=None, + forcing_autoreg_strength=0.8, +): + """ + Create causal links from forcing latents to climate modes. + + This extends the links_coeffs structure to include: + - CO2 forcing latent(s) → climate modes + - Aerosol forcing latent(s) → climate modes + - Autoregressive terms for forcing latents + + Forcing indices are assigned as: + - Climate modes: 0 to n_climate_modes-1 + - CO2 latents: n_climate_modes to n_climate_modes + n_co2_latents - 1 + - Aerosol latents: n_climate_modes + n_co2_latents to end + + Args: + n_climate_modes: Number of climate modes (e.g., 4) + n_co2_latents: Number of latents for CO2 (default 1) + n_aerosol_latents: Number of latents for aerosols (default 4) + tau: Maximum time lag + co2_effect_strength: Coefficient strength for CO2 → mode connections + aerosol_effect_strength: Coefficient strength for aerosol → mode connections + co2_affected_modes: List of climate mode indices affected by CO2 (default: all) + aerosol_affected_modes: List of climate mode indices affected by aerosols (default: all) + forcing_autoreg_strength: Autoregressive coefficient for forcing latents + + Returns: + forcing_links: Dictionary with forcing → mode causal links + forcing_indices: Dict with 'co2' and 'aerosol' index ranges + """ + # Default: all climate modes are affected by forcings + if co2_affected_modes is None: + co2_affected_modes = list(range(n_climate_modes)) + if aerosol_affected_modes is None: + aerosol_affected_modes = list(range(n_climate_modes)) + + # Compute forcing indices + co2_start_idx = n_climate_modes + co2_end_idx = co2_start_idx + n_co2_latents + aerosol_start_idx = co2_end_idx + aerosol_end_idx = aerosol_start_idx + n_aerosol_latents + + forcing_indices = { + "co2": list(range(co2_start_idx, co2_end_idx)), + "aerosol": list(range(aerosol_start_idx, aerosol_end_idx)), + "n_total": aerosol_end_idx, + } + + forcing_links = {} + + # Initialize forcing latent 
entries (no self-loops: forcings are exogenous, + # their trajectories are prescribed by _generate_forcing_trajectories()) + for co2_idx in forcing_indices["co2"]: + forcing_links[co2_idx] = [] + for aerosol_idx in forcing_indices["aerosol"]: + forcing_links[aerosol_idx] = [] + + # Add CO2 → climate mode connections + # CO2 affects all modes with lag 1 (immediate effect on next timestep) + for mode_idx in co2_affected_modes: + for co2_idx in forcing_indices["co2"]: + # Add with some random variation in strength + strength = co2_effect_strength * (0.8 + 0.4 * np.random.rand()) + lag = -np.random.choice([1, 2]) # Lag 1 or 2 + # This link means: climate_mode[mode_idx] is caused by co2[co2_idx] at lag + # We store this in the CLIMATE MODE's entry (target's perspective) + if mode_idx not in forcing_links: + forcing_links[mode_idx] = [] + forcing_links[mode_idx].append(((co2_idx, lag), round(strength, 3))) + + # Add Aerosol → climate mode connections + # Each aerosol latent affects a subset of modes (more localized effect) + for i, aerosol_idx in enumerate(forcing_indices["aerosol"]): + # Each aerosol latent primarily affects ~1-2 climate modes + # Distribute aerosol effects across modes + primary_mode = i % n_climate_modes + affected = [primary_mode] + # 50% chance of also affecting the neighbouring mode + if np.random.rand() > 0.5 and n_climate_modes > 1: + neighbor = (primary_mode + 1) % n_climate_modes + affected.append(neighbor) + + for mode_idx in affected: + strength = aerosol_effect_strength * (0.7 + 0.6 * np.random.rand()) + # 70% chance the aerosol effect is negative (cooling), reflecting + # the dominant real-world radiative effect of sulphate aerosols + if np.random.rand() > 0.3: + strength = -strength + lag = -np.random.choice([1, 2, 3]) + if mode_idx not in forcing_links: + forcing_links[mode_idx] = [] + forcing_links[mode_idx].append(((aerosol_idx, lag), round(strength, 3))) + + return forcing_links, forcing_indices + + +def 
merge_links_coeffs(climate_links, forcing_links): + """ + Merge climate mode links with forcing links into a single links_coeffs dict. + + Args: + climate_links: links_coeffs for climate modes (indices 0 to N-1) + forcing_links: links from forcing latents (includes forcing → mode links) + + Returns: + merged: Combined links_coeffs dictionary + """ + merged = {} + + # Copy climate links + for k, v in climate_links.items(): + merged[k] = list(v) + + # Add/extend with forcing links + for k, v in forcing_links.items(): + if k in merged: + merged[k].extend(v) + else: + merged[k] = list(v) + + return merged + + +def _plot_mode_weights(modes_weights, noise_weights, savar_dataset_dir): + """Plot and save spatial mode and noise weight visualizations.""" + sum_modes = modes_weights.sum(axis=0) + fig, ax = plt.subplots() + im = ax.imshow(sum_modes) + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="5%", pad=0.05) + plt.colorbar(im, cax=cax) + ax.set_title("Sum of Circular Modes") + plt.savefig(savar_dataset_dir / "modes.png") + np.save(savar_dataset_dir / "modes.npy", modes_weights) + plt.close() + + sum_noise = noise_weights.sum(axis=0) + fig, ax = plt.subplots() + im = ax.imshow(sum_noise) + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="5%", pad=0.05) + plt.colorbar(im, cax=cax) + ax.set_title("Sum of Circular Noise") + plt.savefig(savar_dataset_dir / "noise_modes.png") + np.save(savar_dataset_dir / "noise_modes.npy", sum_noise) + plt.close() + + +def _log_snr_metrics(noise_field, deterministic_data, noisy_data, noise_val): + """Compute and log signal-to-noise ratio metrics.""" + if noise_field is None or deterministic_data is None: + return + actual_noise = noisy_data - deterministic_data + signal_std = np.std(deterministic_data) + noise_std = np.std(actual_noise) + snr_std = signal_std / noise_std if noise_std > 0 else np.inf + snr_db = 10 * np.log10(snr_std) if snr_std > 0 else -np.inf + + 
logger.info("SIGNAL-TO-NOISE RATIO METRICS") + logger.info("Signal std: %.6f", signal_std) + logger.info("Noise std: %.6f", noise_std) + logger.info("SNR (std): %.3f (%.2f dB)", snr_std, snr_db) + logger.info("Noise strength: %.6f", noise_val) + + +def _save_savar_artifacts(savar_model, savar_dataset_dir): + """Save forcing fields, latent trajectories, templates, and background to disk.""" + artifacts = [ + ("forcing_data_field", "forcing_data_field.npy"), + ("co2_forcing_data_field", "co2_forcing.npy"), + ("aerosol_forcing_data_field", "aerosol_forcing.npy"), + ("co2_latent_trajectory", "co2_latent_trajectory.npy"), + ("aerosol_latent_trajectory", "aerosol_latent_trajectory.npy"), + ("aerosol_spatial_templates", "aerosol_spatial_templates.npy"), + ] + for attr_name, filename in artifacts: + data = getattr(savar_model, attr_name, None) + if data is not None: + path = savar_dataset_dir / filename + np.save(path, data) + logger.debug("Saved %s to %s (shape: %s)", attr_name, path, data.shape) + + background_field = getattr(savar_model, "background_data_field", None) + if background_field is not None: + background_path = savar_dataset_dir / "background_data_field.npy" + np.save(background_path, background_field) + logger.debug("Saved background data field to %s (shape: %s)", background_path, background_field.shape) + + bg_std = background_field.std() + final_data = savar_model.data_field + data_std = final_data.std() if final_data is not None else 0.0 + ratio = bg_std / data_std if data_std > 0 else 0.0 + logger.info("Background contribution: std=%.6f, ratio to total data=%.3f", bg_std, ratio) + + def generate_save_savar_data( save_dir_path, name, @@ -101,18 +428,163 @@ def generate_save_savar_data( noise_val=0.2, n_per_col=2, # Number of components N = n_per_col**2 difficulty="easy", - seasonality=False, + seasonality=True, + periods=[12, 6, 3], + amplitudes=[0.1, 0.05, 0.02], + phases=[0.0, 0.7853981634, 1.5707963268], # [0, pi/4, pi/2] radians + yearly_jitter_amp: 
float = 0.05, + yearly_jitter_phase: float = 0.10, overlap=0, is_forced=False, f_1=1, f_2=2, - f_time_1=4000, + f_time_1=2000, f_time_2=8000, ramp_type="linear", linearity="polynomial", poly_degrees=[2, 3], plotting=True, + aerosol_scale=0.02, + aerosol_spatial_contrast=1.05, + aerosol_ramp_up_time=2000, + aerosol_peak_time=5000, + aerosol_decline_time=8000, + # Forcing causal structure parameters + n_co2_latents=1, + n_aerosol_latents=2, # Updated from 4 to 2 (2026-01-22 aerosol refactor) + co2_effect_strength=0.25, # Updated from 0.15 to 0.25 + aerosol_effect_strength=0.20, # Updated from 0.10 to 0.20 + forcing_amplification=1.0, # Updated from 1.5 to 1.0 + noise_ar1=True, # Use AR(1) noise for realistic temporal correlations + noise_ar1_rho=0.95, # AR(1) persistence parameter rho (or "decay" for mode-dependent rho_k) + tau=5, + # Background state parameters + enable_background=False, + background_strength=0.3, + background_strength_mode="relative", + background_smoothness=0.15, + background_timescale_rho=0.995, + background_n_modes=3, ): + """ + Generate a complete SAVAR dataset and persist all artefacts to disk. + + This is the main orchestration function. It builds spatial modes, creates + a randomised causal graph (optionally extended with CO2/aerosol forcing + latents), generates deterministic and noisy time series via + :class:`~climatem.synthetic_data.savar.SAVAR`, and saves everything + (parameters, weights, forcing fields, latent trajectories, diagnostic + plots) into ``save_dir_path / name``. + + Parameters + ---------- + **Grid parameters** + + save_dir_path : pathlib.Path + Parent directory under which a sub-folder *name* will be created. + name : str + Identifier for this dataset; used as the sub-folder name and in logs. + comp_size : int, optional + Side length (in pixels) of each spatial component tile. Default 10. + n_per_col : int, optional + Number of components per row/column; total components N = n_per_col**2. 
+ overlap : float, optional + Spatial overlap between component tiles, in [0, 1]. 0 = no overlap, + 1 = all tiles centred at the grid midpoint. + + **Temporal parameters** + + time_len : int, optional + Number of time steps to generate. Default 10 000, chosen to provide + enough samples for stable VAR estimation and spectral analysis while + keeping generation time manageable. + tau : int, optional + Maximum causal time lag (in time steps). Default 5. + seasonality : bool, optional + Whether to add seasonal (periodic) components. + periods, amplitudes, phases : list, optional + Lists defining the harmonic decomposition of the seasonal cycle. + yearly_jitter_amp, yearly_jitter_phase : float, optional + Year-to-year random perturbation of seasonal amplitude and phase. + + **Causal structure parameters** + + difficulty : str, optional + Controls edge density and coefficient magnitude. One of + ``"easy"`` (no cross-links), ``"med_easy"`` (sparse), + ``"med_hard"`` (moderate, halved coefficients), + ``"hard"`` (dense, quartered coefficients). + + **Forcing parameters** + + is_forced : bool, optional + Whether to include exogenous CO2 and aerosol forcing. + f_1, f_2 : float, optional + Forcing magnitude at the first and second plateau, respectively. + f_time_1 : int, optional + End of the first (low) forcing plateau. Default 2000, representing + the pre-industrial steady-state period (~20 % of the time series). + f_time_2 : int, optional + Start of the second (high) forcing plateau. Default 8000, + representing the point at which forcing stabilises (~80 %). + ramp_type : str, optional + Interpolation between the two plateaus (``"linear"`` or other). + aerosol_scale, aerosol_spatial_contrast : float, optional + Magnitude and spatial heterogeneity of aerosol forcing. + aerosol_ramp_up_time : int, optional + Time step at which aerosol emissions begin increasing. Default 2000, + aligned with f_time_1 to co-locate the start of anthropogenic forcing. 
+ aerosol_peak_time : int, optional + Time step at which aerosol emissions peak. Default 5000, placing the + peak at the midpoint of the ramp-up window. + aerosol_decline_time : int, optional + Time step at which aerosol emissions finish declining. Default 8000, + aligned with f_time_2 to model clean-air regulations. + n_co2_latents, n_aerosol_latents : int, optional + Number of latent variables representing CO2 and aerosol forcing. + co2_effect_strength, aerosol_effect_strength : float, optional + Coefficient magnitude for forcing -> climate mode causal links. + forcing_amplification : float, optional + Global multiplier applied to all forcing signals. + + **Noise parameters** + + noise_val : float, optional + Noise standard deviation (before AR(1) colouring). + noise_ar1 : bool, optional + Whether to use temporally correlated AR(1) noise. + noise_ar1_rho : float, optional + AR(1) persistence parameter. + + **Background state parameters** + + enable_background : bool, optional + Whether to add a slowly varying background state. + background_strength : float, optional + Magnitude of the background state. + background_strength_mode : str, optional + ``"relative"`` scales background relative to signal std. + background_smoothness : float, optional + Controls spatial smoothness of background modes. + background_timescale_rho : float, optional + AR(1) coefficient for the background time series. + background_n_modes : int, optional + Number of spatial modes used to construct the background. + + **Miscellaneous** + + linearity : str, optional + ``"linear"`` or ``"nonlinear"``; selects the SAVAR transition model. + poly_degrees : list[int], optional + Polynomial degrees used when ``linearity="nonlinear"``. + plotting : bool, optional + Whether to save diagnostic spatial-mode plots. + + Returns + ------- + numpy.ndarray + The generated (noisy) data field with shape ``(time_len, nx * ny)``. 
+ """ # Setup spatial weights of underlying processes ny = nx = n_per_col * comp_size @@ -124,9 +596,13 @@ def generate_save_savar_data( noise_weights = np.zeros((N, nx, ny)) modes_weights = np.zeros((N, nx, ny)) + # Create a subfolder for this specific SAVAR dataset + savar_dataset_dir = save_dir_path / name + savar_dataset_dir.mkdir(parents=True, exist_ok=True) + # Specify the path where you want to save the data - npy_name = f"{name}.npy" - save_path = save_dir_path / npy_name + npy_name = "savar.npy" + save_path = savar_dataset_dir / npy_name # Center starting position (for fully overlapping modes) center_x_start = (nx - comp_size) // 2 @@ -153,7 +629,8 @@ def generate_save_savar_data( (comp_size, comp_size), random=True ) - # This is the probabiliity of having a link between latent k and j, with k different from j. latents always have one link with themselves at a previous time. + # Probability of having a link between latent k and j (k != j). + # Latents always have one autoregressive link with themselves at a previous time. if difficulty == "easy": prob = 0 # expected N out of N^2 total links if difficulty == "med_easy": @@ -163,16 +640,45 @@ def generate_save_savar_data( if difficulty == "hard": prob = 1 / 2 # (N + 2*N / 2)/N^2 - links_coeffs = create_links_coeffs(N, prob_edge=prob, difficulty=difficulty) + # Create climate mode links (N x N) + climate_links_coeffs = create_links_coeffs(N, prob_edge=prob, tau=tau, difficulty=difficulty) # One good thing of SAVAR is that if the underlying process is stable and stationary, then SAVAR is also both. # Independently of W. 
This is, we only need to check for stationarity of \PHI and not of W^+\PHI W - check_stability(links_coeffs) + check_stability(climate_links_coeffs) + + # Initialize forcing_indices (will be populated if is_forced) + forcing_indices = None + + if is_forced: + # Create forcing -> mode causal links + forcing_links, forcing_indices = create_forcing_links_coeffs( + n_climate_modes=N, + n_co2_latents=n_co2_latents, + n_aerosol_latents=n_aerosol_latents, + tau=tau, + co2_effect_strength=co2_effect_strength, + aerosol_effect_strength=aerosol_effect_strength, + ) + + # Merge climate and forcing links into complete links_coeffs + links_coeffs = merge_links_coeffs(climate_links_coeffs, forcing_links) + + logger.info( + "Created extended causal graph with %d total latents: " + "climate modes 0-%d, CO2 latents %s, aerosol latents %s", + forcing_indices["n_total"], + N - 1, + forcing_indices["co2"], + forcing_indices["aerosol"], + ) + else: + links_coeffs = climate_links_coeffs if is_forced: # turn off forcing by setting the time to the last time step w_f = modes_weights - # A very simple method for adding a focring term (bias on the mean of the noise term) + # A very simple method for adding a forcing term (bias on the mean of the noise term) forcing_dict = { "w_f": w_f, # Shape of the mode of the forcing "f_1": f_1, # Value of the forcing at period_1 @@ -181,46 +687,38 @@ def generate_save_savar_data( "f_time_2": f_time_2, # The period two goes from t= f_time_2 to the end. 
Between the two periods, the forcing is risen linearly "time_len": time_len, "ramp_type": ramp_type, + "aerosol_scale": aerosol_scale, # Scale parameter for aerosol forcing + "aerosol_spatial_contrast": aerosol_spatial_contrast, # Spatial contrast parameter + "aerosol_ramp_up_time": aerosol_ramp_up_time, # When aerosols start increasing + "aerosol_peak_time": aerosol_peak_time, # When aerosols peak + "aerosol_decline_time": aerosol_decline_time, # When aerosols finish declining } + + season_dict = None if seasonality: - raise ValueError("SAVAR data with seasonality not implemented yet") - # We could introduce seasonality if we would wish - # season_dict = {"amplitude": 0.08, "period": 12} + lat = np.linspace(-90, 90, nx) # vary along rows + lat2d = np.repeat(lat[:, None], ny, axis=1) # shape (nx, ny) + season_weight = np.abs(np.sin(2 * np.deg2rad(lat2d))).ravel() + + if phases is None: + phases = [0.0] * len(amplitudes) + + if not (len(amplitudes) == len(periods) == len(phases)): + raise ValueError("season_amplitudes, season_periods, season_phases must have identical lengths.") + + season_dict = { + "amplitudes": amplitudes, # e.g. [0.06, 0.02, 0.01] + "periods": periods, # e.g. [365, 182.5, 60] + "phases": phases, # radian offsets + "season_weight": season_weight, + "yearly_jitter": { + "amplitude": yearly_jitter_amp, # e.g. 0.05 + "phase": yearly_jitter_phase, # e.g. 
0.10 + }, + } if plotting: - # Plot the sum of mode weights - sum_modes = modes_weights.sum(axis=0) - fig, ax = plt.subplots() - im = ax.imshow(sum_modes) - divider = make_axes_locatable(ax) - cax = divider.append_axes("right", size="5%", pad=0.05) - plt.colorbar(im, cax=cax) - ax.set_title("Sum of Circular Modes") - fig_name = f"{name}_modes.png" - modenpy_name = f"{name}_modes.npy" - fig_path = save_dir_path / fig_name - modenpy_path = save_dir_path / modenpy_name - plt.savefig(fig_path) - np.save(modenpy_path, sum_modes) - plt.close() - - # Plot the sum of noise weights - sum_noise = noise_weights.sum(axis=0) - fig, ax = plt.subplots() - im = ax.imshow(sum_noise) - divider = make_axes_locatable(ax) - cax = divider.append_axes("right", size="5%", pad=0.05) - plt.colorbar(im, cax=cax) - ax.set_title("Sum of Circular Noise") - - fig_name = f"{name}_noise_modes.png" - noisenpy_name = f"{name}_noise_modes.npy" - fig_path = save_dir_path / fig_name - sum_noise_npypath = save_dir_path / noisenpy_name - - plt.savefig(fig_path) - np.save(sum_noise_npypath, sum_noise) - plt.close() + _plot_mode_weights(modes_weights, noise_weights, savar_dataset_dir) # Creating a dictionary of parameters parameters = { @@ -237,33 +735,59 @@ def generate_save_savar_data( "ramp_type": ramp_type, "linearity": linearity, "poly_degrees": poly_degrees, - # "season_dict": season_dict, - # "seasonality" : True, + "season_dict": season_dict, + "seasonality": True, + "noise_val": noise_val, + # Forcing causal structure info + "is_forced": is_forced, + "forcing_indices": forcing_indices, + "n_co2_latents": n_co2_latents if is_forced else 0, + "n_aerosol_latents": n_aerosol_latents if is_forced else 0, + "n_climate_modes": N, + "n_total_latents": forcing_indices["n_total"] if forcing_indices else N, + # Forcing signal strength parameters + "forcing_amplification": forcing_amplification if is_forced else None, + "co2_effect_strength": co2_effect_strength if is_forced else None, + 
"aerosol_effect_strength": aerosol_effect_strength if is_forced else None, + # Background state parameters + "enable_background": enable_background, + "background_strength": background_strength, + "background_strength_mode": background_strength_mode, + "background_smoothness": background_smoothness, + "background_timescale_rho": background_timescale_rho, + "background_n_modes": background_n_modes, } + parameters_copy = copy.deepcopy(parameters) + convert_ndarray_to_list(parameters_copy) # safe to mutate + # Specify the path to save the parameters - param_names = f"{name}_parameters.npy" - params_path = save_dir_path / param_names + param_names = "parameters.npy" + params_path = savar_dataset_dir / param_names # Save the dictionary of parameters to a .npy file np.save(params_path, parameters) - param_names = f"{name}_parameters.csv" - params_path = save_dir_path / param_names - save_parameters_to_csv(params_path, parameters) - param_names = f"{name}_links_coeffs.csv" - params_path = save_dir_path / param_names + param_names = "parameters.csv" + params_path = savar_dataset_dir / param_names + save_parameters_to_csv(params_path, parameters_copy) + param_names = "links_coeffs.csv" + params_path = savar_dataset_dir / param_names save_links_coeffs_to_csv(params_path, parameters["links_coeffs"]) - param_names = f"{name}_mode_weights.npy" - params_path = save_dir_path / param_names + param_names = "mode_weights.npy" + params_path = savar_dataset_dir / param_names np.save(params_path, modes_weights) + # Save noise_weights for diagnostics + param_names = "noise_weights.npy" + params_path = savar_dataset_dir / param_names + np.save(params_path, noise_weights) + # Create a copy of the parameters to modify - parameters_copy = parameters.copy() convert_ndarray_to_list(parameters_copy) # Specify the path to save the parameters - param_names = f"{name}_parameters.json" - params_path = save_dir_path / param_names + param_names = "parameters.json" + params_path = savar_dataset_dir / 
param_names # Save the dictionary of parameters to a JSON file with open(params_path, "w") as json_file: @@ -275,25 +799,94 @@ def generate_save_savar_data( links_coeffs=links_coeffs, time_length=time_len, mode_weights=modes_weights, - noise_strength=noise_val, # How to play with this parameter? - # season_dict=season_dict, #turn off by commenting out - # forcing_dict=forcing_dict #turn off by commenting out + # noise_weights defaults to mode_weights inside SAVAR (baseline behavior) + noise_strength=noise_val, + season_dict=season_dict, linearity=linearity, poly_degrees=poly_degrees, + output_save_dir=str(savar_dataset_dir), + # Background state parameters + enable_background=enable_background, + background_strength=background_strength, + background_strength_mode=background_strength_mode, + background_smoothness=background_smoothness, + background_timescale_rho=background_timescale_rho, + background_n_modes=background_n_modes, ) else: savar_model = SAVAR( links_coeffs=links_coeffs, time_length=time_len, mode_weights=modes_weights, + # noise_weights defaults to mode_weights inside SAVAR (baseline behavior) noise_strength=noise_val, - forcing_dict=forcing_dict, # turn off by commenting out + noise_ar1=noise_ar1, + noise_ar1_rho=noise_ar1_rho, + season_dict=season_dict, + forcing_dict=forcing_dict, + forcing_indices=forcing_indices, # Pass forcing indices for causal structure + forcing_amplification=forcing_amplification, linearity=linearity, poly_degrees=poly_degrees, + output_save_dir=str(savar_dataset_dir), + # Background state parameters + enable_background=enable_background, + background_strength=background_strength, + background_strength_mode=background_strength_mode, + background_smoothness=background_smoothness, + background_timescale_rho=background_timescale_rho, + background_n_modes=background_n_modes, ) - savar_model.generate_data() # Remember to generate data, otherwise the data field will be empty + # Generate data with noise (baseline behavior: noise 
added to data_field before dynamics) + logger.info("Generating data with noise (baseline behavior)...") + savar_model.generate_data(include_noise=True) np.save(save_path, savar_model.data_field) + logger.info("Saved noisy data to %s (shape: %s)", save_path, savar_model.data_field.shape) + + # Save noise data field for diagnostics + noise_field = getattr(savar_model, "noise_data_field", None) + if noise_field is not None: + noise_data_path = savar_dataset_dir / "noise_data_field.npy" + np.save(noise_data_path, noise_field) + logger.debug("Saved noise data field to %s (shape: %s)", noise_data_path, noise_field.shape) + + # Also generate deterministic data (separate pass, for diagnostics/SNR) + # Save all deterministic components so they can be preserved + saved_seasonal = savar_model.seasonal_data_field + saved_forcing = savar_model.forcing_data_field + saved_co2_forcing = savar_model.co2_forcing_data_field + saved_aerosol_forcing = savar_model.aerosol_forcing_data_field + saved_background = savar_model.background_data_field + saved_co2_latent = savar_model.co2_latent_trajectory + saved_aerosol_latent = savar_model.aerosol_latent_trajectory + saved_aerosol_templates = savar_model.aerosol_spatial_templates + + savar_model.data_field = None + savar_model.seasonal_data_field = saved_seasonal + savar_model.forcing_data_field = saved_forcing + savar_model.co2_forcing_data_field = saved_co2_forcing + savar_model.aerosol_forcing_data_field = saved_aerosol_forcing + savar_model.background_data_field = saved_background + savar_model.co2_latent_trajectory = saved_co2_latent + savar_model.aerosol_latent_trajectory = saved_aerosol_latent + savar_model.aerosol_spatial_templates = saved_aerosol_templates + + logger.info("Generating deterministic data (no noise, for diagnostics)...") + savar_model.generate_data(include_noise=False) + deterministic_data = savar_model.data_field.copy() + deterministic_path = savar_dataset_dir / "savar_deterministic.npy" + np.save(deterministic_path, 
deterministic_data) + logger.info("Saved deterministic data to %s (shape: %s)", deterministic_path, deterministic_data.shape) + + # Restore the noisy data as the primary output + noisy_data = np.load(save_path) + savar_model.data_field = noisy_data + + _log_snr_metrics(noise_field, deterministic_data, noisy_data, noise_val) + + _save_savar_artifacts(savar_model, savar_dataset_dir) + + logger.info("%s DONE!", name) - print(f"{name} DONE!") return savar_model.data_field diff --git a/climatem/synthetic_data/graph_evaluation.py b/climatem/synthetic_data/graph_evaluation.py index 24c6088..44ccb3e 100644 --- a/climatem/synthetic_data/graph_evaluation.py +++ b/climatem/synthetic_data/graph_evaluation.py @@ -1,4 +1,13 @@ +""" +Evaluation utilities for comparing learned causal graphs against ground truth. + +Provides functions to compute standard causal-discovery metrics -- Structural Hamming Distance (SHD), precision, recall, +and F1 -- between inferred and ground-truth adjacency matrices. Also includes helpers for permuting matrices to align +latent orderings, plotting adjacency heatmaps, and extracting human-readable latent equations from adjacency structures. +""" + import json +import logging from pathlib import Path import matplotlib.pyplot as plt @@ -7,9 +16,32 @@ import seaborn as sns from sklearn.metrics import f1_score, precision_score, recall_score +logger = logging.getLogger(__name__) -def get_permutation_list(mat_adj_w, modes_gt, lat, lon): # , remove_n_latents=0 +def get_permutation_list(mat_adj_w, modes_gt, lat, lon): # , remove_n_latents=0 + """ + Find the permutation that best aligns inferred modes to ground-truth modes. + + Alignment is based on the 2-D grid location of the maximum value of each + spatial mode. The permutation minimises the sum of squared Euclidean + distances between the peak locations of ground-truth and inferred modes. + + Parameters + ---------- + mat_adj_w : np.ndarray + Inferred mode weight matrix, shape ``(lat*lon, n_modes)``. 
+ modes_gt : np.ndarray + Ground-truth mode weights, shape ``(n_modes, lat, lon)``. + lat, lon : int + Spatial grid dimensions. + + Returns + ------- + np.ndarray + Integer array of length ``n_modes`` mapping each ground-truth mode + index to the best-matching inferred mode index. + """ mat_adj_w = mat_adj_w.reshape((lat, lon, mat_adj_w.shape[1])).transpose((2, 0, 1)) idx_gt = np.where(modes_gt == modes_gt.max((1, 2))[:, None, None]) @@ -25,7 +57,28 @@ def get_permutation_list(mat_adj_w, modes_gt, lat, lon): # , remove_n_latents=0 def get_permutation_list_hardcoded_100(mat_adj_w, modes_gt, lat, lon): # , remove_n_latents=0 - + """ + Compute mode permutation for exactly 100 modes on a 10x10 spatial grid. + + This is a specialised (and possibly dead-code) variant of + ``get_permutation_list`` that assumes 100 modes laid out on a 10x10 grid + of 10x10 spatial patches. Each mode is matched by finding the maximum + activation within its expected patch. + + Parameters + ---------- + mat_adj_w : np.ndarray + Inferred mode weight matrix, shape ``(lat*lon, 100)``. + modes_gt : np.ndarray + Ground-truth mode weights, shape ``(100, lat, lon)``. + lat, lon : int + Spatial grid dimensions (expected to be 100 each). + + Returns + ------- + list of int + Permutation mapping ground-truth mode indices to inferred mode indices. + """ mat_adj_w = mat_adj_w.reshape((lat, lon, mat_adj_w.shape[1])).transpose((2, 0, 1)) permutation_list = [] @@ -68,7 +121,7 @@ def permute_matrix(matrix, permutation): return permuted_matrix -def load_and_permute_all_matrices(csv_files, permutation, remove_modes=[]): +def load_and_permute_all_matrices(modes_inferred, modes_gt, adj_w, adj_gt, lat, lon, tau): """ Loads and permutes multiple adjacency matrices, one for each time lag. @@ -80,28 +133,46 @@ def load_and_permute_all_matrices(csv_files, permutation, remove_modes=[]): np.ndarray: A 3D NumPy array containing all permuted adjacency matrices where the shape is (number_of_time_lags, n, n). 
""" - permuted_matrices = [] + # Find the permutation + modes_inferred = modes_inferred.reshape((lat, lon, modes_inferred.shape[-1])).transpose((2, 0, 1)) - for csv_file in csv_files: - # Load the adjacency matrix - adjacency_matrix = load_adjacency_matrix(csv_file) + # Get the flat index of the maximum for each mode + idx_gt_flat = np.argmax(modes_gt.reshape(modes_gt.shape[0], -1), axis=1) # shape: (n_modes,) + idx_inferred_flat = np.argmax(modes_inferred.reshape(modes_inferred.shape[0], -1), axis=1) # shape: (n_modes,) - if len(remove_modes): - adjacency_matrix = np.delete(adjacency_matrix, remove_modes, 0) - adjacency_matrix = np.delete(adjacency_matrix, remove_modes, 1) + # Convert flat indices to 2D coordinates (row, col) + idx_gt = np.array([np.unravel_index(i, (lat, lon)) for i in idx_gt_flat]) # shape: (n_modes, 2) + idx_inferred = np.array([np.unravel_index(i, (lat, lon)) for i in idx_inferred_flat]) # shape: (n_modes, 2) - # Permute the adjacency matrix - permuted_matrix = permute_matrix(adjacency_matrix, permutation) + # Compute error matrix using squared Euclidean distance between indices which yields an (n_modes x n_modes) matrix + permutation_list = ((idx_gt[:, None, :] - idx_inferred[None, :, :]) ** 2).sum(axis=2).argmin(axis=1) + logger.info("permutation_list: %s", permutation_list) - # Append the permuted matrix to the list - permuted_matrices.append(permuted_matrix) + # Permute + for k in range(tau): + adj_w[k] = adj_w[k][np.ix_(permutation_list, permutation_list)] - # Convert the list of permuted matrices to a NumPy array - return np.array(permuted_matrices) + logger.info("PERMUTED THE MATRICES") + + return adj_w def binarize_matrix(A, threshold=0.5): - """Binarizes the adjacency matrix by applying a threshold.""" + """ + Binarise an adjacency matrix by applying a threshold. + + Parameters + ---------- + A : np.ndarray + Real-valued adjacency matrix. + threshold : float + Values strictly above this become 1; all others become 0. 
+ + Returns + ------- + np.ndarray + Integer array with values in {0, 1}. + """ return (A > threshold).astype(int) @@ -193,12 +264,205 @@ def plot_adjacency_matrix( plt.close() +def plot_adjacency_with_forcing_labels( + mat_inferred: np.ndarray, + mat_gt: np.ndarray, + forcing_indices: dict, + path: str, + name: str = "adjacency_with_labels", + threshold: float = 0.5, + tau_idx: int = 0, +): + """ + Plot adjacency matrices with labeled axes showing climate modes, CO2, and aerosol indices. + + Args: + mat_inferred: Inferred adjacency matrices (tau x N x N) + mat_gt: Ground truth adjacency matrices (tau x N x N) + forcing_indices: Dict with 'co2', 'aerosol' index lists and 'n_total' + path: Path where to save the plot + name: Name of the plot file + threshold: Binarization threshold + tau_idx: Which time lag to plot (0-indexed) + """ + co2_indices = forcing_indices.get("co2", []) + aerosol_indices = forcing_indices.get("aerosol", []) + n_total = forcing_indices.get("n_total", mat_inferred.shape[1]) + n_climate = n_total - len(co2_indices) - len(aerosol_indices) + + # Create labels for each index + labels = [] + for i in range(n_total): + if i < n_climate: + labels.append(f"M{i}") + elif i in co2_indices: + labels.append("CO2") + elif i in aerosol_indices: + aero_idx = aerosol_indices.index(i) + labels.append(f"A{aero_idx}") + + # Binarize matrices + mat_inferred_bin = binarize_matrix(mat_inferred[tau_idx], threshold) + mat_gt_bin = binarize_matrix(mat_gt[tau_idx], threshold) + + # Create figure with two subplots + fig, axes = plt.subplots(1, 2, figsize=(14, 6)) + + # Plot inferred adjacency + axes[0].imshow(mat_inferred_bin, cmap="Blues", vmin=0, vmax=1, aspect="equal") + axes[0].set_title(f"Inferred Adjacency (t-{tau_idx+1})", fontsize=12) + axes[0].set_xticks(range(n_total)) + axes[0].set_yticks(range(n_total)) + axes[0].set_xticklabels(labels, fontsize=8) + axes[0].set_yticklabels(labels, fontsize=8) + axes[0].set_xlabel("Source (cause)") + 
axes[0].set_ylabel("Target (effect)") + + # Plot ground truth adjacency + im2 = axes[1].imshow(mat_gt_bin, cmap="Blues", vmin=0, vmax=1, aspect="equal") + axes[1].set_title(f"Ground Truth Adjacency (t-{tau_idx+1})", fontsize=12) + axes[1].set_xticks(range(n_total)) + axes[1].set_yticks(range(n_total)) + axes[1].set_xticklabels(labels, fontsize=8) + axes[1].set_yticklabels(labels, fontsize=8) + axes[1].set_xlabel("Source (cause)") + axes[1].set_ylabel("Target (effect)") + + # Add separating lines between climate modes and forcing latents + for ax in axes: + # Line between climate modes and CO2 + ax.axhline(y=n_climate - 0.5, color="red", linewidth=1.5, linestyle="--") + ax.axvline(x=n_climate - 0.5, color="red", linewidth=1.5, linestyle="--") + # Line between CO2 and aerosols (if both exist) + if co2_indices and aerosol_indices: + co2_end = max(co2_indices) + 0.5 + ax.axhline(y=co2_end, color="orange", linewidth=1, linestyle=":") + ax.axvline(x=co2_end, color="orange", linewidth=1, linestyle=":") + + # Add colorbar + fig.colorbar(im2, ax=axes, shrink=0.6, label="Edge present") + + # Add legend for regions + legend_text = ( + f"M0-M{n_climate-1}: Climate Modes | CO2: CO2 Forcing | A0-A{len(aerosol_indices)-1}: Aerosol Forcings" + ) + fig.text(0.5, 0.02, legend_text, ha="center", fontsize=10, style="italic") + + plt.tight_layout(rect=[0, 0.05, 1, 1]) + plt.savefig(Path(path) / f"{name}.png", dpi=150) + plt.close() + + logger.info(f"Saved labeled adjacency plot to {Path(path) / name}.png") + + +def plot_adjacency_all_lags_with_labels( + mat_inferred: np.ndarray, + mat_gt: np.ndarray, + forcing_indices: dict, + path: str, + name: str = "adjacency_all_lags", + threshold: float = 0.5, +): + """ + Plot adjacency matrices for all time lags with labeled axes. 
+ + Args: + mat_inferred: Inferred adjacency matrices (tau x N x N) + mat_gt: Ground truth adjacency matrices (tau x N x N) + forcing_indices: Dict with 'co2', 'aerosol' index lists and 'n_total' + path: Path where to save the plot + name: Name of the plot file + threshold: Binarization threshold + """ + tau = mat_inferred.shape[0] + co2_indices = forcing_indices.get("co2", []) + aerosol_indices = forcing_indices.get("aerosol", []) + n_total = forcing_indices.get("n_total", mat_inferred.shape[1]) + n_climate = n_total - len(co2_indices) - len(aerosol_indices) + + # Create labels + labels = [] + for i in range(n_total): + if i < n_climate: + labels.append(f"M{i}") + elif i in co2_indices: + labels.append("CO2") + elif i in aerosol_indices: + aero_idx = aerosol_indices.index(i) + labels.append(f"A{aero_idx}") + + # Create figure with 2 rows (inferred, gt) x tau columns + fig, axes = plt.subplots(2, tau, figsize=(4 * tau, 8)) + + if tau == 1: + axes = axes.reshape(2, 1) + + for t in range(tau): + mat_inf_bin = binarize_matrix(mat_inferred[t], threshold) + mat_gt_bin = binarize_matrix(mat_gt[t], threshold) + + # Inferred + axes[0, t].imshow(mat_inf_bin, cmap="Blues", vmin=0, vmax=1, aspect="equal") + axes[0, t].set_title(f"Inferred t-{t+1}", fontsize=10) + if t == 0: + axes[0, t].set_ylabel("Target") + axes[0, t].set_yticks(range(n_total)) + axes[0, t].set_yticklabels(labels, fontsize=7) + else: + axes[0, t].set_yticks([]) + + # Ground truth + axes[1, t].imshow(mat_gt_bin, cmap="Blues", vmin=0, vmax=1, aspect="equal") + axes[1, t].set_title(f"GT t-{t+1}", fontsize=10) + axes[1, t].set_xlabel("Source") + axes[1, t].set_xticks(range(n_total)) + axes[1, t].set_xticklabels(labels, fontsize=7, rotation=45) + if t == 0: + axes[1, t].set_ylabel("Target") + axes[1, t].set_yticks(range(n_total)) + axes[1, t].set_yticklabels(labels, fontsize=7) + else: + axes[1, t].set_yticks([]) + + # Add separator lines + for row in range(2): + axes[row, t].axhline(y=n_climate - 0.5, 
color="red", linewidth=1, linestyle="--") + axes[row, t].axvline(x=n_climate - 0.5, color="red", linewidth=1, linestyle="--") + + plt.suptitle("Adjacency Matrices by Time Lag (Red line separates climate modes from forcings)", fontsize=12) + plt.tight_layout() + plt.savefig(Path(path) / f"{name}.png", dpi=150) + plt.close() + + logger.info(f"Saved multi-lag adjacency plot to {Path(path) / name}.png") + + def evaluate_adjacency_matrix(A_inferred, A_ground_truth, threshold): - """Evaluates the precision, recall, F1-score, and Structural Hamming Distance (SHD) between the inferred and ground - truth adjacency matrices.""" + """ + Evaluate precision, recall, F1, and SHD between inferred and ground-truth graphs. + + Both matrices are binarised with the given *threshold* before comparison. + + Parameters + ---------- + A_inferred : np.ndarray + Inferred adjacency matrix (possibly real-valued). + A_ground_truth : np.ndarray + Ground-truth adjacency matrix. + threshold : float + Values strictly above this become 1; all others become 0. + + Returns + ------- + precision : float + recall : float + f1 : float + shd : int + Structural Hamming Distance (false positives + false negatives). + """ # Binarize the matrices before comparison A_inferred_bin = binarize_matrix(A_inferred, threshold) - print(f"N inferred links: {A_inferred_bin.sum()}") + logger.info(f"N inferred links: {A_inferred_bin.sum()}") A_ground_truth_bin = binarize_matrix(A_ground_truth, threshold) # Flatten the matrices to make comparison easier @@ -218,6 +482,144 @@ def evaluate_adjacency_matrix(A_inferred, A_ground_truth, threshold): return precision, recall, f1, shd +def evaluate_adjacency_by_link_type(A_inferred, A_ground_truth, threshold, forcing_indices): + """ + Evaluate adjacency matrix metrics separately for different link types. 
+ + Args: + A_inferred: Inferred adjacency matrices (tau x N x N) + A_ground_truth: Ground truth adjacency matrices (tau x N x N) + threshold: Binarization threshold + forcing_indices: Dict with 'co2', 'aerosol' index lists and 'n_total' + + Returns: + Dict with metrics for each link type: + - 'overall': Overall metrics + - 'climate_to_climate': Climate mode ↔ Climate mode + - 'co2_to_climate': CO2 → Climate mode + - 'aerosol_to_climate': Aerosol → Climate mode + - 'forcing_autoreg': Forcing autoregressive (CO2→CO2, aerosol→aerosol) + """ + A_inferred_bin = binarize_matrix(A_inferred, threshold) + A_ground_truth_bin = binarize_matrix(A_ground_truth, threshold) + + co2_indices = set(forcing_indices.get("co2", [])) + aerosol_indices = set(forcing_indices.get("aerosol", [])) + n_total = forcing_indices.get("n_total", A_inferred.shape[1]) + n_climate = n_total - len(co2_indices) - len(aerosol_indices) + + results = {} + + # Overall metrics + results["overall"] = _compute_metrics(A_inferred_bin.flatten(), A_ground_truth_bin.flatten()) + + # Climate ↔ Climate (indices 0 to n_climate-1) + climate_mask = np.zeros_like(A_inferred_bin, dtype=bool) + climate_mask[:, :n_climate, :n_climate] = True + results["climate_to_climate"] = _compute_metrics(A_inferred_bin[climate_mask], A_ground_truth_bin[climate_mask]) + + # CO2 → Climate (column indices in co2_indices, row indices in climate) + co2_to_climate_mask = np.zeros_like(A_inferred_bin, dtype=bool) + for co2_idx in co2_indices: + co2_to_climate_mask[:, :n_climate, co2_idx] = True + if co2_to_climate_mask.any(): + results["co2_to_climate"] = _compute_metrics( + A_inferred_bin[co2_to_climate_mask], A_ground_truth_bin[co2_to_climate_mask] + ) + else: + results["co2_to_climate"] = {"precision": 0.0, "recall": 0.0, "f1": 0.0, "shd": 0, "n_gt_links": 0} + + # Aerosol → Climate (column indices in aerosol_indices, row indices in climate) + aerosol_to_climate_mask = np.zeros_like(A_inferred_bin, dtype=bool) + for aerosol_idx in 
aerosol_indices: + aerosol_to_climate_mask[:, :n_climate, aerosol_idx] = True + if aerosol_to_climate_mask.any(): + results["aerosol_to_climate"] = _compute_metrics( + A_inferred_bin[aerosol_to_climate_mask], A_ground_truth_bin[aerosol_to_climate_mask] + ) + else: + results["aerosol_to_climate"] = {"precision": 0.0, "recall": 0.0, "f1": 0.0, "shd": 0, "n_gt_links": 0} + + # Forcing autoregressive (CO2→CO2, aerosol→aerosol diagonal) + forcing_autoreg_mask = np.zeros_like(A_inferred_bin, dtype=bool) + for idx in co2_indices | aerosol_indices: + forcing_autoreg_mask[:, idx, idx] = True + if forcing_autoreg_mask.any(): + results["forcing_autoreg"] = _compute_metrics( + A_inferred_bin[forcing_autoreg_mask], A_ground_truth_bin[forcing_autoreg_mask] + ) + else: + results["forcing_autoreg"] = {"precision": 0.0, "recall": 0.0, "f1": 0.0, "shd": 0, "n_gt_links": 0} + + return results + + +def _compute_metrics(inferred_flat, gt_flat): + """ + Compute precision, recall, F1, and SHD from flattened binary arrays. + + Parameters + ---------- + inferred_flat : np.ndarray + Flattened binary inferred adjacency values. + gt_flat : np.ndarray + Flattened binary ground-truth adjacency values. + + Returns + ------- + dict + Dictionary with keys ``'precision'``, ``'recall'``, ``'f1'``, + ``'shd'``, and ``'n_gt_links'``. 
+ """ + if len(inferred_flat) == 0: + return {"precision": 0.0, "recall": 0.0, "f1": 0.0, "shd": 0, "n_gt_links": 0} + + n_gt_links = int(gt_flat.sum()) + n_inferred_links = int(inferred_flat.sum()) + + if n_gt_links == 0 and n_inferred_links == 0: + return {"precision": 1.0, "recall": 1.0, "f1": 1.0, "shd": 0, "n_gt_links": 0} + elif n_gt_links == 0: + return {"precision": 0.0, "recall": 1.0, "f1": 0.0, "shd": n_inferred_links, "n_gt_links": 0} + elif n_inferred_links == 0: + return {"precision": 1.0, "recall": 0.0, "f1": 0.0, "shd": n_gt_links, "n_gt_links": n_gt_links} + + precision = float(precision_score(gt_flat, inferred_flat, zero_division=0)) + recall = float(recall_score(gt_flat, inferred_flat, zero_division=0)) + f1 = float(f1_score(gt_flat, inferred_flat, zero_division=0)) + + false_positives = int(np.sum((inferred_flat == 1) & (gt_flat == 0))) + false_negatives = int(np.sum((inferred_flat == 0) & (gt_flat == 1))) + shd = false_positives + false_negatives + + return {"precision": precision, "recall": recall, "f1": f1, "shd": shd, "n_gt_links": n_gt_links} + + +def print_evaluation_by_link_type(results): + """ + Log evaluation metrics grouped by causal link type. + + Parameters + ---------- + results : dict + Dictionary returned by ``evaluate_adjacency_by_link_type``, mapping + link-type names (e.g. ``'climate_to_climate'``) to metric dicts. 
+ """ + logger.info("\n%s", "=" * 70) + logger.info("EVALUATION BY LINK TYPE") + logger.info("=" * 70) + + for link_type, metrics in results.items(): + logger.info("\n%s:", link_type.upper().replace("_", " ")) + logger.info(" GT links: %s", metrics["n_gt_links"]) + logger.info(" Precision: %.3f", metrics["precision"]) + logger.info(" Recall: %.3f", metrics["recall"]) + logger.info(" F1: %.3f", metrics["f1"]) + logger.info(" SHD: %s", metrics["shd"]) + + logger.info("\n%s", "=" * 70) + + def extract_adjacency_matrix(links_coeffs, N, tau): """ Extract the ground truth adjacency matrices for each time lag from the links_coeffs. @@ -235,23 +637,38 @@ def extract_adjacency_matrix(links_coeffs, N, tau): adj_matrices = np.zeros((tau, N, N)) # Loop through each component and its links - for key, values in links_coeffs.items(): + for target, values in links_coeffs.items(): for link, coeff in values: - target_var, lag = link + source, lag = link time_lag = -lag # Convert the negative lag to a positive index # Only consider lags that are within the specified time window (tau) if time_lag <= tau: if abs(coeff) > 0.01: - adj_matrices[time_lag - 1, key, target_var] = ( + adj_matrices[time_lag - 1, target, source] = ( 1 # Fill the adjacency matrix at the appropriate time lag ) else: - adj_matrices[time_lag - 1, key, target_var] = 0 + adj_matrices[time_lag - 1, target, source] = 0 return adj_matrices def extract_latent_equations(links_coeffs): + """ + Convert a ``links_coeffs`` dictionary into human-readable latent equations. + + Parameters + ---------- + links_coeffs : dict + Mapping from latent variable index to a list of + ``((linked_var, lag), coefficient)`` tuples. + + Returns + ------- + dict + Mapping from latent variable index to a string equation, e.g. + ``"L0(t) = 0.5 * L1(t - 1) + 0.3 * L0(t - 2)"``. 
+ """ equations = {} for latent_var, links in links_coeffs.items(): @@ -267,6 +684,21 @@ def extract_latent_equations(links_coeffs): def extract_equations_from_adjacency(adj_matrices): + """ + Derive human-readable latent equations from a stack of adjacency matrices. + + Parameters + ---------- + adj_matrices : np.ndarray + Adjacency matrices with shape ``(num_lags, num_latents, num_latents)``. + Non-zero entries indicate a causal link. + + Returns + ------- + dict + Mapping from latent variable index to a string equation built from + the non-zero entries of the adjacency matrices across all lags. + """ num_lags, num_latents, _ = adj_matrices.shape # 5 lags, 16 latents equations = {} @@ -311,71 +743,106 @@ def main(csv_file, permutation): def save_equations_to_json(equations, filename): + """ + Serialise a dictionary of latent equations to a JSON file. + + Parameters + ---------- + equations : dict + Mapping from latent variable index to equation string. + filename : str or Path + Destination file path. + """ with open(filename, "w") as json_file: json.dump(equations, json_file, indent=4) - print(f"Equations saved to {filename}") + logger.info(f"Equations saved to {filename}") # Example usage: +# NOTE: The paths below are hardcoded to a specific user/cluster environment. +# Adjust savar_path, results_path, and config_path to match your setup. 
if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + + threshold = 0.5 + + # load your existing JSON config + config_path = Path("configs/single_param_file_savar.json") + with open(config_path, "r") as f: + cfg = json.load(f) + + exp = cfg["exp_params"] + data = cfg["data_params"] + savar = cfg["savar_params"] + + # pull out exactly the bits you used to hard-code + tau = exp["tau"] + n_modes = exp["d_z"] # latent dim = number of modes + comp_size = savar["comp_size"] + time_len = savar["time_len"] + is_forced = savar["is_forced"] + seasonality = savar["seasonality"] + overlap = savar["overlap"] + difficulty = savar["difficulty"] + lat = lon = int(np.sqrt(n_modes)) * comp_size + noise_val = savar["noise_val"] - # Set parameters here - home_path = Path("$HOME") - tau = 5 - threshold = 0.75 - n_modes_gt = 25 - difficulty = "easy" - iteration = 2999 - comp_size = 25 + home_path = str(Path.home()) + savar_path = "/my_projects/climatem/workspace/pfs7wor9/ka_qa4548-data/SAVAR_DATA_TEST" + results_path = Path( + "my_projects/climatem/workspace/pfs7wor9/ka_qa4548-results/SAVAR_DATA_TEST/var_savar_scenarios_piControl_nonlinear_False_tau_5_z_9_lr_0.001_bs_256_spreg_0_ormuinit_100000.0_spmuinit_0.1_spthres_0.05_fixed_False_num_ensembles_1_instantaneous_False_crpscoef_1_spcoef_0_tempspcoef_0_overlap_0.3_forcing_True" + ) - n_modes = n_modes_gt - lat = lon = int(np.sqrt(n_modes)) * comp_size - folder_results = home_path / f"predictions_{n_modes_gt}_{difficulty}" - savar_folder = home_path / Path("savar_data") - savar_fname = f"m_{n_modes_gt}_{difficulty}_savar_name" - run_name = "model_results_folder" - results_path = folder_results / f"{savar_fname}_{run_name}" - csv_files = [results_path / f"adjacency_transition_time_{i}_iteration_{iteration}.csv" for i in np.arange(5, 0, -1)] - - # Get the permuted adjacency matrices for all time lags - modes_gt = np.load(savar_folder / f"{savar_fname}_mode_weights.npy") - mat_adj_w = np.load(results_path / 
f"adj_w_iteration_{iteration}.npy")[0] - - if n_modes == 100: - # With lots of modes some modes are equal and the other function breaks. This function works for the specifics params of the 100 modes dataset. - permutation_list = get_permutation_list_hardcoded_100(mat_adj_w, modes_gt, lat, lon) - else: - permutation_list = get_permutation_list(mat_adj_w, modes_gt, lat, lon) - permuted_matrices = np.array(load_and_permute_all_matrices(csv_files, permutation_list)) + # Load ground truthh modes + savar_folder = home_path + savar_path + savar_fname = f"modes_{n_modes}_tl_{time_len}_isforced_{is_forced}_difficulty_{difficulty}_noisestrength_{noise_val}_seasonality_{seasonality}_overlap_{overlap}" + # modes_gt_path = savar_folder / Path(f"/{savar_fname}_mode_weights.npy") + modes_gt = np.load(f"{savar_folder}/{savar_fname}_mode_weights.npy") + + result_folder = home_path / results_path + # load CDSD results + cdsd_adj_inferred_path = result_folder / Path("plots/graphs.npy") + cdsd_modes_inferred_path = result_folder / Path("plots/w_decoder.npy") + modes_inferred = np.load(cdsd_modes_inferred_path) + adjacency_inferred = np.load(cdsd_adj_inferred_path) + + # if n_modes == 100: + # # With lots of modes some modes are equal and the other function breaks. This function works for the specifics params of the 100 modes dataset. 
+ # permutation_list = get_permutation_list(mat_adj_w, modes_gt, lat, lon) + # else: + # permutation_list = get_permutation_list(mat_adj_w, modes_gt, lat, lon) + permuted_matrices = np.array( + load_and_permute_all_matrices(modes_inferred, modes_gt, adjacency_inferred, adjacency_inferred, lat, lon, tau) + ) # Load parameters from npy file - params_file = savar_folder / f"{savar_fname}_parameters.npy" + params_file = f"{savar_folder}/{savar_fname}_parameters.npy" params = np.load(params_file, allow_pickle=True).item() links_coeffs = params["links_coeffs"] - gt_adj_list = extract_adjacency_matrix(links_coeffs, n_modes_gt, tau) + gt_adj_list = extract_adjacency_matrix(links_coeffs, n_modes, tau) plot_adjacency_matrix( mat1=binarize_matrix(permuted_matrices, threshold), mat2=gt_adj_list, mat3=gt_adj_list, - path=results_path, + path=result_folder, name=f"permuted_adjacency_thr_{threshold}", no_gt=False, - iteration=iteration, + iteration=20000, plot_through_time=True, ) - save_equations_to_json(extract_latent_equations(links_coeffs), results_path / "gt_eq") + save_equations_to_json(extract_latent_equations(links_coeffs), result_folder / "gt_eq") save_equations_to_json( extract_equations_from_adjacency(binarize_matrix(permuted_matrices, threshold)), - results_path / f"thr_{threshold}_results_eq", + result_folder / f"thr_{threshold}_results_eq", ) precision, recall, f1, shd = evaluate_adjacency_matrix(permuted_matrices, gt_adj_list, threshold) - print(f"Precision: {precision}, Recall: {recall}, F1 Score: {f1}, SHD: {shd}") + logger.info(f"Precision: {precision}, Recall: {recall}, F1 Score: {f1}, SHD: {shd}") results = {"precision": precision, "recall": recall, "f1_score": f1, "shd": shd} # Save results as a JSON file - json_filename = results_path / f"thr_{threshold}_evaluation_results.json" + json_filename = result_folder / f"thr_{threshold}_evaluation_results.json" with open(json_filename, "w") as json_file: json.dump(results, json_file) diff --git 
a/climatem/synthetic_data/savar.py b/climatem/synthetic_data/savar.py index 8884f70..0645894 100644 --- a/climatem/synthetic_data/savar.py +++ b/climatem/synthetic_data/savar.py @@ -7,8 +7,9 @@ """ import itertools as it +import math from copy import deepcopy -from math import pi, sin, sqrt +from math import sqrt from typing import List import numpy as np @@ -17,6 +18,10 @@ from torch.distributions.multivariate_normal import MultivariateNormal from tqdm.auto import tqdm +from climatem.utils import get_logger + +logger = get_logger(__name__) + def dict_to_matrix(links_coeffs, default=0): """ @@ -39,12 +44,114 @@ def dict_to_matrix(links_coeffs, default=0): return graph +def normalize_transition_matrix(phi: np.ndarray, eps: float = 1e-8) -> np.ndarray: + """ + Normalize the transition matrix P so that for each target mode, the sum of all incoming link coefficients equals 1. + + This follows the SAVAR SNR formulation constraint: sum of links in P = 1. + + Args: + phi: Transition matrix of shape (n_modes, n_modes, tau_max) + where phi[j, i, tau] is the coefficient from mode i to mode j at lag tau+1 + eps: Small constant to avoid division by zero + + Returns: + Normalized phi where sum over (i, tau) of |phi[j, i, tau]| = 1 for each j + """ + phi_normalized = phi.copy() + n_modes = phi.shape[0] + + for j in range(n_modes): + # Sum of absolute values of all incoming links to mode j + total = np.abs(phi[j, :, :]).sum() + if total > eps: + phi_normalized[j, :, :] = phi[j, :, :] / total + + return phi_normalized + + +def normalize_transition_with_forcing( + phi_climate: np.ndarray, phi_forcing: np.ndarray, n_climate_modes: int, eps: float = 1e-8 +) -> tuple: + """ + Normalize transition matrix including forcing coefficients. + + For each climate mode j, the sum of all incoming coefficients equals 1: + sum(|phi_climate[j,:,:]|) + sum(|phi_forcing[j,:,:]|) = 1 + + This ensures the SAVAR SNR constraint is maintained when forcing + is included in the dynamics. 
+ + Args: + phi_climate: Climate-to-climate coefficients (n_climate, n_climate, tau_max) + phi_forcing: Forcing-to-climate coefficients (n_climate, n_forcing, tau_max) + n_climate_modes: Number of climate modes + eps: Small constant to avoid division by zero + + Returns: + Tuple of (normalized phi_climate, normalized phi_forcing) + """ + phi_climate_norm = phi_climate.copy() + phi_forcing_norm = phi_forcing.copy() + + for j in range(n_climate_modes): + # Total incoming coefficient sum (climate + forcing) + climate_sum = np.abs(phi_climate[j, :, :]).sum() + forcing_sum = np.abs(phi_forcing[j, :, :]).sum() + total = climate_sum + forcing_sum + + if total > eps: + phi_climate_norm[j, :, :] = phi_climate[j, :, :] / total + phi_forcing_norm[j, :, :] = phi_forcing[j, :, :] / total + + logger.debug( + f"Mode {j}: climate links = {climate_sum:.4f}, " f"forcing links = {forcing_sum:.4f}, total = {total:.4f}" + ) + + return phi_climate_norm, phi_forcing_norm + + class SAVAR: - """Main class containing SAVAR model.""" + """ + SAVAR synthetic climate data generator. + + Generates spatiotemporal synthetic climate data with a known causal structure, + used to benchmark causal discovery methods. Based on Tibau et al. (2022), extended + with GPU acceleration, external forcing (CO2 + aerosol), and background state. + + Data generation pipeline: + 1. Seasonality: harmonic components with optional yearly jitter + 2. Forcing trajectories: CO2 (monotonic ramp) and aerosol (regional, staggered) + latent time series are generated as exogenous inputs + 3. Dynamics: autoregressive evolution in latent (mode) space: + z_j(t) = sum_{k,tau} P_climate[j,k,tau] * z_k(t-tau) + + sum_{f,tau} P_forcing[j,f,tau] * forcing_f(t-tau) + N_j(t) + S_t = W^{-1} * z(t) (project back to observation space) + where W = mode_weights (mixing matrix), P_climate/P_forcing = transition + matrices, and N_t is observation-space noise with covariance s * W^T * W. + 4. 
Background: optional slow, spatially-smooth AR(1) drift added post-dynamics + + Key attributes: + links_coeffs (dict): Causal graph as {target_idx: [((source_idx, lag), coeff), ...]}. + n_climate_modes (int): Number of climate latent variables (from mode_weights). + n_vars (int): Total latent count (climate + forcing modes). + mode_weights (ndarray): Spatial patterns (mixing matrix W), shape (n_modes, D_x, D_y) + where D_x * D_y = spatial_resolution. + data_field (ndarray): Generated spatiotemporal data, shape (spatial_resolution, time_length). + noise_data_field (ndarray): Pre-generated noise, shape (spatial_resolution, time_length + transient). + seasonal_data_field (ndarray): Seasonality component (same shape as data_field before trimming). + co2_latent_trajectory (ndarray): CO2 latent time series, shape (time,). + aerosol_latent_trajectory (ndarray): Aerosol latent time series, shape (n_aerosol, time). + aerosol_spatial_templates (ndarray): Orthogonal spatial templates for aerosol extraction, + shape (n_aerosol, spatial_resolution). + forcing_coeffs (dict): Extracted forcing-to-climate causal coefficients. + forcing_amplification (float): Scaling factor for forcing contributions. 
+ """ __slots__ = [ "links_coeffs", "n_vars", + "n_climate_modes", "time_length", "transient", "spatial_resolution", @@ -54,19 +161,38 @@ class SAVAR: "noise_cov", "noise_strength", "noise_variance", + "noise_ar1", + "noise_ar1_rho", "latent_noise_cov", "fast_noise_cov", "forcing_dict", + "forcing_indices", + "forcing_coeffs", + "forcing_amplification", "season_dict", "data_field", "noise_data_field", "seasonal_data_field", "forcing_data_field", + "co2_forcing_data_field", + "aerosol_forcing_data_field", + "co2_latent_trajectory", + "aerosol_latent_trajectory", + "aerosol_spatial_templates", + "background_data_field", "linearity", "poly_degrees", "verbose", "model_seed", "nnar_model", + "output_save_dir", + # Background state parameters + "enable_background", + "background_strength", + "background_strength_mode", + "background_smoothness", + "background_timescale_rho", + "background_n_modes", ] def __init__( @@ -78,10 +204,14 @@ def __init__( noise_weights: np.ndarray = None, noise_strength: float = 1, noise_variance: float = 1, + noise_ar1: bool = True, + noise_ar1_rho: float = 0.95, noise_cov: np.ndarray = None, latent_noise_cov: np.ndarray = None, fast_cov: np.ndarray = None, forcing_dict: dict = None, + forcing_indices: dict = None, + forcing_amplification: float = 5.0, linearity: str = "linear", poly_degrees: List[int] = [2], season_dict: dict = None, @@ -89,8 +219,18 @@ def __init__( noise_data_field: np.ndarray = None, seasonal_data_field: np.ndarray = None, forcing_data_field: np.ndarray = None, + co2_forcing_data_field: np.ndarray = None, + aerosol_forcing_data_field: np.ndarray = None, verbose: bool = False, model_seed: int = None, + output_save_dir: str = None, + # Background state parameters + enable_background: bool = False, + background_strength: float = 0.3, + background_strength_mode: str = "relative", + background_smoothness: float = 0.15, + background_timescale_rho: float = 0.995, + background_n_modes: int = 3, ): self.links_coeffs = links_coeffs 
@@ -98,6 +238,8 @@ def __init__( self.transient = transient self.noise_strength = noise_strength self.noise_variance = noise_variance # TODO: NOT USED. + self.noise_ar1 = noise_ar1 + self.noise_ar1_rho = noise_ar1_rho self.noise_cov = noise_cov self.latent_noise_cov = latent_noise_cov # D_x @@ -107,6 +249,8 @@ def __init__( self.noise_weights = noise_weights self.forcing_dict = forcing_dict + self.forcing_indices = forcing_indices + self.forcing_amplification = forcing_amplification self.season_dict = season_dict self.linearity = linearity self.poly_degrees = poly_degrees @@ -115,104 +259,218 @@ def __init__( self.verbose = verbose self.model_seed = model_seed + self.output_save_dir = output_save_dir + + # Background state parameters + self.enable_background = enable_background + self.background_strength = background_strength + self.background_strength_mode = background_strength_mode + self.background_smoothness = background_smoothness + self.background_timescale_rho = background_timescale_rho + self.background_n_modes = background_n_modes # Computed attributes - print("Creating attributes") - self.n_vars = len(links_coeffs) + logger.debug("Creating attributes") + # n_climate_modes is the number of climate variables (from mode_weights) + self.n_climate_modes = mode_weights.shape[0] + # n_vars is total latents (climate + forcing) if forcing_indices provided + if forcing_indices is not None: + self.n_vars = forcing_indices.get("n_total", self.n_climate_modes) + else: + self.n_vars = self.n_climate_modes self.tau_max = max(abs(lag) for (_, lag), _ in it.chain.from_iterable(self.links_coeffs.values())) - self.spatial_resolution = deepcopy(self.mode_weights.reshape(self.n_vars, -1).shape[1]) + self.spatial_resolution = deepcopy(self.mode_weights.reshape(self.n_climate_modes, -1).shape[1]) print("spatial-resolution done") + # Extract forcing → mode coefficients if forcing is used + self.forcing_coeffs = self._extract_forcing_coefficients() if forcing_indices else None + 
+ # Initialize forcing latent trajectories (populated during forcing generation) + self.co2_latent_trajectory = None + self.aerosol_latent_trajectory = None + self.aerosol_spatial_templates = None # Orthogonal spatial templates for aerosol extraction + logger.debug("spatial-resolution done") + if self.noise_weights is None: self.noise_weights = deepcopy(self.mode_weights) if self.latent_noise_cov is None: self.latent_noise_cov = np.eye(self.n_vars) if self.fast_noise_cov is None: self.fast_noise_cov = np.zeros((self.spatial_resolution, self.spatial_resolution)) - print("copies done") + logger.debug("copies done") # Empty attributes self.noise_data_field = noise_data_field self.seasonal_data_field = seasonal_data_field self.forcing_data_field = forcing_data_field + self.co2_forcing_data_field = co2_forcing_data_field + self.aerosol_forcing_data_field = aerosol_forcing_data_field + self.background_data_field = None if np.random is not None: np.random.seed(model_seed) - def generate_data(self, train_nnar=True) -> None: - """Generates the data of savar :return:""" + def _extract_forcing_coefficients(self): + """ + Extract forcing → climate mode causal coefficients from links_coeffs. 
+ + Returns a dictionary with structure: + { + 'co2_to_modes': {mode_idx: [(forcing_idx, lag, coeff), ...]}, + 'aerosol_to_modes': {mode_idx: [(forcing_idx, lag, coeff), ...]}, + } + """ + if self.forcing_indices is None: + return None + + co2_indices = set(self.forcing_indices.get("co2", [])) + aerosol_indices = set(self.forcing_indices.get("aerosol", [])) + + forcing_coeffs = { + "co2_to_modes": {m: [] for m in range(self.n_climate_modes)}, + "aerosol_to_modes": {m: [] for m in range(self.n_climate_modes)}, + } + + # Scan links_coeffs for forcing → mode connections + for target_idx, links in self.links_coeffs.items(): + # Only consider climate modes as targets + if target_idx >= self.n_climate_modes: + continue + + for (source_idx, lag), coeff in links: + if source_idx in co2_indices: + forcing_coeffs["co2_to_modes"][target_idx].append((source_idx, lag, coeff)) + elif source_idx in aerosol_indices: + forcing_coeffs["aerosol_to_modes"][target_idx].append((source_idx, lag, coeff)) + + # Print summary + n_co2_links = sum(len(v) for v in forcing_coeffs["co2_to_modes"].values()) + n_aerosol_links = sum(len(v) for v in forcing_coeffs["aerosol_to_modes"].values()) + logger.info( + f"Extracted forcing coefficients: {n_co2_links} CO2→mode links, {n_aerosol_links} aerosol→mode links" + ) + + return forcing_coeffs + + def generate_data(self, train_nnar=True, include_noise=True) -> None: + """ + Generates the data of savar. 
+ + Args: + train_nnar: Whether to train NNAR model for nonlinear data + include_noise: Whether to include noise in the generated data + """ # Prepare the datafield if self.data_field is None: - if self.verbose: - print("Creating empty data field") - # Compute the field + logger.debug("Creating empty data field") self.data_field = np.zeros((self.spatial_resolution, self.time_length + self.transient)) - # Add noise first - if self.noise_data_field is None: - if self.verbose: - print("Creating noise data field") - self._add_noise_field() + self._apply_noise(include_noise) + self._apply_seasonality() + self._apply_forcing() + self._apply_dynamics(train_nnar) + self._apply_background() + + def _apply_noise(self, include_noise): + """Add noise to data_field BEFORE dynamics (baseline behavior).""" + if include_noise: + if self.noise_data_field is None: + logger.debug("Creating noise data field and adding to data_field") + self._add_noise_field() + else: + logger.debug("Reusing existing noise_data_field, adding to data_field") + self.data_field += self.noise_data_field else: - self.data_field += self.noise_data_field - - # Add seasonality - if self.season_dict is not None: - if self.verbose: - print("Adding seasonality forcing") - self._add_seasonality_forcing() + logger.debug("Skipping noise (generating deterministic data)") + + def _apply_seasonality(self): + """Add seasonality to data_field if configured.""" + if self.season_dict is None: + logger.info("No seasonality") + return + logger.debug("Adding seasonality forcing") + if self.seasonal_data_field is not None: + logger.debug("Reusing existing seasonal_data_field") + self.data_field += self.seasonal_data_field else: - print("No seasonality") + self._add_seasonality_forcing() - # Add external forcing - if self.forcing_dict is not None: - initial_data = self.data_field.copy() - self._add_external_forcing() - diff = self.data_field - initial_data - if self.verbose: - print("Adding external forcing") - print(f"Max 
change in data field: {diff.max()}") - print(f"Mean change in data field: {diff.mean()}") - print(f"Sample values after forcing applied:\n{diff[:, :5]}") - elif self.verbose: - print("No forcing") + def _apply_forcing(self): + """Generate forcing latent trajectories if configured.""" + if self.forcing_dict is None: + return + logger.debug("Generating forcing latent trajectories (applied during dynamics)") + if self.co2_latent_trajectory is None: + self._generate_forcing_trajectories() + else: + logger.debug("Reusing existing forcing latent trajectories") + def _apply_dynamics(self, train_nnar): + """Compute data using the configured linearity model.""" # Compute the data if self.linearity == "linear": - if self.verbose: - print("Creating linear data") + logger.debug("Creating linear data") self._create_linear() elif self.linearity == "polynomial": - if self.verbose: - print("Creating polynomial data") + logger.debug("Creating polynomial data") self._create_polynomial() else: - if self.verbose: - print("Creating nonlinear data") + logger.debug("Creating nonlinear data") if train_nnar: - if self.verbose: - print("Training NNAR model before data generation...") + logger.info("Training NNAR model before data generation...") self.train_nnar(num_epochs=50, learning_rate=0.001, batch_size=32) self._create_nonlinear() + def _apply_background(self): + """Add background state AFTER mode dynamics to avoid entanglement with causal graph.""" + if not self.enable_background: + return + logger.debug("Adding low-frequency background state") + if self.background_data_field is not None: + logger.debug("Reusing existing background_data_field") + self.data_field += self.background_data_field + else: + self._add_background_field( + background_strength=self.background_strength, + background_strength_mode=self.background_strength_mode, + spatial_smoothness=self.background_smoothness, + time_rho=self.background_timescale_rho, + n_bg_modes=self.background_n_modes, + ) + def 
generate_cov_noise_matrix(self) -> np.ndarray: """ - W in NxL data_field L times T. + Generate the noise covariance matrix: Cov = s · W^+ · (W^+)^T. - :return: - """ + This gives noise that is anti-correlated with the mode spatial structure, + ensuring the model can distinguish signal from noise. W^+ is the + pseudoinverse of the noise_weights mixing matrix. - W = deepcopy(self.noise_weights).reshape(self.n_vars, -1) - print(f"noise_weights copied, {W.shape}") + Returns: + cov: Covariance matrix of shape (spatial_resolution, spatial_resolution). + """ + W = deepcopy(self.noise_weights).reshape(self.n_climate_modes, -1) + logger.debug(f"noise_weights copied, {W.shape}") W_plus = np.linalg.pinv(W) - print("noise_weights inverted") - # Can we speed this up? since they are all np.eye - cov = self.noise_strength * W_plus @ W_plus.transpose() # + self.fast_noise_cov - print("cov created inverted") + logger.debug("noise_weights inverted") + cov = self.noise_strength * W_plus @ W_plus.transpose() + logger.debug("cov created") return cov def _add_noise_field(self): + """ + Generate noise in observation space with covariance s · W^+ · (W^+)^T. + + Uses the pseudoinverse of the mixing matrix so that the noise covariance + is *anti-correlated* with the mode spatial patterns. This is essential for + the model to distinguish signal (mode-aligned) from noise. + + NOTE: This method stores the noise in self.noise_data_field but does NOT + add it to self.data_field. The noise is applied during dynamics + (_create_linear, etc.) 
following: S_t = W^{-1}·P·W·S_{t-τ} + N_t + """ + logger.info("Generate noise_data_field with covariance s·W^+·(W^+)^T") if self.noise_cov is None: if self.verbose: @@ -229,77 +487,227 @@ def _add_noise_field(self): noise_data_field = distrib.sample(sample_shape=torch.Size([self.time_length + self.transient])) self.noise_data_field = noise_data_field.detach().cpu().numpy().transpose() - # self.noise_data_field = np.random.multivariate_normal(mean=np.zeros(self.spatial_resolution), cov=self.noise_cov, - # size=self.time_length + self.transient).transpose() - + logger.info( + f"Generated noise with shape {self.noise_data_field.shape}, " + f"std={self.noise_data_field.std():.4f}, " + f"noise_strength (s)={self.noise_strength:.4f}" + ) + # Add noise to data_field BEFORE dynamics (baseline behavior). + # The AR dynamics will then evolve the noise-inclusive field, + # effectively dampening the noise through the transition coefficients. self.data_field += self.noise_data_field def _add_seasonality_forcing(self): + """ + Add deterministic seasonal cycle to the data field. + + Constructs a sum of sinusoidal harmonics (e.g. annual, semi-annual) with + optional year-to-year jitter in amplitude and phase. The result is stored + in ``self.seasonal_data_field`` and added to ``self.data_field``. + + Configuration comes from ``self.season_dict`` with keys: + periods (list[float]): harmonic periods in timesteps, e.g. [365, 182.5, 60]. + amplitudes (list[float]): amplitude for each harmonic. + phases (list[float]): phase offset (radians) for each harmonic. + yearly_jitter (dict|None): {"amplitude": float, "phase": float} controlling + inter-annual variability of each harmonic. + season_weight (ndarray|None): per-gridpoint multiplier (latitude-dependent + seasonality strength), shape (spatial_resolution,). + Adds external forcing to the data field using PyTorch tensors for GPU acceleration. 
- # A*sin((2pi/lambda)*x) A = amplitude, lambda = period - amplitude = self.season_dict["amplitude"] - period = self.season_dict["period"] - season_weight = self.season_dict.get("season_weight", None) + Allows for both linear and nonlinear ramps. + """ + periods = self.season_dict["periods"] # e.g. [12, 6, 3] for year, half-year, quarter-year + amplitudes = self.season_dict["amplitudes"] # same length as periods + phases = self.season_dict.get("phases", [0.0] * len(periods)) + + # year-to-year amplitude / phase jitter + jitter_cfg = self.season_dict.get("yearly_jitter") # None or dict + base_P = periods[0] # assume first is annual (12 months) + dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float32 + + T = self.time_length + self.transient + ncy = math.ceil(T / base_P) # # of whole cycles + + L = self.data_field.shape[0] + T = self.time_length + self.transient + t = torch.arange(T, device=dev, dtype=dtype) + seasonal = torch.zeros((L, T), device=dev, dtype=dtype) + + σ_A = jitter_cfg["amplitude"] if jitter_cfg else 0.0 + σ_φ = jitter_cfg["phase"] if jitter_cfg else 0.0 + + # allow vector inputs, default to identical values otherwise + σ_Ak = torch.as_tensor(σ_A).expand(len(periods)).to(dtype=dtype, device=dev) + σ_φk = torch.as_tensor(σ_φ).expand(len(periods)).to(dtype=dtype, device=dev) + + for k, (A, P, φ) in enumerate(zip(amplitudes, periods, phases)): + # one jitter draw *per year* for this harmonic + amp_noise_k = 1 + σ_Ak[k] * torch.randn(ncy, device=dev, dtype=dtype) + phase_noise_k = σ_φk[k] * torch.randn(ncy, device=dev, dtype=dtype) + + amp_series_k = amp_noise_k.repeat_interleave(base_P)[:T] # (T,) + phase_series_k = phase_noise_k.repeat_interleave(base_P)[:T] # (T,) + + seasonal += amp_series_k * A * torch.sin(2 * math.pi / P * (t + phase_series_k) + φ) + + w = self.season_dict.get("season_weight") + if w is not None: + if not torch.is_tensor(w): + w = torch.as_tensor(w, dtype=dtype, device=dev) + else: + w = 
w.to(device=dev, dtype=dtype) + if w.ndim > 1: + w = w.reshape(-1) + if w.numel() != L: + raise ValueError(f"season_weight has length {w.numel()} but grid has {L} points") + seasonal *= w.reshape(L, 1) + + seasonal_np = seasonal.cpu().numpy() + self.seasonal_data_field = seasonal_np + self.data_field += seasonal_np + + def _apply_season_forcing_interaction(self, forcing_field, interaction_cfg=None): + """Modulate a forcing field by the seasonal cycle if requested.""" + if interaction_cfg is None: + interaction_cfg = (self.forcing_dict or {}).get("season_interaction") + + if not interaction_cfg: + return forcing_field + + if self.seasonal_data_field is None: + logger.warning("season_interaction requested but seasonal_data_field is missing; skipping interaction") + return forcing_field + + # Ensure forcing and seasonal fields share the same device / dtype for math + np_dtype = None + if torch.is_tensor(forcing_field): + dev = forcing_field.device + dtype = forcing_field.dtype + forcing_tensor = forcing_field + else: + dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float32 + forcing_tensor = torch.as_tensor(forcing_field, device=dev, dtype=dtype) + np_dtype = getattr(forcing_field, "dtype", None) - seasonal_trend = np.asarray( - [amplitude * sin((2 * pi / period) * x) for x in range(self.time_length + self.transient)] - ) + # Bring the precomputed seasonal cycle onto the same device so interactions are cheap. + seasonal_tensor = torch.as_tensor(self.seasonal_data_field, device=dev, dtype=dtype) - seasonal_data_field = np.ones_like(self.data_field) - seasonal_data_field *= seasonal_trend.reshape(1, -1) + if seasonal_tensor.shape != forcing_tensor.shape: + # Using mismatched grids would silently broadcast; guard it so users catch config errors. 
+ raise ValueError( + f"seasonal_data_field shape {seasonal_tensor.shape} does not match forcing shape {forcing_tensor.shape}" + ) - # Apply seasonal weights - if season_weight is not None: - season_weight = season_weight.sum(axis=0).reshape(self.spatial_resolution) # vector dim L - seasonal_data_field *= season_weight[:, None] # L times T + # Optional normalisation lets us work with comparable seasonal amplitudes + eps = float(interaction_cfg.get("eps", 1e-6)) + norm = str(interaction_cfg.get("normalisation", "zscore")).lower() + + if norm == "zscore": + # Remove the seasonal mean and scale by variance so anomalies are dimensionless. + mean = seasonal_tensor.mean(dim=1, keepdim=True) + std = seasonal_tensor.std(dim=1, keepdim=True) + seasonal_tensor = (seasonal_tensor - mean) / (std + eps) + elif norm == "minmax": + # Stretch to [-0.5, 0.5] so the strength parameter is intuitive. + s_min = seasonal_tensor.amin(dim=1, keepdim=True) + s_max = seasonal_tensor.amax(dim=1, keepdim=True) + seasonal_tensor = (seasonal_tensor - s_min) / (s_max - s_min + eps) + seasonal_tensor = seasonal_tensor - 0.5 + elif norm in ("none", "identity"): + pass + else: + raise ValueError(f"Unsupported season_interaction normalisation '{norm}'") + + # Apply the requested interaction mode to blend seasonality with forcing + mode = str(interaction_cfg.get("mode", "multiplicative")).lower() + strength = float(interaction_cfg.get("strength", 1.0)) + + if mode == "multiplicative": + # Seasonal anomalies rescale the forcing field, raising or lowering its amplitude. + scale = 1.0 + strength * seasonal_tensor + min_scale = interaction_cfg.get("min_scale") + max_scale = interaction_cfg.get("max_scale") + if min_scale is not None: + scale = torch.clamp(scale, min=float(min_scale)) + if max_scale is not None: + scale = torch.clamp(scale, max=float(max_scale)) + forcing_tensor = forcing_tensor * scale + elif mode == "additive": + # Inject the seasonal fluctuations directly as an additional perturbation. 
+ forcing_tensor = forcing_tensor + strength * seasonal_tensor + elif mode == "hybrid": + # Combine both: a multiplicative scaling plus an additive share (controlled via mix). + mix = float(interaction_cfg.get("mix", 0.5)) + scale = 1.0 + strength * seasonal_tensor + forcing_tensor = forcing_tensor * scale + mix * strength * seasonal_tensor + else: + raise ValueError(f"Unsupported season_interaction mode '{mode}'") - self.seasonal_data_field = seasonal_data_field + # Final affine tweak so users can bias the modulation if desired + bias = float(interaction_cfg.get("bias", 0.0)) + if bias != 0.0: + forcing_tensor = forcing_tensor + bias + + if torch.is_tensor(forcing_field): + return forcing_tensor - # Add it to the data field. - self.data_field += seasonal_data_field + result = forcing_tensor.detach().cpu().numpy() + if np_dtype is not None: + # Preserve the caller's dtype to avoid surprising precision changes. + result = result.astype(np_dtype, copy=False) + return result - def _add_external_forcing(self): + def create_co2_forcing(self) -> np.ndarray: """ - Adds external forcing to the data field using PyTorch tensors for GPU acceleration. + Create a CO2 forcing field that grows over time with mild spatial variability. - Allows for both linear and nonlinear ramps. + Uses f_1, f_2, f_time_1, f_time_2, and ramp_type from forcing_dict to control the temporal evolution of CO2 + forcing. + + Returns an array shaped (spatial_resolution, time_length + transient) that can be added to the synthetic field + or used as an external driver. 
""" - if self.forcing_dict is None: - raise TypeError("Forcing dict is empty") - - w_f = deepcopy(self.forcing_dict.get("w_f")) - f_1 = float(self.forcing_dict.get("f_1", 0)) - f_2 = float(self.forcing_dict.get("f_2", 0)) - f_time_1 = self.forcing_dict.get("f_time_1", 0) - f_time_2 = self.forcing_dict.get("f_time_2", self.time_length) - ramp_type = self.forcing_dict.get("ramp_type", "linear") # Default to linear - - if w_f is None: - w_f = deepcopy(self.mode_weights) - w_f = (w_f != 0).astype(int) # Convert non-zero elements to 1 - - # Merge last two dims first => shape (d_z, lat*lon) - temp = w_f.reshape(w_f.shape[0], w_f.shape[1] * w_f.shape[2]) - # sum over dim=0 => shape (lat*lon,) - w_f_sum = torch.tensor(temp.sum(axis=0), dtype=torch.float32, device="cuda") + + dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float32 + + time_len = self.time_length + self.transient + if time_len <= 0: + raise ValueError("Time length including transient must be positive") + + spatial_len = self.spatial_resolution + if spatial_len <= 0: + raise ValueError("Spatial resolution must be positive") + + # Get forcing parameters from forcing_dict + forcing_cfg = self.forcing_dict or {} + f_1 = float(forcing_cfg.get("f_1", 0.0)) + f_2 = float(forcing_cfg.get("f_2", 0.1)) + f_time_1 = int(forcing_cfg.get("f_time_1", 0)) + f_time_2 = int(forcing_cfg.get("f_time_2", time_len)) + ramp_type = forcing_cfg.get("ramp_type", "linear") + + # Adjust times to include transient period f_time_1 += self.transient f_time_2 += self.transient - time_length = self.time_length + self.transient # Generate the forcing trend using torch tensors if ramp_type == "linear": - ramp = torch.linspace(f_1, f_2, f_time_2 - f_time_1, dtype=torch.float32, device="cuda") + ramp = torch.linspace(f_1, f_2, f_time_2 - f_time_1, dtype=dtype, device=dev) elif ramp_type == "quadratic": - t = torch.linspace(0, 1, f_time_2 - f_time_1, dtype=torch.float32, device="cuda") + t = torch.linspace(0, 
1, f_time_2 - f_time_1, dtype=dtype, device=dev) ramp = f_1 + (f_2 - f_1) * t**2 elif ramp_type == "exponential": - t = torch.linspace(0, 1, f_time_2 - f_time_1, dtype=torch.float32, device="cuda") + t = torch.linspace(0, 1, f_time_2 - f_time_1, dtype=dtype, device=dev) ramp = f_1 + (f_2 - f_1) * (torch.exp(t) - 1) / (torch.exp(torch.tensor(1.0)) - 1) elif ramp_type == "sigmoid": - t = torch.linspace(-6, 6, f_time_2 - f_time_1, dtype=torch.float32, device="cuda") + t = torch.linspace(-6, 6, f_time_2 - f_time_1, dtype=dtype, device=dev) ramp = f_1 + (f_2 - f_1) * (1 / (1 + torch.exp(-t))) elif ramp_type == "sinusoidal": - t = torch.linspace(0, pi, f_time_2 - f_time_1, dtype=torch.float32, device="cuda") + t = torch.linspace(0, math.pi, f_time_2 - f_time_1, dtype=dtype, device=dev) ramp = f_1 + (f_2 - f_1) * (0.5 * (1 - torch.cos(t))) else: raise ValueError( @@ -309,19 +717,44 @@ def _add_external_forcing(self): # Generate the forcing trend using torch tensors trend = torch.cat( [ - torch.full((f_time_1,), f_1, dtype=torch.float32, device="cuda"), + torch.full((f_time_1,), f_1, dtype=dtype, device=dev), ramp, - torch.full((time_length - f_time_2,), f_2, dtype=torch.float32, device="cuda"), + torch.full((time_len - f_time_2,), f_2, dtype=dtype, device=dev), ] - ).reshape(1, time_length) + ).reshape(1, time_len) - if w_f_sum.dim() == 2: - w_f_sum = w_f_sum.sum(dim=0, keepdim=True) # Sum across the correct dimension + logger.info( + f"CO2 forcing: Using {ramp_type} ramp from f_1={f_1} to f_2={f_2} " + f"(time {f_time_1} to {f_time_2}, total length {time_len})" + ) - # Compute the forcing field on GPU - forcing_field = (w_f_sum.reshape(1, -1) * trend.T).T - self.forcing_data_field = forcing_field.cpu().numpy() + logger.debug( + f"CO2 forcing: Using {ramp_type} ramp from f_1={f_1} to f_2={f_2} " + f"(time {f_time_1} to {f_time_2}, total length {time_len})" + ) + if spatial_len == 1: + base_pattern = torch.ones(1, device=dev, dtype=dtype) + else: + coords = 
torch.linspace(-1.0, 1.0, spatial_len, device=dev, dtype=dtype) + num_lobes = max(3, min(8, spatial_len // 64 + 3)) + rand_params = torch.as_tensor(np.random.rand(num_lobes, 3), device=dev, dtype=dtype) + centers = rand_params[:, 0] * 2.0 - 1.0 + widths = 0.3 + rand_params[:, 1] * 0.7 + amplitudes = 0.3 + rand_params[:, 2] * 0.7 + + diff = coords.unsqueeze(0) - centers.unsqueeze(1) + gaussians = torch.exp(-0.5 * (diff / widths.unsqueeze(1)) ** 2) + base_pattern = (amplitudes.unsqueeze(1) * gaussians).sum(dim=0) + base_pattern = base_pattern / (base_pattern.mean() + 1e-6) + + spatial_pattern = base_pattern + spatial_pattern = spatial_pattern / (spatial_pattern.mean() + 1e-8) + + # Add a deterministic oscillation so the grid is not uniform. + if spatial_len > 1: + idx = torch.linspace(0.0, 2.0 * math.pi, spatial_len, device=dev, dtype=dtype) + spatial_pattern = spatial_pattern * (1.0 + 0.1 * torch.sin(idx)) if self.verbose: print(f"Using {ramp_type} ramp: f_1={f_1}, f_2={f_2}, f_time_1={f_time_1}, f_time_2={f_time_2}") print(f"Forcing data field mean: {self.forcing_data_field.mean()}") @@ -329,74 +762,469 @@ def _add_external_forcing(self): # data_field_before = self.data_field.copy() - self.data_field += self.forcing_data_field + spatial_pattern = torch.clamp(spatial_pattern, min=0.05) + forcing = spatial_pattern.unsqueeze(1) * trend + + forcing_np = forcing.detach().cpu().numpy() # data_field_after = self.data_field if self.verbose: print(f"After addition - Data field mean: {self.data_field.mean()}") - # # Convert tensors to numpy for plotting if necessary - # if isinstance(w_f_sum, torch.Tensor): - # w_f_sum = w_f_sum.cpu().numpy() - # if isinstance(forcing_field, torch.Tensor): - # forcing_field = forcing_field.cpu().numpy() - # if isinstance(data_field_before, torch.Tensor): - # data_field_before = data_field_before.cpu().numpy() - # if isinstance(data_field_after, torch.Tensor): - # data_field_after = data_field_after.cpu().numpy() - - # # Compute mean values 
over spatial dimensions - # mean_forcing = forcing_field.mean(axis=0) - # mean_data_before = data_field_before.mean(axis=0) - # mean_data_after = data_field_after.mean(axis=0) - - # # Plot 1: Mean Forcing over Time - # plt.figure(figsize=(10, 4)) - # plt.plot(range(time_length), mean_forcing, label="Mean Forcing", color="blue") - # plt.axvline(x=f_time_1, linestyle="--", color="gray", label="Start Forcing") - # plt.axvline(x=f_time_2, linestyle="--", color="gray", label="End Forcing") - # plt.xlabel("Time Steps") - # plt.ylabel("Forcing Intensity") - # plt.title("Evolution of External Forcing Over Time") - # plt.legend() - # plt.grid() - # plt.savefig(f"mean_forcing_over_time_{f_1}_{f_2}_{ramp_type}.png") # Save to a file - # plt.close() - - # # Plot 2: Mean Data Before and After Forcing - # plt.figure(figsize=(10, 4)) - # plt.plot(range(time_length), mean_data_before, label="Data Before Forcing", color="red", linestyle="dashed") - # plt.plot(range(time_length), mean_data_after, label="Data After Forcing", color="green") - # plt.axvline(x=f_time_1, linestyle="--", color="gray", label="Start Forcing") - # plt.axvline(x=f_time_2, linestyle="--", color="gray", label="End Forcing") - # plt.xlabel("Time Steps") - # plt.ylabel("Mean Data Value") - # plt.title("Effect of Forcing on Data Field") - # plt.legend() - # plt.grid() - # plt.savefig(f"mean_data_before_after_forcing_{f_1}_{f_2}_{ramp_type}.png") # Save to a file - # plt.close() + return forcing_np + + def _resolve_aerosol_warming_index(self, warming_index, n_aerosol_latents: int): + """Resolve and validate the optional warming latent index.""" + if warming_index is None and n_aerosol_latents >= 4: + warming_index = 3 + try: + return int(warming_index) if warming_index is not None else None + except (TypeError, ValueError): + return None + + def _create_aerosol_spatial_templates( + self, n_aerosol_latents: int, spatial_len: int, aerosol_contrast: float, dev, dtype + ) -> torch.Tensor: + """Create and normalize 
aerosol spatial templates for latent projection.""" + if spatial_len == 1: + return torch.ones(n_aerosol_latents, 1, device=dev, dtype=dtype) + + grid_size = int(math.sqrt(spatial_len)) + is_2d_grid = (grid_size * grid_size == spatial_len) and grid_size > 1 + + if is_2d_grid: + logger.info(f"Creating 2D aerosol templates for {grid_size}x{grid_size} grid") + spatial_templates = self._create_aerosol_templates_2d(n_aerosol_latents, grid_size, dev, dtype) + else: + logger.info(f"Creating 1D aerosol templates for {spatial_len} points") + spatial_templates = self._create_aerosol_templates_1d(n_aerosol_latents, spatial_len, dev, dtype) + + self._normalize_aerosol_templates(spatial_templates, aerosol_contrast) + self._log_aerosol_template_orthogonality(spatial_templates) + return spatial_templates + + def _create_aerosol_templates_2d(self, n_aerosol_latents: int, grid_size: int, dev, dtype) -> torch.Tensor: + """Build 2D region-localized aerosol templates.""" + spatial_templates = torch.zeros(n_aerosol_latents, grid_size * grid_size, device=dev, dtype=dtype) + + lat_coords = torch.linspace(-1.0, 1.0, grid_size, device=dev, dtype=dtype) + lon_coords = torch.linspace(-1.0, 1.0, grid_size, device=dev, dtype=dtype) + lat_grid, lon_grid = torch.meshgrid(lat_coords, lon_coords, indexing="ij") + + def gaussian_2d(lat_center, lon_center, lat_sigma, lon_sigma, angle=0.0): + x = lon_grid - lon_center + y = lat_grid - lat_center + if angle != 0.0: + cos_a = math.cos(angle) + sin_a = math.sin(angle) + x_rot = x * cos_a + y * sin_a + y_rot = -x * sin_a + y * cos_a + else: + x_rot = x + y_rot = y + return torch.exp(-0.5 * ((x_rot / lon_sigma) ** 2 + (y_rot / lat_sigma) ** 2)) + + def smooth_window(x, start, end, sharpness=8.0): + return torch.sigmoid((x - start) * sharpness) * torch.sigmoid((end - x) * sharpness) + + north_mask = smooth_window(lat_grid, -1.0, -0.05, sharpness=8.0) + north_mask_strict = smooth_window(lat_grid, -1.0, -0.30, sharpness=8.0) + + if n_aerosol_latents >= 1: + 
na = ( + 1.0 * gaussian_2d(-0.80, -0.80, 0.30, 0.38, angle=-0.25) + + 0.8 * gaussian_2d(-0.85, -0.95, 0.30, 0.28, angle=-0.05) + + 0.7 * gaussian_2d(-0.70, -0.35, 0.26, 0.32, angle=0.10) + ) + na *= north_mask_strict * smooth_window(lon_grid, -1.0, 0.00, sharpness=6.5) + spatial_templates[0] = -na.reshape(-1) * 1.65 + + if n_aerosol_latents >= 2: + eu = ( + 1.0 * gaussian_2d(-0.80, 0.40, 0.24, 0.30, angle=0.20) + + 0.8 * gaussian_2d(-0.75, 0.50, 0.22, 0.26, angle=0.30) + + 0.6 * gaussian_2d(-0.70, 0.30, 0.20, 0.24, angle=-0.05) + ) + eu *= north_mask_strict * smooth_window(lon_grid, -0.30, 0.55, sharpness=7.5) + spatial_templates[1] = -eu.reshape(-1) * 1.55 + + if n_aerosol_latents >= 3: + ea = ( + 1.0 * gaussian_2d(0.45, 0.65, 0.16, 0.22, angle=-0.20) + + 0.6 * gaussian_2d(0.35, 0.80, 0.12, 0.18, angle=-0.10) + + 0.5 * gaussian_2d(0.50, 0.45, 0.14, 0.18, angle=0.10) + ) + ea *= north_mask * smooth_window(lon_grid, 0.35, 0.98, sharpness=8.0) + spatial_templates[2] = -ea.reshape(-1) * 1.10 + + if n_aerosol_latents >= 4: + ind = ( + 1.0 * gaussian_2d(0.25, 0.25, 0.10, 0.12, angle=0.30) + + 0.6 * gaussian_2d(0.15, 0.35, 0.08, 0.10, angle=-0.10) + + 0.4 * gaussian_2d(0.35, 0.15, 0.10, 0.12, angle=0.00) + ) + ind *= smooth_window(lat_grid, -0.55, -0.05, sharpness=8.0) * smooth_window( + lon_grid, 0.20, 0.65, sharpness=8.0 + ) + spatial_templates[3] = -ind.reshape(-1) * 0.95 + + for i in range(4, n_aerosol_latents): + wave_num = i - 2 + pattern_2d = torch.abs(torch.sin(wave_num * math.pi * lat_grid)) * torch.abs( + torch.cos((wave_num + 1) * math.pi * lon_grid) + ) + spatial_templates[i] = -pattern_2d.reshape(-1) * (1.0 + 0.2 * i) + + return spatial_templates + + def _create_aerosol_templates_1d(self, n_aerosol_latents: int, spatial_len: int, dev, dtype) -> torch.Tensor: + """Build fallback 1D aerosol templates for non-square grids.""" + spatial_templates = torch.zeros(n_aerosol_latents, spatial_len, device=dev, dtype=dtype) + coords = torch.linspace(-1.0, 1.0, 
spatial_len, device=dev, dtype=dtype) + + if n_aerosol_latents >= 1: + northern_mask = (coords < 0.0).float() + spatial_templates[0] = -torch.sigmoid((-coords - 0.3) * 8.0) * northern_mask * 1.5 + + if n_aerosol_latents >= 2: + southern_mask = (coords > 0.0).float() + spatial_templates[1] = -torch.sigmoid((coords - 0.3) * 8.0) * southern_mask * 1.0 + + if n_aerosol_latents >= 3: + spatial_templates[2] = -torch.exp(-((coords / 0.3) ** 2)) * 0.8 + + if n_aerosol_latents >= 4: + mid_lat_mask = ((coords > -0.5) & (coords < 0.5)).float() + spatial_templates[3] = -torch.abs(torch.sin(2 * math.pi * coords)) * mid_lat_mask * 1.2 + + for i in range(4, n_aerosol_latents): + wave_num = i - 2 + spatial_templates[i] = -torch.abs(torch.sin(wave_num * math.pi * coords + i * 0.3)) * (1.0 + 0.2 * i) + + return spatial_templates + + def _normalize_aerosol_templates(self, spatial_templates: torch.Tensor, aerosol_contrast: float) -> None: + """Apply contrast and L2 normalization in-place.""" + n_aerosol_latents = spatial_templates.shape[0] + for i in range(n_aerosol_latents): + if aerosol_contrast != 1.0: + spatial_templates[i] = -torch.abs(spatial_templates[i]).pow(aerosol_contrast) + + norm = torch.sqrt((spatial_templates[i] ** 2).sum()) + if norm > 1e-8: + spatial_templates[i] = spatial_templates[i] / norm + else: + logger.warning(f"Aerosol template {i} has near-zero norm, skipping normalization") + + def _log_aerosol_template_orthogonality(self, spatial_templates: torch.Tensor) -> None: + """Log pairwise inner products for diagnostics.""" + n_aerosol_latents = spatial_templates.shape[0] + if n_aerosol_latents <= 1: + return + + logger.info("Aerosol spatial template orthogonality:") + for i in range(n_aerosol_latents): + for j in range(i + 1, n_aerosol_latents): + inner_prod = (spatial_templates[i] * spatial_templates[j]).sum().item() + logger.info(f" Template {i} · Template {j} = {inner_prod:.4f}") + + def _accumulate_aerosol_latent_forcing( + self, + forcing: torch.Tensor, + 
spatial_templates: torch.Tensor, + latent_signs: List[float], + t: torch.Tensor, + time_len: int, + base_ramp_up_time: int, + base_peak_time: int, + base_decline_time: int, + timing_stagger: float, + aerosol_scale: float, + ) -> None: + """Accumulate latent-specific temporal patterns into the forcing field.""" + n_aerosol_latents = spatial_templates.shape[0] + for i in range(n_aerosol_latents): + offset_fraction = i / n_aerosol_latents + latent_ramp_up = base_ramp_up_time + int(offset_fraction * timing_stagger * time_len) + latent_peak = base_peak_time + int(offset_fraction * timing_stagger * 0.7 * time_len) + latent_decline = base_decline_time + int(offset_fraction * timing_stagger * 0.5 * time_len) + + latent_peak = min(latent_peak, time_len - 100) + latent_decline = min(latent_decline, time_len - 10) + + t_ramp = latent_ramp_up / time_len + t_peak = latent_peak / time_len + t_decline = latent_decline / time_len + ramp_up = torch.sigmoid((t - t_ramp) * 10.0 / (t_peak - t_ramp + 1e-6)) + ramp_down = torch.sigmoid((t_decline - t) * 10.0 / (t_decline - t_peak + 1e-6)) + envelope = ramp_up * ramp_down + + base_freq = 3.0 + i * 2.0 + freq_mod = 1.0 + 0.2 * torch.sin(2.0 * math.pi * base_freq * t + i * 0.5) + latent_seasonal = torch.sin(2.0 * math.pi * (6.0 + i) * t) + latent_bursts = torch.sin(2.0 * math.pi * (18.0 + i * 3) * t + 0.3 * i) + + trend_sign = latent_signs[i] if i < len(latent_signs) else -1.0 + latent_trend = ( + trend_sign * aerosol_scale * envelope * freq_mod * (1.0 + 0.15 * latent_seasonal + 0.05 * latent_bursts) + ) + forcing += spatial_templates[i].unsqueeze(1) * latent_trend.unsqueeze(0) + + sign_label = "warming" if trend_sign > 0 else "cooling" + logger.debug( + f" Latent {i}: ramp_up={latent_ramp_up}, peak={latent_peak}, decline={latent_decline}, " + f"freq={base_freq}, {sign_label}" + ) + + def create_aerosol_forcing(self) -> np.ndarray: + """ + Create an aerosol forcing field with distinct temporal dynamics per region. 
+ + Each spatial region (corresponding to an aerosol latent) has a staggered temporal signal with unique timing and + frequency modulation. This ensures aerosol latents are distinguishable for causal discovery. + + Uses aerosol_ramp_up_time, aerosol_peak_time, and aerosol_decline_time from forcing_dict as base timing, with + staggered offsets per region. + """ + dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float32 + + time_len = self.time_length + self.transient + if time_len <= 0: + raise ValueError("Time length including transient must be positive") + + spatial_len = self.spatial_resolution + if spatial_len <= 0: + raise ValueError("Spatial resolution must be positive") + + forcing_cfg = self.forcing_dict or {} + aerosol_scale = max(float(forcing_cfg.get("aerosol_scale", 0.03)), 0.0) + aerosol_contrast = max(float(forcing_cfg.get("aerosol_spatial_contrast", 1.05)), 0.5) + + base_ramp_up_time = int(forcing_cfg.get("aerosol_ramp_up_time", int(0.2 * self.time_length))) + self.transient + base_peak_time = int(forcing_cfg.get("aerosol_peak_time", int(0.6 * self.time_length))) + self.transient + base_decline_time = int(forcing_cfg.get("aerosol_decline_time", int(0.85 * self.time_length))) + self.transient + timing_stagger = float(forcing_cfg.get("aerosol_timing_stagger", 0.3)) + + t = torch.linspace(0.0, 1.0, time_len, device=dev, dtype=dtype) + n_aerosol_latents = len(self.forcing_indices.get("aerosol", [])) if self.forcing_indices else 1 + n_aerosol_latents = max(n_aerosol_latents, 1) + + warming_index = self._resolve_aerosol_warming_index( + forcing_cfg.get("aerosol_warming_index", None), + n_aerosol_latents, + ) + latent_signs = [-1.0] * n_aerosol_latents + if warming_index is not None and 0 <= warming_index < n_aerosol_latents: + latent_signs[warming_index] = 1.0 + + spatial_templates = self._create_aerosol_spatial_templates( + n_aerosol_latents, + spatial_len, + aerosol_contrast, + dev, + dtype, + ) + forcing = 
torch.zeros((spatial_len, time_len), device=dev, dtype=dtype) + + logger.info( + f"Aerosol forcing: Creating {n_aerosol_latents} orthogonal spatial templates with distinct temporal patterns " + f"(base_ramp={base_ramp_up_time}, base_peak={base_peak_time}, base_decline={base_decline_time}, " + f"timing_stagger={timing_stagger}, scale={aerosol_scale}, spatial_contrast={aerosol_contrast})" + ) + if warming_index is not None and 0 <= warming_index < n_aerosol_latents: + logger.info(f"Aerosol forcing: warming latent index={warming_index}") + + self._accumulate_aerosol_latent_forcing( + forcing=forcing, + spatial_templates=spatial_templates, + latent_signs=latent_signs, + t=t, + time_len=time_len, + base_ramp_up_time=base_ramp_up_time, + base_peak_time=base_peak_time, + base_decline_time=base_decline_time, + timing_stagger=timing_stagger, + aerosol_scale=aerosol_scale, + ) + + forcing_np = forcing.detach().cpu().numpy() + self.aerosol_spatial_templates = spatial_templates.detach().cpu().numpy() + logger.info(f"[AEROSOL FORCING] Stored {n_aerosol_latents} spatial templates for latent extraction") + return forcing_np + + def _generate_forcing_trajectories(self) -> None: + """ + Generate forcing latent trajectories BEFORE dynamics. + + This creates the CO2 and aerosol forcing fields and extracts their + latent representations (time series) that will be used as exogenous + inputs during the dynamics computation. + + Must be called BEFORE _create_linear(), _create_nonlinear(), or _create_polynomial(). 
+ + The forcing latent trajectories are stored in: + - self.co2_latent_trajectory: shape (time,) - scalar CO2 latent over time + - self.aerosol_latent_trajectory: shape (n_aerosol_latents, time) - aerosol latents over time + """ + if self.forcing_dict is None or self.forcing_indices is None: + logger.info("No forcing configured, skipping trajectory generation") + return + + time_len = self.time_length + self.transient + + # Initialize forcing data fields if needed + if self.forcing_data_field is None: + self.forcing_data_field = np.zeros((self.spatial_resolution, time_len)) + if self.co2_forcing_data_field is None: + self.co2_forcing_data_field = np.zeros((self.spatial_resolution, time_len)) + if self.aerosol_forcing_data_field is None: + self.aerosol_forcing_data_field = np.zeros((self.spatial_resolution, time_len)) + + # Generate CO2 forcing field and extract latent trajectory + co2_forcing = self.create_co2_forcing() + self.co2_forcing_data_field = co2_forcing + self.co2_latent_trajectory = co2_forcing.mean(axis=0) # shape: (time,) + + logger.info( + f"[FORCING TRAJ] Generated CO2 latent trajectory: shape {self.co2_latent_trajectory.shape}, " + f"range [{self.co2_latent_trajectory.min():.4f}, {self.co2_latent_trajectory.max():.4f}]" + ) + + # Generate aerosol forcing field and extract latent trajectories via template projection + aerosol_forcing = self.create_aerosol_forcing() + self.aerosol_forcing_data_field = aerosol_forcing + + n_aerosol_latents = len(self.forcing_indices.get("aerosol", [])) + + if n_aerosol_latents > 0 and self.aerosol_spatial_templates is not None: + aerosol_latents = np.zeros((n_aerosol_latents, time_len)) + for i in range(n_aerosol_latents): + template = self.aerosol_spatial_templates[i] # (spatial_resolution,) + # Inner product of template with forcing field at each timestep + aerosol_latents[i] = template @ aerosol_forcing # (time,) + self.aerosol_latent_trajectory = aerosol_latents + + logger.info( + f"[FORCING TRAJ] Generated aerosol 
latent trajectories: shape {self.aerosol_latent_trajectory.shape}" + ) + for i in range(n_aerosol_latents): + logger.debug( + f" Aerosol latent {i}: range [{aerosol_latents[i].min():.4f}, {aerosol_latents[i].max():.4f}]" + ) + else: + self.aerosol_latent_trajectory = None + if n_aerosol_latents > 0: + logger.info( + f"[FORCING TRAJ] Warning: {n_aerosol_latents} aerosol latents configured but no spatial templates" + ) + + # Store combined forcing field for diagnostics (but DON'T add to data_field) + self.forcing_data_field = co2_forcing + aerosol_forcing def _create_linear(self): - """Weights N \times L data_field L \times T.""" - weights = deepcopy(self.mode_weights.reshape(self.n_vars, -1)) - # weights_inv = np.linalg.pinv(weights) - weights_inv = torch.Tensor(np.linalg.pinv(weights)).to(device="cuda") - weights = torch.Tensor(weights).to(device="cuda") - time_len = deepcopy(self.time_length) - time_len += self.transient + """ + Create linear SAVAR dynamics with forcing in the dynamics equation: + + z_j(t) = sum_{k,τ} P_climate[j,k,τ] · z_k(t-τ) + sum_{f,τ} P_forcing[j,f,τ] · forcing_f(t-τ) + N_j(t) + S_t = W^{-1} · z(t) + + where: + - W is the mixing matrix (mode_weights) + - P_climate is the climate-to-climate transition matrix + - P_forcing is the forcing-to-climate transition matrix + - forcing_f are exogenous forcing latent trajectories (CO2, aerosol) + - N_t ~ N(0, s·W·W^T) is the noise at each timestep + - The sum of all links (climate + forcing) equals 1 for each target mode + """ + dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float32 + + weights = deepcopy(self.mode_weights.reshape(self.n_climate_modes, -1)) + weights_inv = torch.tensor(np.linalg.pinv(weights), device=dev, dtype=dtype) + weights_torch = torch.tensor(weights, device=dev, dtype=dtype) + + time_len = self.time_length + self.transient tau_max = self.tau_max - # phi = dict_to_matrix(self.links_coeffs) - phi = 
torch.Tensor(dict_to_matrix(self.links_coeffs)).to(device="cuda") + # Get full transition matrix + phi_full_np = dict_to_matrix(self.links_coeffs) + + # Extract climate-to-climate submatrix + phi_climate_np = phi_full_np[: self.n_climate_modes, : self.n_climate_modes, :] + + # Check if we have forcing + has_forcing = self.forcing_indices is not None and self.co2_latent_trajectory is not None + + if has_forcing: + # Extract forcing-to-climate submatrix: phi_full[climate_targets, forcing_sources, :] + phi_forcing_np = phi_full_np[: self.n_climate_modes, self.n_climate_modes :, :] + + # Normalize both climate and forcing coefficients together + phi_climate_np, phi_forcing_np = normalize_transition_with_forcing( + phi_climate_np, phi_forcing_np, self.n_climate_modes + ) + phi_forcing = torch.tensor(phi_forcing_np, device=dev, dtype=dtype) + + # Prepare forcing latent trajectories + co2_latent = torch.tensor(self.co2_latent_trajectory, device=dev, dtype=dtype) + + # Concatenate all forcing latents: (n_forcing, time) + if self.aerosol_latent_trajectory is not None: + aerosol_latent = torch.tensor(self.aerosol_latent_trajectory, device=dev, dtype=dtype) + forcing_latents = torch.cat( + [co2_latent.unsqueeze(0), aerosol_latent], dim=0 # (1, time) # (n_aerosol, time) + ) + else: + forcing_latents = co2_latent.unsqueeze(0) # (1, time) + + logger.debug(f"[LINEAR] Forcing latents shape: {forcing_latents.shape}") + logger.debug(f"[LINEAR] Forcing coefficient matrix shape: {phi_forcing_np.shape}") + else: + # No forcing - use raw coefficients (already bounded < 1 by create_links_coeffs) + pass + + phi_climate = torch.tensor(phi_climate_np, device=dev, dtype=dtype) + + # Initialize data_field from existing (includes noise added by _add_noise_field) + data_field = torch.tensor(self.data_field, device=dev, dtype=dtype) + + if has_forcing: + logger.info("create_linear (with forcing in dynamics: S_t = W^{-1}·(P_c·z + P_f·f))") + else: + logger.info("create_linear (no forcing: S_t = 
W^{-1}·P·W·S)") + # data_field = deepcopy(self.data_field) data_field = torch.Tensor(self.data_field).to(device="cuda") if self.verbose: print("create_linear") for t in tqdm(range(tau_max, time_len)): + # Climate-to-climate AR contribution: W^{-1} · P_climate · W · S_{t-τ:t-1} + ar_contribution = torch.zeros(self.spatial_resolution, 1, device=dev, dtype=dtype) for i in range(tau_max): - data_field[..., t : t + 1] += weights_inv @ phi[..., i] @ weights @ data_field[..., t - 1 - i : t - i] - # data_field[..., t:t + 1] += torch.matmul(torch.matmul(torch.matmul(weights_inv, phi[..., i]), weights), data_field[..., t - 1 - i:t - i]) + ar_contribution += ( + weights_inv @ phi_climate[..., i] @ weights_torch @ data_field[..., t - 1 - i : t - i] + ) + + # Forcing-to-climate contribution (if forcing exists) + if has_forcing: + # Forcing contribution in latent space: P_forcing @ forcing_latents + forcing_contrib_latent = torch.zeros(self.n_climate_modes, 1, device=dev, dtype=dtype) + for i in range(tau_max): + lag_idx = t - 1 - i + if lag_idx >= 0: + # phi_forcing: (n_climate, n_forcing, tau_max) + # forcing_latents: (n_forcing, time) + forcing_at_lag = forcing_latents[:, lag_idx].unsqueeze(1) # (n_forcing, 1) + forcing_contrib_latent += phi_forcing[..., i] @ forcing_at_lag + + # Project forcing contribution to observation space: W^{-1} @ forcing_contrib + forcing_contrib_obs = weights_inv @ forcing_contrib_latent + ar_contribution += forcing_contrib_obs + + # S_t = AR_contribution + forcing_contribution + # Noise is already in data_field (added before dynamics by _add_noise_field) + data_field[..., t : t + 1] += ar_contribution self.data_field = data_field[..., self.transient :].detach().cpu().numpy() @@ -411,15 +1239,13 @@ def _create_intervened_nextstep(self, input_data, intervened_mode=None, interven This is to keep the savar structure similar to the one of `self.data_field` """ - weights = deepcopy(self.mode_weights.reshape(self.n_vars, -1)) + weights = 
deepcopy(self.mode_weights.reshape(self.n_climate_modes, -1)) # weights_inv = np.linalg.pinv(weights) weights_inv = torch.Tensor(np.linalg.pinv(weights)).to(device="cuda") weights = torch.Tensor(weights).to(device="cuda") tau = input_data.shape[1] - # phi = dict_to_matrix(self.links_coeffs) phi = torch.Tensor(dict_to_matrix(self.links_coeffs)).to(device="cuda") - # data_field = deepcopy(self.data_field) next_step = torch.zeros(self.spatial_resolution).to(device="cuda") change_indices = [] @@ -487,7 +1313,7 @@ def train_nnar(self, num_epochs=50, learning_rate=0.001, batch_size=32): optimizer.step() if (epoch + 1) % 5 == 0: - print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {sum(batch_losses)/len(batch_losses):.6f}") + logger.info(f"Epoch [{epoch+1}/{num_epochs}], Loss: {sum(batch_losses)/len(batch_losses):.6f}") if self.verbose: print("Training of single-layer NNAR model completed.") @@ -498,64 +1324,183 @@ def _create_nonlinear(self): same logic as _create_linear to step forward in time and adds the nonlinearity (sigmoid) before adding to data_field. - If train_nnar=True was set, we assume self.nnar_model was trained in generate_data(). - Otherwise, we can do a direct inline "torch.sigmoid(...)" approach. - Can be increased in complexity if needed + z_j(t) = tanh(sum_{k,τ} P_climate[j,k,τ] · z_k(t-τ)) + sum_{f,τ} P_forcing[j,f,τ] · forcing_f(t-τ) + N_j(t) + S_t = W^{-1} · z(t) + + Note: The nonlinearity (tanh) is applied only to climate contributions. + Forcing contributions remain linear as they are exogenous inputs. 
""" + dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float32 - weights = torch.Tensor(np.linalg.pinv(self.mode_weights.reshape(self.n_vars, -1))).to("cuda") - phi = torch.Tensor(dict_to_matrix(self.links_coeffs)).to("cuda") - mode_weights_tensor = torch.Tensor(self.mode_weights.reshape(self.n_vars, -1)).to("cuda") - data_field = torch.Tensor(self.data_field).to("cuda") + weights_np = np.linalg.pinv(self.mode_weights.reshape(self.n_climate_modes, -1)) + weights_inv = torch.tensor(weights_np, device=dev, dtype=dtype) + mode_weights_tensor = torch.tensor(self.mode_weights.reshape(self.n_climate_modes, -1), device=dev, dtype=dtype) time_len = self.time_length + self.transient tau_max = self.tau_max - if self.verbose: - print("create_nonlinear (single-layer net + sigmoid)") + # Get full transition matrix + phi_full_np = dict_to_matrix(self.links_coeffs) + + # Extract climate-to-climate submatrix + phi_climate_np = phi_full_np[: self.n_climate_modes, : self.n_climate_modes, :] + + # Check if we have forcing + has_forcing = self.forcing_indices is not None and self.co2_latent_trajectory is not None + + if has_forcing: + # Extract forcing-to-climate submatrix + phi_forcing_np = phi_full_np[: self.n_climate_modes, self.n_climate_modes :, :] + + # Normalize both climate and forcing coefficients together + phi_climate_np, phi_forcing_np = normalize_transition_with_forcing( + phi_climate_np, phi_forcing_np, self.n_climate_modes + ) + phi_forcing = torch.tensor(phi_forcing_np, device=dev, dtype=dtype) + + # Prepare forcing latent trajectories + co2_latent = torch.tensor(self.co2_latent_trajectory, device=dev, dtype=dtype) + + if self.aerosol_latent_trajectory is not None: + aerosol_latent = torch.tensor(self.aerosol_latent_trajectory, device=dev, dtype=dtype) + forcing_latents = torch.cat([co2_latent.unsqueeze(0), aerosol_latent], dim=0) + else: + forcing_latents = co2_latent.unsqueeze(0) + + logger.debug(f"[NONLINEAR] Forcing latents shape: 
{forcing_latents.shape}") + else: + # No forcing - use raw coefficients (already bounded < 1 by create_links_coeffs) + pass + + phi_climate = torch.tensor(phi_climate_np, device=dev, dtype=dtype) + data_field = torch.tensor(self.data_field, device=dev, dtype=dtype) + + if has_forcing: + logger.info("create_nonlinear (with forcing: tanh(climate) + linear(forcing))") + else: + logger.info("create_nonlinear (no forcing: tanh(climate) + noise)") for t in tqdm(range(tau_max, time_len)): - # Sum up influences from each lag - nonlinear_contrib = 0.0 + # Climate contribution with nonlinearity + nonlinear_contrib = torch.zeros(self.spatial_resolution, device=dev, dtype=dtype) for i in range(tau_max): - # get linear combination as in _create_linear - lincombo = weights @ phi[..., i] @ mode_weights_tensor @ data_field[..., (t - 1 - i) : (t - i)] - # Apply a sigmoid (or feed it through the small neural net if you want more complexity) + lincombo = ( + weights_inv @ phi_climate[..., i] @ mode_weights_tensor @ data_field[..., (t - 1 - i) : (t - i)] + ) lincombo_nl = torch.sigmoid(lincombo) - # accumulate nonlinear_contrib += lincombo_nl.squeeze(-1) - # Add the (nonlinear) effect to the data field at time t + # Forcing contribution (linear, exogenous) + if has_forcing: + forcing_contrib_latent = torch.zeros(self.n_climate_modes, 1, device=dev, dtype=dtype) + for i in range(tau_max): + lag_idx = t - 1 - i + if lag_idx >= 0: + forcing_at_lag = forcing_latents[:, lag_idx].unsqueeze(1) + forcing_contrib_latent += phi_forcing[..., i] @ forcing_at_lag + + forcing_contrib_obs = weights_inv @ forcing_contrib_latent + nonlinear_contrib += forcing_contrib_obs.squeeze(-1) + + # S_t = nonlinear_contribution + forcing_contribution + # Noise is already in data_field (added before dynamics by _add_noise_field) data_field[:, t] += nonlinear_contrib self.data_field = data_field[:, self.transient :].detach().cpu().numpy() def _create_polynomial(self): - """Example polynomial autoregression, e.g. 
x^2 for poly_degree=2.""" - w_np = np.linalg.pinv(self.mode_weights.reshape(self.n_vars, -1)) - phi_np = dict_to_matrix(self.links_coeffs) + """ + Example polynomial autoregression, e.g. x^2 for poly_degree=2. w_np = + np.linalg.pinv(self.mode_weights.reshape(self.n_climate_modes, -1)) phi_np = dict_to_matrix(self.links_coeffs) + + z_j(t) = sum_deg (sum_{k,τ} P_climate[j,k,τ] · z_k(t-τ))^deg + sum_{f,τ} P_forcing[j,f,τ] · forcing_f(t-τ) + N_j(t) + S_t = W^{-1} · z(t) - w_torch = torch.Tensor(w_np).to("cuda") - phi_torch = torch.Tensor(phi_np).to("cuda") - mw_torch = torch.Tensor(self.mode_weights.reshape(self.n_vars, -1)).to("cuda") - data_field = torch.Tensor(self.data_field).to("cuda") + Note: The polynomial nonlinearity is applied only to climate contributions. + Forcing contributions remain linear as they are exogenous inputs. + """ + dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float32 + + w_np = np.linalg.pinv(self.mode_weights.reshape(self.n_climate_modes, -1)) + w_torch = torch.tensor(w_np, device=dev, dtype=dtype) + mw_torch = torch.tensor( + self.mode_weights.reshape(self.n_climate_modes, -1), + device=dev, + dtype=dtype, + ) time_len = self.time_length + self.transient tau_max = self.tau_max - if self.verbose: - print(f"create_polynomial with degrees={self.poly_degrees}") + # Get full transition matrix + phi_full_np = dict_to_matrix(self.links_coeffs) + + # Extract climate-to-climate submatrix + phi_climate_np = phi_full_np[: self.n_climate_modes, : self.n_climate_modes, :] + + # Check if we have forcing + has_forcing = self.forcing_indices is not None and self.co2_latent_trajectory is not None + + if has_forcing: + # Extract forcing-to-climate submatrix + phi_forcing_np = phi_full_np[: self.n_climate_modes, self.n_climate_modes :, :] + + # Normalize both climate and forcing coefficients together + phi_climate_np, phi_forcing_np = normalize_transition_with_forcing( + phi_climate_np, phi_forcing_np, 
self.n_climate_modes + ) + phi_forcing = torch.tensor(phi_forcing_np, device=dev, dtype=dtype) + + # Prepare forcing latent trajectories + co2_latent = torch.tensor(self.co2_latent_trajectory, device=dev, dtype=dtype) + + if self.aerosol_latent_trajectory is not None: + aerosol_latent = torch.tensor(self.aerosol_latent_trajectory, device=dev, dtype=dtype) + forcing_latents = torch.cat([co2_latent.unsqueeze(0), aerosol_latent], dim=0) + else: + forcing_latents = co2_latent.unsqueeze(0) + + logger.debug(f"[POLYNOMIAL] Forcing latents shape: {forcing_latents.shape}") + else: + # No forcing - use raw coefficients (already bounded < 1 by create_links_coeffs) + pass + + phi_climate = torch.tensor(phi_climate_np, device=dev, dtype=dtype) + data_field = torch.tensor(self.data_field, device=dev, dtype=dtype) + + if has_forcing: + logger.info( + f"create_polynomial (with forcing: poly(climate) + linear(forcing) + noise) degrees={self.poly_degrees}" + ) + else: + logger.info(f"create_polynomial (no forcing: poly(climate) + noise) degrees={self.poly_degrees}") for t in tqdm(range(tau_max, time_len)): - # For each time step, sum over the contributions of all lags + # Climate contribution with polynomial nonlinearity + poly_contrib = torch.zeros(self.spatial_resolution, device=dev, dtype=dtype) for i in range(tau_max): - lincombo = w_torch @ phi_torch[..., i] @ mw_torch @ data_field[..., (t - 1 - i) : (t - i)] + lincombo = w_torch @ phi_climate[..., i] @ mw_torch @ data_field[..., (t - 1 - i) : (t - i)] # For each requested polynomial degree, add its effect - poly_sum = 0.0 for deg in self.poly_degrees: - poly_sum += lincombo**deg - - data_field[:, t] += poly_sum.squeeze(-1) + poly_contrib += lincombo**deg + + # Forcing contribution (linear, exogenous) + if has_forcing: + forcing_contrib_latent = torch.zeros(self.n_climate_modes, 1, device=dev, dtype=dtype) + for i in range(tau_max): + lag_idx = t - 1 - i + if lag_idx >= 0: + forcing_at_lag = forcing_latents[:, 
lag_idx].unsqueeze(1) + forcing_contrib_latent += phi_forcing[..., i] @ forcing_at_lag + + forcing_contrib_obs = w_torch @ forcing_contrib_latent + poly_contrib += forcing_contrib_obs.squeeze(-1) + + # S_t = polynomial_contribution + forcing_contribution + # Noise is already in data_field (added before dynamics by _add_noise_field) + data_field[:, t] += poly_contrib self.data_field = data_field[:, self.transient :].detach().cpu().numpy() diff --git a/climatem/synthetic_data/utils.py b/climatem/synthetic_data/utils.py index 0fb4362..1eeaaf8 100644 --- a/climatem/synthetic_data/utils.py +++ b/climatem/synthetic_data/utils.py @@ -8,6 +8,33 @@ from tigramite.data_processing import smooth +def permute_matrices( + lat, + lon, + modes_inferred, + modes_gt, + mat_transition, + tau, +): + modes_inferred = modes_inferred.reshape((lat, lon, modes_inferred.shape[-1])).transpose((2, 0, 1)) + + idx_gt_flat = np.argmax(modes_gt.reshape(modes_gt.shape[0], -1), axis=1) # shape: (n_modes,) + idx_inferred_flat = np.argmax(modes_inferred.reshape(modes_inferred.shape[0], -1), axis=1) # shape: (n_modes,) + + # Convert flat indices to 2D coordinates (row, col) + idx_gt = np.array([np.unravel_index(i, (lat, lon)) for i in idx_gt_flat]) # shape: (n_modes, 2) + idx_inferred = np.array([np.unravel_index(i, (lat, lon)) for i in idx_inferred_flat]) # shape: (n_modes, 2) + + # Compute error matrix using squared Euclidean distance between indices which yields an (n_modes x n_modes) matrix + permutation_list = ((idx_gt[:, None, :] - idx_inferred[None, :, :]) ** 2).sum(axis=2).argmin(axis=1) + + # Permute + for k in range(tau): + mat_transition[k] = mat_transition[k][np.ix_(permutation_list, permutation_list)] + + return mat_transition + + def check_stability(graph: Union[np.ndarray, dict], lag_first_axis: bool = False, verbose: bool = False): """ Raises an AssertionError if the input graph corresponds to a non-stationary process. 
@@ -264,30 +291,3 @@ def create_graph(links_coeffs, return_lag=True): return np.asarray(graph), max_lag else: return np.asarray(graph) - - -def permute_matrices( - lat, - lon, - modes_inferred, - modes_gt, - mat_transition, - tau, -): - modes_inferred = modes_inferred.reshape((lat, lon, modes_inferred.shape[-1])).transpose((2, 0, 1)) - - idx_gt_flat = np.argmax(modes_gt.reshape(modes_gt.shape[0], -1), axis=1) # shape: (n_modes,) - idx_inferred_flat = np.argmax(modes_inferred.reshape(modes_inferred.shape[0], -1), axis=1) # shape: (n_modes,) - - # Convert flat indices to 2D coordinates (row, col) - idx_gt = np.array([np.unravel_index(i, (lat, lon)) for i in idx_gt_flat]) # shape: (n_modes, 2) - idx_inferred = np.array([np.unravel_index(i, (lat, lon)) for i in idx_inferred_flat]) # shape: (n_modes, 2) - - # Compute error matrix using squared Euclidean distance between indices which yields an (n_modes x n_modes) matrix - permutation_list = ((idx_gt[:, None, :] - idx_inferred[None, :, :]) ** 2).sum(axis=2).argmin(axis=1) - - # Permute - for k in range(tau): - mat_transition[k] = mat_transition[k][np.ix_(permutation_list, permutation_list)] - - return mat_transition diff --git a/mappings/healpix_lonlat_mapping.npy b/mappings/healpix_lonlat_mapping.npy deleted file mode 100644 index 1db6fe2..0000000 Binary files a/mappings/healpix_lonlat_mapping.npy and /dev/null differ diff --git a/mappings/healpix_resdown4_lonlat_mapping.npy b/mappings/healpix_resdown4_lonlat_mapping.npy deleted file mode 100644 index 5d92f84..0000000 Binary files a/mappings/healpix_resdown4_lonlat_mapping.npy and /dev/null differ diff --git a/poetry.lock b/poetry.lock index 867d37b..bf2fc46 100644 --- a/poetry.lock +++ b/poetry.lock @@ -314,78 +314,6 @@ files = [ [package.dependencies] typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} -[[package]] -name = "astropy" -version = "6.1.7" -description = "Astronomy and astrophysics core library" -optional = false 
-python-versions = ">=3.10" -groups = ["main"] -markers = "python_version <= \"3.11\"" -files = [ - {file = "astropy-6.1.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:be954c5f7707a089609053665aeb76493b79e5c4753c39486761bc6d137bf040"}, - {file = "astropy-6.1.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b5e48df5ab2e3e521e82a7233a4b1159d071e64e6cbb76c45415dc68d3b97af1"}, - {file = "astropy-6.1.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55c78252633c644361e2f7092d71f80ef9c2e6649f08d97711d9f19af514aedc"}, - {file = "astropy-6.1.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:985e5e74489d23f1a11953b6b283fccde3f46cb6c68fee4f7228e5f6d8350ba9"}, - {file = "astropy-6.1.7-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:dc2ea28ed41a3d92c39b1481d9c5be016ae58d68f144f3fd8cecffe503525bab"}, - {file = "astropy-6.1.7-cp310-cp310-win32.whl", hash = "sha256:4e4badadd8dfa5dca08fd86e9a50a3a91af321975859f5941579e6b7ce9ba199"}, - {file = "astropy-6.1.7-cp310-cp310-win_amd64.whl", hash = "sha256:8d7f6727689288ee08fc0a4a297fc7e8089d01718321646bd00fea0906ad63dc"}, - {file = "astropy-6.1.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:09edca01276ee63f7b2ff511da9bfb432068ba3242e27ef27d76e5a171087b7e"}, - {file = "astropy-6.1.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:072f62a67992393beb016dc80bee8fb994fda9aa69e945f536ed8ac0e51291e6"}, - {file = "astropy-6.1.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b2706156d3646f9c9a7fc810475d8ab0df4c717beefa8326552576a0f8ddca20"}, - {file = "astropy-6.1.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fcd99e627692f8e58bb3097d330bfbd109a22e00dab162a67f203b0a0601ad2c"}, - {file = "astropy-6.1.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b0ebbcb637b2e9bcb73011f2b7890d7a3f5a41b66ccaad7c28f065e81e28f0b2"}, - {file = "astropy-6.1.7-cp311-cp311-win32.whl", hash = 
"sha256:192b12ede49cd828362ab1a6ede2367fe203f4d851804ec22fa92e009a524281"}, - {file = "astropy-6.1.7-cp311-cp311-win_amd64.whl", hash = "sha256:3cac64bcdf570c947019bd2bc96711eeb2c7763afe192f18c9551e52a6c296b2"}, - {file = "astropy-6.1.7-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f2a8bcbb1306052cc38c9eed2c9331bfafe2582b499a7321946abf74b26eb256"}, - {file = "astropy-6.1.7-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:eaf88878684f9d31aff36475c90d101f4cff22fdd4fd50098d9950fd56994df7"}, - {file = "astropy-6.1.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1cb8cd231e53556e4eebe0393ea95a8cea6b2ff4187c95ac4ff8b17e7a8da823"}, - {file = "astropy-6.1.7-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ad36334d138a4f71d6fdcf225a98ad1dad6c343da4362d5a47a71f5c9da3ca9"}, - {file = "astropy-6.1.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:dd731c526869d0c68507be7b31dd10871b7c44d310bb5495476505560c83cd33"}, - {file = "astropy-6.1.7-cp312-cp312-win32.whl", hash = "sha256:662bacd7ae42561e038cbd85eea3b749308cf3575611a745b60f034d3350c97a"}, - {file = "astropy-6.1.7-cp312-cp312-win_amd64.whl", hash = "sha256:5b4d02a98a0bf91ff7fd4ef0bd0ecca83c9497338cb88b61ec9f971350688222"}, - {file = "astropy-6.1.7-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:fbeaf04427987c0c6fa2e579eb40011802b06fba6b3a7870e082d5c693564e1b"}, - {file = "astropy-6.1.7-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ab6e88241a14185b9404b02246329185b70292984aa0616b20a0628dfe4f4ebb"}, - {file = "astropy-6.1.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0529c75565feaabb629946806b4763ae7b02069aeff4c3b56a69e8a9e638500"}, - {file = "astropy-6.1.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c5ec347631da77573fc729ba04e5d89a3bc94500bf6037152a2d0f9965ae1ce"}, - {file = "astropy-6.1.7-cp313-cp313-musllinux_1_2_x86_64.whl", hash = 
"sha256:dc496f87aaccaa5c6624acc985b8770f039c5bbe74b120c8ed7bad3698e24e1b"}, - {file = "astropy-6.1.7-cp313-cp313-win32.whl", hash = "sha256:b1e01d534383c038dbf8664b964fa4ea818c7419318830d3c732c750c64115c6"}, - {file = "astropy-6.1.7-cp313-cp313-win_amd64.whl", hash = "sha256:af08cf2b0368f1ea585eb26a55d99a2de9e9b0bd30aba84b5329059c3ec33590"}, - {file = "astropy-6.1.7.tar.gz", hash = "sha256:a405ac186306b6cb152e6df2f7444ab8bd764e4127d7519da1b3ae4dd65357ef"}, -] - -[package.dependencies] -astropy-iers-data = ">=0.2024.10.28.0.34.7" -numpy = ">=1.23" -packaging = ">=19.0" -pyerfa = ">=2.0.1.1" -PyYAML = ">=3.13" - -[package.extras] -all = ["asdf-astropy (>=0.3)", "astropy[recommended]", "astropy[typing]", "beautifulsoup4", "bleach", "bottleneck", "certifi", "dask[array]", "fsspec[http] (>=2023.4.0)", "h5py", "html5lib", "ipython (>=4.2)", "jplephem", "mpmath", "pandas", "pre-commit", "pyarrow (>=7.0.0)", "pytest (>=7.0)", "pytz", "s3fs (>=2023.4.0)", "sortedcontainers"] -docs = ["Jinja2 (>=3.1.3)", "astropy[recommended]", "matplotlib (>=3.9.1)", "numpy (<2.0)", "pytest (>=7.0)", "sphinx", "sphinx-astropy[confv2] (>=1.9.1)", "sphinx-changelog (>=1.2.0)", "sphinx_design", "sphinxcontrib-globalsubs (>=0.1.1)", "tomli"] -recommended = ["matplotlib (>=3.5.0,!=3.5.2)", "scipy (>=1.8)"] -test = ["pytest (>=7.0)", "pytest-astropy (>=0.10)", "pytest-astropy-header (>=0.2.1)", "pytest-doctestplus (>=0.12)", "pytest-xdist", "threadpoolctl"] -test-all = ["array-api-strict", "astropy[test]", "coverage[toml]", "ipython (>=4.2)", "objgraph", "sgp4 (>=2.3)", "skyfield (>=1.20)"] -typing = ["typing_extensions (>=4.0.0)"] - -[[package]] -name = "astropy-iers-data" -version = "0.2026.2.16.0.48.25" -description = "IERS Earth Rotation and Leap Second tables for the astropy core package" -optional = false -python-versions = ">=3.8" -groups = ["main"] -markers = "python_version <= \"3.11\"" -files = [ - {file = "astropy_iers_data-0.2026.2.16.0.48.25-py3-none-any.whl", hash = 
"sha256:180d1c3f59d18aa616345560799c2d88ec6e5164b8c45c746380acf892946136"}, - {file = "astropy_iers_data-0.2026.2.16.0.48.25.tar.gz", hash = "sha256:be14512844e71536a15e165d729385f3cb4865d7822172509e68c4ac79322067"}, -] - -[package.extras] -docs = ["pytest"] -test = ["hypothesis", "pytest", "pytest-remotedata"] - [[package]] name = "asttokens" version = "3.0.0" @@ -469,6 +397,90 @@ files = [ [package.extras] dev = ["backports.zoneinfo", "freezegun (>=1.0,<2.0)", "jinja2 (>=3.0)", "pytest (>=6.0)", "pytest-cov", "pytz", "setuptools", "tzdata"] +[[package]] +name = "basemap" +version = "1.4.1" +description = "Plot data on map projections with matplotlib" +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*, <3.13" +groups = ["main"] +markers = "python_version <= \"3.11\"" +files = [ + {file = "basemap-1.4.1-cp27-cp27m-win32.whl", hash = "sha256:355c984cbde3f098cac49f41e2ab0bbbd49091770eeb58486a001f9ee865f725"}, + {file = "basemap-1.4.1-cp27-cp27m-win_amd64.whl", hash = "sha256:4fb6763f4f2bb904fea0afbd8194f14af0b01fb52d719be35dd4b423e8d8dca2"}, + {file = "basemap-1.4.1-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:be103ef8c2cb1e7363e3ab48ae633ec2643f0b8a9226d119ec8681096b81335c"}, + {file = "basemap-1.4.1-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:c76d77ac21b6fb5008ca0283c6f86d24bf7ed6d0d18867edcfc534748cf348b5"}, + {file = "basemap-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a8786f23e3f456ab402b8160efb7b8b1c3780d6b84154dfcfdcd3f0401ef3285"}, + {file = "basemap-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:22793cf731d6cc06e11cf2760fd2875cd18b2980ce1c5ba15f6ebd9e224b1a1d"}, + {file = "basemap-1.4.1-cp310-cp310-manylinux1_i686.whl", hash = "sha256:8818fbb5e1133f0955a62ea159a3792124fc06f9b37ea4427612f6def923b711"}, + {file = "basemap-1.4.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:650e9b2170579193cb89bea2c328740fc13c0327b8ea55173490f760d0f04f8c"}, + {file = "basemap-1.4.1-cp310-cp310-win32.whl", hash = 
"sha256:2e198e442bae7ad0d25f529788fcb5802192d844d856fa966f6018a46648a375"}, + {file = "basemap-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:91d03458343d73f6a5e4d42e79df59cf57eba699094b71ce7c57ec40c8de9f59"}, + {file = "basemap-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:509bb451ce26f1f7651cc19d27ef049935304c7795a141cfbe05582802554226"}, + {file = "basemap-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a6860ab68a461cd36b6eda5d3973bfa1e3a4fdeaf608e8dd4b324b08e8561eed"}, + {file = "basemap-1.4.1-cp311-cp311-manylinux1_i686.whl", hash = "sha256:500a3d314166057d0b3c065e63147042c996ead1af5d59e319f813b83d4ea220"}, + {file = "basemap-1.4.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:6e0a13ba17ca011c55db44fab9210363154bc67482776d3b149a3ef1b4b01bad"}, + {file = "basemap-1.4.1-cp311-cp311-win32.whl", hash = "sha256:e6b8430a4f73485fcdb766cd2acb9e766cdacc67adca7757e7a2915b685d8461"}, + {file = "basemap-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:c4d80330f00728acfd88be78e0485fd688984db873853f9c75647ec78e653203"}, + {file = "basemap-1.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:51a5e8f2183e7505f4dfbb965048348c2f1e501283b13ad9950409c3da9655f6"}, + {file = "basemap-1.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7f99f3da41fc3e6bfb726f0fbd3302939a3b03e353ccd2be5ebd7bc09b84c386"}, + {file = "basemap-1.4.1-cp312-cp312-manylinux1_i686.whl", hash = "sha256:20644986a63d57c9d94afb8aba180d5362c1bd08d61d4537aa6221a70e3df762"}, + {file = "basemap-1.4.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:b64f000374a41bf19ac8f8b39cbd2a3b244c79673cd686c9eb507c19daef9e78"}, + {file = "basemap-1.4.1-cp312-cp312-win32.whl", hash = "sha256:84918a6e030adc9aedee92fe622adde8b54cc81051a6ce6723c04987081d92f4"}, + {file = "basemap-1.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:ff588f63f35d0b58607c9642cc2cf68967b464ee7aa27a6d87db2ec8edc910e2"}, + {file = "basemap-1.4.1-cp35-cp35m-manylinux1_i686.whl", hash = 
"sha256:b9679c4cacc9af04d9cedb9afd01b9c85bc12df1c8f9685147500e91a33ad5a8"}, + {file = "basemap-1.4.1-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:d1600576b21a37883992d5bff0c4485e5468a14fe2df7616e08824103b21465d"}, + {file = "basemap-1.4.1-cp35-cp35m-win32.whl", hash = "sha256:427a2050dcfbbde5bb4421f87b3f594852160029f8bbdf8d0fff81ad554de2f4"}, + {file = "basemap-1.4.1-cp35-cp35m-win_amd64.whl", hash = "sha256:87091beba0e5f6cfc416d1847a582645b2447cd6bf1e8c4ceaabd98d3493ef41"}, + {file = "basemap-1.4.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:f448cfeb0090d4f53f1381c8019c3841e621b7b14df3816172ee72385c35b8b0"}, + {file = "basemap-1.4.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:669d89621f8c7899c061fcf37bc32df24ebeaf5eb0a02b9d9b9896ca6698f6de"}, + {file = "basemap-1.4.1-cp36-cp36m-win32.whl", hash = "sha256:ff7b29d1920c77da51386d5767674c498bbdea822a1b51fda3d25a3be8c066d2"}, + {file = "basemap-1.4.1-cp36-cp36m-win_amd64.whl", hash = "sha256:5c50fa14a71d115a7614b3a659f14d1e65dbf43dbcbedae0c901058860806f5b"}, + {file = "basemap-1.4.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:d56d7ed588a5ca5d364b42c73c933258d394c10b3c4dc224e7d31bb41a6fbef5"}, + {file = "basemap-1.4.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:e3345544d530248fa4b4fb89f4b9f93219bffa5983b83f00e42b55b5abf615e8"}, + {file = "basemap-1.4.1-cp37-cp37m-win32.whl", hash = "sha256:2985d8d937427a20868dc18cd677bd6a6b4b94efc1b33ab29706d4e0dac418cf"}, + {file = "basemap-1.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:7b58e500413dbe655127bbe9da82c9d175eadd130ed66b6a9efea8d46b65881b"}, + {file = "basemap-1.4.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:353b2d02bd70fd48b3d0e146a028862b7d9e09ffb7995c70151fb80607a2df25"}, + {file = "basemap-1.4.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:ad169a9ca573752867b136b62ac70cc12ab0ce86e3c69d57e924eaa21b5f5ec7"}, + {file = "basemap-1.4.1-cp38-cp38-win32.whl", hash = "sha256:d273b628adb276200eb9b97f6175f9145b8a280d55d995869ec4bb3c561021c2"}, + 
{file = "basemap-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:f88ecab0468121d94e67e9497329cd2207dd2c7bf9a4779433cedc9cf6ab86f4"}, + {file = "basemap-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c93beda915b0b68e74a8eec986313e4758a0caa1bd57fee8660362c82f78b26f"}, + {file = "basemap-1.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3752eec899129a6645170fb1e7a16d0263f8ed2691f9f37916ecb19a2947ab72"}, + {file = "basemap-1.4.1-cp39-cp39-manylinux1_i686.whl", hash = "sha256:e20adfa6e77f367567d0389fb1da8ed6ea0ca9e871fd9fec7cfdc2de78495b99"}, + {file = "basemap-1.4.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:1cb70fe82ea9f49ee6df2ecafeade1e9ef4d5654a3e5c2b6f1b50bbfcfada33a"}, + {file = "basemap-1.4.1-cp39-cp39-win32.whl", hash = "sha256:ffe185bc4be347b0dc8ad68d2085ec699a558357d6019b5daf6465b6bbdf793d"}, + {file = "basemap-1.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:b196ae9446183deaaedfbd2e757d3edf23225584ce1a47b795cba96febd6acdf"}, + {file = "basemap-1.4.1.zip", hash = "sha256:6822d3d08c11cdc443e3ed01d61d512e7cf62d4b40bbc1d235f0a898f2c54a7a"}, +] + +[package.dependencies] +basemap-data = ">=1.3.2,<1.4" +matplotlib = {version = ">=1.5,<3.9", markers = "python_version >= \"3.5\""} +numpy = {version = ">=1.21,<1.27", markers = "python_version >= \"3.8\""} +packaging = {version = ">=16.0,<24.0", markers = "python_version >= \"3.5\""} +pyproj = {version = ">=1.9.3,<3.7.0", markers = "python_version >= \"3.5\""} +pyshp = {version = ">=1.2,<2.4", markers = "python_version >= \"2.7\""} + +[package.extras] +doc = ["cftime (>=1.4.0,<1.7.0)", "furo (>=2022.4.7,<2023.9.11)", "netCDF4 (>=1.5.6,<1.7.0)", "scipy (>=1.2,<1.12)", "sphinx (>=5.3,<7.2)"] +lint = ["astroid (>=1.6,<2.0)", "astroid (>=2.4,<2.5)", "astroid (>=2.5,<2.7)", "astroid (>=2.8,<3.1)", "flake8 (>=2.6,<3.0)", "flake8 (>=2.6,<3.0)", "flake8 (>=2.6,<3.0)", "flake8 (>=2.6,<3.9)", "flake8 (>=2.6,<4.0)", "flake8 (>=2.6,<6.2)", "pylint (>=1.9,<2.0)", "pylint (>=2.11,<3.1)", "pylint (>=2.6,<2.7)", 
"pylint (>=2.7,<2.10)", "unittest2"] +owslib = ["OWSLib (>=0.28.1,<0.30.0)", "OWSLib (>=0.8.0,<0.11.0)", "OWSLib (>=0.8.0,<0.11.0)", "OWSLib (>=0.8.0,<0.15.0)", "OWSLib (>=0.8.0,<0.18.0)", "OWSLib (>=0.8.0,<0.19.0)", "OWSLib (>=0.8.0,<0.20.0)", "ordereddict"] +pillow = ["pillow (>=3.4.0,<4.0.0)", "pillow (>=3.4.0,<4.0.0)", "pillow (>=4.3.0,<5.0.0)", "pillow (>=5.4.0,<6.0.0)", "pillow (>=6.2.2,<7.0.0)", "pillow (>=7.1.0,<8.0.0)", "pillow (>=8.3.2,<9.0.0)", "pillow (>=9.4.0,<10.2.0)"] +test = ["coverage (>=3.7,<4.0)", "coverage (>=4.5,<5.0)", "coverage (>=4.5,<5.0)", "coverage (>=4.5,<5.0)", "coverage (>=5.5,<6.0)", "coverage (>=5.5,<6.0)", "coverage (>=5.5,<7.4)", "pytest (>=2.9.0,<3.0)", "pytest (>=3.2.0,<3.3)", "pytest (>=3.2.0,<3.3)", "pytest (>=4.6.9,<5.0)", "pytest (>=4.6.9,<5.0)", "pytest (>=6.1.2,<6.2)", "pytest (>=6.2.5,<7.5)", "pytest-cov (>=2.5,<2.6)", "pytest-cov (>=2.5,<2.6)", "pytest-cov (>=2.5,<2.6)", "pytest-cov (>=2.5,<2.9)", "pytest-cov (>=2.9,<3.0)", "pytest-cov (>=2.9,<4.2)", "unittest2"] + +[[package]] +name = "basemap-data" +version = "1.3.2" +description = "Data assets for matplotlib basemap" +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*, <4" +groups = ["main"] +markers = "python_version <= \"3.11\"" +files = [ + {file = "basemap_data-1.3.2-py2.py3-none-any.whl", hash = "sha256:26e794556c496b26f7714658cdbea5c68cb47d6a8a9fb0e674844fa89c56fc59"}, + {file = "basemap_data-1.3.2.zip", hash = "sha256:0072efd6f12c76e9f35e8fd718360d634b849ba988e74acccaf1ec536275f70b"}, +] + [[package]] name = "beautifulsoup4" version = "4.13.3" @@ -1544,62 +1556,6 @@ files = [ {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, ] -[[package]] -name = "healpy" -version = "1.19.0" -description = "Healpix tools package for Python" -optional = false -python-versions = ">=3.10" -groups = ["main"] -markers = "python_version <= \"3.11\"" -files = [ - {file = 
"healpy-1.19.0-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:36f85568670f36f928aba0eb73299ed70a06e58dc32360f13d0f9a443781c7bb"}, - {file = "healpy-1.19.0-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:8ec4744f16f1590fe47a685258ba119c0fa49dff74fa6f970b7a16c712302b0d"}, - {file = "healpy-1.19.0-cp310-cp310-macosx_15_0_arm64.whl", hash = "sha256:20225e0312ea37f472a539521824001700afacf486f11a3ec905f85ec8e75f8c"}, - {file = "healpy-1.19.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c187bdab29ba6b9c34e55907b24af6867f2b5d3cd48ba150e696dfa200827ea4"}, - {file = "healpy-1.19.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:f62bf1c078787931f226a0550138cad98cf19f2b72e89f1f0bdd4c342dbc1aa1"}, - {file = "healpy-1.19.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f05943f632dfe0cad774227038cd4d751525ac5e7c3fb1e47afa7971680facd8"}, - {file = "healpy-1.19.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c5b60a21effe4af4831e0fda67e38e633ca8bf33a940581186e20654cf2d72ea"}, - {file = "healpy-1.19.0-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:bb31f28fb15aa7e99725e92d25f458f856ac16542938d8bedd797a74449540b2"}, - {file = "healpy-1.19.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3262006e70fe9328f7fb6788197186e44ad70b94d2cca2cb61f828c7808e5fe4"}, - {file = "healpy-1.19.0-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:8910dd4cd8aec740ab954af5b230d5655caae2d25f55dd8739dde9588257dc85"}, - {file = "healpy-1.19.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4894571e46e5604d9b7135b7c6f588b515cc7b91831c5068940d476754660f1e"}, - {file = "healpy-1.19.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:66939e3d4bc436f5b6249b2356ecf6607e24d25e2de2728d26b4f62d1bdfd2c3"}, - {file = "healpy-1.19.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:f6b654b6ea825f7fa59b607ccb513fb104976ba78a7d073b5f9db0cbdf2dae66"}, - {file = "healpy-1.19.0-cp311-cp311-musllinux_1_2_x86_64.whl", 
hash = "sha256:9af09c0655c3e015d9ec55aadcf6628f95f4f534c0a6320716da8319939ee6d3"}, - {file = "healpy-1.19.0-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:31496bc2eb9f52c1ea3ea4c7e67771e245a49f1a93937ff73f9056b666cd1f30"}, - {file = "healpy-1.19.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:69987864e318b4e22e2407670bdb2d791fcd9c244bc6c25c7001e14fc34f15a2"}, - {file = "healpy-1.19.0-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:e8b304fced0dad0a0a41ddd33bb6b057525f9ddd4dada993bf3d6aa830a08e31"}, - {file = "healpy-1.19.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:730095eef4bbe3c94039edb330d1800285cf43c7e16a81ce5b87edb31a8ffa42"}, - {file = "healpy-1.19.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:89dbe14b270b5479a9652d1ab9f28fc7fce34253549db665f2e91e6cea588883"}, - {file = "healpy-1.19.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:29bb5f7acde68ca5850c3a59ba33cacf4c3768c9e85f7013a99aa5fbd127ce84"}, - {file = "healpy-1.19.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:126bef3ae6594f461435ac85c6bcb32e427eca43f743881d00186bc3a55da1d8"}, - {file = "healpy-1.19.0-cp313-cp313-macosx_13_0_x86_64.whl", hash = "sha256:97155fe60e8309caa610cee6a26028219f181314dc8ad33517070ec48ea316f7"}, - {file = "healpy-1.19.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:70171ba093c9d7740d7248560609f6d8c17e6e1b9543e7b732f092b2a03dda55"}, - {file = "healpy-1.19.0-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:82b1f40ce8a982bf209a06e0f48b389ff1cbcd2e1524f12a2a78e92e600907b2"}, - {file = "healpy-1.19.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:086bae2510ac60f7cf9bbb520e1d1b8864cc8fca74b46b16e24857e74d08f896"}, - {file = "healpy-1.19.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:920c0a1c6749c05c8ad9522a5a2630f7bc83124c5742ef50f91b9f5e6a1bdcc7"}, - {file = "healpy-1.19.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = 
"sha256:26e3e6c50f2d256c9218f0fb0624406a93c4f22bc66deebf51a0f31fd5594b89"}, - {file = "healpy-1.19.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:bc3c243985abf1f1b5d91da12a23ceed49181dd9d0e46ed362babdc92c814aa8"}, - {file = "healpy-1.19.0-cp314-cp314-macosx_13_0_x86_64.whl", hash = "sha256:75179a3681d2e2bfcfe7b33171cbb520be914453062217e6beac15f3c11f9a63"}, - {file = "healpy-1.19.0-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:52cbcaacdb1ff252ce8edd39d07d44e47666cc09c471bdd8eb1ee472773bafb4"}, - {file = "healpy-1.19.0-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:6236f8b900ae7914c8c3c04d3ddd9a01c695e001798e334d2bffb989603bc46b"}, - {file = "healpy-1.19.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b3cf9c627af77a3fc670d6240d8a83485f373c274945ca8ad156fc2a25d6eb61"}, - {file = "healpy-1.19.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:c22360ad7ec6f16bdbb650dececf52d21f7f0e03a943eb1a1046fdb9c51a3c5e"}, - {file = "healpy-1.19.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:1a479826c1d6ae8312476e079a3f5c7aa3e40911d74f2d7069adcd62e98604ad"}, - {file = "healpy-1.19.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:a103e327102e124785c1b5bdd379c6e5469137633212b56a6196bf04ddd100ff"}, - {file = "healpy-1.19.0.tar.gz", hash = "sha256:28e839cb885a23d36c77fc3423a3cb9271a07fda94085bd12fc329f941130ec5"}, -] - -[package.dependencies] -astropy = "*" -numpy = ">=1.19" - -[package.extras] -all = ["matplotlib", "scipy"] -doc = ["ipykernel", "matplotlib", "nbsphinx", "numpydoc", "scipy", "sphinx (>=6,<9)"] -test = ["matplotlib", "pytest", "pytest-astropy-header", "pytest-cython", "pytest-doctestplus", "requests", "scipy"] - [[package]] name = "httpcore" version = "1.0.7" @@ -3954,35 +3910,6 @@ files = [ {file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"}, ] -[[package]] -name = "pyerfa" -version = "2.0.1.5" -description = "Python bindings for ERFA" 
-optional = false -python-versions = ">=3.9" -groups = ["main"] -markers = "python_version <= \"3.11\"" -files = [ - {file = "pyerfa-2.0.1.5-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:b282d7c60c4c47cf629c484c17ac504fcb04abd7b3f4dfcf53ee042afc3a5944"}, - {file = "pyerfa-2.0.1.5-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:be1aeb70390dd03a34faf96749d5cabc58437410b4aab7213c512323932427df"}, - {file = "pyerfa-2.0.1.5-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b0603e8e1b839327d586c8a627cdc634b795e18b007d84f0cda5500a0908254e"}, - {file = "pyerfa-2.0.1.5-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e43c7194e3242083f2350b46c09fd4bf8ba1bcc0ebd1460b98fc47fe2389906"}, - {file = "pyerfa-2.0.1.5-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:07b80cd70701f5d066b1ac8cce406682cfcd667a1186ec7d7ade597239a6021d"}, - {file = "pyerfa-2.0.1.5-cp39-abi3-win32.whl", hash = "sha256:d30b9b0df588ed5467e529d851ea324a67239096dd44703125072fd11b351ea2"}, - {file = "pyerfa-2.0.1.5-cp39-abi3-win_amd64.whl", hash = "sha256:66292d437dcf75925b694977aa06eb697126e7b86553e620371ed3e48b5e0ad0"}, - {file = "pyerfa-2.0.1.5-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:4991dee680ff36c87911d8faa4c7d1aa6278ad9b5e0d16158cf22fa7d74ba25c"}, - {file = "pyerfa-2.0.1.5-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:690e258294202c86f479e78e80fd235cd27bd717f7f60062fccc3dbd6ef0b1a9"}, - {file = "pyerfa-2.0.1.5-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:171ce9676a448a7eb555f03aa19ad5c749dbced1ce4f9923e4d93443c4a9c612"}, - {file = "pyerfa-2.0.1.5.tar.gz", hash = "sha256:17d6b24fe4846c65d5e7d8c362dcb08199dc63b30a236aedd73875cc83e1f6c0"}, -] - -[package.dependencies] -numpy = ">=1.19.3" - -[package.extras] -docs = ["sphinx-astropy (>=1.3)"] -test = ["pytest", "pytest-doctestplus (>=0.7)"] - [[package]] name = "pyflakes" version = "3.2.0" @@ -5619,4 +5546,4 @@ type = ["pytest-mypy"] [metadata] 
lock-version = "2.1" python-versions = ">=3.10,<3.12" -content-hash = "6ed7da60935958024b402007a6534a7f77cf53fe58e6a3989e26dd61d3266de4" +content-hash = "f9e1bfec7b95f189eb9ebc0afe65b0cac734f7b550a67e0928e11a7059dd8378" diff --git a/pyproject.toml b/pyproject.toml index a0e8bee..4fcf1a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ omegaconf = "~2.3.0" torchmetrics = "~1.6.1" geopy = "~2.4.1" tigramite = "^5.2.7.0" +basemap = "^1.4.1" cfgrib = "^0.9.15.0" cftime = "^1.6.4.post1" dask = "^2025.2.0" @@ -48,7 +49,7 @@ scikit-learn = "^1.6.1" jupyter-contrib-nbextensions = "^0.7.0" notebook = "^7.3.2" jupyterlab-widgets = "^3.0.13" -healpy = "^1.19.0" + [tool.poetry.group.dev.dependencies] # Optional dependencies that need to be installed with poetry diff --git a/scripts/analyze_savar_run.py b/scripts/analyze_savar_run.py new file mode 100644 index 0000000..1fdfed9 --- /dev/null +++ b/scripts/analyze_savar_run.py @@ -0,0 +1,349 @@ +#!/usr/bin/env python3 +"""Post-hoc analysis of a trained SAVAR experiment. + +Loads a trained LatentTSDCD model checkpoint and its SAVAR dataset, then +produces diagnostic plots and metrics (causal graph recovery, forcing +latent correlations, mode weight analysis). 
+""" + +import argparse +import os +from pathlib import Path +from typing import Dict, Tuple + +import torch +import torch.nn.functional as F + +from climatem.config import ( + dataParams, + expParams, + gtParams, + modelParams, + optimParams, + plotParams, + savarParams, + trainParams, +) +from climatem.data_loader.causal_datamodule import CausalClimateDataModule +from climatem.model.tsdcd_latent import LatentTSDCD +from climatem.utils import load_config + + +def _expand_path(path_str: str, fallback: Path) -> str: + expanded = Path(os.path.expandvars(path_str)) + if expanded.exists(): + return str(expanded) + return str(fallback) + + +def _load_config(config_path: Path) -> Dict: + if not config_path.exists(): + raise FileNotFoundError(f"Config file {config_path} not found") + return load_config(config_path) + + +def _instantiate_params(cfg: Dict, project_root: Path): + exp_cfg = dict(cfg["exp_params"]) + exp_cfg["exp_path"] = str(project_root / "ws" / "SAVAR_RESULTS") + experiment_params = expParams(**exp_cfg) + + data_cfg = dict(cfg["data_params"]) + data_cfg.pop("seq_len", None) + data_cfg["data_dir"] = _expand_path(data_cfg["data_dir"], project_root / "ws" / "SAVAR_DATA") + data_cfg["climateset_data"] = _expand_path( + data_cfg.get("climateset_data", data_cfg["data_dir"]), + project_root / "ws" / "SAVAR_DATA", + ) + data_cfg["icosahedral_coordinates_path"] = _expand_path( + data_cfg["icosahedral_coordinates_path"], + project_root / "mappings" / "vertex_lonlat_mapping.npy", + ) + data_params_obj = dataParams(**data_cfg) + + gt_params_obj = gtParams(**cfg["gt_params"]) + + train_cfg = dict(cfg["train_params"]) + train_cfg.pop("ratio_valid", None) + train_params_obj = trainParams(**train_cfg) + + model_params_obj = modelParams(**cfg["model_params"]) + optim_params_obj = optimParams(**cfg["optim_params"]) + + plot_cfg = cfg.get("plot_params", {}) + plot_params_obj = plotParams(**plot_cfg) if plot_cfg else plotParams() + + savar_cfg = cfg.get("savar_params", {}) + 
savar_params_obj = savarParams(**savar_cfg) if savar_cfg else savarParams() + + return ( + experiment_params, + data_params_obj, + gt_params_obj, + train_params_obj, + model_params_obj, + optim_params_obj, + plot_params_obj, + savar_params_obj, + ) + + +def _encode_means(model: LatentTSDCD, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + b, tau, d, _ = x.shape + device = x.device + history = torch.zeros(b, tau, d, model.d_z, device=device) + future = torch.zeros(b, d, model.d_z, device=device) + + for i in range(d): + for t in range(tau): + q_mu, _ = model.autoencoder(x[:, t, i], i, encode=True) + history[:, t, i] = q_mu + q_mu_y, _ = model.autoencoder(y[:, i], i, encode=True) + future[:, i] = q_mu_y + + return history, future + + +def _summarize_prediction(pred: torch.Tensor, truth: torch.Tensor) -> Dict[str, float]: + pred_flat = pred.reshape(pred.shape[0], -1) + truth_flat = truth.reshape(truth.shape[0], -1) + + diff = pred_flat - truth_flat + rmse = torch.sqrt(torch.mean(diff**2)).item() + mae = torch.mean(torch.abs(diff)).item() + + pred_centered = pred_flat - pred_flat.mean(dim=1, keepdim=True) + truth_centered = truth_flat - truth_flat.mean(dim=1, keepdim=True) + corr = F.cosine_similarity(pred_centered, truth_centered, dim=1).mean().item() + + pred_std = pred_flat.std(dim=1, unbiased=False).mean().item() + truth_std = truth_flat.std(dim=1, unbiased=False).mean().item() + amp_ratio = pred_std / truth_std if truth_std > 0 else float("nan") + + sign_agreement = ((pred_flat.sign() == truth_flat.sign()) | (truth_flat == 0)).float().mean().item() + + return { + "rmse": rmse, + "mae": mae, + "corr": corr, + "amp_ratio": amp_ratio, + "sign_agreement": sign_agreement, + } + + +def main(): + parser = argparse.ArgumentParser(description="Diagnose latent transition behaviour on SAVAR runs.") + parser.add_argument( + "--run-dir", + type=Path, + required=True, + help="Experiment directory (contains training_results/model.pth).", + ) + 
parser.add_argument( + "--config-path", + type=Path, + default=Path("configs/single_param_file_savar.json"), + help="Config file used for training.", + ) + parser.add_argument("--rollout-steps", type=int, default=6, help="Number of autoregressive steps to compute.") + parser.add_argument( + "--gpu", + action="store_true", + help="Force GPU execution. By default, the script runs on CPU to avoid latent sampling device mismatches.", + ) + args = parser.parse_args() + + run_dir = args.run_dir.resolve() + project_root = Path(__file__).resolve().parents[1] + cfg = _load_config(args.config_path.resolve()) + + ( + experiment_params, + data_params_obj, + gt_params_obj, + train_params_obj, + model_params_obj, + optim_params_obj, + plot_params_obj, + savar_params_obj, + ) = _instantiate_params(cfg, project_root) + + force_gpu = args.gpu + gpu_available = torch.cuda.is_available() + use_gpu = force_gpu and gpu_available and experiment_params.gpu + + if force_gpu and not gpu_available: + print("--gpu requested but CUDA is not available; falling back to CPU.") + + if not use_gpu and experiment_params.gpu: + print("Running diagnostics on CPU to avoid device mismatches in latent sampling.") + + experiment_params.gpu = use_gpu + device = torch.device("cuda" if use_gpu else "cpu") + + if not use_gpu: + torch.set_default_tensor_type("torch.FloatTensor") + + torch.set_grad_enabled(False) + + datamodule = CausalClimateDataModule( + tau=experiment_params.tau, + future_timesteps=experiment_params.future_timesteps, + num_months_aggregated=data_params_obj.num_months_aggregated, + train_val_interval_length=data_params_obj.train_val_interval_length, + in_var_ids=data_params_obj.in_var_ids, + out_var_ids=data_params_obj.out_var_ids, + train_years=data_params_obj.train_years, + train_historical_years=data_params_obj.train_historical_years, + test_years=data_params_obj.test_years, + val_split=1 - train_params_obj.ratio_train, + seq_to_seq=data_params_obj.seq_to_seq, + 
channels_last=data_params_obj.channels_last, + train_scenarios=data_params_obj.train_scenarios, + test_scenarios=data_params_obj.test_scenarios, + train_models=data_params_obj.train_models, + batch_size=data_params_obj.batch_size, + eval_batch_size=data_params_obj.eval_batch_size, + num_workers=experiment_params.num_workers, + pin_memory=experiment_params.pin_memory, + load_train_into_mem=data_params_obj.load_train_into_mem, + load_test_into_mem=data_params_obj.load_test_into_mem, + verbose=experiment_params.verbose, + seed=experiment_params.random_seed, + seq_len=data_params_obj.seq_len, + data_dir=data_params_obj.climateset_data, + output_save_dir=data_params_obj.data_dir, + num_ensembles=data_params_obj.num_ensembles, + lon=experiment_params.lon, + lat=experiment_params.lat, + num_levels=data_params_obj.num_levels, + global_normalization=data_params_obj.global_normalization, + seasonality_removal=data_params_obj.seasonality_removal, + reload_climate_set_data=data_params_obj.reload_climate_set_data, + icosahedral_coordinates_path=data_params_obj.icosahedral_coordinates_path, + time_len=savar_params_obj.time_len, + comp_size=savar_params_obj.comp_size, + noise_val=savar_params_obj.noise_val, + n_per_col=savar_params_obj.n_per_col, + difficulty=savar_params_obj.difficulty, + seasonality=savar_params_obj.seasonality, + overlap=savar_params_obj.overlap, + is_forced=savar_params_obj.is_forced, + f_1=savar_params_obj.f_1, + f_2=savar_params_obj.f_2, + f_time_1=savar_params_obj.f_time_1, + f_time_2=savar_params_obj.f_time_2, + ramp_type=savar_params_obj.ramp_type, + linearity=savar_params_obj.linearity, + poly_degrees=savar_params_obj.poly_degrees, + plot_original_data=savar_params_obj.plot_original_data, + ) + datamodule.setup() + + d = len(data_params_obj.in_var_ids) + num_input = d * experiment_params.tau * (model_params_obj.tau_neigh * 2 + 1) + + model = LatentTSDCD( + num_layers=model_params_obj.num_layers, + num_hidden=model_params_obj.num_hidden, + 
num_input=num_input, + num_output=model_params_obj.num_output, + num_layers_mixing=model_params_obj.num_layers_mixing, + num_hidden_mixing=model_params_obj.num_hidden_mixing, + position_embedding_dim=model_params_obj.position_embedding_dim, + reduce_encoding_pos_dim=model_params_obj.reduce_encoding_pos_dim, + coeff_kl=optim_params_obj.coeff_kl, + d=d, + distr_z0="gaussian", + distr_encoder="gaussian", + distr_transition="gaussian", + distr_decoder="gaussian", + d_x=experiment_params.d_x, + d_z=experiment_params.d_z, + tau=experiment_params.tau, + instantaneous=model_params_obj.instantaneous, + nonlinear_dynamics=model_params_obj.nonlinear_dynamics, + nonlinear_mixing=model_params_obj.nonlinear_mixing, + hard_gumbel=model_params_obj.hard_gumbel, + no_gt=gt_params_obj.no_gt, + debug_gt_graph=gt_params_obj.debug_gt_graph, + debug_gt_z=gt_params_obj.debug_gt_z, + debug_gt_w=gt_params_obj.debug_gt_w, + tied_w=model_params_obj.tied_w, + fixed=model_params_obj.fixed, + fixed_output_fraction=model_params_obj.fixed_output_fraction, + ).to(device) + model.eval() + + model_path = run_dir / "training_results" / "model.pth" + if not model_path.exists(): + raise FileNotFoundError(f"Model checkpoint {model_path} not found") + state_dict = torch.load(model_path, map_location=device) + model.load_state_dict(state_dict) + + val_batch = next(iter(datamodule.val_dataloader())) + x_val, y_val = val_batch + x_val = torch.nan_to_num(x_val).to(device) + y_val = torch.nan_to_num(y_val[:, 0]).to(device) + + px_mu_tf, _, _, pz_mu_tf, _ = model.predict(x_val, y_val, teacher_forcing=True) + _, q_mu_future = _encode_means(model, x_val, y_val) + latent_mae_tf = torch.mean(torch.abs(pz_mu_tf - q_mu_future)).item() + latent_norm_tf = pz_mu_tf.norm(dim=-1).mean().item() + latent_norm_target = q_mu_future.norm(dim=-1).mean().item() + tf_summary = _summarize_prediction(px_mu_tf.cpu(), y_val.cpu()) + + px_mu_free, _, _, pz_mu_free, _ = model.predict(x_val, teacher_forcing=False) + free_summary = 
_summarize_prediction(px_mu_free.cpu(), y_val.cpu()) + latent_mae_free = torch.mean(torch.abs(pz_mu_free - q_mu_future)).item() + latent_norm_free = pz_mu_free.norm(dim=-1).mean().item() + + rollout_preds = [] + rollout_latents = [] + history = x_val.clone() + for step in range(args.rollout_steps): + px_mu_step, _, _, pz_mu_step, _ = model.predict(history, teacher_forcing=False) + rollout_preds.append(px_mu_step.cpu()) + rollout_latents.append(pz_mu_step.cpu()) + history = torch.cat([history[:, 1:], px_mu_step.unsqueeze(1)], dim=1) + + rollout_stats = [] + for idx, (pred_step, latent_step) in enumerate(zip(rollout_preds, rollout_latents), start=1): + rollout_stats.append( + { + "step": idx, + "pred_mean": pred_step.mean().item(), + "pred_std": pred_step.std(unbiased=False).item(), + "latent_mean_abs": latent_step.abs().mean().item(), + "latent_norm": latent_step.norm(dim=-1).mean().item(), + } + ) + + print("\nTeacher-forcing check:") + print(f" latent MAE (pred vs encoded): {latent_mae_tf:.4f}") + print(f" latent norm predicted / target: {latent_norm_tf:.4f} / {latent_norm_target:.4f}") + print( + f" reconstruction -> RMSE: {tf_summary['rmse']:.4f}, " + f"MAE: {tf_summary['mae']:.4f}, Corr: {tf_summary['corr']:.4f}, " + f"Amplitude ratio: {tf_summary['amp_ratio']:.4f}, Sign agreement: {tf_summary['sign_agreement']:.4f}" + ) + + print("\nFree-run one step check:") + print(f" latent MAE (pred vs encoded): {latent_mae_free:.4f}") + print(f" latent norm predicted: {latent_norm_free:.4f}") + print( + f" prediction -> RMSE: {free_summary['rmse']:.4f}, " + f"MAE: {free_summary['mae']:.4f}, Corr: {free_summary['corr']:.4f}, " + f"Amplitude ratio: {free_summary['amp_ratio']:.4f}, Sign agreement: {free_summary['sign_agreement']:.4f}" + ) + + print("\nAutoregressive rollout (no teacher forcing):") + for stats in rollout_stats: + print( + f" step {stats['step']:>2}: pred_mean={stats['pred_mean']:+.4f}, " + f"pred_std={stats['pred_std']:.4f}, 
latent_mean_abs={stats['latent_mean_abs']:.4f}, " + f"latent_norm={stats['latent_norm']:.4f}" + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/check_new_forcing_correlations.py b/scripts/check_new_forcing_correlations.py new file mode 100644 index 0000000..22319c3 --- /dev/null +++ b/scripts/check_new_forcing_correlations.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +""" +Check forcing latent correlations in the NEW SAVAR dataset with orthogonal spatial templates. +""" + +import numpy as np +from scipy.stats import pearsonr + +# NEW data directory with updated parameters +data_dir = "/hkfs/work/workspace_haic/scratch/qa4548-climate_ws/SAVAR_DATA_TEST/m_4_tl_10000_ifd_True_dif_easy_ns_0.25_ses_False_ol_False_f1_0_f2_5.0_ft1_2500_ft2_7500_rmp_sinusoidal_lin_linear_pds_[2]_asp_2.0_asc_1.5_art_2500_apt_5000_adt_7500" + +co2_latent = np.load(f"{data_dir}/co2_latent_trajectory.npy") +aerosol_latent = np.load(f"{data_dir}/aerosol_latent_trajectory.npy") + +print(f"CO2 latent: {co2_latent.shape}") +print(f"Aerosol latent: {aerosol_latent.shape}") + +# Aerosol is (4, time) so transpose to (time, 4) +if aerosol_latent.shape[0] == 4: + aerosol_latent = aerosol_latent.T + print(f"Transposed aerosol to: {aerosol_latent.shape}") + +print("\n" + "="*80) +print("GROUND TRUTH FORCING LATENT CORRELATIONS (NEW DATA)") +print("="*80) + +# CO2 vs each of the 4 aerosol latents +print("\nCO2 ↔ Aerosol latents (temporal correlation):") +for i in range(4): + corr, p = pearsonr(co2_latent, aerosol_latent[:, i]) + print(f" CO2 ↔ Aerosol[{i}]: r={corr:.4f}, p={p:.4e}") + +# CO2 vs mean of aerosols +aerosol_mean = aerosol_latent.mean(axis=1) +corr_mean, p_mean = pearsonr(co2_latent, aerosol_mean) +print(f"\n CO2 ↔ Mean(Aerosol): r={corr_mean:.4f}, p={p_mean:.4e}") + +if abs(corr_mean) > 0.5: + print(" ⚠️ MODERATE CORRELATION: Limited aerosol independence") +elif abs(corr_mean) > 0.3: + print(" ⚠️ WEAK CORRELATION: Some CO2-aerosol coupling remains") +else: + print(" ✓ Low 
correlation: Good aerosol independence") + +# Inter-aerosol correlations +print("\n" + "="*80) +print("INTER-AEROSOL CORRELATIONS (should be diverse)") +print("="*80) +inter_aerosol_corrs = [] +for i in range(4): + for j in range(i+1, 4): + corr, _ = pearsonr(aerosol_latent[:, i], aerosol_latent[:, j]) + inter_aerosol_corrs.append(abs(corr)) + print(f" Aerosol[{i}] ↔ Aerosol[{j}]: r={corr:.4f}") + +mean_inter_corr = np.mean(inter_aerosol_corrs) +print(f"\n Mean |correlation|: {mean_inter_corr:.4f}") + +if mean_inter_corr > 0.5: + print(" ⚠️ HIGH: Aerosol latents still too correlated") +elif mean_inter_corr > 0.3: + print(" ⚠️ MODERATE: Some redundancy remains") +else: + print(" ✓ Low correlation: Good aerosol diversity") + +# Signal strength +print("\n" + "="*80) +print("FORCING SIGNAL STRENGTH") +print("="*80) +print(f"CO2 std: {co2_latent.std():.4f}") +print(f"Aerosol stds: {[f'{aerosol_latent[:, i].std():.4f}' for i in range(4)]}") +print(f"Mean aerosol std: {aerosol_latent.std(axis=0).mean():.4f}") + +aerosol_to_co2_ratio = aerosol_latent.std(axis=0).mean() / co2_latent.std() +print(f"\nAerosol/CO2 signal ratio: {aerosol_to_co2_ratio:.4f}") + +if aerosol_to_co2_ratio < 0.2: + print(" ⚠️ Aerosol signal still much weaker than CO2") +elif aerosol_to_co2_ratio < 0.5: + print(" ⚠️ Aerosol signal somewhat weak compared to CO2") +else: + print(" ✓ Aerosol signal strength comparable to CO2") + +print("\n" + "="*80) +print("SUMMARY") +print("="*80) + +issues = [] +if abs(corr_mean) > 0.4: + issues.append(f"CO2-Aerosol correlation still moderate (r={corr_mean:.3f})") +if mean_inter_corr > 0.4: + issues.append(f"Inter-aerosol correlation high ({mean_inter_corr:.3f})") +if aerosol_to_co2_ratio < 0.3: + issues.append(f"Aerosol signal weak ({aerosol_to_co2_ratio:.1%} of CO2)") + +if issues: + print("\n⚠️ Issues detected:") + for i, issue in enumerate(issues, 1): + print(f" {i}. 
{issue}") + print("\nOrthogonal spatial templates may need further tuning.") +else: + print("\n✓ Ground truth forcing data looks good!") + print(" - Low CO2-aerosol correlation") + print(" - Low inter-aerosol correlation") + print(" - Adequate signal strength") + print("\n→ Ready to train model on this dataset") diff --git a/scripts/diagnose_forcing_latents.py b/scripts/diagnose_forcing_latents.py new file mode 100644 index 0000000..53e22e1 --- /dev/null +++ b/scripts/diagnose_forcing_latents.py @@ -0,0 +1,413 @@ +#!/usr/bin/env python3 +""" +Diagnostic script to check forcing latent encoder health. + +Run this on a trained model checkpoint to diagnose why forcing latents aren't learning. +""" + +import sys +import torch +import numpy as np +from pathlib import Path +from scipy.stats import pearsonr + +def diagnose_forcing_encoders(checkpoint_path): + """Load checkpoint and analyze forcing encoder parameters.""" + + print(f"Loading checkpoint from: {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location='cpu') + + # Extract model state dict + if 'model_state_dict' in checkpoint: + state_dict = checkpoint['model_state_dict'] + else: + state_dict = checkpoint + + print("\n" + "="*80) + print("FORCING ENCODER PARAMETER ANALYSIS") + print("="*80) + + # List all forcing-related keys for debugging + forcing_keys = [k for k in state_dict.keys() if 'forcing' in k.lower()] + if forcing_keys: + print(f"\nFound {len(forcing_keys)} forcing-related parameters:") + for key in sorted(forcing_keys)[:20]: # Show first 20 + print(f" - {key}") + if len(forcing_keys) > 20: + print(f" ... 
and {len(forcing_keys) - 20} more") + + # Check if forcing encoders exist + has_co2_encoder = any('co2_forcing_encoder' in k for k in state_dict.keys()) + has_aerosol_encoder = any('aerosol_forcing_encoder' in k for k in state_dict.keys()) + + if not has_co2_encoder and not has_aerosol_encoder: + print("\n❌ NO FORCING ENCODERS FOUND!") + print("Model does not have use_forced_latents=True or not trained with forcing.") + return + + print(f"\n✓ CO2 Encoder: {'Present' if has_co2_encoder else 'Missing'}") + print(f"✓ Aerosol Encoder: {'Present' if has_aerosol_encoder else 'Missing'}") + + # Analyze encoder weights + print("\n" + "-"*80) + print("ENCODER WEIGHT MAGNITUDES") + print("-"*80) + + encoder_stats = {} + + # CO2 encoder + if has_co2_encoder: + co2_keys = [k for k in state_dict.keys() if 'co2_forcing_encoder_mu' in k and 'weight' in k] + if co2_keys: + print(f"\nCO2 Encoder ALL layers:") + for key in sorted(co2_keys): + w = state_dict[key] + print(f" {key}: shape={w.shape}, norm={w.norm().item():.4f}") + + # Use first layer for stats + co2_weight = state_dict[co2_keys[0]] + encoder_stats['co2_encoder'] = { + 'key': co2_keys[0], + 'shape': co2_weight.shape, + 'mean': co2_weight.mean().item(), + 'std': co2_weight.std().item(), + 'max': co2_weight.max().item(), + 'norm': co2_weight.norm().item() + } + + # Check final layer output dimension + final_layer_keys = [k for k in co2_keys if 'lin2' in k] + if final_layer_keys: + final_w = state_dict[final_layer_keys[0]] + n_output = final_w.shape[0] + print(f"\n Final layer outputs {n_output} dimensions (expected 1 for n_forced_latents_co2)") + if n_output != 1: + print(f" ⚠️ ENCODER BUG: CO2 encoder outputs {n_output} instead of 1 latent!") + + # Aerosol encoder + if has_aerosol_encoder: + aerosol_keys = [k for k in state_dict.keys() if 'aerosol_forcing_encoder_mu' in k and 'weight' in k] + if aerosol_keys: + print(f"\nAerosol Encoder ALL layers:") + for key in sorted(aerosol_keys): + w = state_dict[key] + print(f" 
{key}: shape={w.shape}, norm={w.norm().item():.4f}") + + # Use first layer for stats + aerosol_weight = state_dict[aerosol_keys[0]] + encoder_stats['aerosol_encoder'] = { + 'key': aerosol_keys[0], + 'shape': aerosol_weight.shape, + 'mean': aerosol_weight.mean().item(), + 'std': aerosol_weight.std().item(), + 'max': aerosol_weight.max().item(), + 'norm': aerosol_weight.norm().item() + } + + # Check final layer output dimension + final_layer_keys = [k for k in aerosol_keys if 'lin2' in k] + if final_layer_keys: + final_w = state_dict[final_layer_keys[0]] + n_output = final_w.shape[0] + print(f"\n ⚠️ CRITICAL: Final layer outputs {n_output} dimensions (expected 4 for n_forced_latents_aerosol)") + if n_output != 4: + print(f" ⚠️ ENCODER BUG: Aerosol encoder outputs {n_output} instead of 4 latents!") + + # Climate encoder for comparison + climate_keys = [k for k in state_dict.keys() if 'mu_encoder' in k and 'weight' in k and 'forcing' not in k] + if climate_keys: + climate_weight = state_dict[climate_keys[0]] + encoder_stats['climate_encoder'] = { + 'key': climate_keys[0], + 'shape': climate_weight.shape, + 'mean': climate_weight.mean().item(), + 'std': climate_weight.std().item(), + 'max': climate_weight.max().item(), + 'norm': climate_weight.norm().item() + } + print(f"\nClimate Encoder Weight (for comparison):") + print(f" Shape: {climate_weight.shape}") + print(f" Mean: {climate_weight.mean().item():.6f}") + print(f" Std: {climate_weight.std().item():.6f}") + print(f" L2 Norm: {climate_weight.norm().item():.6f}") + + # CRITICAL: Check encoder correlation (are CO2 and aerosol encoders learning the same thing?) 
+ print("\n" + "-"*80) + print("ENCODER CORRELATION ANALYSIS (Redundancy Check)") + print("-"*80) + + encoder_correlation = None + if has_co2_encoder and has_aerosol_encoder: + co2_keys = [k for k in state_dict.keys() if 'co2_forcing_encoder_mu' in k and 'weight' in k] + aerosol_keys = [k for k in state_dict.keys() if 'aerosol_forcing_encoder_mu' in k and 'weight' in k] + + if co2_keys and aerosol_keys: + co2_w = state_dict[co2_keys[0]] + aerosol_w = state_dict[aerosol_keys[0]] + + # Flatten weights to compute correlation + co2_flat = co2_w.flatten().numpy() + + # For aerosol, we may have multiple output dimensions (n_aerosol_latents) + # Check each aerosol latent encoder separately + if aerosol_w.dim() == 2: + # Shape: (input_dim, n_aerosol_latents) + n_aerosol_latents = aerosol_w.shape[1] + print(f"\nComparing CO2 encoder to {n_aerosol_latents} aerosol latent encoders:") + + correlations = [] + for i in range(n_aerosol_latents): + aerosol_flat = aerosol_w[:, i].flatten().numpy() + + # Ensure same length for correlation + min_len = min(len(co2_flat), len(aerosol_flat)) + if min_len > 1: + corr, p_value = pearsonr(co2_flat[:min_len], aerosol_flat[:min_len]) + correlations.append(corr) + print(f" CO2 ↔ Aerosol[{i}]: r={corr:.4f} (p={p_value:.4e})") + + encoder_correlation = np.mean(correlations) + print(f"\n Mean correlation: {encoder_correlation:.4f}") + + if encoder_correlation > 0.8: + print(f" ⚠️ HIGH CORRELATION (>0.8): Encoders may have collapsed to similar solutions!") + elif encoder_correlation > 0.5: + print(f" ⚠️ MODERATE CORRELATION (>0.5): Some redundancy detected") + else: + print(f" ✓ Low correlation: Encoders appear to learn distinct patterns") + else: + # Fallback for other shapes + aerosol_flat = aerosol_w.flatten().numpy() + min_len = min(len(co2_flat), len(aerosol_flat)) + if min_len > 1: + corr, p_value = pearsonr(co2_flat[:min_len], aerosol_flat[:min_len]) + encoder_correlation = corr + print(f"\n CO2 ↔ Aerosol encoder correlation: r={corr:.4f} 
(p={p_value:.4e})") + + if abs(corr) > 0.8: + print(f" ⚠️ HIGH CORRELATION: Encoders collapsed!") + elif abs(corr) > 0.5: + print(f" ⚠️ MODERATE CORRELATION: Some redundancy") + else: + print(f" ✓ Low correlation: Distinct patterns") + + # Check bias terms + co2_bias_keys = [k for k in state_dict.keys() if 'co2_forcing_encoder_mu' in k and 'bias' in k] + aerosol_bias_keys = [k for k in state_dict.keys() if 'aerosol_forcing_encoder_mu' in k and 'bias' in k] + + if co2_bias_keys and aerosol_bias_keys: + co2_bias = state_dict[co2_bias_keys[0]] + aerosol_bias = state_dict[aerosol_bias_keys[0]] + + print(f"\n Bias comparison:") + print(f" CO2 bias: {co2_bias.numpy()}") + print(f" Aerosol bias: {aerosol_bias.numpy()}") + + # Check if aerosol biases are just scaled/shifted versions of CO2 + if co2_bias.numel() == 1 and aerosol_bias.numel() > 1: + bias_std = aerosol_bias.std() + if bias_std < 0.01: + print(f" ⚠️ All aerosol biases are nearly identical (std={bias_std:.4f})") + + # Analyze encoder logvars + print("\n" + "-"*80) + print("ENCODER LOGVAR VALUES") + print("-"*80) + + # CO2 logvar + co2_logvar_keys = [k for k in state_dict.keys() if 'co2_forcing_encoder_logvar' in k] + if co2_logvar_keys: + co2_logvar = state_dict[co2_logvar_keys[0]] + print(f"\nCO2 Encoder Logvar:") + print(f" Shape: {co2_logvar.shape}") + print(f" Values: {co2_logvar.numpy()}") + print(f" Std: {torch.exp(0.5 * co2_logvar).numpy()}") + + # Aerosol logvar + aerosol_logvar_keys = [k for k in state_dict.keys() if 'aerosol_forcing_encoder_logvar' in k] + if aerosol_logvar_keys: + aerosol_logvar = state_dict[aerosol_logvar_keys[0]] + print(f"\nAerosol Encoder Logvar:") + print(f" Shape: {aerosol_logvar.shape}") + print(f" Values: {aerosol_logvar.numpy()}") + print(f" Std: {torch.exp(0.5 * aerosol_logvar).numpy()}") + + # Climate logvar + climate_logvar_keys = [k for k in state_dict.keys() if k.endswith('logvar_encoder') and 'forcing' not in k] + if climate_logvar_keys: + climate_logvar = 
state_dict[climate_logvar_keys[0]] + print(f"\nClimate Encoder Logvar (for comparison):") + print(f" Shape: {climate_logvar.shape}") + print(f" Mean: {climate_logvar.mean().item():.6f}") + print(f" Range: [{climate_logvar.min().item():.6f}, {climate_logvar.max().item():.6f}]") + + # Analyze decoder weights + print("\n" + "-"*80) + print("DECODER WEIGHT ANALYSIS (by latent)") + print("-"*80) + + decoder_keys = [k for k in state_dict.keys() if 'decoder' in k and 'weight' in k.lower() and 'w_adj' not in k] + if decoder_keys: + # Try to find decoder output layer + decoder_weight_key = [k for k in decoder_keys if 'layers' in k or 'output' in k] + if decoder_weight_key: + print(f"\nFound decoder weights: {decoder_weight_key[0]}") + + # Try to load w_adj directly from checkpoint if available + w_adj_keys = [k for k in state_dict.keys() if 'w_adj' in k] + decoder_norms = None + if w_adj_keys: + print(f"\nFound w_adj in checkpoint: {w_adj_keys}") + for key in w_adj_keys: + w_adj = state_dict[key] + print(f"\n {key}:") + print(f" Shape: {w_adj.shape}") + if w_adj.dim() >= 2: + # Compute L2 norm per latent (column norms for decoder utilization) + # w_adj shape is typically (d_x, d_z) or (d, d_x, d_z) + if w_adj.dim() == 3: + # Average over feature dimension d + norms = torch.norm(w_adj, dim=(0, 1)) # Norm over (d, d_x) + else: + norms = torch.norm(w_adj, dim=0) # Norm over d_x (rows) + + decoder_norms = norms + print(f" Per-latent decoder L2 norms: {norms.numpy()}") + + # Check for dead latents + threshold = norms.mean() * 0.01 # 1% of mean + dead_latents = (norms < threshold).nonzero(as_tuple=True)[0] + if len(dead_latents) > 0: + print(f" ⚠️ DEAD LATENTS (norm < {threshold:.6f}): {dead_latents.numpy()}") + + # Check if forcing latents have significantly lower norms + # Assuming latents are ordered: [climate_latents, co2_latents, aerosol_latents] + # This needs to be inferred from config or specified + if len(norms) == 9: # Example: 4 climate + 1 CO2 + 4 aerosol + 
climate_norms = norms[:4] + co2_norms = norms[4:5] + aerosol_norms = norms[5:9] + + print(f"\n Decoder utilization by latent type:") + print(f" Climate latents (0-3): mean={climate_norms.mean():.4f}, std={climate_norms.std():.4f}") + print(f" CO2 latent (4): {co2_norms[0]:.4f}") + print(f" Aerosol latents (5-8): mean={aerosol_norms.mean():.4f}, std={aerosol_norms.std():.4f}") + + if co2_norms.mean() < climate_norms.mean() * 0.2: + print(f" ⚠️ CO2 decoder usage is {(co2_norms.mean()/climate_norms.mean()):.1%} of climate") + if aerosol_norms.mean() < climate_norms.mean() * 0.2: + print(f" ⚠️ Aerosol decoder usage is {(aerosol_norms.mean()/climate_norms.mean()):.1%} of climate") + + # Check for optimizer state + print("\n" + "-"*80) + print("OPTIMIZER STATE") + print("-"*80) + + if 'optimizer_state_dict' in checkpoint: + opt_state = checkpoint['optimizer_state_dict'] + print("\n✓ Optimizer state found in checkpoint") + + # Check if forcing encoder parameters have optimizer state + if 'state' in opt_state: + param_groups = opt_state.get('param_groups', []) + print(f" Number of parameter groups: {len(param_groups)}") + + for i, group in enumerate(param_groups): + print(f" Group {i}: lr={group.get('lr', 'N/A')}, {len(group.get('params', []))} parameters") + else: + print("\n❌ No optimizer state in checkpoint") + + # Summary and diagnosis + print("\n" + "="*80) + print("DIAGNOSIS SUMMARY") + print("="*80) + + issues_found = [] + recommendations = [] + + # Check encoder weight norms + if 'co2_encoder' in encoder_stats and 'climate_encoder' in encoder_stats: + ratio = encoder_stats['co2_encoder']['norm'] / encoder_stats['climate_encoder']['norm'] + if ratio < 0.1: + issues_found.append(f"CO2 encoder weights are {ratio:.2%} of climate encoder (too small)") + recommendations.append("Increase forcing_latent_supervision_coeff or check gradient flow") + print(f"\n✓ CO2/Climate encoder weight ratio: {ratio:.2%}") + + if 'aerosol_encoder' in encoder_stats and 'climate_encoder' in 
encoder_stats: + ratio = encoder_stats['aerosol_encoder']['norm'] / encoder_stats['climate_encoder']['norm'] + if ratio < 0.1: + issues_found.append(f"Aerosol encoder weights are {ratio:.2%} of climate encoder (too small)") + recommendations.append("Increase forcing_latent_supervision_coeff or aerosol_effect_strength") + print(f"✓ Aerosol/Climate encoder weight ratio: {ratio:.2%}") + + # Check encoder correlation + if encoder_correlation is not None: + if encoder_correlation > 0.8: + issues_found.append(f"CO2 and aerosol encoders highly correlated (r={encoder_correlation:.3f})") + recommendations.append("CRITICAL: Encoders collapsed! Add diversity loss or increase aerosol_timing_stagger") + elif encoder_correlation > 0.5: + issues_found.append(f"CO2 and aerosol encoders moderately correlated (r={encoder_correlation:.3f})") + recommendations.append("Increase temporal diversity (aerosol_timing_stagger) or spatial contrast") + print(f"✓ CO2↔Aerosol encoder correlation: {encoder_correlation:.3f}") + + # Check decoder utilization + if decoder_norms is not None and len(decoder_norms) == 9: + climate_norms = decoder_norms[:4] + forcing_norms = decoder_norms[4:] + ratio = forcing_norms.mean() / climate_norms.mean() + if ratio < 0.3: + issues_found.append(f"Forcing latents underutilized by decoder ({ratio:.1%} vs climate)") + recommendations.append("Increase decoder_utilization_coeff or min_forcing_decoder_norm") + print(f"✓ Forcing/Climate decoder usage ratio: {ratio:.2%}") + + # Check logvar values + if co2_logvar_keys: + co2_std = torch.exp(0.5 * state_dict[co2_logvar_keys[0]]) + if (co2_std < 0.01).any(): + issues_found.append("CO2 encoder has very small variance (std < 0.01)") + recommendations.append("Check if CO2 encoder is receiving gradients during training") + print(f"✓ CO2 latent std range: [{co2_std.min().item():.4f}, {co2_std.max().item():.4f}]") + + if aerosol_logvar_keys: + aerosol_std = torch.exp(0.5 * state_dict[aerosol_logvar_keys[0]]) + if (aerosol_std < 
0.01).any(): + issues_found.append("Aerosol encoder has very small variance (std < 0.01)") + recommendations.append("Check if aerosol encoder is receiving gradients during training") + print(f"✓ Aerosol latent std range: [{aerosol_std.min().item():.4f}, {aerosol_std.max().item():.4f}]") + + # Print summary + if issues_found: + print("\n" + "="*80) + print("⚠️ ISSUES DETECTED:") + print("="*80) + for i, issue in enumerate(issues_found, 1): + print(f" {i}. {issue}") + + if recommendations: + print("\n" + "="*80) + print("💡 RECOMMENDATIONS:") + print("="*80) + for i, rec in enumerate(recommendations, 1): + print(f" {i}. {rec}") + else: + print("\n✓ No obvious parameter issues detected") + print(" If forcing attribution is still poor, check:") + print(" 1. Training loss curves (are forcing losses decreasing?)") + print(" 2. Ground truth forcing data quality") + print(" 3. Whether forcing effects are strong enough in synthetic data") + + print("\n" + "="*80) + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("Usage: python diagnose_forcing_latents.py ") + print("\nExample:") + print(" python diagnose_forcing_latents.py /path/to/results/model_40000.pt") + sys.exit(1) + + checkpoint_path = Path(sys.argv[1]) + if not checkpoint_path.exists(): + print(f"Error: Checkpoint not found: {checkpoint_path}") + sys.exit(1) + + diagnose_forcing_encoders(checkpoint_path) diff --git a/scripts/main_picabu.py b/scripts/main_picabu.py index 0d112c9..da51f28 100755 --- a/scripts/main_picabu.py +++ b/scripts/main_picabu.py @@ -1,5 +1,22 @@ -# Here we have a quick main where we are testing data loading with different ensemble members and ideally with different climate models. +"""Main training entry point for the ClimatEM causal discovery model. + +This script parses a JSON configuration file (via ``--config-path``), builds +the data module, model, and trainer objects, then launches training through +HuggingFace Accelerate. 
It is typically invoked with:: + + accelerate launch scripts/main_picabu.py --config-path path/to/config.json + +Workflow +-------- +1. Parse CLI arguments and load the JSON config. +2. Instantiate all parameter classes (``expParams``, ``dataParams``, etc.). +3. Build a ``CausalClimateDataModule`` for temporal-causal sequences. +4. Construct the ``LatentTSDCD`` model. +5. Hand both to ``TrainingLatent`` which runs the ALM optimisation loop. +6. Compute and save final metrics (SHD, precision, recall, MCC). +""" import json +import logging import os import time import warnings @@ -10,6 +27,14 @@ from accelerate import Accelerator from accelerate.utils import DistributedDataParallelKwargs +# Ensure INFO logs are emitted to console; force=True replaces any handlers that were already configured. +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + force=True, +) +logger = logging.getLogger(__name__) + from climatem.config import * from climatem.data_loader.causal_datamodule import CausalClimateDataModule from climatem.model.metrics import edge_errors, mcc_latent, precision_recall, shd, w_mae @@ -33,18 +58,30 @@ def __init__(self, **kwargs): def to_dict(self): return self.__dict__ - # def fancy_print(self, prefix=''): - # str_list = [] - # for key, val in self.__dict__.items(): - # str_list.append(prefix + f"{key} = {val}") - # return '\n'.join(str_list) - def main( experiment_params, data_params, train_params, model_params, optim_params, plot_params, savar_params ): - """ - :param hp: object containing hyperparameter values + """Build the datamodule, model, and trainer, then run the full training loop. + + Parameters + ---------- + experiment_params : expParams + Experiment-level settings (paths, GPU flag, random seed, latent dims). + data_params : dataParams + Data loading settings (paths, scenarios, variables, batch size). + gt_params : gtParams + Ground-truth debugging flags (whether to use GT graph/latents/weights). NOTE: ``main`` does not currently take this argument.
+ train_params : trainParams + Training hyper-parameters (learning rate, iterations, patience). + model_params : modelParams + Model architecture settings (layers, hidden dims, dynamics type). + optim_params : optimParams + Optimisation settings (loss coefficients, ALM penalties). + plot_params : plotParams + Plotting frequency and options. + savar_params : savarParams + Synthetic SAVAR data generation parameters. """ t0 = time.time() @@ -66,11 +103,11 @@ def main( device = torch.device("cuda" if (torch.cuda.is_available() and experiment_params.gpu) else "cpu") if data_params.data_format == "hdf5": - print("IS HDF5") + logger.info("IS HDF5") return else: if model_params.instantaneous and experiment_params.tau == 0: - print("Using instantaneous connections") + logger.info("Using instantaneous connections") tau = experiment_params.tau + 1 else: tau = experiment_params.tau @@ -112,6 +149,7 @@ def main( seasonality_removal=data_params.seasonality_removal, reload_climate_set_data=data_params.reload_climate_set_data, icosahedral_coordinates_path=data_params.icosahedral_coordinates_path, + forcing_conditioning=data_params.forcing_conditioning, # Below SAVAR data arguments time_len=savar_params.time_len, comp_size=savar_params.comp_size, @@ -119,6 +157,11 @@ def main( n_per_col=savar_params.n_per_col, difficulty=savar_params.difficulty, seasonality=savar_params.seasonality, + periods=savar_params.periods, + amplitudes=savar_params.amplitudes, + phases=savar_params.phases, + yearly_jitter_amp=savar_params.yearly_jitter_amp, + yearly_jitter_phase=savar_params.yearly_jitter_phase, overlap=savar_params.overlap, is_forced=savar_params.is_forced, f_1=savar_params.f_1, @@ -129,6 +172,25 @@ def main( linearity=savar_params.linearity, poly_degrees=savar_params.poly_degrees, plot_original_data=savar_params.plot_original_data, + use_separate_forcings=savar_params.use_separate_forcings, + forcing_amplification=savar_params.forcing_amplification, + aerosol_scale = 
savar_params.aerosol_scale, + aerosol_spatial_contrast = savar_params.aerosol_spatial_contrast, + aerosol_ramp_up_time = savar_params.aerosol_ramp_up_time, + aerosol_peak_time = savar_params.aerosol_peak_time, + aerosol_decline_time = savar_params.aerosol_decline_time, + # Forcing causal structure parameters + n_co2_latents=savar_params.n_co2_latents, + n_aerosol_latents=savar_params.n_aerosol_latents, + co2_effect_strength=savar_params.co2_effect_strength, + aerosol_effect_strength=savar_params.aerosol_effect_strength, + # Background state parameters + enable_background=savar_params.enable_background, + background_strength=savar_params.background_strength, + background_strength_mode=savar_params.background_strength_mode, + background_smoothness=savar_params.background_smoothness, + background_timescale_rho=savar_params.background_timescale_rho, + background_n_modes=savar_params.background_n_modes ) datamodule.setup() @@ -137,21 +199,25 @@ def main( # WE SHOULD REMOVE THIS, and initialize with params d = len(data_params.in_var_ids) - print(f"Using {d} variables") + logger.info(f"Using {d} variables") + # num_input = (number of variables) * (number of time lags). + # With instantaneous connections the current time step is included as well (tau + 1); + # no spatial-neighbourhood (tau_neigh) factor enters this computation. if model_params.instantaneous: num_input = d * (experiment_params.tau + 1) else: num_input = d * (experiment_params.tau) # set the model - model = LatentTSDCD( - num_layers=model_params.num_layers, - num_hidden=model_params.num_hidden, - num_input=num_input, - num_output=2, # This should be parameterized somewhere?
+ model = LatentTSDCD( + num_layers=model_params.num_layers, + num_hidden=model_params.num_hidden, + num_input=num_input, + num_output=2, # 2 outputs per latent: mean and variance of a Gaussian distribution num_layers_mixing=model_params.num_layers_mixing, num_hidden_mixing=model_params.num_hidden_mixing, position_embedding_dim=model_params.position_embedding_dim, + reduce_encoding_pos_dim=model_params.reduce_encoding_pos_dim, transition_param_sharing=model_params.transition_param_sharing, position_embedding_transition=model_params.position_embedding_transition, coeff_kl=optim_params.coeff_kl, @@ -177,7 +243,14 @@ def main( # also fixed=model_params.fixed, fixed_output_fraction=model_params.fixed_output_fraction, - ) + use_exogenous=model_params.use_exogenous, + d_y_co2=model_params.d_y_co2, + d_y_aerosol=model_params.d_y_aerosol, + use_forced_latents=model_params.use_forced_latents, + n_forced_latents_co2=model_params.n_forced_latents_co2, + n_forced_latents_aerosol=model_params.n_forced_latents_aerosol, + forcing_arch=model_params.forcing_arch, + ) # Make folder to save run results exp_path = Path(experiment_params.exp_path) @@ -212,7 +285,6 @@ def main( hp["train_params"] = train_params.__dict__ hp["model_params"] = model_params.__dict__ hp["optim_params"] = optim_params.__dict__ - hp["savar_params"] = savar_params.__dict__ with open(exp_path / "params.json", "w") as file: json.dump(hp, file, indent=4) @@ -240,10 +312,11 @@ def main( accelerator, wandbname=name, profiler=False, + profiler_path="./log", ) # where is the model at this point? - print("Where is my model?", next(trainer.model.parameters()).device) + logger.info("Where is my model? 
%s", next(trainer.model.parameters()).device) valid_loss = trainer.train_with_QPM() @@ -323,9 +396,9 @@ def main( # assert that trainer.model is in eval mode if trainer.model.training: - print("Model is in train mode") + logger.info("Model is in train mode") else: - print("Model is in eval mode") + logger.info("Model is in eval mode") # NOTE: just dummies here for now # train_mse, train_smape, val_mse, val_smape = 10.0, 10.0, 10.0, 10.0 @@ -400,14 +473,14 @@ def assert_args( # get user's scratch directory: scratch_path = os.getenv("SCRATCH") params["data_params"]["data_dir"] = params["data_params"]["data_dir"].replace("$SCRATCH", scratch_path) - print ("new data path:", params["data_params"]["data_dir"]) + logger.info("new data path: %s", params["data_params"]["data_dir"]) params["exp_params"]["exp_path"] = params["exp_params"]["exp_path"].replace("$SCRATCH", scratch_path) - print ("new exp path:", params["exp_params"]["exp_path"]) + logger.info("new exp path: %s", params["exp_params"]["exp_path"]) # get directory of project via current file (aka .../climatem/scripts/main_picabu.py) params["data_params"]["icosahedral_coordinates_path"] = params["data_params"]["icosahedral_coordinates_path"].replace("$CLIMATEMDIR", root_path.absolute().as_posix()) - print ("new icosahedron path:", params["data_params"]["icosahedral_coordinates_path"]) + logger.info("new icosahedron path: %s", params["data_params"]["icosahedral_coordinates_path"]) experiment_params = expParams(**params["exp_params"]) data_params = dataParams(**params["data_params"]) @@ -427,7 +500,13 @@ def assert_args( #Below is coherent with savar data generation if savar_params.use_correct_hyperparams: - experiment_params.d_z = int(savar_params.n_per_col**2) + n_climate_latents = int(savar_params.n_per_col**2) + if model_params.use_forced_latents: + experiment_params.d_z = n_climate_latents + model_params.n_forced_latents_co2 + model_params.n_forced_latents_aerosol + else: + experiment_params.d_z = 
n_climate_latents + if not savar_params.is_forced: + model_params.use_exogenous = False if savar_params.difficulty == "easy": optim_params.sparsity_upper_threshold = 1/(experiment_params.d_z*experiment_params.tau) #expected N out of N^2*tau total links if savar_params.difficulty == "med_easy": @@ -448,4 +527,3 @@ def assert_args( ) main(experiment_params, data_params, train_params, model_params, optim_params, plot_params, savar_params) - diff --git a/scripts/varimax_pcmci_savar_evaluation.py b/scripts/varimax_pcmci_savar_evaluation.py index 366ae67..f62ea35 100644 --- a/scripts/varimax_pcmci_savar_evaluation.py +++ b/scripts/varimax_pcmci_savar_evaluation.py @@ -1,4 +1,16 @@ +"""Baseline evaluation of causal discovery using PCMCI with Varimax-PCA preprocessing. + +This script provides a baseline comparison for the ClimatEM causal discovery +model. It applies Varimax-rotated PCA to extract latent modes from synthetic +SAVAR data and then runs PCMCI (Runge et al., 2019) to infer causal links. +The inferred adjacency matrix is compared against the known ground-truth +structure using precision, recall, F1, and SHD. + +The script also loads results from the ClimatEM model (CDSD) for side-by-side +comparison. +""" import json +import logging from pathlib import Path import matplotlib.pyplot as plt @@ -17,21 +29,30 @@ from tigramite.pcmci import PCMCI from tqdm.auto import tqdm +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + parcorr = ParCorr(significance="analytic") def extract_adjacency_matrix(links_coeffs, N, tau): - """ - Extract the ground truth adjacency matrices for each time lag from the links_coeffs. - - Args: - links_coeffs (dict): The dictionary of causal links between latent variables. - N (int): The number of latent variables. - tau (int): The maximum time lag. - - Returns: - adj_matrices (np.ndarray): The ground truth adjacency matrices (tau x N x N), - where each matrix corresponds to a different time lag. 
+ """Extract ground-truth adjacency matrices for each time lag. + + Parameters + ---------- + links_coeffs : dict + Dictionary mapping each latent variable index to a list of + ``((target_var, lag), coefficient)`` tuples describing causal links. + N : int + Number of latent variables. + tau : int + Maximum time lag to consider. + + Returns + ------- + adj_matrices : np.ndarray + Binary adjacency matrices with shape ``(tau, N, N)`` where entry + ``[t, i, j]`` is 1 if variable *j* causes variable *i* at lag *t+1*. """ # Initialize a 3D array to store adjacency matrices for each time lag (tau x N x N) adj_matrices = np.zeros((tau, N, N)) @@ -54,8 +75,25 @@ def extract_adjacency_matrix(links_coeffs, N, tau): def evaluate_adjacency_matrix(A_inferred, A_ground_truth, threshold): - """Evaluates the precision, recall, F1-score, and Structural Hamming Distance (SHD) between the inferred and ground - truth adjacency matrices.""" + """Evaluate precision, recall, F1-score, and SHD between two adjacency matrices. + + Parameters + ---------- + A_inferred : np.ndarray + Inferred adjacency matrix (may be real-valued). + A_ground_truth : np.ndarray + Ground-truth adjacency matrix (may be real-valued). + threshold : float + Threshold for binarising both matrices before comparison. + + Returns + ------- + precision : float + recall : float + f1 : float + shd : int + Structural Hamming Distance (false positives + false negatives). + """ # Binarize the matrices before comparison A_inferred_bin = binarize_matrix(A_inferred, threshold) A_ground_truth_bin = binarize_matrix(A_ground_truth, threshold) @@ -77,45 +115,49 @@ def evaluate_adjacency_matrix(A_inferred, A_ground_truth, threshold): return precision, recall, f1, shd -def extract_adjacency_matrix(links_coeffs, N, tau): - """ - Extract the ground truth adjacency matrices for each time lag from the links_coeffs. - - Args: - links_coeffs (dict): The dictionary of causal links between latent variables. 
- N (int): The number of latent variables. - tau (int): The maximum time lag. - - Returns: - adj_matrices (np.ndarray): The ground truth adjacency matrices (tau x N x N), - where each matrix corresponds to a different time lag. - """ - # Initialize a 3D array to store adjacency matrices for each time lag (tau x N x N) - adj_matrices = np.zeros((tau, N, N)) - - # Loop through each component and its links - for key, values in links_coeffs.items(): - for link, coeff in values: - target_var, lag = link - time_lag = -lag # Convert the negative lag to a positive index - # Only consider lags that are within the specified time window (tau) - if time_lag <= tau: - if abs(coeff) > 0.01: - adj_matrices[time_lag - 1, key, target_var] = ( - 1 # Fill the adjacency matrix at the appropriate time lag - ) - else: - adj_matrices[time_lag - 1, key, target_var] = 0 - - return adj_matrices - - def binarize_matrix(A, threshold=0.5): - """Binarizes the adjacency matrix by applying a threshold.""" + """Binarise an adjacency matrix by applying a threshold. + + Parameters + ---------- + A : np.ndarray + Real-valued adjacency matrix. + threshold : float + Values strictly above this are set to 1; all others to 0. + + Returns + ------- + np.ndarray + Integer array with values in {0, 1}. + """ return (A > threshold).astype(int) def varimax(Phi, gamma=1, q=20, tol=1e-6): + """Compute the Varimax rotation of a factor loading matrix. + + Implements the standard Varimax criterion (Kaiser, 1958) via SVD-based + iterative optimisation. + + Parameters + ---------- + Phi : np.ndarray + Factor loading matrix of shape ``(p, k)`` where *p* is the number of + observed variables and *k* is the number of factors. + gamma : float, optional + Rotation parameter. ``gamma=1`` gives standard Varimax. + q : int, optional + Maximum number of iterations. + tol : float, optional + Convergence tolerance (ratio of successive singular-value sums). 
+ + Returns + ------- + rotated : np.ndarray + Rotated loading matrix ``Phi @ R``. + R : np.ndarray + Orthogonal rotation matrix. + """ p, k = Phi.shape R = eye(k) d = 0 @@ -132,7 +174,6 @@ def varimax(Phi, gamma=1, q=20, tol=1e-6): if __name__ == "__main__": - # load your existing JSON config config_path = Path("configs/single_param_file_savar.json") with open(config_path, "r") as f: @@ -169,68 +210,16 @@ def varimax(Phi, gamma=1, q=20, tol=1e-6): params = np.load(params_file, allow_pickle=True).item() links_coeffs = params["links_coeffs"] - # modes_gt = np.load(savar_folder / f"{savar_fname[:-4]}_mode_weights.npy") - # modes_gt -= modes_gt.mean() - # modes_gt /= modes_gt.std() - adj_gt = extract_adjacency_matrix(links_coeffs, n_modes, tau) n_gt_connections = (np.array(adj_gt) > 0).sum() - # load CDSD results (already permuted / aligned) + # NOTE: Adjust these paths to match your local environment. cdsd_adj_inferred_path = Path("/home/ka/ka_iti/ka_qa4548/my_projects/climatem/workspace/pfs7wor9/ka_qa4548-results/SAVAR_DATA_TEST/var_savar_scenarios_piControl_nonlinear_False_tau_5_z_9_lr_0.001_bs_256_spreg_0_ormuinit_100000.0_spmuinit_0.1_spthres_0.05_fixed_False_num_ensembles_1_instantaneous_False_crpscoef_1_spcoef_0_tempspcoef_0_overlap_0.3_forcing_True/plots/graphs.npy") cdsd_modes_inferred_path = Path("/home/ka/ka_iti/ka_qa4548/my_projects/climatem/workspace/pfs7wor9/ka_qa4548-results/SAVAR_DATA_TEST/var_savar_scenarios_piControl_nonlinear_False_tau_5_z_9_lr_0.001_bs_256_spreg_0_ormuinit_100000.0_spmuinit_0.1_spthres_0.05_fixed_False_num_ensembles_1_instantaneous_False_crpscoef_1_spcoef_0_tempspcoef_0_overlap_0.3_forcing_True/plots/w_decoder.npy") modes_inferred = np.load(cdsd_modes_inferred_path) adj_w = np.load(cdsd_adj_inferred_path) - ############################ - - # # Fit PCA + varimax - # pca_model = PCA(n_modes).fit(savar_data.T) - # latent_data = pca_model.transform(savar_data.T) - # varimaxpcs, varimax_rotation = varimax(latent_data) - - # # To recover 
which mode is which and permute accordingly when evaluating - # inverse_varimax = dot(latent_data, np.linalg.pinv(varimax_rotation)) - # reverted_data = pca_model.inverse_transform(inverse_varimax) - - # dataframe = pp.DataFrame(varimaxpcs, datatime={0: np.arange(len(varimaxpcs))}, var_names=var_names) - # # Run PCMCI - # pcmci = PCMCI(dataframe=dataframe, cond_ind_test=parcorr, verbosity=1) - - # results = pcmci.run_pcmci(tau_min=1, tau_max=5, pc_alpha=None, alpha_level=0.001) - - # Permute accordingly before evaluating learned graph. - # individual_modes = np.zeros((n_modes, time_len, lat, lon)) - # for k in range(n_modes): - # latent_data_bis = np.zeros(latent_data.shape) - # latent_data_bis[:, k] = latent_data[:, k] - # inverse_varimax = dot(latent_data_bis, np.linalg.pinv(varimax_rotation)) - # reverted_data = pca_model.inverse_transform(inverse_varimax) - # individual_modes[k] = reverted_data.reshape((-1, lat, lon)) - # individual_modes = individual_modes.std(1) - # individual_modes -= individual_modes.mean() - # individual_modes /= individual_modes.std() - - # permutation_list = ((modes_gt[:, None] - individual_modes[None]) ** 2).sum((2, 3)).argmin(1) - - # # Get adjacency matrix from PCMCI graph - # graph = results["graph"] - # graph[ - # results["val_matrix"] - # < np.abs(results["val_matrix"].flatten()[results["val_matrix"].flatten().argsort()[::-1][n_gt_connections - 1]]) - # ] = "" - - # adj_matrix_inferred = np.zeros((tau, n_modes, n_modes)) - # for k in range(n_modes): - # graph_k = graph[k] - # for j in range(n_modes): - # adj_matrix_inferred[:, k, j] = graph_k[j][1:] == "-->" - - # for k in range(tau): - # adj_matrix_inferred[k] = adj_matrix_inferred[k][np.ix_(permutation_list, permutation_list)] - # adj_matrix_inferred = adj_matrix_inferred.transpose((0, 2, 1)) - - # Find the permutation + # Find the permutation modes_inferred = modes_inferred.reshape((lat, lon, modes_inferred.shape[-1])).transpose((2, 0, 1)) # Get the flat index of the maximum for 
each mode @@ -243,15 +232,15 @@ def varimax(Phi, gamma=1, q=20, tol=1e-6): # Compute error matrix using squared Euclidean distance between indices which yields an (n_modes x n_modes) matrix permutation_list = ((idx_gt[:, None, :] - idx_inferred[None, :, :]) ** 2).sum(axis=2).argmin(axis=1) - print("permutation_list:", permutation_list) + logger.info("permutation_list: %s", permutation_list) - # Permute + # Permute for k in range(tau): adj_w[k] = adj_w[k][np.ix_(permutation_list, permutation_list)] - print("PERMUTED THE MATRICES") + logger.info("PERMUTED THE MATRICES") precision, recall, f1, shd = evaluate_adjacency_matrix(adj_w, adj_gt, 0.9) - print(f"difficuly {difficulty} results:") - print(f"Precision: {precision}, Recall: {recall}, F1 Score: {f1}, SHD: {shd}") + logger.info("difficulty %s results:", difficulty) + logger.info("Precision: %s, Recall: %s, F1 Score: %s, SHD: %s", precision, recall, f1, shd)