-
Notifications
You must be signed in to change notification settings - Fork 0
Support hourly data #54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 33 commits
e491bcf
8066c9a
bd9c771
3874ba6
0dd47b0
615a43a
2e1a57f
ee0626e
a9d519d
33befb3
87405cc
977ff25
74bd3a9
6b33b05
84c0603
146cb48
5e8707e
d08c9d7
85d403d
21ce834
7452cd7
afc34ed
34096e1
0ccb9ef
9cdc177
b1d1428
1e7f9bd
832634e
ee82421
d65eca2
207fdac
31be9d1
1081822
d2c07a6
56181dc
c69694a
eb24a91
b983280
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,7 @@ | ||
| import warnings | ||
|
|
||
| import numpy as np | ||
| from .utils import add_month_day_dims, calc_stats | ||
| from .utils import add_month_day_dims, calc_stats, add_month_hour_dims | ||
| from .geo_embedding_utils import ( | ||
| calculate_sh_geo_pos_embeddings, | ||
| compute_patch_geo_pos_embedding, | ||
|
|
@@ -28,7 +28,21 @@ def __init__( | |
| sh_pos_table: str = None, # Optional; str formatted path to precomputed table of sh | ||
| sh_embed_dim: int = 96, # sh_embed_dim should <= (sh_order_L + 1)**2 | ||
| sh_order_L: int = 10, | ||
| is_hourly: bool = False, | ||
| ): | ||
| """Initialize the dataset with daily and monthly data, and optional land mask. | ||
|
|
||
| Args: | ||
| daily_da: xarray DataArray with daily data (M, time, H, W) | ||
| monthly_da: xarray DataArray with monthly data (M, H, W) | ||
| land_mask: Optional xarray DataArray with land mask (H, W) or (1, H, W) | ||
| time_dim: Name of the time dimension in the input data | ||
| spatial_dims: Tuple of (lat_dim, lon_dim) names in the input data | ||
| patch_size: Tuple of (patch_height, patch_width) in pixels | ||
| stride: Tuple of (stride_height, stride_width) in pixels. If None, defaults to patch_size (non-overlapping patches). | ||
| is_hourly: Whether the daily data is hourly (T=31*24) or daily (T=31). | ||
|
|
||
| """ | ||
| self.spatial_dims = spatial_dims | ||
| self.patch_size = patch_size | ||
| self.daily_da = daily_da | ||
|
|
@@ -53,46 +67,55 @@ def __init__( | |
| f"Patch size {patch_size} is larger than data dimensions {daily_da.sizes}" | ||
| ) | ||
|
|
||
| # Reshape daily → (M, T=31, H, W), monthly → (M, H, W), | ||
| # and get padded_days_mask → (M, T=31) | ||
| daily_mt, monthly_m, padded_days_mask, daily_timef = add_month_day_dims( | ||
| daily_da, monthly_da, time_dim=time_dim | ||
| ) | ||
| if is_hourly: | ||
| # hours_per_day == 24 | ||
| # Reshape daily → (M, T=31*24, H, W), monthly → (M, H, W), | ||
| # and get padded_days_mask → (M, T=31*24) | ||
| daily_mt, monthly_m, padded_days_mask, daily_timef = add_month_hour_dims( | ||
| daily_da, monthly_da, time_dim=time_dim | ||
| ) | ||
| else: | ||
| # Reshape daily → (M, T=31, H, W), monthly → (M, H, W), | ||
| # and get padded_days_mask → (M, T=31) | ||
| daily_mt, monthly_m, padded_days_mask, daily_timef = add_month_day_dims( | ||
| daily_da, monthly_da, time_dim=time_dim | ||
| ) | ||
|
|
||
| # Convert to numpy once — all __getitem__ calls use these | ||
| self.daily_np = daily_mt.to_numpy().copy().astype(np.float32) # (M, T=31, H, W) float | ||
| self.monthly_np = monthly_m.to_numpy().copy().astype(np.float32) # (M, H, W) float | ||
| self.padded_mask_np = padded_days_mask.to_numpy().copy() # (M, T=31) bool | ||
| self.daily_timef_np = daily_timef.to_numpy().copy().astype(np.float32) # (M,T=31, 4) | ||
| self.daily_t = torch.from_numpy(daily_mt.values.astype(np.float32)) # (M, T=31, H, W) | ||
| self.monthly_t = torch.from_numpy(monthly_m.values.astype(np.float32)) # (M, H, W) | ||
| self.padded_days_tensor = torch.from_numpy(padded_days_mask.values.copy()).bool() # (M, T=31) | ||
| self.daily_timef_t = torch.from_numpy(daily_timef.values.astype(np.float32)) # (M, T=31, 4) | ||
|
|
||
| # Store coordinate arrays | ||
| self.lat_coords = daily_da[spatial_dims[0]].to_numpy().copy() | ||
| self.lon_coords = daily_da[spatial_dims[1]].to_numpy().copy() | ||
|
|
||
| if land_mask is not None: | ||
| lm = land_mask.to_numpy().copy() | ||
| lm = torch.from_numpy(land_mask.values.copy()).bool() | ||
| if lm.ndim == 3: | ||
| lm = lm.squeeze(0) # (1, H, W) → (H, W) | ||
| self.land_mask_np = lm | ||
| self.land_mask_t = lm | ||
| else: | ||
| self.land_mask_np = None | ||
| self.land_mask_t = None | ||
|
|
||
| # Precompute the NaN mask before filling NaNs | ||
| # daily_mask: True where NaN (i.e. missing ocean data, not land) | ||
| self.daily_nan_mask = np.isnan(self.daily_np) # (M, T=31, H, W) | ||
| self.daily_nan_mask = torch.isnan(self.daily_t) # (M, T=31, H, W) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Adding the `_t` suffix to tensors for clarity, as you've done above, is a really good idea. Should we adopt that consistently, i.e. also for `daily_nan_mask`? Or do you think that is not needed?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Agreed, it is fixed. |
||
|
|
||
| # NaNs will be filled with 0 in-place | ||
| np.nan_to_num(self.daily_np, copy=False, nan=0.0) | ||
| self.daily_t.nan_to_num_(nan=0.0) | ||
|
|
||
| # Stats will be set later via set_stats() for train/test datasets | ||
| self.daily_mean = None | ||
| self.daily_std = None | ||
|
|
||
| # Precompute padded_days_mask as a tensor (same for all patches) | ||
| self.padded_days_tensor = torch.from_numpy(self.padded_mask_np).bool() | ||
| # Pre-build zero land tensor for the no-mask case | ||
| ph, pw = self.patch_size | ||
| self._zero_land = torch.zeros(ph, pw, dtype=torch.bool) | ||
|
|
||
| # Precompute lazy index mapping for patches | ||
| H, W = self.daily_np.shape[2], self.daily_np.shape[3] | ||
| H, W = self.daily_t.shape[2], self.daily_t.shape[3] | ||
| self.patch_indices = self._compute_patch_indices(H, W) | ||
|
|
||
| # Precompute geoposition and scale embeddings for patches | ||
|
|
@@ -101,6 +124,9 @@ def __init__( | |
| self.patch_geo_embeddings, self.patch_scale_features = ( | ||
| self._compute_geoscalepatch_embeddings() | ||
| ) | ||
| self.scale_f_dim = torch.tensor(self.patch_scale_features.shape[-1]) | ||
| self.sh_embed_dim_t = torch.tensor(self.sh_embed_dim) | ||
| self.harmonic_order_t = torch.tensor(self.sh_order_L) | ||
|
|
||
| def _get_geo_pos(self, sh_pos_table: str): | ||
| """Calculate or retrieve spherical harmonics based geo position embeddings.""" | ||
|
|
@@ -205,33 +231,19 @@ def __getitem__(self, idx): | |
| ph, pw = self.patch_size | ||
|
|
||
| # Extract spatial patch via numpy slicing — faster than xarray indexing | ||
| daily_patch = self.daily_np[ | ||
| :, :, i : i + ph, j : j + pw | ||
| ] # (M, T, H, W) -> (M,T,pH, pW) | ||
| monthly_patch = self.monthly_np[ | ||
| :, i : i + ph, j : j + pw | ||
| ] # (M, H, W) -> (M, pH, pW) | ||
| daily_nan_mask = self.daily_nan_mask[ | ||
| :, :, i : i + ph, j : j + pw | ||
| ] # (M, T, H, W) -> (M, T, pH, pW) | ||
|
|
||
| if self.land_mask_np is not None: | ||
| land_patch = self.land_mask_np[i : i + ph, j : j + pw] # (H, W) | ||
| land_tensor = torch.from_numpy(np.ascontiguousarray(land_patch)).bool() | ||
| else: | ||
| land_tensor = torch.zeros(ph, pw, dtype=torch.bool) | ||
| # (M, T, H, W) -> (M,T,pH, pW) | ||
| daily_tensor = self.daily_t[:, :, i : i + ph, j : j + pw].unsqueeze(0) | ||
|
|
||
| # geo_pos_tensor = self.sh_geo_pos[i: i + ph, j: j + pw] # (H,W, sh_emb_dim) -> (pH, pW, sh_embed_dim) | ||
| # (M, H, W) -> (M, pH, pW) | ||
| monthly_tensor = self.monthly_t[:, i : i + ph, j : j + pw] | ||
|
|
||
| # Convert to tensors (from_numpy is zero-copy on contiguous arrays) | ||
| # (1, M, T, H, W) | ||
| daily_tensor = torch.from_numpy(daily_patch).unsqueeze(0) | ||
| # (M, H, W) | ||
| monthly_tensor = torch.from_numpy(monthly_patch) | ||
| # (1, M, T, H, W) | ||
| daily_nan_mask = torch.from_numpy(daily_nan_mask).unsqueeze(0) | ||
| # ( M, T, 2) | ||
| daily_timef_tensor = torch.from_numpy(self.daily_timef_np) | ||
| # (M, T, H, W) -> (M, T, pH, pW) | ||
| daily_nan_mask = self.daily_nan_mask[:, :, i : i + ph, j : j + pw].unsqueeze(0) | ||
|
|
||
| if self.land_mask_t is not None: | ||
| land_tensor = self.land_mask_t[i : i + ph, j : j + pw] # (H, W) | ||
| else: | ||
| land_tensor = self._zero_land | ||
|
|
||
| # daily_mask: NaN locations that are NOT land | ||
| # Reshape land_tensor for broadcasting: (pH, pW) → (1, 1, 1, pH, pW) | ||
|
|
@@ -249,25 +261,20 @@ def __getitem__(self, idx): | |
| # get scale feature for patch | ||
| scale_feature_tensor = self.patch_scale_features[idx] # (10,) | ||
|
|
||
| # create tensors to pass sh embedding dimension, harmonic order, and scale feature dim | ||
| sh_embed_dim = torch.tensor(self.sh_embed_dim) | ||
| harmonic_order = torch.tensor(self.sh_order_L) | ||
| scale_f_dim = torch.tensor(len(scale_feature_tensor)) | ||
|
|
||
| # Convert to tensors | ||
| return { | ||
| "daily_patch": daily_tensor, # (C=1, M, T=31, pH, pW) | ||
| "monthly_patch": monthly_tensor, # (M, pH, pW) | ||
| "daily_mask_patch": daily_mask_tensor, # (C=1, M, T=31, pH, pW) | ||
| "land_mask_patch": land_tensor, # (pH,pW) True=Land | ||
| "daily_timef_patch": daily_timef_tensor, # (M, T=31, 2) | ||
| "daily_timef_patch": self.daily_timef_t, # (M, T=31, 2) | ||
| "padded_days_mask": self.padded_days_tensor, # (M, T=31) True=padded | ||
| "scale_feature_patch": scale_feature_tensor, # (10,) | ||
| "geo_pos_embedding_patch": geo_pos_embedding_tensor, # (sh_embed_dim,) | ||
| "sh_embed_dim": sh_embed_dim, | ||
| "harmonic_order": harmonic_order, | ||
| "scale_f_dim": scale_f_dim, | ||
| "coords": (i, j), | ||
| "sh_embed_dim": self.sh_embed_dim_t, | ||
| "harmonic_order": self.harmonic_order_t, | ||
| "scale_f_dim": self.scale_f_dim, | ||
| "coords": torch.tensor([i, j]), | ||
| "lat_patch": lat_patch, # (pH,) | ||
| "lon_patch": lon_patch, # (pW,) | ||
| } | ||
|
|
@@ -282,14 +289,14 @@ def compute_stats(self, indices: list = None) -> Tuple[np.ndarray, np.ndarray]: | |
| Tuple of (mean, std) arrays | ||
| """ | ||
| if indices is None: | ||
| data = self.monthly_np # (M, H, W) | ||
| data = self.monthly_t.numpy() # (M, H, W) | ||
| else: | ||
| # Stack selected spatial patches | ||
| ph, pw = self.patch_size | ||
| patches = [] | ||
| for idx in indices: | ||
| i, j = self.patch_indices[idx] | ||
| patch = self.monthly_np[:, i : i + ph, j : j + pw] | ||
| patch = self.monthly_t[:, i : i + ph, j : j + pw].numpy() | ||
| patches.append(patch) | ||
| data = np.concatenate(patches, axis=-1) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should probably update this comment, as the conversion format is now a torch tensor rather than numpy.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed.