Skip to content
Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added StormCast SDA model
- Added beta serve utils with inference server and client implementations
- Added HealPix data assimilation (HealDA) model
- Added `energy_score` metric for multivariate ensemble forecast verification

### Changed

Expand Down
1 change: 1 addition & 0 deletions docs/modules/statistics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Various statistic and metric calculations for analysing inference data.
statistics.acc
statistics.brier
statistics.crps
statistics.energy_score
statistics.fss
statistics.lsd
statistics.mae
Expand Down
1 change: 1 addition & 0 deletions earth2studio/statistics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .base import Metric, Statistic
from .brier import brier_score
from .crps import crps
from .energy_score import energy_score
from .fss import fss
from .lsd import log_spectral_distance
from .moments import mean, std, variance # noqa
Expand Down
292 changes: 292 additions & 0 deletions earth2studio/statistics/energy_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from earth2studio.statistics.moments import mean
from earth2studio.utils import handshake_coords, handshake_dim
from earth2studio.utils.type import CoordSystem


class energy_score:
"""
Compute the Energy Score for multivariate ensemble forecast verification.

The Energy Score is the multivariate generalization of CRPS. Given an
ensemble forecast {x_1, ..., x_M} and an observation y, the Energy Score
is defined as:

ES = (1/M) * sum_m ||x_m - y|| - 1/(2*M^2) * sum_m sum_m' ||x_m - x_m'||

where ||.|| denotes the Euclidean norm computed over the multivariate
dimensions. This is a proper scoring rule that is minimized when the forecast
distribution matches the true distribution.

Unlike CRPS which evaluates each variable/grid point independently, the
Energy Score captures whether the ensemble preserves spatial correlations
across variables and grid points.

.. warning::
Setting ``multivariate_dimensions`` to large spatial grids (e.g.,
``['lat', 'lon']`` with 721x1440 = ~1M elements) produces a feature
vector of that size per ensemble member. For M=50 members this requires
~200 MB per tensor in float32. Prefer selecting a subset of dimensions
unless full-field verification is explicitly needed.

Parameters
----------
ensemble_dimension: str
A name corresponding to the dimension to perform the ensemble
reduction over. Example: 'ensemble'
multivariate_dimensions: list[str]
Dimensions over which to compute the Euclidean norm.
Example: ['variable', 'lat', 'lon'] for full spatial ES, or
['variable'] for per-grid-point multivariate ES.
reduction_dimensions: list[str], optional
Dimensions over which to average the energy score after computation.
By default None (no additional reduction).
weights: torch.Tensor, optional
Weights for the reduction dimensions. Must have the same number of
dimensions as passed in reduction_dimensions. By default None.
fair: bool, optional
If True, use the fair (unbiased) Energy Score estimator, which
replaces the ``1/(2*M^2)`` denominator with ``1/(2*M*(M-1))``,
excluding the zero self-distance diagonal from the denominator count.
Requires at least 2 ensemble members. By default False.

References
----------
Gneiting, T. and Raftery, A. E. (2007), "Strictly Proper Scoring Rules,
Prediction, and Estimation", Journal of the American Statistical
Association, 102(477), 359-378.
"""

def __init__(
self,
ensemble_dimension: str,
multivariate_dimensions: list[str],
reduction_dimensions: list[str] | None = None,
weights: torch.Tensor = None,
fair: bool = False,
):
if not isinstance(ensemble_dimension, str):
raise ValueError("Error! ensemble_dimension must be a string, not a list.")
if (
not isinstance(multivariate_dimensions, list)
or len(multivariate_dimensions) == 0
):
raise ValueError(
"Error! multivariate_dimensions must be a non-empty list of strings."
)

self.ensemble_dimension = ensemble_dimension
self.multivariate_dimensions = multivariate_dimensions
self._reduction_dimensions = reduction_dimensions
self.fair = fair
if reduction_dimensions is not None:
self.mean = mean(reduction_dimensions, weights=weights, batch_update=False)

def __str__(self) -> str:
dims = (
self._reduction_dimensions if self._reduction_dimensions is not None else []
)
return "_".join(dims + ["energy_score"])

@property
def reduction_dimensions(self) -> list[str]:
"""All dimensions that will be reduced/removed from the output."""
dims = [self.ensemble_dimension] + self.multivariate_dimensions
if self._reduction_dimensions is not None:
dims = dims + self._reduction_dimensions
return dims

def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
"""Output coordinate system of the computed statistic, corresponding to
the given input coordinates.

Parameters
----------
input_coords : CoordSystem
Input coordinate system to transform into output_coords

Returns
-------
CoordSystem
Coordinate system dictionary
"""
output_coords = input_coords.copy()
for dimension in self.reduction_dimensions:
handshake_dim(input_coords, dimension)
output_coords.pop(dimension)

return output_coords

