Source code for icclim._core.generic.indicator

"""Contain the GenericIndicator class."""

from __future__ import annotations

import contextlib
import logging
from collections.abc import Sequence
from copy import deepcopy
from functools import reduce
from typing import TYPE_CHECKING, Any, Literal

import numpy as np
import xarray as xr
from jinja2 import Environment
from pint.errors import DimensionalityError, UndefinedUnitError
from xarray import DataArray

# xclim imports are deferred to avoid triggering fire module load (and numba cache errors) on import.
from icclim._core.climate_variable import must_run_bootstrap
from icclim._core.constants import MIN_LEN_FOR_FREQ_INFERENCE, RESAMPLE_METHOD
from icclim._core.generic.functions import check_freq
from icclim._core.generic.generic_templates import INDICATORS_TEMPLATES_EN
from icclim._core.model.indicator import Indicator
from icclim.exception import InvalidIcclimArgumentError

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from collections.abc import Callable

    import jinja2

    from icclim._core.climate_variable import ClimateVariable
    from icclim._core.model.index_config import IndexConfig
    from icclim.frequency import Frequency


def _get_ureg() -> Any:  # noqa: ANN401
    """Lazily initialize the unit registry and its hydro context."""
    from xclim.core.units import units as ureg  # noqa: PLC0415

    if "hydro" not in ureg._contexts:  # noqa: SLF001
        context_hydro = ureg.Context("hydro")
        context_hydro.add_transformation(
            "[mass] / [length] ** 2",
            "[length]",
            lambda ureg, x: x * ureg("1 mm") / ureg("1 kg / m^2"),
        )
        context_hydro.add_transformation(
            "[length]",
            "[mass] / [length] ** 2",
            lambda ureg, x: x * ureg("1 kg / m^2") / ureg("1 mm"),
        )
        context_hydro.add_transformation(
            "[mass] / [length] ** 2 / [time]",
            "[length] / [time]",
            lambda ureg, x: x * ureg("1 mm / s") / ureg("1 kg / m^2 / s"),
        )
        context_hydro.add_transformation(
            "[length] / [time]",
            "[mass] / [length] ** 2 / [time]",
            lambda ureg, x: x * ureg("1 kg / m^2 / s") / ureg("1 mm / s"),
        )
        ureg.add_context(context_hydro)
    return ureg


jinja_env = Environment(autoescape=True)


