Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
e491bcf
add a util function for hourly data
SarahAlidoost Jun 10, 2026
8066c9a
fix minor docstrings
SarahAlidoost Jun 10, 2026
bd9c771
Merge branch 'main' into support_hourly
SarahAlidoost Jun 10, 2026
3874ba6
fix encoder_decoder
SarahAlidoost Jun 12, 2026
0dd47b0
update nb
SarahAlidoost Jun 13, 2026
615a43a
remove unused argument from model
SarahAlidoost Jun 17, 2026
2e1a57f
remove setting device from the model class
SarahAlidoost Jun 17, 2026
ee0626e
remove unused arg
SarahAlidoost Jun 17, 2026
a9d519d
refcator dataset
SarahAlidoost Jun 18, 2026
33befb3
fix a bug
SarahAlidoost Jun 18, 2026
87405cc
refactor the model
SarahAlidoost Jun 18, 2026
977ff25
bring nb from main
SarahAlidoost Jun 22, 2026
74bd3a9
bring nb from main
SarahAlidoost Jun 22, 2026
6b33b05
Merge branch 'main' into support_hourly
SarahAlidoost Jun 22, 2026
84c0603
add doc about jupyter hub
SarahAlidoost Jun 22, 2026
146cb48
remove some of dropouts
SarahAlidoost Jun 23, 2026
5e8707e
remove permute
SarahAlidoost Jun 23, 2026
d08c9d7
Merge branch 'main' into support_hourly
SarahAlidoost Jun 24, 2026
85d403d
make batch tensor
SarahAlidoost Jun 24, 2026
21ce834
set batch device
SarahAlidoost Jun 24, 2026
7452cd7
fix the title of the histogram
SarahAlidoost Jun 24, 2026
afc34ed
remove to dtype in forward
SarahAlidoost Jun 24, 2026
34096e1
refactor model
SarahAlidoost Jun 24, 2026
0ccb9ef
refactor temporal agg
SarahAlidoost Jun 24, 2026
9cdc177
refactor encoder
SarahAlidoost Jun 24, 2026
b1d1428
use checkpoint in forward of the model
SarahAlidoost Jun 24, 2026
1e7f9bd
improve predict and train
SarahAlidoost Jun 25, 2026
832634e
improve encoder
SarahAlidoost Jun 25, 2026
ee82421
improve train
SarahAlidoost Jun 25, 2026
d65eca2
add notebook for hourly data
SarahAlidoost Jun 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ X_mixed (MonthlyConvDecoder)---------> Output
We explain the model architecture in more detail in the [code and math
description](docs/code_math_description.md) document.

## Using HPC (Levante)

If you have access to the Levante HPC cluster, you can run the workflow using
`slurm` or `jupyter hub`. Please refer to the [Levante usage
instructions](docs/levante_usage.md) for more details.

## References