def __call__(
self,
x: torch.Tensor,
x_coords: CoordSystem,
y: torch.Tensor,
y_coords: CoordSystem,
) -> tuple[torch.Tensor, CoordSystem]:
"""
Apply the Energy Score metric to ensemble forecast `x` and observation `y`.

Parameters
----------
x : torch.Tensor
Ensemble forecast tensor. Must contain the ensemble dimension.
x_coords : CoordSystem
Coordinate system describing the `x` tensor. Must contain
`ensemble_dimension` and all `multivariate_dimensions`.
y : torch.Tensor
Observation tensor. Must not contain the ensemble dimension.
y_coords : CoordSystem
Coordinate system describing the `y` tensor. Must contain
all `multivariate_dimensions`.

Returns
-------
tuple[torch.Tensor, CoordSystem]
Energy Score tensor with appropriate reduced coordinates.
"""
# Validate reduction dimensions exist in x_coords
if not all(rd in x_coords for rd in self.reduction_dimensions):
raise ValueError(
"Initialized reduction dimension does not appear in passed coords"
)

# Ensemble dimension should be in x but not in y
if self.ensemble_dimension in y_coords:
raise ValueError(
f"{self.ensemble_dimension} should not be in y_coords but is."
)
if x.ndim != y.ndim + 1:
raise ValueError(
"x and y must have broadcastable shapes but got "
+ f"{x.shape} and {y.shape}"
)

# Validate multivariate dimensions exist in y_coords
for mv_dim in self.multivariate_dimensions:
if mv_dim not in y_coords:
raise ValueError(
f"Multivariate dimension '{mv_dim}' not found in y_coords."
)

# Input coordinate checking (skip ensemble dim)
coord_count = 0
for c in x_coords:
if c != self.ensemble_dimension:
handshake_dim(y_coords, c, coord_count)
handshake_coords(y_coords, x_coords, c)
coord_count += 1

# Compute the Energy Score
out = _energy_score_compute(
x,
y,
x_coords,
self.ensemble_dimension,
self.multivariate_dimensions,
self.fair,
)

# Build output coords: remove ensemble and multivariate dims from y_coords
out_coords = y_coords.copy()
for mv_dim in self.multivariate_dimensions:
out_coords.pop(mv_dim, None)

# Apply additional reduction if requested
if self._reduction_dimensions is not None:
out, out_coords = self.mean(out, out_coords)

return out, out_coords


def _energy_score_compute(
ensemble: torch.Tensor,
truth: torch.Tensor,
x_coords: CoordSystem,
ensemble_dimension: str,
multivariate_dimensions: list[str],
fair: bool = False,
) -> torch.Tensor:
"""
Compute the Energy Score.

ES = (1/M) * sum_m ||x_m - y|| - 1/(2*M^2) * sum_m sum_m' ||x_m - x_m'||

For the fair (unbiased) estimator:
ES = (1/M) * sum_m ||x_m - y|| - 1/(2*M*(M-1)) * sum_{m!=m'} ||x_m - x_m'||

Parameters
----------
ensemble : torch.Tensor
Ensemble forecast tensor with ensemble dimension.
truth : torch.Tensor
Observation tensor without ensemble dimension.
x_coords : CoordSystem
Coordinate system for the ensemble tensor.
ensemble_dimension : str
Name of the ensemble dimension.
multivariate_dimensions : list[str]
Dimensions over which to compute the Euclidean norm.
fair : bool
Whether to use the fair (unbiased) estimator.

Returns
-------
torch.Tensor
Energy Score values.
"""
coord_keys = list(x_coords.keys())
ens_dim = coord_keys.index(ensemble_dimension)
M = ensemble.shape[ens_dim]
mv_dims = [coord_keys.index(d) for d in multivariate_dimensions]

# Term 1: (1/M) * sum_m ||x_m - y||
diff_xy = ensemble - truth.unsqueeze(ens_dim)
term1 = torch.sqrt((diff_xy * diff_xy).sum(dim=mv_dims)).mean(dim=ens_dim)

# Term 2: pairwise ensemble spread via cdist
# Permute to (*remaining, M, *mv) then flatten to (batch, M, D)
all_dims = list(range(ensemble.ndim))
remaining_dims = [d for d in all_dims if d != ens_dim and d not in mv_dims]
x_perm = ensemble.permute(*remaining_dims, ens_dim, *mv_dims)

remaining_shape = x_perm.shape[: len(remaining_dims)]
mv_size = 1
for d in mv_dims:
mv_size *= ensemble.shape[d]
batch_size = 1
for s in remaining_shape:
batch_size *= s
x_flat = x_perm.reshape(batch_size, M, mv_size)

pairwise_dists = torch.cdist(x_flat, x_flat, p=2)

if fair:
if M < 2:
raise ValueError(f"Fair Energy Score requires ensemble size >= 2, got {M}.")
mask = ~torch.eye(M, device=pairwise_dists.device, dtype=torch.bool)
pairwise_sum = (pairwise_dists * mask.unsqueeze(0)).sum(dim=(-1, -2))
term2 = pairwise_sum / (2.0 * M * (M - 1))
else:
pairwise_sum = pairwise_dists.sum(dim=(-1, -2))
term2 = pairwise_sum / (2.0 * M * M)

term2 = term2.reshape(remaining_shape)
return term1 - term2
Loading