Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions esmvalcore/cmor/_fixes/fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,16 @@
class Fix:
"""Base class for dataset fixes."""

GROUP_CUBES_BY_DATE = False
"""Flag for grouping cubes for fix_metadata.

Fixes are applied to each group element individually.

If ``False`` (default), group cubes by file. If ``True``, group cubes by
date.

"""

def __init__(
self,
vardef: VariableInfo,
Expand Down
39 changes: 39 additions & 0 deletions esmvalcore/cmor/_fixes/native6/era5.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import datetime
import logging

import iris
import numpy as np
from cf_units import Unit
from iris.cube import CubeList
Expand Down Expand Up @@ -433,6 +434,44 @@ def fix_metadata(self, cubes):
return cubes


class Rsut(Fix):
"""Fixes for rsut."""

# Enable grouping cubes by date for fix_metadata since multiple variables
# from multiple files are needed
GROUP_CUBES_BY_DATE = True

def fix_metadata(self, cubes):
"""Fix metadata.

Derive rsut as

rsut = rsdt - rsnt

with

rsut = TOA Outgoing Shortwave Radiation
rsdt = TOA Incoming Shortwave Radiation
rsnt = TOA Net Incoming Shortwave Radiation

"""
rsdt_cube = cubes.extract_cube(
iris.NameConstraint(long_name="TOA incident solar radiation"),
)
rsnt_cube = cubes.extract_cube(
iris.NameConstraint(
long_name="Mean top net short-wave radiation flux",
),
)
rsdt_cube = Rsdt(None).fix_metadata([rsdt_cube])[0]
rsdt_cube.convert_units(self.vardef.units)

rsdt_cube.data = rsdt_cube.core_data() - rsnt_cube.core_data()
rsdt_cube.attributes["positive"] = "up"

return iris.cube.CubeList([rsdt_cube])


class Rss(Fix):
"""Fixes for Rss."""

Expand Down
40 changes: 32 additions & 8 deletions esmvalcore/cmor/fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
from iris.cube import Cube, CubeList

from esmvalcore.cmor._fixes.fix import Fix
from esmvalcore.io.local import LocalFile
from esmvalcore.io.local import LocalFile, _get_start_end_date

if TYPE_CHECKING:
from collections.abc import Iterable

from esmvalcore.config import Session

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -129,6 +131,27 @@ def fix_file( # noqa: PLR0913
return result


def _group_cubes(fixes: Iterable[Fix], cubes: CubeList) -> dict[Any, CubeList]:
"""Group cubes for fix_metadata; each group is processed individually."""
grouped_cubes: dict[Any, CubeList] = defaultdict(CubeList)

# Group by date
if any(fix.GROUP_CUBES_BY_DATE for fix in fixes):
for cube in cubes:
if "source_file" in cube.attributes:
dates = _get_start_end_date(cube.attributes["source_file"])
else:
dates = None
grouped_cubes[dates].append(cube)

# Group by file name
else:
for cube in cubes:
grouped_cubes[cube.attributes.get("source_file", "")].append(cube)

return grouped_cubes


def fix_metadata(
cubes: Sequence[Cube],
short_name: str,
Expand Down Expand Up @@ -192,13 +215,14 @@ def fix_metadata(
)
fixed_cubes = CubeList()

# Group cubes by input file and apply all fixes to each group element
# (i.e., each file) individually
by_file = defaultdict(list)
for cube in cubes:
by_file[cube.attributes.get("source_file", "")].append(cube)

for group in by_file.values():
# Group cubes and apply all fixes to each group element individually. There
# are two options for grouping:
# (1) By input file name (default).
# (2) By time range (can be enabled by setting the attribute
# GROUP_CUBES_BY_DATE=True for the fix class; see
# _fixes.native6.era5.Rsut for an example).
grouped_cubes = _group_cubes(fixes, cubes)
for group in grouped_cubes.values():
cube_list = CubeList(group)
for fix in fixes:
cube_list = fix.fix_metadata(cube_list)
Expand Down
1 change: 1 addition & 0 deletions tests/unit/cmor/test_fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def setUp(self):
self.cube = self._create_mock_cube()
self.fixed_cube = self._create_mock_cube()
self.mock_fix = Mock()
self.mock_fix.GROUP_CUBES_BY_DATE = False
self.mock_fix.fix_metadata.return_value = [self.fixed_cube]
self.expected_get_fixes_call = {
"project": "project",
Expand Down