diff --git a/esmvalcore/cmor/_fixes/fix.py b/esmvalcore/cmor/_fixes/fix.py index ac25a3a762..7f7ce98aac 100644 --- a/esmvalcore/cmor/_fixes/fix.py +++ b/esmvalcore/cmor/_fixes/fix.py @@ -48,6 +48,16 @@ class Fix: """Base class for dataset fixes.""" + GROUP_CUBES_BY_DATE = False + """Flag for grouping cubes for fix_metadata. + + Fixes are applied to each group element individually. + + If ``False`` (default), group cubes by file. If ``True``, group cubes by + date. + + """ + def __init__( self, vardef: VariableInfo, diff --git a/esmvalcore/cmor/_fixes/native6/era5.py b/esmvalcore/cmor/_fixes/native6/era5.py index 1f26b27138..c5718bade0 100644 --- a/esmvalcore/cmor/_fixes/native6/era5.py +++ b/esmvalcore/cmor/_fixes/native6/era5.py @@ -3,6 +3,7 @@ import datetime import logging +import iris import numpy as np from cf_units import Unit from iris.cube import CubeList @@ -433,6 +434,44 @@ def fix_metadata(self, cubes): return cubes +class Rsut(Fix): + """Fixes for rsut.""" + + # Enable grouping cubes by date for fix_metadata since multiple variables + # from multiple files are needed + GROUP_CUBES_BY_DATE = True + + def fix_metadata(self, cubes): + """Fix metadata. + + Derive rsut as + + rsut = rsdt - rsnt + + with + + rsut = TOA Outgoing Shortwave Radiation + rsdt = TOA Incoming Shortwave Radiation + rsnt = TOA Net Incoming Shortwave Radiation + + """ + rsdt_cube = cubes.extract_cube( + iris.NameConstraint(long_name="TOA incident solar radiation"), + ) + rsnt_cube = cubes.extract_cube( + iris.NameConstraint( + long_name="Mean top net short-wave radiation flux", + ), + ) + rsdt_cube = Rsdt(None).fix_metadata([rsdt_cube])[0] + rsdt_cube.convert_units(self.vardef.units) + + rsdt_cube.data = rsdt_cube.core_data() - rsnt_cube.core_data() + rsdt_cube.attributes["positive"] = "up" + + return iris.cube.CubeList([rsdt_cube]) + + class Rss(Fix): """Fixes for Rss.""" diff --git a/esmvalcore/cmor/fix.py b/esmvalcore/cmor/fix.py index ef23022846..c0fd96bf11 100644 --- a/esmvalcore/cmor/fix.py +++ b/esmvalcore/cmor/fix.py @@ -16,9 +16,11 @@ from iris.cube import Cube, CubeList from esmvalcore.cmor._fixes.fix import Fix -from esmvalcore.io.local import LocalFile +from esmvalcore.io.local import LocalFile, _get_start_end_date if TYPE_CHECKING: + from collections.abc import Iterable + from esmvalcore.config import Session logger = logging.getLogger(__name__) @@ -129,6 +131,27 @@ def fix_file( # noqa: PLR0913 return result +def _group_cubes(fixes: Iterable[Fix], cubes: CubeList) -> dict[Any, CubeList]: + """Group cubes for fix_metadata; each group is processed individually.""" + grouped_cubes: dict[Any, CubeList] = defaultdict(CubeList) + + # Group by date + if any(fix.GROUP_CUBES_BY_DATE for fix in fixes): + for cube in cubes: + if "source_file" in cube.attributes: + dates = _get_start_end_date(cube.attributes["source_file"]) + else: + dates = None + grouped_cubes[dates].append(cube) + + # Group by file name + else: + for cube in cubes: + grouped_cubes[cube.attributes.get("source_file", "")].append(cube) + + return grouped_cubes + + def fix_metadata( cubes: Sequence[Cube], short_name: str, @@ -192,13 +215,14 @@ def fix_metadata( ) fixed_cubes = CubeList() - # Group cubes by input file and apply all fixes to each group element - # (i.e., each file) individually - by_file = defaultdict(list) - for cube in cubes: - by_file[cube.attributes.get("source_file", "")].append(cube) - - for group in by_file.values(): + # Group cubes and apply all fixes to each group element individually. There + # are two options for grouping: + # (1) By input file name (default). + # (2) By time range (can be enabled by setting the attribute + # GROUP_CUBES_BY_DATE=True for the fix class; see + # _fixes.native6.era5.Rsut for an example). + grouped_cubes = _group_cubes(fixes, cubes) + for group in grouped_cubes.values(): cube_list = CubeList(group) for fix in fixes: cube_list = fix.fix_metadata(cube_list) diff --git a/tests/unit/cmor/test_fix.py b/tests/unit/cmor/test_fix.py index 80f37f32c8..56c5e26054 100644 --- a/tests/unit/cmor/test_fix.py +++ b/tests/unit/cmor/test_fix.py @@ -189,6 +189,7 @@ def setUp(self): self.cube = self._create_mock_cube() self.fixed_cube = self._create_mock_cube() self.mock_fix = Mock() + self.mock_fix.GROUP_CUBES_BY_DATE = False self.mock_fix.fix_metadata.return_value = [self.fixed_cube] self.expected_get_fixes_call = { "project": "project",