"""Percentile based threshold module."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, cast
import xarray as xr
from xarray import DataArray, Dataset
from icclim._core.constants import (
DEFAULT_DOY_WINDOW,
DOY_COORDINATE,
DOY_PERCENTILE_UNIT,
PERIOD_PERCENTILE_UNIT,
UNITS_KEY,
)
from icclim._core.generic.threshold.threshold_templates import (
EN_THRESHOLD_TEMPLATE,
PercentileTemplateConfig,
ThresholdMetadata,
)
from icclim._core.input_parsing import (
PercentileDataArray,
build_reference_da,
get_dataarray_from_dataset,
is_dataset_path,
read_clim_bounds,
standardize_percentile_dim_name,
)
from icclim._core.model.quantile_interpolation import (
QuantileInterpolation,
QuantileInterpolationRegistry,
)
from icclim._core.model.threshold import Threshold, ThresholdValueType
if TYPE_CHECKING:
# Standard library
from collections.abc import Callable, Sequence
from datetime import datetime
# Third-party
import jinja2
import pint
# Local
from icclim._core.model.operator import Operator
[docs]
class PercentileThreshold(Threshold):
"""
Percentile based threshold (e.g. "<= 10 doy_per").
The percentiles to be computed are expected to be either:
* "doy percentiles" (unit: "doy_per"). They are usually used for temperatures
indices such as the ECAD :ref:`tx90p <ecad_functions_api>`.
These percentiles are computed per day of year (doy) and by aggregating
values on the time axis ranged by ``reference_period``, using the
``doy_window_width`` parameter to control the time axis window of aggregation.
The resulting `value` is a DataArray with a "dayofyear" dimension ranging from
0 to 365 with one value per day of the year.
* "period percentiles" (unit: "period_per"). They are usually used for liquide
precipitation indices such as the ECAD :ref:`r75p <ecad_functions_api>`
or even :ref:`r75ptot <ecad_functions_api>`.
These percentiles are computed per grid cell on the period ranged by
``reference_period``.
The resulting ``value`` is a DataArray with per grid cell values and no time axis.
``is_ready`` becomes True when `prepare` method has been called, the actual
percentiles are then computed and accessible in ``value`` property.
Once ``is_ready`` is True, ``unit`` property can be set and will attempt a pint unit
conversion similar to what is done on ``BasicThreshold``.
Before that, setting unit has no effect.
"""
reference_period: Sequence[datetime | str] | None
doy_window_width: int
only_leap_years: bool
interpolation: QuantileInterpolation
initial_value: list[float] | None
is_doy_per_threshold: bool
_prepared_value: PercentileDataArray
_initial_unit: str | None
@property # type: ignore[override]
def unit(self) -> str | None:
"""The unit of the threshold."""
res = None
if self.is_ready:
res = self._prepared_value.attrs.get(UNITS_KEY, None)
else:
res = self._initial_unit
return res.replace("°C", "degC") if res else None
[docs]
@unit.setter
def unit(self, unit: str | xr.DataArray | pint.Quantity | pint.Unit | None) -> None:
if self.is_ready:
if (
self._prepared_value.attrs.get(UNITS_KEY, None) is not None
and unit is not None
):
from xclim.core.units import convert_units_to # noqa: PLC0415
self._prepared_value = convert_units_to(
self._prepared_value,
unit,
context="hydro",
)
self._prepared_value.attrs[UNITS_KEY] = unit
else:
self._initial_unit = unit # type: ignore[assignment]
[docs]
@property # type: ignore[override]
def value(self) -> PercentileDataArray: # type: ignore[override]
"""
The computed percentile threshold.
Raises
------
RuntimeError
If the threshold is not ready (prepare has not been called).
"""
if self.is_ready:
return self._prepared_value
msg = (
"Property `value` is not ready. For PercentileDataArray,"
" you must call `.prepare` first and fill `studied_data`"
" parameter in order to prepare `value`."
)
raise RuntimeError(msg)
def __init__(
self,
operator: str | Operator,
value: DataArray | float | Sequence[float],
unit: str | None = None,
doy_window_width: int = DEFAULT_DOY_WINDOW,
only_leap_years: bool = False,
interpolation: QuantileInterpolation
| str = QuantileInterpolationRegistry.MEDIAN_UNBIASED,
reference_period: Sequence[datetime | str] | None = None,
threshold_min_value: pint.Quantity | None = None,
initial_query: str | None = None,
threshold_var_name: str | None = None,
**kwargs, # noqa: ARG002
) -> None:
if is_dataset_path(cast("Any", value)) or isinstance(value, Dataset):
value, is_doy_per_threshold = _build_per_thresh_from_dataset(
value=value,
unit=unit,
threshold_var_name=cast("str", threshold_var_name),
reference_period=cast("Sequence[datetime | str]", reference_period),
)
else:
is_doy_per_threshold = unit == DOY_PERCENTILE_UNIT
if isinstance(value, DataArray):
self.prepare = cast("Any", None) # type: ignore[method-assign]
self._prepared_value = PercentileDataArray.from_da(value)
self.is_ready = True
self._initial_unit = None
self.initial_value = None
self.unit = self._prepared_value.attrs[UNITS_KEY]
else:
self.is_ready = False
self._initial_unit = unit
if isinstance(value, (float, int)):
self.initial_value = [float(value)]
elif isinstance(value, (list, tuple)):
self.initial_value = [float(x) for x in value]
else:
self.initial_value = None
self.unit = unit
self.operator = operator
self.threshold_var_name = threshold_var_name
self.initial_query = initial_query
self.threshold_min_value = threshold_min_value
self.reference_period = reference_period
self.doy_window_width = doy_window_width
self.only_leap_years = only_leap_years
self.interpolation = QuantileInterpolationRegistry.lookup(interpolation)
self.unit = unit
self.is_doy_per_threshold = is_doy_per_threshold
[docs]
def prepare(self, studied_data: DataArray) -> None:
"""
Prepare the data for calculating percentiles.
Parameters
----------
studied_data : DataArray
The input data to be prepared.
Raises
------
NotImplementedError
If the percentile unit is unknown.
Returns
-------
None
"""
if self._initial_unit == DOY_PERCENTILE_UNIT:
prepared_data = _build_doy_per(
studied_data=studied_data,
per_val=cast("Sequence[float]", self.initial_value),
reference_period=cast("Sequence[str]", self.reference_period),
interpolation=self.interpolation,
only_leap_years=self.only_leap_years,
doy_window_width=self.doy_window_width,
percentile_min_value=self.threshold_min_value,
)
elif self._initial_unit == PERIOD_PERCENTILE_UNIT:
prepared_data = _build_period_per(
studied_data=studied_data,
per_val=cast("Sequence[float]", self.initial_value),
reference_period=cast("Sequence[str]", self.reference_period),
interpolation=self.interpolation,
only_leap_years=self.only_leap_years,
percentile_min_value=self.threshold_min_value,
)
else:
msg = f"Unknown percentile unit '{self._initial_unit}'."
raise NotImplementedError(msg)
self._prepared_value = prepared_data.chunk("auto")
self.is_ready = True
def __eq__(self, other: object) -> bool:
"""
Compare two PercentileThreshold objects for equality.
Parameters
----------
other : object
The object to compare with.
Returns
-------
bool
True if the objects are equal, False otherwise.
"""
return (
isinstance(other, PercentileThreshold)
and self.initial_query == other.initial_query
and self.doy_window_width == other.doy_window_width
and self.only_leap_years == other.only_leap_years
and self.interpolation == other.interpolation
and self.reference_period == other.reference_period
and self.unit == other.unit
and self.threshold_min_value == other.threshold_min_value
)
def __hash__(self) -> int:
"""Return the hash of the threshold."""
return hash(
(
self.operator,
self.initial_query,
self.doy_window_width,
self.only_leap_years,
self.interpolation,
tuple(self.reference_period) if self.reference_period else None,
self.unit,
self.threshold_min_value,
)
)
# noinspection PyMethodOverriding
# (reason: with * and **kwargs we can have a different signature while still
# being liskov proof)
[docs]
def compute(
self,
comparison_data: xr.DataArray,
override_op: Callable[[DataArray, DataArray], DataArray] | None = None,
**kwargs,
) -> DataArray:
"""
Compute the percentile threshold.
Parameters
----------
comparison_data : xr.DataArray
The data array to compare with the threshold.
override_op : Callable[[DataArray, DataArray], DataArray] | None, optional
An optional override operator to use instead of the default operator.
**kwargs
Additional keyword arguments.
The `freq` parameter is used to specify the frequency of the data.
The `bootstrap` parameter is used to specify whether to run bootstrap.
Returns
-------
DataArray
The computed percentile threshold.
Raises
------
RuntimeError
If the PercentileThreshold is not ready. You must first call `.prepare`
with a `studied_data` parameter in order to prepare the threshold
for computation.
"""
op_func = (
override_op
if override_op is not None
else cast("Operator", self.operator).compute
)
if self.is_ready:
return self._per_compute(
comparison_data,
self.value,
op_func,
self.is_doy_per_threshold,
cast("str", kwargs.get("freq", "")),
bootstrap=kwargs.get("bootstrap", False),
)
msg = (
"This PercentileThreshold is not ready. You must first call `.prepare`"
" with a `studied_data` parameter in order to prepare the threshold"
" for computation."
)
raise RuntimeError(msg)
def _get_metadata_templates(self, per_coord: DataArray) -> ThresholdMetadata:
if self.is_doy_per_threshold:
if per_coord.size == 1:
return EN_THRESHOLD_TEMPLATE["single_doy_percentile"]
return EN_THRESHOLD_TEMPLATE["multiple_doy_percentiles"]
if per_coord.size == 1:
return EN_THRESHOLD_TEMPLATE["single_period_percentile"]
return EN_THRESHOLD_TEMPLATE["multiple_period_percentiles"]
def _per_compute(
self,
comparison_data: xr.DataArray,
per: xr.DataArray,
op: Callable[[DataArray, DataArray], DataArray],
is_doy_per_threshold: bool,
freq: str,
bootstrap: bool,
) -> DataArray:
from xclim.core.bootstrapping import percentile_bootstrap # noqa: PLC0415
@percentile_bootstrap
def __per_compute(
da: xr.DataArray,
per: xr.DataArray,
op: Callable[[DataArray, DataArray], DataArray],
is_doy_per_threshold: bool,
freq: str, # noqa: ARG001
bootstrap: bool, # noqa: ARG001
) -> DataArray:
if self.threshold_min_value is not None:
# there is only a threshold_min_value when we are computing > or >=
thresh = self.threshold_min_value
from xclim.core.units import convert_units_to # noqa: PLC0415
thresh = convert_units_to(thresh, per, context="hydro")
per = per.where(per > thresh, thresh)
if is_doy_per_threshold:
from xclim.core.calendar import resample_doy # noqa: PLC0415
threshold_value = resample_doy(per, da)
else:
threshold_value = per
return op(da, threshold_value)
return __per_compute(
comparison_data,
per,
op,
is_doy_per_threshold,
freq,
bootstrap,
)
def _compute_per(
per_val: float | Sequence[float], alpha: float, beta: float, study: DataArray
) -> PercentileDataArray:
from collections.abc import Sequence # noqa: PLC0415
from xclim.core.calendar import build_climatology_bounds # noqa: PLC0415
from xclim.core.utils import calc_perc # noqa: PLC0415
computed_per = xr.apply_ufunc(
calc_perc,
study,
input_core_dims=[["time"]],
output_core_dims=[["percentiles"]],
keep_attrs=True,
kwargs={
"percentiles": per_val,
"alpha": alpha,
"beta": beta,
"copy": True,
},
dask="parallelized",
output_dtypes=[study.dtype],
dask_gufunc_kwargs={
"output_sizes": {
"percentiles": len(per_val) if isinstance(per_val, Sequence) else 1
},
"allow_rechunk": True,
},
)
computed_per = computed_per.assign_coords(
percentiles=xr.DataArray(per_val, dims=("percentiles",)),
)
return PercentileDataArray.from_da(
source=computed_per,
climatology_bounds=build_climatology_bounds(study),
)
def _build_period_per(
studied_data: DataArray,
per_val: Sequence[float],
reference_period: Sequence[str],
interpolation: QuantileInterpolation,
only_leap_years: bool,
percentile_min_value: pint.Quantity | None,
) -> PercentileDataArray:
reference = build_reference_da(
studied_data,
reference_period,
only_leap_years,
percentile_min_value=percentile_min_value,
)
return _compute_per(per_val, interpolation.alpha, interpolation.beta, reference)
def _build_doy_per(
studied_data: DataArray,
per_val: Sequence[float],
reference_period: Sequence[str],
interpolation: QuantileInterpolation,
only_leap_years: bool,
doy_window_width: int,
percentile_min_value: pint.Quantity | None,
) -> PercentileDataArray:
reference = build_reference_da(
studied_data,
reference_period,
only_leap_years,
percentile_min_value,
)
from xclim.core.calendar import percentile_doy # noqa: PLC0415
return percentile_doy(
arr=reference,
window=doy_window_width,
per=per_val,
alpha=interpolation.alpha,
beta=interpolation.beta,
)
def _build_per_thresh_from_dataset(
value: ThresholdValueType,
unit: str | None,
threshold_var_name: str,
reference_period: Sequence[datetime | str],
) -> tuple[DataArray, bool]:
if (isinstance(value, (str, Dataset)) and is_dataset_path(value)) or isinstance(
value, Dataset
):
v = cast("Dataset | str", value)
thresh_da = get_dataarray_from_dataset(threshold_var_name, v)
elif isinstance(value, DataArray):
thresh_da = value
else:
msg = f"Cannot build threshold from a {type(value)}."
raise NotImplementedError(msg)
built_value = PercentileDataArray.from_da(
standardize_percentile_dim_name(thresh_da),
read_clim_bounds(reference_period, thresh_da),
)
if unit is not None:
if built_value.attrs.get(UNITS_KEY, None) is not None:
from xclim.core.units import convert_units_to # noqa: PLC0415
built_value = convert_units_to(built_value, unit, context="hydro")
built_value.attrs[UNITS_KEY] = unit
return built_value, DOY_COORDINATE in built_value.coords