autocast.metrics.ensemble

class BTSCMMetric(score_dims='spatial', reduce_all=True, dist_sync_on_step=False)

Bases: BaseMetric[Float[Tensor, 'batch time spatial *spatial channel ensemble'], Float[Tensor, 'batch time spatial *spatial channel']]

Base class for ensemble metrics that operate on spatial tensors.

Checks input types and shapes and converts inputs to Tensor.

Parameters:
  • score_dims (Literal['spatial', 'temporal'] | None) – Which dimensions to reduce when computing the score: 'spatial' averages over the spatial dims → (B, T, C); 'temporal' averages over the time dim → (B, S, C); None applies no reduction → (B, T, S, C). These are the output shapes of .score(). The reduce_all parameter, applied in the compute() method, then determines the final output shape.

  • reduce_all (bool) – If True, return a scalar by averaging over all non-batch dims

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward()
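The shape rules for score_dims and reduce_all can be illustrated with a small NumPy sketch. The helper reduce_score is hypothetical and is not the base-class implementation; it assumes a single flattened spatial axis S, and it stops at per-sample means, so how compute() aggregates across batches is not shown.

```python
import numpy as np


def reduce_score(pointwise, score_dims="spatial", reduce_all=True):
    """Hypothetical helper mirroring the documented output shapes.

    pointwise has shape (B, T, S, C) with one flattened spatial axis S.
    """
    if score_dims == "spatial":
        out = pointwise.mean(axis=2)  # average over S -> (B, T, C)
    elif score_dims == "temporal":
        out = pointwise.mean(axis=1)  # average over T -> (B, S, C)
    elif score_dims is None:
        out = pointwise  # no reduction -> (B, T, S, C)
    else:
        raise ValueError(f"unknown score_dims: {score_dims!r}")
    if reduce_all:
        # Average over every non-batch dimension -> (B,)
        out = out.reshape(out.shape[0], -1).mean(axis=1)
    return out


x = np.arange(2 * 3 * 4 * 5, dtype=float).reshape(2, 3, 4, 5)
print(reduce_score(x, "spatial", reduce_all=False).shape)   # (2, 3, 5)
print(reduce_score(x, "temporal", reduce_all=False).shape)  # (2, 4, 5)
```

Because every sample has the same number of points, averaging the per-sample values over the batch gives the same scalar as one global mean.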

score(y_pred, y_true)

Compute metric score, then reduce according to self.score_dims.

Parameters:
  • y_pred (Tensor | ndarray) – Predictions of shape (B, T, S, C, M)

  • y_true (Tensor | ndarray) – Ground truth of shape (B, T, S, C)

Returns:

Tensor of shape (B, T, C) if score_dims='spatial', (B, S, C) if score_dims='temporal', or (B, T, S, C) if None.

Return type:

Float[Tensor, 'batch time channel'] | Float[Tensor, 'batch spatial *spatial channel'] | Float[Tensor, 'batch time spatial *spatial channel']

class CRPS(score_dims='spatial', reduce_all=True, dist_sync_on_step=False)

Bases: BTSCMMetric

Continuous Ranked Probability Score (CRPS) for ensemble forecasts.

References:

Hersbach, H., 2000: Decomposition of the Continuous Ranked Probability Score for Ensemble Prediction Systems. Wea. Forecasting, 15, 559-570, https://doi.org/10.1175/1520-0434(2000)015<0559:DOTCRP>2.0.CO;2.

name: str = 'crps'
Parameters:
  • score_dims (Literal['spatial', 'temporal'] | None)

  • reduce_all (bool)

  • dist_sync_on_step (bool)

class CRPSMAETerm(score_dims='spatial', reduce_all=True, dist_sync_on_step=False)

Bases: BTSCMMetric

Mean-absolute-error term in the CRPS decomposition.

Note

This is the first CRPS term,

\[\frac{1}{M}\sum_{m=1}^{M} |x_m - y|,\]

so it is MAE-like, but it is not the deterministic MAE of the ensemble mean.
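A two-member example makes the difference from the MAE of the ensemble mean concrete. This NumPy sketch is illustrative only; crps_mae_term is a hypothetical name, and the ensemble is assumed to sit on the last axis as in (B, T, S, C, M).

```python
import numpy as np


def crps_mae_term(ens, y):
    """(1/M) * sum_m |x_m - y|, with the ensemble on the last axis."""
    return np.abs(ens - y[..., None]).mean(axis=-1)


ens = np.array([[-1.0, 1.0]])  # one point, M = 2 members
y = np.array([0.0])

print(crps_mae_term(ens, y))          # [1.]  each member is 1 away from y
print(np.abs(ens.mean(axis=-1) - y))  # [0.]  the ensemble mean hits y exactly
```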

Parameters:
  • score_dims (Literal['spatial', 'temporal'] | None)

  • reduce_all (bool)

  • dist_sync_on_step (bool)

name: str = 'crps_mae_term'
class CRPSSpreadTerm(score_dims='spatial', reduce_all=True, dist_sync_on_step=False)

Bases: BTSCMMetric

Pairwise spread term in the CRPS decomposition.

Note

This is the second CRPS term,

\[\frac{1}{2M^2}\sum_{j=1}^{M}\sum_{k=1}^{M}|x_j - x_k|,\]

represented via the sort-based identity used elsewhere in this module.
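One common sort-based identity is ΣⱼΣₖ|xⱼ − xₖ| = 2 Σᵢ (2i − M − 1) x₍ᵢ₎ for the sorted members x₍₁₎ ≤ … ≤ x₍M₎. Whether the module uses exactly this form is an assumption. The NumPy sketch below checks it against the explicit pairwise sum and then forms CRPS as the MAE term minus the spread term. All function names are hypothetical.

```python
import numpy as np


def spread_term_pairwise(ens):
    """1/(2 M^2) * sum_j sum_k |x_j - x_k| via an explicit (M, M) difference."""
    m = ens.shape[-1]
    diffs = np.abs(ens[..., :, None] - ens[..., None, :])
    return diffs.sum(axis=(-2, -1)) / (2 * m**2)


def spread_term_sorted(ens):
    """Same quantity in O(M log M) from sorted members x_(1) <= ... <= x_(M)."""
    m = ens.shape[-1]
    xs = np.sort(ens, axis=-1)
    coeff = 2 * np.arange(1, m + 1) - m - 1  # (2i - M - 1) for i = 1..M
    return (coeff * xs).sum(axis=-1) / m**2


def crps_ensemble(ens, y):
    """CRPS = MAE term - spread term, ensemble on the last axis."""
    mae_term = np.abs(ens - y[..., None]).mean(axis=-1)
    return mae_term - spread_term_sorted(ens)


rng = np.random.default_rng(0)
ens = rng.normal(size=(4, 10))
print(np.allclose(spread_term_pairwise(ens), spread_term_sorted(ens)))  # True
```

The sorted form avoids materializing the (M, M) difference tensor, which matters for large ensembles on big grids.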

Parameters:
  • score_dims (Literal['spatial', 'temporal'] | None)

  • reduce_all (bool)

  • dist_sync_on_step (bool)

name: str = 'crps_spread_term'
class FairCRPS(score_dims='spatial', reduce_all=True, dist_sync_on_step=False)

Bases: BTSCMMetric

Fair Continuous Ranked Probability Score (fCRPS) for ensemble forecasts.

References:

Ferro, C.A.T. (2014), Fair scores for ensemble forecasts. Q.J.R. Meteorol. Soc., 140: 1917-1923. https://doi.org/10.1002/qj.2270

name: str = 'fcrps'
Parameters:
  • score_dims (Literal['spatial', 'temporal'] | None)

  • reduce_all (bool)

  • dist_sync_on_step (bool)
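Following Ferro (2014), the fair CRPS keeps the same MAE term but normalizes the pairwise spread term by 2M(M − 1) instead of 2M², which removes the finite-ensemble bias. This NumPy sketch uses a hypothetical fair_crps function and is not the class implementation.

```python
import numpy as np


def fair_crps(ens, y):
    """Fair CRPS: the spread term uses 1/(2M(M-1)) instead of 1/(2M^2)."""
    m = ens.shape[-1]
    if m < 2:
        raise ValueError("fair CRPS needs at least two ensemble members")
    mae_term = np.abs(ens - y[..., None]).mean(axis=-1)
    pair_sum = np.abs(ens[..., :, None] - ens[..., None, :]).sum(axis=(-2, -1))
    spread_term = pair_sum / (2 * m * (m - 1))
    return mae_term - spread_term


# Members at -1 and 1 around y = 0: standard CRPS is 0.5, fair CRPS is 0.
print(fair_crps(np.array([[-1.0, 1.0]]), np.array([0.0])))  # [0.]
```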

class FairCRPSMAETerm(score_dims='spatial', reduce_all=True, dist_sync_on_step=False)

Bases: BTSCMMetric

Mean-absolute-error term in the fCRPS decomposition.

Parameters:
  • score_dims (Literal['spatial', 'temporal'] | None)

  • reduce_all (bool)

  • dist_sync_on_step (bool)

name: str = 'fcrps_mae_term'
class FairCRPSSpreadTerm(score_dims='spatial', reduce_all=True, dist_sync_on_step=False)

Bases: BTSCMMetric

Pairwise spread term in the fCRPS decomposition.

Parameters:
  • score_dims (Literal['spatial', 'temporal'] | None)

  • reduce_all (bool)

  • dist_sync_on_step (bool)

name: str = 'fcrps_spread_term'
class AlphaFairCRPS(alpha=0.95, *, score_dims='spatial', reduce_all=True, dist_sync_on_step=False)

Bases: BTSCMMetric

Almost Fair Continuous Ranked Probability Score (afCRPS), computed in a numerically stable form.

Note

Definition:

\[\text{afCRPS}_\alpha := \alpha\,\text{fCRPS} + (1-\alpha)\,\text{CRPS}.\]

The implementation follows eq. (4) of the AIFS-CRPS paper, which rearranges the score into a sum of positive terms to avoid numerical instability.

References

Lang, S., Alexe, M., Clare, M. C., Roberts, C., Adewoyin, R., Bouallègue, Z. B., … & Leutbecher, M. (2024). AIFS-CRPS: ensemble forecasting using a model trained with a loss function based on the continuous ranked probability score. arXiv preprint arXiv:2412.15832.

Parameters:
  • alpha (float)

  • score_dims (Literal['spatial', 'temporal'] | None)

  • reduce_all (bool)

  • dist_sync_on_step (bool)
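Expanding α·fCRPS + (1 − α)·CRPS gives a pairwise form over j ≠ k with ε = (1 − α)/M. In this form each summand |xⱼ − y| + |xₖ − y| − (1 − ε)|xⱼ − xₖ| is non-negative by the triangle inequality. The NumPy sketch below is a derivation written for illustration, not code taken from autocast; it should correspond to the eq. (4) form, but the exact grouping in the paper may differ. The names afcrps_definition, afcrps_stable, and eps are hypothetical.

```python
import numpy as np


def afcrps_definition(ens, y, alpha=0.95):
    """alpha * fCRPS + (1 - alpha) * CRPS, computed term by term."""
    m = ens.shape[-1]
    mae_term = np.abs(ens - y[..., None]).mean(axis=-1)
    pair_sum = np.abs(ens[..., :, None] - ens[..., None, :]).sum(axis=(-2, -1))
    fcrps = mae_term - pair_sum / (2 * m * (m - 1))
    crps = mae_term - pair_sum / (2 * m**2)
    return alpha * fcrps + (1 - alpha) * crps


def afcrps_stable(ens, y, alpha=0.95):
    """Same score as a normalized sum of non-negative pairwise terms (j != k)."""
    m = ens.shape[-1]
    eps = (1 - alpha) / m
    err = np.abs(ens - y[..., None])                      # (..., M)
    pair = np.abs(ens[..., :, None] - ens[..., None, :])  # (..., M, M)
    terms = err[..., :, None] + err[..., None, :] - (1 - eps) * pair
    off_diag = ~np.eye(m, dtype=bool)                     # drop the j == k terms
    return (terms * off_diag).sum(axis=(-2, -1)) / (2 * m * (m - 1))


rng = np.random.default_rng(1)
ens = rng.normal(size=(3, 8))
y = rng.normal(size=3)
print(np.allclose(afcrps_definition(ens, y), afcrps_stable(ens, y)))  # True
```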

name: str = 'afcrps'
class AlphaFairCRPSMAETerm(alpha=0.95, *, score_dims='spatial', reduce_all=True, dist_sync_on_step=False)

Bases: BTSCMMetric

Mean-absolute-error term paired with afCRPS monitoring.

The MAE-like term itself does not depend on alpha, but this class accepts it so experiment configs can keep the afCRPS diagnostic bundle parameterized consistently.

Parameters:
  • alpha (float)

  • score_dims (Literal['spatial', 'temporal'] | None)

  • reduce_all (bool)

  • dist_sync_on_step (bool)

name: str = 'afcrps_mae_term'
class AlphaFairCRPSSpreadTerm(alpha=0.95, *, score_dims='spatial', reduce_all=True, dist_sync_on_step=False)

Bases: BTSCMMetric

Pairwise spread term in the afCRPS decomposition.

Parameters:
  • alpha (float)

  • score_dims (Literal['spatial', 'temporal'] | None)

  • reduce_all (bool)

  • dist_sync_on_step (bool)

name: str = 'afcrps_spread_term'
class EnergyScore(alpha=1.0, vector_dims='spatial_temporal', **kwargs)

Bases: BTSCMMetric

Energy score (multivariate CRPS) for ensemble forecasts.

For a vector-valued forecast with ensemble members \(x_m \in \mathbb{R}^d\) and observation \(y \in \mathbb{R}^d\), this computes

\[ES_\alpha(F, y) = \frac{1}{M} \sum_{m=1}^M \lVert x_m - y \rVert_2^\alpha - \frac{1}{2M^2} \sum_{m=1}^M \sum_{j=1}^M \lVert x_m - x_j \rVert_2^\alpha,\]

with \(\alpha \in (0, 2)\).

Note

The vector_dims argument controls which dimensions define the multivariate vector used in the norm.
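For a single forecast whose vector dimensions have already been flattened into a length-d vector, the energy score formula translates directly to NumPy. This is a sketch that assumes inputs of shape (M, d) and (d,). How autocast flattens the dimensions selected by vector_dims is not shown, and energy_score is a hypothetical name.

```python
import numpy as np


def energy_score(ens, y, alpha=1.0):
    """Energy score for one forecast: ens is (M, d), y is (d,), 0 < alpha < 2."""
    m = ens.shape[0]
    skill = (np.linalg.norm(ens - y, axis=-1) ** alpha).mean()
    pair = np.linalg.norm(ens[:, None, :] - ens[None, :, :], axis=-1) ** alpha
    return skill - pair.sum() / (2 * m**2)


# With d = 1 and alpha = 1 the energy score reduces to the ensemble CRPS.
print(energy_score(np.array([[-1.0], [1.0]]), np.array([0.0])))  # 0.5
```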

Parameters:
  • alpha (float)

  • vector_dims (Literal['spatial', 'temporal', 'spatial_temporal', 'spatial_temporal_channels'])

name: str = 'energy'
vector_dims: Literal['spatial', 'temporal', 'spatial_temporal', 'spatial_temporal_channels']
class VariogramScore(p=0.5, weights=None, vector_dims='spatial_temporal', **kwargs)

Bases: BTSCMMetric

Variogram score for multivariate ensemble forecasts.

For vector-valued forecast members \(x_m \in \mathbb{R}^d\) and observation \(y \in \mathbb{R}^d\), this computes

\[VS_p(F, y) = \sum_{i,j=1}^d w_{ij} \left(\frac{1}{M}\sum_{m=1}^M |x_{m,i} - x_{m,j}|^p - |y_i - y_j|^p\right)^2,\]

with \(p > 0\) and non-negative weights \(w_{ij}\).

Note

The vector_dims argument controls which dimensions define the multivariate vector used by the variogram transformation.

Parameters:
  • p (float)

  • weights (Tensor | ndarray | None)

  • vector_dims (Literal['spatial', 'temporal', 'spatial_temporal', 'spatial_temporal_channels'])

name: str = 'variogram'
vector_dims: Literal['spatial', 'temporal', 'spatial_temporal', 'spatial_temporal_channels']
class SpreadSkillRatio(eps=1e-06, **kwargs)

Bases: BTSCMMetric

Corrected spread-to-skill ratio (SSR) for ensemble forecasts.

Note

Uses the corrected finite-ensemble form:

\[\text{SSR}_{\text{corrected}} = \frac{\text{Spread}}{\text{Skill}}\sqrt{\frac{M + 1}{M}},\]

where skill is the RMSE of the ensemble mean and spread is the ensemble standard deviation. The spatial/temporal reductions selected by score_dims are applied to the pointwise MSE and variance before the square roots and the ratio are taken.

Parameters:

eps (float)
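The reduce-first order (average variance and MSE, then take square roots, then divide) can be sketched in NumPy for one field with S points and M members. Two assumptions apply: a plain mean over S stands in for score_dims='spatial', and eps is assumed to guard only the division, which may differ from the actual placement. spread_skill_ratio is a hypothetical name.

```python
import numpy as np


def spread_skill_ratio(ens, y, eps=1e-6):
    """Corrected SSR for ens of shape (S, M) and y of shape (S,)."""
    m = ens.shape[-1]
    var = ens.var(axis=-1, ddof=1).mean()        # <unbiased ensemble variance>
    mse = ((ens.mean(axis=-1) - y) ** 2).mean()  # <squared error of ensemble mean>
    spread, skill = np.sqrt(var), np.sqrt(mse)
    # Assumption: eps only protects the division against zero skill.
    return spread / (skill + eps) * np.sqrt((m + 1) / m)


# Members -1 and 1 (unbiased variance 2), truth 1 (MSE 1): sqrt(2) * sqrt(3/2) = sqrt(3)
print(spread_skill_ratio(np.array([[-1.0, 1.0]]), np.array([1.0])))
```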

name: str = 'ssr'
score(y_pred, y_true)

Compute the corrected spread-to-skill ratio.

Reductions (spatial/temporal) are applied to the variance and MSE before taking the square root and computing the ratio (i.e., reduce variance/MSE first, then sqrt, then divide).

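Parameters:
  • y_pred (Tensor | ndarray) – Predictions of shape (B, T, S, C, M)

  • y_true (Tensor | ndarray) – Ground truth of shape (B, T, S, C)

Returns:

SSR of shape (B, T, C) if score_dims='spatial', (B, S, C) if score_dims='temporal', or (B, T, S, C) if None.

Return type:

Tensor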

class EnsembleSpread(*, corrected=True, score_dims='spatial', reduce_all=True, dist_sync_on_step=False)

Bases: BTSCMMetric

Ensemble spread for probabilistic forecasts.

Note

By default, returns a finite-ensemble corrected spread:

\[\text{Spread}_{\text{corr}} = \sqrt{\left\langle \mathrm{Var}_{m,\text{unbiased}}(x_m)\right\rangle} \sqrt{\frac{M + 1}{M}}.\]

This correction is commonly used to make spread and skill comparable for finite ensemble sizes when the unbiased sample variance is used. It matches the form used in LoLA/paper evaluations (Appendix “Spread / Skill”), spread = sqrt((M+1)/(M-1) * mean((x_m - mean_m)^2)). The two forms agree because Var_unbiased = (M/(M-1)) * mean((x_m - mean_m)^2), and (M/(M-1)) * ((M+1)/M) = (M+1)/(M-1).

If corrected=False, returns the uncorrected macroscopic ensemble standard deviation computed from the unbiased variance estimator:

\[\sqrt{\left\langle \mathrm{Var}_{m,\text{unbiased}}(x_m)\right\rangle}.\]

Parameters:
  • corrected (bool)

  • score_dims (Literal['spatial', 'temporal'] | None)

  • reduce_all (bool)

  • dist_sync_on_step (bool)
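A quick numerical check, written as a NumPy sketch over a single field of shape (S, M), confirms that the documented corrected spread equals the LoLA-style form and is larger than the uncorrected value. The random field is illustrative data only, and the variable names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
ens = rng.normal(size=(16, 5))  # (S, M): 16 grid points, M = 5 members
m = ens.shape[-1]

# Documented form: sqrt(<Var_unbiased>) * sqrt((M + 1) / M)
spread_doc = np.sqrt(ens.var(axis=-1, ddof=1).mean()) * np.sqrt((m + 1) / m)

# LoLA-style form: sqrt((M + 1) / (M - 1) * <mean_m (x_m - xbar)^2>)
biased = ((ens - ens.mean(axis=-1, keepdims=True)) ** 2).mean(axis=-1)
spread_alt = np.sqrt((m + 1) / (m - 1) * biased.mean())

# corrected=False: square root of the mean unbiased variance
spread_uncorrected = np.sqrt(ens.var(axis=-1, ddof=1).mean())

print(np.isclose(spread_doc, spread_alt))  # True
```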

name: str = 'spread'
score(y_pred, y_true)

Compute metric score, then reduce according to self.score_dims.

Parameters:
  • y_pred (Tensor | ndarray) – Predictions of shape (B, T, S, C, M)

  • y_true (Tensor | ndarray) – Ground truth of shape (B, T, S, C)

Returns:

Tensor of shape (B, T, C) if score_dims='spatial', (B, S, C) if score_dims='temporal', or (B, T, S, C) if None.

Return type:

Float[Tensor, 'batch time channel'] | Float[Tensor, 'batch spatial *spatial channel'] | Float[Tensor, 'batch time spatial *spatial channel']

class EnsembleSkill(score_dims='spatial', reduce_all=True, dist_sync_on_step=False)

Bases: BTSCMMetric

Ensemble skill defined as RMSE of the ensemble mean.

Note

Skill is defined as the RMSE of the ensemble mean:

\[\text{Skill} = \sqrt{\left\langle (\bar{x} - y)^2 \right\rangle},\]

where \(\langle \cdot \rangle\) denotes the mean over the reduced dimensions (the spatial dimensions by default).

This metric reduces the squared error over spatial/temporal dimensions before taking the square root (macroscopic RMSE), as is commonly done in ensemble forecast evaluation (and in LoLA/paper appendices).

In the default spatial-reduction evaluation path, this is numerically equivalent to the deterministic RMSE metric applied to an ensemble prediction tensor, because that metric first averages the predictions over the ensemble dimension and then computes the RMSE.

Parameters:
  • score_dims (Literal['spatial', 'temporal'] | None)

  • reduce_all (bool)

  • dist_sync_on_step (bool)
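Reducing before the square root (macroscopic RMSE) generally gives a different value from averaging pointwise RMSEs, and by Jensen's inequality it is never smaller. This NumPy sketch assumes a single field of shape (S, M); ensemble_skill is a hypothetical name.

```python
import numpy as np


def ensemble_skill(ens, y):
    """Macroscopic RMSE of the ensemble mean for ens (S, M) and y (S,)."""
    err = ens.mean(axis=-1) - y
    return np.sqrt((err**2).mean())  # average squared error over S, then sqrt


ens = np.array([[0.0, 2.0], [3.0, 3.0]])  # ensemble means [1, 3]
y = np.array([0.0, 0.0])                  # errors [1, 3]
print(ensemble_skill(ens, y))               # sqrt((1 + 9) / 2) = sqrt(5), about 2.236
print(np.abs(ens.mean(axis=-1) - y).mean())  # 2.0, the mean of pointwise RMSEs
```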

name: str = 'skill'
score(y_pred, y_true)

Compute metric score, then reduce according to self.score_dims.

Parameters:
  • y_pred (Tensor | ndarray) – Predictions of shape (B, T, S, C, M)

  • y_true (Tensor | ndarray) – Ground truth of shape (B, T, S, C)

Returns:

Tensor of shape (B, T, C) if score_dims='spatial', (B, S, C) if score_dims='temporal', or (B, T, S, C) if None.

Return type:

Float[Tensor, 'batch time channel'] | Float[Tensor, 'batch spatial *spatial channel'] | Float[Tensor, 'batch time spatial *spatial channel']

class WinklerScore(alpha=0.1, **kwargs)

Bases: BTSCMMetric

Winkler interval score for central prediction intervals.

For significance level \(\alpha \in (0, 1)\), this metric computes central \((1-\alpha)\) prediction intervals from ensemble quantiles and returns the per-point interval score

\[W_\alpha = (u - l) + \frac{2}{\alpha}(l - y)\,\mathbf{1}(y < l) + \frac{2}{\alpha}(y - u)\,\mathbf{1}(y > u),\]

where \(l\) and \(u\) are the lower and upper interval bounds.

Shape conventions:
  • Input prediction tensor: y_pred has shape (B, T, S..., C, M).

  • Input truth tensor: y_true has shape (B, T, S..., C).

  • Quantiles are computed along the ensemble dim M: l (lower) is q_{alpha/2} and u (upper) is q_{1-alpha/2}, each of shape (B, T, S..., C).

  • y corresponds to y_true, with the same shape (B, T, S..., C).

The internal _score returns pointwise Winkler scores with shape (B, T, S..., C). The base class then applies the score_dims and reduce_all reductions:
  • score_dims='spatial' (default) → (B, T, C)

  • score_dims='temporal' → (B, S..., C)

  • score_dims=None → (B, T, S..., C)

  • if reduce_all=True (default), compute() returns a scalar.

Lower values are better: narrow intervals are rewarded, and misses are penalized in proportion to their distance outside the interval.
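The pointwise score can be sketched in NumPy with the ensemble on the last axis. The interval bounds come from np.quantile with its default linear interpolation. That choice is an assumption, since the quantile method used by the class is not stated. winkler_score is a hypothetical name.

```python
import numpy as np


def winkler_score(ens, y, alpha=0.1):
    """Pointwise Winkler score of the central (1 - alpha) interval.

    ens: (..., M) with the ensemble last; y: (...).
    """
    lower = np.quantile(ens, alpha / 2, axis=-1)
    upper = np.quantile(ens, 1 - alpha / 2, axis=-1)
    width = upper - lower
    below = (2 / alpha) * (lower - y) * (y < lower)  # penalty when y < l
    above = (2 / alpha) * (y - upper) * (y > upper)  # penalty when y > u
    return width + below + above


# Members 0..10 with alpha = 0.2 give the interval [1, 9] (width 8).
ens = np.tile(np.arange(11.0), (3, 1))
print(winkler_score(ens, np.array([5.0, -1.0, 12.0]), alpha=0.2))  # about [8, 28, 38]
```

An observation inside the interval pays only the width. One that falls 2 below the interval pays 8 + (2/0.2)·2 = 28.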

References:

Winkler, R. L. (1972). A Decision-Theoretic Approach to Interval Estimation. Journal of the American Statistical Association, 67(337), 187-191. https://doi.org/10.1080/01621459.1972.10481224

Gneiting, T., & Raftery, A. E. (2007). Strictly Proper Scoring Rules, Prediction, and Estimation. Journal of the American Statistical Association, 102(477), 359-378. https://doi.org/10.1198/016214506000001437

name: str = 'winkler'
Parameters:

alpha (float)

class MultiWinkler(coverage_levels=None, score_dims='spatial', dist_sync_on_step=False)

Bases: Metric

Average Winkler interval score across multiple central coverage levels.

This is the interval-score analogue of MultiCoverage: it evaluates a grid of nominal central prediction intervals and returns one scalar by averaging the Winkler score across interval levels, time, space, channels, and samples. Lower values are better.

Note

This is an unweighted average of central interval scores across the chosen coverage grid. The weighted interval score (WIS) uses a related multi-level interval-score construction with prescribed weights.
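The unweighted multi-level average can be sketched as a loop over coverage levels, with alpha = 1 − coverage for each central interval. The coverage grid in this NumPy sketch is a hypothetical example, not the class default, which is not documented here. Quantiles again use np.quantile's default interpolation as an assumption.

```python
import numpy as np


def multi_winkler(ens, y, coverage_levels=(0.5, 0.8, 0.9)):
    """Unweighted mean Winkler score over several central coverage levels."""
    scores = []
    for coverage in coverage_levels:
        alpha = 1.0 - coverage  # central (1 - alpha) interval
        lower = np.quantile(ens, alpha / 2, axis=-1)
        upper = np.quantile(ens, 1 - alpha / 2, axis=-1)
        score = (upper - lower) + (2 / alpha) * (
            (lower - y) * (y < lower) + (y - upper) * (y > upper)
        )
        scores.append(score.mean())  # average over all remaining dims
    return float(np.mean(scores))


ens = np.tile(np.arange(11.0), (2, 1))
# Observations inside every interval: widths 5, 8, 9 average to 22 / 3.
print(multi_winkler(ens, np.array([5.0, 5.0])))
```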

References

Bracher, J., Ray, E. L., Gneiting, T., & Reich, N. G. (2021). Evaluating epidemic forecasts in an interval format. PLOS Computational Biology, 17(2), e1008618. https://doi.org/10.1371/journal.pcbi.1008618

Parameters:
  • coverage_levels (list[float] | None)

  • score_dims (Literal['spatial', 'temporal'] | None)

  • dist_sync_on_step (bool)

name: str = 'multiwinkler'
update(y_pred, y_true)

Update metric state with one validation batch.

Parameters:
  • y_pred – Ensemble predictions for the batch

  • y_true – Ground truth for the batch

Return type:

None

compute()

Return average Winkler score across levels and reduced dimensions.

Return type:

Tensor

reset()

Reset metric state and dynamic state-shape flag.

Return type:

None

score(y_pred, y_true)

Compute per-level Winkler scores before batch aggregation.

Parameters:
  • y_pred – Ensemble predictions

  • y_true – Ground truth

Return type:

Tensor

plot(save_path=None, title='Winkler Interval Scores', cmap_str='viridis', save_csv=True)

Plot Winkler score against nominal interval coverage.

Parameters:
  • save_path (str | None)

  • title (str)

  • cmap_str (str)

  • save_csv (bool)