diff --git a/.gitignore b/.gitignore index 8a9f4dc..5b04ab0 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,8 @@ *__pycache__* profsea.egg-info data/ +profsea/profsea-assets/ + # Ignore the documentation build directory docs/build *.DS_Store diff --git a/profsea/components/spatial/sterodynamic.py b/profsea/components/spatial/sterodynamic.py index 03308db..74b969f 100644 --- a/profsea/components/spatial/sterodynamic.py +++ b/profsea/components/spatial/sterodynamic.py @@ -77,7 +77,23 @@ def _load_CMIP6_slopes(self) -> da.Array: # Concatenate along a new dimension (representing the ensemble/models) slopes_stack = xr.concat(datasets, dim="model") - return slopes_stack + # Read land mask if present + mask_files = list(Path(self.patterns_dir).glob("*/zos_mask_ssp585_*.nc")) + + if mask_files: + self.land_mask_present = True + + datasets_mask = [ + xr.open_dataset(f, chunks={"lat": 45, "lon": 45})["zos_mask"] for f in mask_files + ] + mask_stack = xr.concat(datasets_mask, dim="model") + mask_stack = mask_stack.sum(dim='model', skipna=True) + + else: + self.land_mask_present = False + mask_stack = None + + return slopes_stack, mask_stack def _calc_expansion_contribution( self, rng: np.random.Generator, state: ClimateState @@ -99,7 +115,10 @@ def _calc_expansion_contribution( A dask array of shape (members, years, lat, lon) containing the thermal expansion contribution to the sterodynamic component for each member and year. """ # Select slope coefficients based on the MIP - coeffs_da = self._load_CMIP6_slopes() + coeffs_da, mask_da = self._load_CMIP6_slopes() + + if self.land_mask_present: # apply land mask + coeffs_da = coeffs_da.where(mask_da == 0.) # Align the grid coordinates + interpolate if necessary interp_da = interpolate_to_grid(coeffs_da, state.grid_lats, state.grid_lons) @@ -112,7 +131,7 @@ def _calc_expansion_contribution( return coeffs[rand_samples, :, :] else: # Calc pattern ensemble mean - mean_coeff = da.mean(coeffs, axis=0) + mean_coeff = da.nanmean(coeffs, axis=0) return da.broadcast_to( mean_coeff, (state.n_members, state.grid_lats.shape[0], state.grid_lons.shape[0]), diff --git a/profsea/utils/utils.py b/profsea/utils/utils.py index 8657eda..d682f22 100644 --- a/profsea/utils/utils.py +++ b/profsea/utils/utils.py @@ -7,7 +7,7 @@ def sample_members_2D(array: np.ndarray, percentiles: list | np.ndarray) -> np.ndarray: """Sample real ensemble members from a 2D numpy array.""" # Caculate statistical timeseries, then match with closest real timeseries - array_percentiles = np.percentile(array, percentiles, axis=0) + array_percentiles = np.nanpercentile(array, percentiles, axis=0) distances = cdist(array_percentiles, array) mem_indices = np.argmin(distances, axis=1) return array[mem_indices] @@ -28,7 +28,6 @@ def interpolate(data: da.array, lats: int, lons: int) -> np.ndarray: ).data return data_interp - def interpolate_to_grid( data: xr.DataArray, target_lats: np.ndarray, @@ -41,15 +40,42 @@ def interpolate_to_grid( """ # Normalize source longitudes to [-180, 180) and sort monotonically data = data.assign_coords(lon=(((data.lon + 180) % 360) - 180)) - data = data.sortby("lon") + data = data.sortby(["lat", "lon"]) # Normalize target longitudes to [-180, 180) and sort target_lons_norm = np.sort(((target_lons + 180) % 360) - 180) - # Interpolate! - data_interp = data.interp( + # Pad longitude with one points from each end to handle periodicity in zonal direction + data_padded = data.pad(lon=1, mode='wrap') # need more padding for higher-order interpolation + lon = data.lon.values + lon_padded = np.concatenate([[lon[-1] - 360], lon, [lon[0] + 360]]) + data_padded['lon'] = lon_padded + data_padded = data_padded.sortby(["lat", "lon"]) + + # Now interpolate + data_padded = data_padded.chunk({"lat": -1, "lon": -1}) + for dim in ["lat", "lon"]: + data_padded = data_padded.interpolate_na( + dim=dim, method="nearest", + ) # this to handle nan values or land mask + + data_interp = data_padded.interp( lat=target_lats, lon=target_lons_norm, method=grid_interpolation ) + + # Account for land mask (1 where ocean, 0 where land (NaN)) + ocean_mask = (~data.isnull()).astype(float) + ocean_mask_padded = ocean_mask.pad(lon=1, mode='wrap') + ocean_mask_padded['lon'] = lon_padded + ocean_mask_padded = ocean_mask_padded.sortby(["lat", "lon"]) + ocean_mask_padded = ocean_mask_padded.chunk({"lat": -1, "lon": -1}) + ocean_mask_interp = ocean_mask_padded.interp( + lat=target_lats, lon=target_lons_norm, method="nearest" + ) + + data_interp = data_interp.where(ocean_mask_interp == 1) + data_interp = data_interp.chunk("auto") + return data_interp