Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
*__pycache__*
profsea.egg-info
data/
profsea/profsea-assets/

# Ignore the documentation build directory
docs/build
*.DS_Store
Expand Down
25 changes: 22 additions & 3 deletions profsea/components/spatial/sterodynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,23 @@ def _load_CMIP6_slopes(self) -> da.Array:
# Concatenate along a new dimension (representing the ensemble/models)
slopes_stack = xr.concat(datasets, dim="model")

return slopes_stack
# Read land mask if present
mask_files = list(Path(self.patterns_dir).glob("*/zos_mask_ssp585_*.nc"))

if mask_files:
self.land_mask_present = True

datasets_mask = [
xr.open_dataset(f, chunks={"lat": 45, "lon": 45})["zos_mask"] for f in mask_files
]
mask_stack = xr.concat(datasets_mask, dim="model")
mask_stack = mask_stack.sum(dim='model', skipna=True)

else:
self.land_mask_present = False
mask_stack = None

return slopes_stack, mask_stack

def _calc_expansion_contribution(
self, rng: np.random.Generator, state: ClimateState
Expand All @@ -99,7 +115,10 @@ def _calc_expansion_contribution(
A dask array of shape (members, years, lat, lon) containing the thermal expansion contribution to the sterodynamic component for each member and year.
"""
# Select slope coefficients based on the MIP
coeffs_da = self._load_CMIP6_slopes()
coeffs_da, mask_da = self._load_CMIP6_slopes()

if self.land_mask_present: # apply land mask
coeffs_da = coeffs_da.where(mask_da == 0.)

# Align the grid coordinates + interpolate if necessary
interp_da = interpolate_to_grid(coeffs_da, state.grid_lats, state.grid_lons)
Expand All @@ -112,7 +131,7 @@ def _calc_expansion_contribution(
return coeffs[rand_samples, :, :]
else:
# Calc pattern ensemble mean
mean_coeff = da.mean(coeffs, axis=0)
mean_coeff = da.nanmean(coeffs, axis=0)
return da.broadcast_to(
mean_coeff,
(state.n_members, state.grid_lats.shape[0], state.grid_lons.shape[0]),
Expand Down
36 changes: 31 additions & 5 deletions profsea/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
def sample_members_2D(array: np.ndarray, percentiles: list | np.ndarray) -> np.ndarray:
"""Sample real ensemble members from a 2D numpy array."""
# Caculate statistical timeseries, then match with closest real timeseries
array_percentiles = np.percentile(array, percentiles, axis=0)
array_percentiles = np.nanpercentile(array, percentiles, axis=0)
distances = cdist(array_percentiles, array)
mem_indices = np.argmin(distances, axis=1)
return array[mem_indices]
Expand All @@ -28,7 +28,6 @@ def interpolate(data: da.array, lats: int, lons: int) -> np.ndarray:
).data
return data_interp


def interpolate_to_grid(
data: xr.DataArray,
target_lats: np.ndarray,
Expand All @@ -41,15 +40,42 @@ def interpolate_to_grid(
"""
# Normalize source longitudes to [-180, 180) and sort monotonically
data = data.assign_coords(lon=(((data.lon + 180) % 360) - 180))
data = data.sortby("lon")
data = data.sortby(["lat", "lon"])

# Normalize target longitudes to [-180, 180) and sort
target_lons_norm = np.sort(((target_lons + 180) % 360) - 180)

# Interpolate!
data_interp = data.interp(
# Pad longitude with one point from each end to handle periodicity in the zonal direction
data_padded = data.pad(lon=1, mode='wrap') # need more padding for higher-order interpolation
lon = data.lon.values
lon_padded = np.concatenate([[lon[-1] - 360], lon, [lon[0] + 360]])
data_padded['lon'] = lon_padded
data_padded = data_padded.sortby(["lat", "lon"])

# Now interpolate
data_padded = data_padded.chunk({"lat": -1, "lon": -1})
for dim in ["lat", "lon"]:
data_padded = data_padded.interpolate_na(
dim=dim, method="nearest",
) # this to handle nan values or land mask

data_interp = data_padded.interp(
lat=target_lats, lon=target_lons_norm, method=grid_interpolation
)

# Account for land mask (1 where ocean, 0 where land (NaN))
ocean_mask = (~data.isnull()).astype(float)
ocean_mask_padded = ocean_mask.pad(lon=1, mode='wrap')
ocean_mask_padded['lon'] = lon_padded
ocean_mask_padded = ocean_mask_padded.sortby(["lat", "lon"])
ocean_mask_padded = ocean_mask_padded.chunk({"lat": -1, "lon": -1})
ocean_mask_interp = ocean_mask_padded.interp(
lat=target_lats, lon=target_lons_norm, method="nearest"
)

data_interp = data_interp.where(ocean_mask_interp == 1)
data_interp = data_interp.chunk("auto")

return data_interp


Expand Down