- [Attention is all you need](https://doi.org/10.48550/arXiv.1706.03762)
Expand Down
119 changes: 63 additions & 56 deletions climanet/dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings

import numpy as np
from .utils import add_month_day_dims, calc_stats
from .utils import add_month_day_dims, calc_stats, add_month_hour_dims
from .geo_embedding_utils import (
calculate_sh_geo_pos_embeddings,
compute_patch_geo_pos_embedding,
Expand All @@ -28,7 +28,21 @@ def __init__(
sh_pos_table: str = None, # Optional; str formatted path to precomputed table of sh
sh_embed_dim: int = 96, # sh_embed_dim should <= (sh_order_L + 1)**2
sh_order_L: int = 10,
is_hourly: bool = False,
):
"""Initialize the dataset with daily and monthly data, and optional land mask.

Args:
daily_da: xarray DataArray with daily data (M, time, H, W)
monthly_da: xarray DataArray with monthly data (M, H, W)
land_mask: Optional xarray DataArray with land mask (H, W) or (1, H, W)
time_dim: Name of the time dimension in the input data
spatial_dims: Tuple of (lat_dim, lon_dim) names in the input data
patch_size: Tuple of (patch_height, patch_width) in pixels
stride: Tuple of (stride_height, stride_width) in pixels. If None, defaults to patch_size (non-overlapping patches).
is_hourly: Whether the daily data is hourly (T=31*24) or daily (T=31).

"""
self.spatial_dims = spatial_dims
self.patch_size = patch_size
self.daily_da = daily_da
Expand All @@ -53,46 +67,55 @@ def __init__(
f"Patch size {patch_size} is larger than data dimensions {daily_da.sizes}"
)

# Reshape daily → (M, T=31, H, W), monthly → (M, H, W),
# and get padded_days_mask → (M, T=31)
daily_mt, monthly_m, padded_days_mask, daily_timef = add_month_day_dims(
daily_da, monthly_da, time_dim=time_dim
)
if is_hourly:
# hours_per_day == 24
# Reshape daily → (M, T=31*24, H, W), monthly → (M, H, W),
# and get padded_days_mask → (M, T=31*24)
daily_mt, monthly_m, padded_days_mask, daily_timef = add_month_hour_dims(
daily_da, monthly_da, time_dim=time_dim
)
else:
# Reshape daily → (M, T=31, H, W), monthly → (M, H, W),
# and get padded_days_mask → (M, T=31)
daily_mt, monthly_m, padded_days_mask, daily_timef = add_month_day_dims(
daily_da, monthly_da, time_dim=time_dim
)

# Convert to numpy once — all __getitem__ calls use these
self.daily_np = daily_mt.to_numpy().copy().astype(np.float32) # (M, T=31, H, W) float
self.monthly_np = monthly_m.to_numpy().copy().astype(np.float32) # (M, H, W) float
self.padded_mask_np = padded_days_mask.to_numpy().copy() # (M, T=31) bool
self.daily_timef_np = daily_timef.to_numpy().copy().astype(np.float32) # (M,T=31, 4)
self.daily_t = torch.from_numpy(daily_mt.values.astype(np.float32)) # (M, T=31, H, W)
self.monthly_t = torch.from_numpy(monthly_m.values.astype(np.float32)) # (M, H, W)
self.padded_days_tensor = torch.from_numpy(padded_days_mask.values.copy()).bool() # (M, T=31)
self.daily_timef_t = torch.from_numpy(daily_timef.values.astype(np.float32)) # (M, T=31, 4)

# Store coordinate arrays
self.lat_coords = daily_da[spatial_dims[0]].to_numpy().copy()
self.lon_coords = daily_da[spatial_dims[1]].to_numpy().copy()

if land_mask is not None:
lm = land_mask.to_numpy().copy()
lm = torch.from_numpy(land_mask.values.copy()).bool()
if lm.ndim == 3:
lm = lm.squeeze(0) # (1, H, W) → (H, W)
self.land_mask_np = lm
self.land_mask_t = lm
else:
self.land_mask_np = None
self.land_mask_t = None

# Precompute the NaN mask before filling NaNs
# daily_mask: True where NaN (i.e. missing ocean data, not land)
self.daily_nan_mask = np.isnan(self.daily_np) # (M, T=31, H, W)
self.daily_nan_mask = torch.isnan(self.daily_t) # (M, T=31, H, W)

# NaNs will be filled with 0 in-place
np.nan_to_num(self.daily_np, copy=False, nan=0.0)
self.daily_t.nan_to_num_(nan=0.0)

# Stats will be set later via set_stats() for train/test datasets
self.daily_mean = None
self.daily_std = None

# Precompute padded_days_mask as a tensor (same for all patches)
self.padded_days_tensor = torch.from_numpy(self.padded_mask_np).bool()
# Pre-build zero land tensor for the no-mask case
ph, pw = self.patch_size
self._zero_land = torch.zeros(ph, pw, dtype=torch.bool)

# Precompute lazy index mapping for patches
H, W = self.daily_np.shape[2], self.daily_np.shape[3]
H, W = self.daily_t.shape[2], self.daily_t.shape[3]
self.patch_indices = self._compute_patch_indices(H, W)

# Precompute geoposition and scale embeddings for patches
Expand All @@ -101,6 +124,9 @@ def __init__(
self.patch_geo_embeddings, self.patch_scale_features = (
self._compute_geoscalepatch_embeddings()
)
self.scale_f_dim = torch.tensor(self.patch_scale_features.shape[-1])
self.sh_embed_dim_t = torch.tensor(self.sh_embed_dim)
self.harmonic_order_t = torch.tensor(self.sh_order_L)

def _get_geo_pos(self, sh_pos_table: str):
"""Calculate or retrieve spherical harmonics based geo position embeddings."""
Expand Down Expand Up @@ -205,33 +231,19 @@ def __getitem__(self, idx):
ph, pw = self.patch_size

# Extract spatial patch via numpy slicing — faster than xarray indexing
daily_patch = self.daily_np[
:, :, i : i + ph, j : j + pw
] # (M, T, H, W) -> (M,T,pH, pW)
monthly_patch = self.monthly_np[
:, i : i + ph, j : j + pw
] # (M, H, W) -> (M, pH, pW)
daily_nan_mask = self.daily_nan_mask[
:, :, i : i + ph, j : j + pw
] # (M, T, H, W) -> (M, T, pH, pW)

if self.land_mask_np is not None:
land_patch = self.land_mask_np[i : i + ph, j : j + pw] # (H, W)
land_tensor = torch.from_numpy(np.ascontiguousarray(land_patch)).bool()
else:
land_tensor = torch.zeros(ph, pw, dtype=torch.bool)
# (M, T, H, W) -> (M,T,pH, pW)
daily_tensor = self.daily_t[:, :, i : i + ph, j : j + pw ].unsqueeze(0)

# geo_pos_tensor = self.sh_geo_pos[i: i + ph, j: j + pw] # (H,W, sh_emb_dim) -> (pH, pW, sh_embed_dim)
# (M, H, W) -> (M, pH, pW)
monthly_tensor = self.monthly_t[:, i : i + ph, j : j + pw]

# Convert to tensors (from_numpy is zero-copy on contiguous arrays)
# (1, M, T, H, W)
daily_tensor = torch.from_numpy(daily_patch).unsqueeze(0)
# (M, H, W)
monthly_tensor = torch.from_numpy(monthly_patch)
# (1, M, T, H, W)
daily_nan_mask = torch.from_numpy(daily_nan_mask).unsqueeze(0)
# ( M, T, 2)
daily_timef_tensor = torch.from_numpy(self.daily_timef_np)
# (M, T, H, W) -> (M, T, pH, pW)
daily_nan_mask = self.daily_nan_mask[:, :, i : i + ph, j : j + pw].unsqueeze(0)

if self.land_mask_t is not None:
land_tensor = self.land_mask_t[i : i + ph, j : j + pw] # (H, W)
else:
land_tensor = self._zero_land

# daily_mask: NaN locations that are NOT land
# Reshape land_tensor for broadcasting: (pH, pW) → (1, 1, 1, pH, pW)
Expand All @@ -249,25 +261,20 @@ def __getitem__(self, idx):
# get scale feature for patch
scale_feature_tensor = self.patch_scale_features[idx] # (10,)

# create tensors to pass sh embedding dimension, harmonic order, and scale feature dim
sh_embed_dim = torch.tensor(self.sh_embed_dim)
harmonic_order = torch.tensor(self.sh_order_L)
scale_f_dim = torch.tensor(len(scale_feature_tensor))

# Convert to tensors
return {
return {
"daily_patch": daily_tensor, # (C=1, M, T=31, pH, pW)
"monthly_patch": monthly_tensor, # (M, pH, pW)
"daily_mask_patch": daily_mask_tensor, # (C=1, M, T=31, pH, pW)
"land_mask_patch": land_tensor, # (pH,pW) True=Land
"daily_timef_patch": daily_timef_tensor, # (M, T=31, 2)
"daily_timef_patch": self.daily_timef_t, # (M, T=31, 2)
"padded_days_mask": self.padded_days_tensor, # (M, T=31) True=padded
"scale_feature_patch": scale_feature_tensor, # (10,)
"geo_pos_embedding_patch": geo_pos_embedding_tensor, # (sh_embed_dim,)
"sh_embed_dim": sh_embed_dim,
"harmonic_order": harmonic_order,
"scale_f_dim": scale_f_dim,
"coords": (i, j),
"sh_embed_dim": self.sh_embed_dim_t,
"harmonic_order": self.harmonic_order_t,
"scale_f_dim": self.scale_f_dim,
"coords": torch.tensor([i, j]),
"lat_patch": lat_patch, # (pH,)
"lon_patch": lon_patch, # (pW,)
}
Expand All @@ -282,14 +289,14 @@ def compute_stats(self, indices: list = None) -> Tuple[np.ndarray, np.ndarray]:
Tuple of (mean, std) arrays
"""
if indices is None:
data = self.monthly_np # (M, H, W)
data = self.monthly_t.numpy() # (M, H, W)
else:
# Stack selected spatial patches
ph, pw = self.patch_size
patches = []
for idx in indices:
i, j = self.patch_indices[idx]
patch = self.monthly_np[:, i : i + ph, j : j + pw]
patch = self.monthly_t[:, i : i + ph, j : j + pw].numpy()
patches.append(patch)
data = np.concatenate(patches, axis=-1)

Expand Down
36 changes: 20 additions & 16 deletions climanet/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def predict_monthly_var(
run_dir: str = ".",
verbose: bool = True,
dataloader_num_workers: int = 2,
predict_threads: int | None = None,
):
"""
Predicts monthly variable values using a trained model and a provided dataset.
Expand Down Expand Up @@ -107,37 +106,42 @@ def predict_monthly_var(
# Initialize an empty list to store predictions
base_dataset = dataset.dataset if hasattr(dataset, "dataset") else dataset

M = base_dataset.monthly_np.shape[0]
M = base_dataset.monthly_t.shape[0]
H, W = base_dataset.patch_size
all_predictions = torch.empty(len(dataset), M, H, W)
all_predictions = torch.empty(len(dataset), M, H, W, device=device)

# Set up logging
writer = setup_logging(run_dir)

with torch.no_grad():
with torch.inference_mode():
idx = 0
average_loss = 0.0
for i, batch in enumerate(dataloader):
# Move batch to the appropriate device
batch = {
k: v.to(device, non_blocking=use_cuda)
for k, v in batch.items()
}

predictions = model(
batch["daily_patch"].to(device, non_blocking=use_cuda),
batch["daily_mask_patch"].to(device, non_blocking=use_cuda),
batch["daily_timef_patch"].to(device, non_blocking=use_cuda),
batch["land_mask_patch"].to(device, non_blocking=use_cuda),
batch["geo_pos_embedding_patch"].to(device, non_blocking=use_cuda),
batch["scale_feature_patch"].to(device, non_blocking=use_cuda),
batch["padded_days_mask"].to(device, non_blocking=use_cuda),
batch["daily_patch"],
batch["daily_mask_patch"],
batch["daily_timef_patch"],
batch["land_mask_patch"],
batch["geo_pos_embedding_patch"],
batch["scale_feature_patch"],
batch["padded_days_mask"],
)

# Compute masked loss
loss = compute_masked_loss(
predictions,
batch["monthly_patch"].to(device, non_blocking=use_cuda),
batch["land_mask_patch"].to(device, non_blocking=use_cuda),
batch["monthly_patch"],
batch["land_mask_patch"],
)
average_loss += loss.item()
average_loss += loss.detach()

all_predictions[idx : idx + predictions.size(0)] = predictions.cpu()
all_predictions[idx : idx + predictions.size(0)] = predictions.detach()
idx += predictions.size(0)

if verbose:
Expand All @@ -147,7 +151,7 @@ def predict_monthly_var(

writer.add_scalar("Progress/Batch", i + 1, idx)

average_loss = average_loss / len(dataloader)
average_loss = average_loss.item() / len(dataloader)

if verbose:
print(f"Average loss over all batches: {average_loss:.4f}")
Expand Down
Loading
Loading