[docs] class GenericIndicator(Indicator): """ GenericIndicator are climate indicators wich are not specific to a particular domain. They can be computed from any climate variable and are combined with `Threshold` objects to create personalized indicators. Parameters ---------- name: str The name of the indicator. process: Callable[..., DataArray] The function that processes the indicator. definition: str The definition of the indicator. check_vars: Callable[[list[ClimateVariable], GenericIndicator], None], optional A function that checks if the variables meet the indicator requirements. Defaults to None. sampling_methods: list[str], optional A list of sampling methods that can be used with the indicator. Defaults to None. missing: str, optional The method for handling missing values. Defaults to "any". missing_options: dict, optional Additional options for handling missing values. Defaults to None. qualifiers: tuple, optional Additional qualifiers for the indicator. Defaults to (). Attributes ---------- missing: str The method for handling missing values. missing_options: dict | None Additional options for handling missing values. """ missing: str missing_options: dict | None def __init__( self, name: str, process: Callable[..., DataArray], definition: str, check_vars: ( Callable[[list[ClimateVariable], GenericIndicator], None] | None ) = None, sampling_methods: list[str] | None = None, missing: str = "any", missing_options: dict | None = None, qualifiers: tuple = (), ) -> None: """ Initialize a GenericIndicator object. Parameters ---------- name : str The name of the indicator. process : Callable[..., DataArray] The processing function of the indicator. definition : str A definition for the indicator. check_vars : Callable[[list[ClimateVariable], GenericIndicator], None] | None, optional A function that checks the variables used by the indicator, by default None. sampling_methods : list[str] | None, optional The sampling methods used by the indicator, by default None. missing : str, optional The method for handling missing values, by default "any". missing_options : Any, optional The options for handling missing values, by default None. qualifiers : tuple, optional The qualifiers for the indicator, by default (). Raises ------ ValueError If `missing_options` is set with `missing` method being from context. Notes ----- See the `GenericIndicatorRegistry` class for a list of available indicators. Examples -------- >>> from icclim.generic_indices import GenericIndicator >>> def process(climate_vars, resample_freq): ... out = climate_vars[0].studied_data + climate_vars[1].studied_data ... out.resample(time=resample_freq).mean() ... return out >>> def check_vars(climate_vars, indicator): ... if len(climate_vars) != 2: ... raise ValueError( ... "This indicator requires exactly 2 climate variables." ... ) >>> indicator = GenericIndicator( ... name="test", ... process=process, ... definition="This is a test indicator", ... check_vars=check_vars, ... sampling_methods=["daily"], ... missing="skip", ... missing_options=None, ... qualifiers=(), ... ) """ super().__init__() self.missing_options = missing_options self.missing = missing or "any" en_indicator_templates = deepcopy(INDICATORS_TEMPLATES_EN[name]) self.name = name self.process = process self.standard_name = en_indicator_templates["standard_name"] self.cell_methods = en_indicator_templates["cell_methods"] self.long_name = en_indicator_templates["long_name"] self.check_vars = check_vars self.definition = definition self.qualifiers = qualifiers self.sampling_methods = ( sampling_methods if sampling_methods is not None else [RESAMPLE_METHOD] ) def __hash__(self) -> int: """Return the hash of the indicator.""" return hash(self.name)
[docs] def preprocess( self, climate_vars: list[ClimateVariable], jinja_scope: dict[str, Any], output_frequency: Frequency, src_freq: Frequency, output_unit: str | None, coef: float | None, sampling_method: str, ) -> list[ClimateVariable]: """ Preprocesses the climate variables before computing the indicator. Parameters ---------- climate_vars : list[ClimateVariable] The list of climate variables to be preprocessed. jinja_scope : dict[str, Any] The Jinja scope used for formatting the template. output_frequency : Frequency The desired frequency of the output. src_freq : Frequency The source frequency of the climate variables. output_unit : str | None The desired output unit of the indicator. If None, no unit conversion is performed. coef : float | None The coefficient to multiply the climate variable data with. If None, no multiplication is performed. sampling_method : str The sampling method used for some specific indicators. See `difference_of_means` for example. Returns ------- list[ClimateVariable] The preprocessed climate variables. """ self._check_for_invalid_setup(climate_vars, sampling_method) climate_vars = self._apply_transforms(climate_vars, output_unit, coef) if output_frequency.indexer: from xclim.core.calendar import select_time # noqa: PLC0415 for climate_var in climate_vars: climate_var.studied_data = select_time( climate_var.studied_data, **output_frequency.indexer, drop=True, ) if output_frequency.seasonal_bounds: _apply_seasonal_mask(climate_vars, output_frequency.seasonal_bounds) _check_data(climate_vars, src_freq.pandas_freq) _check_cf(climate_vars) self._format_template(jinja_scope=jinja_scope) return climate_vars
def _apply_transforms( self, climate_vars: list[ClimateVariable], output_unit: str | None, coef: float | None, ) -> list[ClimateVariable]: if output_unit is not None: if _is_amount_unit(output_unit): climate_vars = _convert_rates_to_amounts( climate_vars=climate_vars, output_unit=output_unit, ) elif _is_a_diff_indicator(self) and output_unit != "%": from xclim.core.units import convert_units_to # noqa: PLC0415 # [gh:255] Indicators computing the difference between two # variables must first convert the units of input variables # to the expected output unit in order to avoid converting # the output of the difference. # In other words: a 15 Kelvin difference *is* equivalent # to a 15 degC difference, but if we would convert the unit after # computing the difference, we could get -258.15 degC from the # 15 Kelvin. for climate_var in climate_vars: climate_var.studied_data = convert_units_to( climate_var.studied_data, target=output_unit, context="hydro", ) if climate_var.threshold is not None: climate_var.threshold.unit = output_unit if coef is not None: for climate_var in climate_vars: climate_var.studied_data = coef * climate_var.studied_data return climate_vars
[docs] def postprocess( # noqa: C901 self, result: DataArray, climate_vars: list[ClimateVariable], output_freq: str, src_freq: str, indexer: dict[Any, Any] | None, out_unit: str | None, allow_partial_seasons: bool | Literal["start", "end"], ) -> DataArray: """ Postprocesses the result of the indicator computation. Parameters ---------- result : DataArray The result of the indicator computation. climate_vars : list[ClimateVariable] The list of climate variables used for the computation. output_freq : str The desired output frequency of the postprocessed result. src_freq : str The source frequency of the input data. indexer : dict The indexer used to subset the input data. out_unit : str | None The desired output unit of the postprocessed result. If None, no unit conversion is performed. Returns ------- DataArray The postprocessed result. """ """ >>> PATCHED: Difference-aware postprocess Convert absolute temperatures to degC at the very end. Temperature differences (deltaT) remain in K (numerically identical to degC differences). Precipitation / amounts handled normally using hydro context. """ from xclim.core.units import convert_units_to # noqa: PLC0415 ureg = _get_ureg() if out_unit is not None and _is_amount_unit(out_unit): # Use Pint dimensionality + hydro context for precipitation-like units current_unit = result.attrs.get("units", None) if current_unit is not None: try: q = 1 * ureg(current_unit) if q.check("[mass] / [length] ** 2 / [time]") or q.check( "[mass] / [length] ** 2" ): # Always use hydro context for rates and amounts with ureg.context("hydro"): result = convert_units_to(result, out_unit, context="hydro") else: result = convert_units_to(result, out_unit, context="hydro") except ( DimensionalityError, UndefinedUnitError, ValueError, AttributeError, KeyError, ) as e: logger.warning( "Unit conversion failed for unit '%s': %s", current_unit, e, exc_info=True, ) else: result = convert_units_to(result, out_unit, context="hydro") elif out_unit is not None: result = convert_units_to(result, out_unit, context="hydro") if self.missing != "skip" and indexer is not None: # reference variable is a subset of the studied variable, # so no need to check it. it = filter(lambda cv: not cv.is_reference, climate_vars) das = [cv.studied_data for cv in it] if "time" in result.dims: # If src_freq cannot be inferred by xclim, fall back to universal check_freq if src_freq is None: try: src_freq = check_freq(result, dim="time") except (ValueError, TypeError, AttributeError, KeyError): src_freq = "D" # safe fallback: daily result = self._handle_missing_values( in_data=das, out_data=result, resample_freq=output_freq, src_freq=src_freq, indexer=indexer, allow_partial_seasons=allow_partial_seasons, ) for prop in self.templated_properties: result.attrs[prop] = getattr(self, prop) result.attrs["history"] = "" return result
# >>> PATCHED helper: difference-aware flag def _is_a_diff_indicator(self: Indicator) -> bool: return "compute_diff" in self.qualifiers def __call__(self, config: IndexConfig) -> DataArray: """ Compute the indicator based on the given configuration. Parameters ---------- config : IndexConfig The configuration object containing the settings for computing the indicator. Returns ------- DataArray The computed indicator as a DataArray. """ src_freq = config.climate_variables[0].source_frequency base_jinja_scope = { "np": np, "enumerate": enumerate, "len": len, "output_freq": config.frequency, "source_freq": src_freq, } climate_vars_meta = _get_climate_vars_metadata( config.climate_variables, src_freq, base_jinja_scope, jinja_env, ) jinja_scope: dict[str, Any] = { "min_spell_length": config.min_spell_length, "rolling_window_width": config.rolling_window_width, "climate_vars": climate_vars_meta, "is_compared_to_reference": config.is_compared_to_reference, "reference_period": config.reference_period, } jinja_scope.update(base_jinja_scope) climate_vars = self.preprocess( climate_vars=config.climate_variables, jinja_scope=jinja_scope, output_frequency=config.frequency, src_freq=src_freq, output_unit=config.out_unit, coef=config.coef, sampling_method=config.sampling_method, ) result = self.process( climate_vars=climate_vars, resample_freq=config.frequency, min_spell_length=config.min_spell_length, rolling_window_width=config.rolling_window_width, group_by_freq=config.frequency.group_by_key, is_compared_to_reference=config.is_compared_to_reference, logical_link=config.logical_link, date_event=config.date_event, source_freq_delta=src_freq.delta, to_percent=config.out_unit == "%", sampling_method=config.sampling_method, run_index=config.run_index, ) return self.postprocess( result, climate_vars=climate_vars, output_freq=config.frequency.pandas_freq, src_freq=src_freq.pandas_freq, indexer=config.frequency.indexer, out_unit=config.out_unit, allow_partial_seasons=config.allow_partial_seasons, ) def __eq__(self, other: object) -> bool: """ Check if two GenericIndicator objects are equal. Parameters ---------- other : Any The object to compare with. Returns ------- bool True if the two objects are equal, False otherwise. """ return ( isinstance(other, GenericIndicator) and self.long_name == other.long_name and self.standard_name == other.standard_name and self.process == other.process ) def __str__(self) -> str: """ Return the name of the indicator. Returns ------- str The name of the indicator. """ return self.name def _check_for_invalid_setup( self, climate_vars: list[ClimateVariable], sampling_method: str, ) -> None: if not _same_freq_for_all(climate_vars): msg = ( "All variables must have the same time frequency (for example daily) to" " be compared with each others, but this was not the case." ) raise InvalidIcclimArgumentError(msg) if sampling_method not in self.sampling_methods: msg = ( f"{self.name} can only be computed with the following" f" sampling_method(s): {self.sampling_methods}" ) raise InvalidIcclimArgumentError(msg) if self.check_vars is not None: # Run indicator specific check method self.check_vars(climate_vars, self) def _format_template(self, jinja_scope: dict) -> None: for templated_property in self.templated_properties: template = jinja_env.from_string( getattr(self, templated_property), globals=jinja_scope, ) setattr(self, templated_property, template.render()) def _handle_missing_values( self, in_data: Sequence[DataArray] | DataArray, out_data: DataArray, resample_freq: str | None = None, src_freq: str | None = None, indexer: dict[Any, Any] | None = None, allow_partial_seasons: bool | Literal["start", "end"] = False, ) -> DataArray: """ Handle missing values in climate index computations. Parameters ---------- in_data : list[xr.DataArray] Input DataArrays with time coordinates. out_data : xr.DataArray Output DataArray from the index computation. resample_freq : str, optional Target resampling frequency (e.g. "M", "Y"). src_freq : str, optional Source timestep frequency (e.g. "D"). indexer : dict, optional Extra arguments used by some missing value methods. """ from xclim.core.missing import MISSING_METHODS # noqa: PLC0415 missing_class = MISSING_METHODS[self.missing] # Get the class missing_obj = missing_class() # Instantiate with no args # We flag periods according to the missing method. Skip variables without a time coordinate. miss = ( missing_obj( da, freq=resample_freq, src_timestep=src_freq, **(indexer or {}) ) for da in (in_data if isinstance(in_data, Sequence) else [in_data]) if "time" in da.coords ) # Reduce by logical OR across all input masks mask = reduce(np.logical_or, miss) # Reindex mask to match output if needed if isinstance(mask, DataArray) and mask.time.size < out_data.time.size: mask = mask.reindex(time=out_data.time, fill_value=True) if allow_partial_seasons is True: # Unmask the first and last periods mask = xr.where( (mask.time == mask.time[0]) | (mask.time == mask.time[-1]), False, mask ) elif allow_partial_seasons == "start": # Unmask only the first period mask = xr.where(mask.time == mask.time[0], False, mask) elif allow_partial_seasons == "end": # Unmask only the last period mask = xr.where(mask.time == mask.time[-1], False, mask) return out_data.where(~mask)
def _same_freq_for_all(climate_vars: list[ClimateVariable]) -> bool: if len(climate_vars) == 1: return True freqs = [xr.infer_freq(a.studied_data.time) for a in climate_vars] return all(x == freqs[0] for x in freqs[1:]) def _get_climate_vars_metadata( climate_vars: list[ClimateVariable], resample_freq: Frequency, jinja_scope: dict[str, Any], jinja_env: jinja2.Environment, ) -> list[dict[str, Any]]: return [ c_var.build_indicator_metadata( resample_freq, must_run_bootstrap(c_var.studied_data, c_var.threshold), jinja_scope, jinja_env, ) for c_var in climate_vars ] def _convert_rates_to_amounts( climate_vars: list[ClimateVariable], output_unit: str ) -> list[ClimateVariable]: """ Convert rate-like climate variables to amount units using xclim's rate2amount. Handles both classic rates (e.g., mm/s) and precipitation (kg m-2 / time) using a dedicated Pint context. """ from xclim.core.units import rate2amount # noqa: PLC0415 ureg = _get_ureg() for climate_var in climate_vars: # Get the current unit of the variable current_unit = climate_var.studied_data.attrs.get("units", None) if current_unit is None: continue # Skip conversion if already an amount if _is_amount_unit(current_unit): continue try: # Clean unit string for Pint parsing unit_clean = current_unit.replace(" ", "*").replace("-", "**") q = 1 * ureg(unit_clean) dims = q.dimensionality # Determine if this is a precipitation-like unit (mass/area or mass/area/time) is_precip = ( dims == ureg("kg / m^2").dimensionality or dims == ureg("kg / m^2 / s").dimensionality or "kg/m2" in current_unit.replace(" ", "").lower() ) if is_precip: # >>> Hydro context applied here logger.info( "Converting %s: %s -> %s with hydro context", climate_var.name, current_unit, output_unit, ) with ureg.context("hydro"): da = rate2amount(climate_var.studied_data, out_units=output_unit) else: # Non-precipitation rates: normal conversion da = rate2amount(climate_var.studied_data, out_units=output_unit) # Update the variable with converted data climate_var.studied_data = da.astype("float64") except ( TypeError, ValueError, AttributeError, KeyError, DimensionalityError, UndefinedUnitError, ) as e: # Skip on error but report it logger.warning( "_convert_rates_to_amounts: exception for unit '%s', skipping: %s", current_unit, e, exc_info=True, # optional: include traceback ) continue return climate_vars def _is_amount_unit(unit: str) -> bool: """ Return True if `unit` is an amount unit (length-like), False otherwise. """ ureg = _get_ureg() unit = unit.strip() offset_units = ["K", "degC", "degree_Celsius", "C", "?C"] if any(u in unit for u in offset_units): return False try: # Fix for Pint parsing: replace spaces with * and '-' with '**' unit_clean = unit.replace(" ", "*").replace("-", "**") q = 1 * ureg(unit_clean) # Compare to the dimensionality of 1 meter return q.dimensionality == (1 * ureg("m")).dimensionality except (DimensionalityError, UndefinedUnitError, AttributeError, TypeError) as e: logger.warning( "_is_amount_unit: exception for unit '%s', returning False: %s", unit, e ) return False
[docs] def _check_cf(climate_vars: list[ClimateVariable]) -> None: """Compare metadata attributes to CF-Convention standards. Default cfchecks use the specifications in `xclim.core.utils.VARIABLES`, assuming the indicator's inputs are using the CMIP6/xclim variable names correctly. Variables absent from these default specs are silently ignored. When subclassing this method, use functions decorated using `xclim.core.options.cfcheck`. """ from xclim.core.cfchecks import cfcheck_from_name # noqa: PLC0415 for da in climate_vars: with contextlib.suppress(KeyError): # Silently ignore unknown variables. cfcheck_from_name(str(da.name), da.studied_data)
def _check_data(climate_vars: list, src_freq: str) -> None: if src_freq is None: return for climate_var in climate_vars: da = climate_var.studied_data if ( "time" in da.coords and da.time.ndim == 1 and len(da.time) > MIN_LEN_FOR_FREQ_INFERENCE ): inferred_freq = check_freq(da, dim="time", strict=True) if inferred_freq != src_freq: msg = ( f"[icclim] Frequency mismatch for variable '{climate_var.name}': " f"expected '{src_freq}', inferred '{inferred_freq}'" ) raise InvalidIcclimArgumentError(msg) def _is_a_diff_indicator(indicator: Indicator) -> bool: return "compute_diff" in indicator.qualifiers def _apply_seasonal_mask( climate_vars: list[ClimateVariable], seasonal_bounds: tuple[DataArray, DataArray], ) -> None: start_da, end_da = seasonal_bounds for climate_var in climate_vars: da = climate_var.studied_data doy = da.time.dt.dayofyear mask = xr.where( start_da <= end_da, (doy >= start_da) & (doy <= end_da), (doy >= start_da) | (doy <= end_da), ) climate_var.studied_data = da.where(mask)