From 7e15172077362638c749e8c3e7e42ff4a8a32bf6 Mon Sep 17 00:00:00 2001 From: abhaygoudannavar Date: Fri, 27 Feb 2026 18:46:01 +0000 Subject: [PATCH 1/3] feat: Add Energy Score metric for multivariate ensemble verification Add energy_score class to earth2studio.statistics module implementing the multivariate generalization of CRPS for ensemble forecast verification. Features: - Follows existing Metric protocol (ensemble_dimension, reduction_dimensions) - Supports configurable multivariate_dimensions for flexible norm computation - Fair (unbiased) and standard estimator variants - Efficient pairwise computation via torch.cdist - Optional weighted reduction over additional dimensions - Pure PyTorch, no new dependencies Includes comprehensive test suite (9 tests) with accuracy, non-negativity, perfect ensemble, and multi-dimensional tests. References: Gneiting and Raftery (2007), JASA 102(477), 359-378 --- earth2studio/statistics/__init__.py | 1 + earth2studio/statistics/energy_score.py | 327 +++++++++++++++++++++ test/statistics/test_energy_score.py | 359 ++++++++++++++++++++++++ 3 files changed, 687 insertions(+) create mode 100644 earth2studio/statistics/energy_score.py create mode 100644 test/statistics/test_energy_score.py diff --git a/earth2studio/statistics/__init__.py b/earth2studio/statistics/__init__.py index 13bf2a7cb..2952920cc 100644 --- a/earth2studio/statistics/__init__.py +++ b/earth2studio/statistics/__init__.py @@ -18,6 +18,7 @@ from .base import Metric, Statistic from .brier import brier_score from .crps import crps +from .energy_score import energy_score from .fss import fss from .lsd import log_spectral_distance from .moments import mean, std, variance # noqa diff --git a/earth2studio/statistics/energy_score.py b/earth2studio/statistics/energy_score.py new file mode 100644 index 000000000..453afa0f0 --- /dev/null +++ b/earth2studio/statistics/energy_score.py @@ -0,0 +1,327 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from earth2studio.statistics.moments import mean +from earth2studio.utils import handshake_coords, handshake_dim +from earth2studio.utils.type import CoordSystem + + +class energy_score: + """ + Compute the Energy Score for multivariate ensemble forecast verification. + + The Energy Score is the multivariate generalization of CRPS. Given an + ensemble forecast {x_1, ..., x_M} and an observation y, the Energy Score + is defined as: + + ES = (1/M) * sum_m ||x_m - y|| - 1/(2*M^2) * sum_m sum_m' ||x_m - x_m'|| + + where ||.|| denotes the Euclidean norm computed over the multivariate + dimensions. This is a proper scoring rule that is minimized when the + forecast distribution matches the true distribution. + + Unlike CRPS which evaluates each variable/grid point independently, the + Energy Score captures whether the ensemble preserves spatial correlations + across variables and grid points. + + Parameters + ---------- + ensemble_dimension: str + A name corresponding to the dimension to perform the ensemble + reduction over. Example: 'ensemble' + multivariate_dimensions: list[str] + Dimensions over which to compute the Euclidean norm. + Example: ['variable', 'lat', 'lon'] for full spatial ES, or + ['variable'] for per-grid-point multivariate ES. + reduction_dimensions: list[str], optional + Dimensions over which to average the energy score after computation. + By default None (no additional reduction). + weights: torch.Tensor, optional + Weights for the reduction dimensions. Must have the same number of + dimensions as passed in reduction_dimensions. By default None. + fair: bool, optional + If True, use the fair (unbiased) Energy Score estimator with the + correction factor M/(M-1). By default False. + + References + ---------- + Gneiting, T. and Raftery, A. E. (2007), "Strictly Proper Scoring Rules, + Prediction, and Estimation", Journal of the American Statistical + Association, 102(477), 359-378. + """ + + def __init__( + self, + ensemble_dimension: str, + multivariate_dimensions: list[str], + reduction_dimensions: list[str] | None = None, + weights: torch.Tensor = None, + fair: bool = False, + ): + if not isinstance(ensemble_dimension, str): + raise ValueError("Error! ensemble_dimension must be a string, not a list.") + if ( + not isinstance(multivariate_dimensions, list) + or len(multivariate_dimensions) == 0 + ): + raise ValueError( + "Error! multivariate_dimensions must be a non-empty list of strings." + ) + + self.ensemble_dimension = ensemble_dimension + self.multivariate_dimensions = multivariate_dimensions + self._reduction_dimensions = reduction_dimensions + self.fair = fair + if reduction_dimensions is not None: + self.mean = mean(reduction_dimensions, weights=weights, batch_update=False) + + def __str__(self) -> str: + dims = ( + self._reduction_dimensions if self._reduction_dimensions is not None else [] + ) + return "_".join(dims + ["energy_score"]) + + @property + def reduction_dimensions(self) -> list[str]: + """All dimensions that will be reduced/removed from the output.""" + dims = [self.ensemble_dimension] + self.multivariate_dimensions + if self._reduction_dimensions is not None: + dims = dims + self._reduction_dimensions + return dims + + def output_coords(self, input_coords: CoordSystem) -> CoordSystem: + """Output coordinate system of the computed statistic, corresponding to + the given input coordinates. + + Parameters + ---------- + input_coords : CoordSystem + Input coordinate system to transform into output_coords + + Returns + ------- + CoordSystem + Coordinate system dictionary + """ + output_coords = input_coords.copy() + for dimension in self.reduction_dimensions: + handshake_dim(input_coords, dimension) + output_coords.pop(dimension) + + return output_coords + + def __call__( + self, + x: torch.Tensor, + x_coords: CoordSystem, + y: torch.Tensor, + y_coords: CoordSystem, + ) -> tuple[torch.Tensor, CoordSystem]: + """ + Apply the Energy Score metric to ensemble forecast `x` and observation `y`. + + Parameters + ---------- + x : torch.Tensor + Ensemble forecast tensor. Must contain the ensemble dimension. + x_coords : CoordSystem + Coordinate system describing the `x` tensor. Must contain + `ensemble_dimension` and all `multivariate_dimensions`. + y : torch.Tensor + Observation tensor. Must not contain the ensemble dimension. + y_coords : CoordSystem + Coordinate system describing the `y` tensor. Must contain + all `multivariate_dimensions`. + + Returns + ------- + tuple[torch.Tensor, CoordSystem] + Energy Score tensor with appropriate reduced coordinates. + """ + # Validate reduction dimensions exist in x_coords + if not all(rd in x_coords for rd in self.reduction_dimensions): + raise ValueError( + "Initialized reduction dimension does not appear in passed coords" + ) + + # Ensemble dimension should be in x but not in y + if self.ensemble_dimension in y_coords: + raise ValueError( + f"{self.ensemble_dimension} should not be in y_coords but is." + ) + if x.ndim != y.ndim + 1: + raise ValueError( + "x and y must have broadcastable shapes but got " + + f"{x.shape} and {y.shape}" + ) + + # Validate multivariate dimensions exist in y_coords + for mv_dim in self.multivariate_dimensions: + if mv_dim not in y_coords: + raise ValueError( + f"Multivariate dimension '{mv_dim}' not found in y_coords." + ) + + # Input coordinate checking (skip ensemble dim) + coord_count = 0 + for c in x_coords: + if c != self.ensemble_dimension: + handshake_dim(y_coords, c, coord_count) + handshake_coords(y_coords, x_coords, c) + coord_count += 1 + + # Compute the Energy Score + out = _energy_score_compute( + x, + y, + x_coords, + self.ensemble_dimension, + self.multivariate_dimensions, + self.fair, + ) + + # Build output coords: remove ensemble and multivariate dims from y_coords + out_coords = y_coords.copy() + for mv_dim in self.multivariate_dimensions: + out_coords.pop(mv_dim, None) + + # Apply additional reduction if requested + if self._reduction_dimensions is not None: + out, out_coords = self.mean(out, out_coords) + + return out, out_coords + + +def _energy_score_compute( + ensemble: torch.Tensor, + truth: torch.Tensor, + x_coords: CoordSystem, + ensemble_dimension: str, + multivariate_dimensions: list[str], + fair: bool = False, +) -> torch.Tensor: + """ + Compute the Energy Score. + + ES = (1/M) * sum_m ||x_m - y|| - 1/(2*M^2) * sum_m sum_m' ||x_m - x_m'|| + + For the fair (unbiased) estimator: + ES = (1/M) * sum_m ||x_m - y|| - 1/(2*M*(M-1)) * sum_{m!=m'} ||x_m - x_m'|| + + Parameters + ---------- + ensemble : torch.Tensor + Ensemble forecast tensor with ensemble dimension. + truth : torch.Tensor + Observation tensor without ensemble dimension. + x_coords : CoordSystem + Coordinate system for the ensemble tensor. + ensemble_dimension : str + Name of the ensemble dimension. + multivariate_dimensions : list[str] + Dimensions over which to compute the Euclidean norm. + fair : bool + Whether to use the fair (unbiased) estimator. + + Returns + ------- + torch.Tensor + Energy Score values. + """ + coord_keys = list(x_coords.keys()) + ens_dim = coord_keys.index(ensemble_dimension) + M = ensemble.shape[ens_dim] + + # Get indices for multivariate dims in ensemble tensor + mv_dims = [coord_keys.index(d) for d in multivariate_dimensions] + + # Term 1: (1/M) * sum_m ||x_m - y|| + # Expand truth to match ensemble shape along ensemble dim + diff_xy = ensemble - truth.unsqueeze(ens_dim) + # Compute L2 norm over multivariate dims + # We need to square, sum over mv_dims, then sqrt + term1 = _l2_norm_over_dims(diff_xy, mv_dims).mean(dim=ens_dim) + + # Term 2: 1/(2*M^2) * sum_m sum_m' ||x_m - x_m'|| + # or for fair: 1/(2*M*(M-1)) * sum_{m!=m'} ||x_m - x_m'|| + # We use torch.cdist for efficiency + + # We need to reshape ensemble so that ens_dim and mv_dims are isolated + # Move ens_dim to position 0 and mv_dims to the end, flatten the rest + # into a batch dimension for cdist + + # Strategy: flatten multivariate dims into a single dim, then use cdist + # First, move ensemble dim to -2 and multivariate dims to -1 (flattened) + + # Determine permutation: [remaining_dims..., ens_dim, mv_dims_flattened...] + all_dims = list(range(ensemble.ndim)) + remaining_dims = [d for d in all_dims if d != ens_dim and d not in mv_dims] + perm = remaining_dims + [ens_dim] + mv_dims + x_perm = ensemble.permute(*perm) + + # Now shape is (*remaining, M, *mv_sizes) + remaining_shape = x_perm.shape[: len(remaining_dims)] + mv_size = 1 + for d in mv_dims: + mv_size *= ensemble.shape[d] + + # Reshape to (batch, M, D) where D = product of multivariate dim sizes + batch_size = 1 + for s in remaining_shape: + batch_size *= s + x_flat = x_perm.reshape(batch_size, M, mv_size) + + # Use cdist for pairwise distances + pairwise_dists = torch.cdist(x_flat, x_flat, p=2) # (batch, M, M) + + if fair: + # Fair estimator: exclude diagonal (m == m'), divide by M*(M-1) + if M < 2: + raise ValueError("Fair Energy Score requires at least 2 ensemble members.") + # Zero out diagonal then sum, which gives sum_{m!=m'} + mask = ~torch.eye(M, device=pairwise_dists.device, dtype=torch.bool) + pairwise_sum = (pairwise_dists * mask.unsqueeze(0)).sum(dim=(-1, -2)) + term2 = pairwise_sum / (2.0 * M * (M - 1)) + else: + # Standard estimator: sum all pairs including diagonal, divide by 2*M^2 + pairwise_sum = pairwise_dists.sum(dim=(-1, -2)) + term2 = pairwise_sum / (2.0 * M * M) + + # Reshape term2 back to remaining_shape + term2 = term2.reshape(remaining_shape) + + es = term1 - term2 + return es + + +def _l2_norm_over_dims(x: torch.Tensor, dims: list[int]) -> torch.Tensor: + """Compute L2 norm over specified dimensions. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + dims : list[int] + Dimensions to compute the norm over (will be reduced). + + Returns + ------- + torch.Tensor + Tensor with specified dimensions removed. + """ + return torch.sqrt((x * x).sum(dim=dims)) diff --git a/test/statistics/test_energy_score.py b/test/statistics/test_energy_score.py new file mode 100644 index 000000000..d485d3ef5 --- /dev/null +++ b/test/statistics/test_energy_score.py @@ -0,0 +1,359 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from collections import OrderedDict + +import numpy as np +import pytest +import torch + +from earth2studio.statistics import energy_score, lat_weight +from earth2studio.statistics.energy_score import _energy_score_compute +from earth2studio.utils.coords import handshake_coords, handshake_dim + +lat_weights = lat_weight(torch.as_tensor(np.linspace(-90.0, 90.0, 10))) + + +@pytest.mark.parametrize( + "ensemble_dimension", + [ + "ensemble", + ], +) +@pytest.mark.parametrize( + "reduction_weights", + [ + (None, None), + (["lat"], lat_weights), + ], +) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_energy_score( + ensemble_dimension: str, + reduction_weights: tuple[list[str], np.ndarray], + device: str, +) -> None: + + x = torch.randn((8, 1, 2, 10, 20), device=device) + + x_coords = OrderedDict( + { + "ensemble": np.arange(8), + "time": np.array([np.datetime64("1993-04-05T00:00")]), + "variable": np.array(["t2m", "tcwv"]), + "lat": np.linspace(-90.0, 90.0, 10), + "lon": np.linspace(0.0, 360.0, 20, endpoint=False), + } + ) + + y_coords = copy.deepcopy(x_coords) + y_coords.pop(ensemble_dimension) + y_shape = [len(y_coords[c]) for c in y_coords] + y = torch.randn(y_shape, device=device) + + reduction_dimensions, weights = reduction_weights + if weights is not None: + weights = weights.to(device) + + # Use lon as the multivariate dimension for testing + ES = energy_score( + ensemble_dimension, + multivariate_dimensions=["lon"], + reduction_dimensions=reduction_dimensions, + weights=weights, + ) + + z, c = ES(x, x_coords, y, y_coords) + + # Ensemble dim and multivariate dims should be removed + assert ensemble_dimension not in c + assert "lon" not in c + if reduction_dimensions is not None: + assert all(rd not in c for rd in reduction_dimensions) + assert list(z.shape) == [len(val) for val in c.values()] + + # Check output coords match + out_test_coords = ES.output_coords(x_coords) + for i, ci in enumerate(c): + handshake_dim(out_test_coords, ci, i) + handshake_coords(out_test_coords, c, ci) + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_energy_score_fair(device: str) -> None: + """Test fair Energy Score with > 1 ensemble members.""" + x = torch.randn((8, 1, 2, 10, 20), device=device) + + x_coords = OrderedDict( + { + "ensemble": np.arange(8), + "time": np.array([np.datetime64("1993-04-05T00:00")]), + "variable": np.array(["t2m", "tcwv"]), + "lat": np.linspace(-90.0, 90.0, 10), + "lon": np.linspace(0.0, 360.0, 20, endpoint=False), + } + ) + + y_coords = copy.deepcopy(x_coords) + y_coords.pop("ensemble") + y_shape = [len(y_coords[c]) for c in y_coords] + y = torch.randn(y_shape, device=device) + + ES = energy_score( + "ensemble", + multivariate_dimensions=["lon"], + fair=True, + ) + + z, c = ES(x, x_coords, y, y_coords) + + assert "ensemble" not in c + assert "lon" not in c + assert list(z.shape) == [len(val) for val in c.values()] + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_energy_score_failures(device: str) -> None: + x = torch.randn((8, 1, 2, 10, 20), device=device) + + x_coords = OrderedDict( + { + "ensemble": np.arange(8), + "time": np.array([np.datetime64("1993-04-05T00:00")]), + "variable": np.array(["t2m", "tcwv"]), + "lat": np.linspace(-90.0, 90.0, 10), + "lon": np.linspace(0.0, 360.0, 20, endpoint=False), + } + ) + + # Error: ensemble_dimension not a string + with pytest.raises(ValueError): + energy_score(["ensemble"], multivariate_dimensions=["lon"]) + + # Error: multivariate_dimensions is empty + with pytest.raises(ValueError): + energy_score("ensemble", multivariate_dimensions=[]) + + ES = energy_score("ensemble", multivariate_dimensions=["lon"]) + + # Error: ensemble_dimension in y_coords + with pytest.raises(ValueError): + y_coords = copy.deepcopy(x_coords) + y_shape = [len(y_coords[c]) for c in y_coords] + y = torch.randn(y_shape, device=device) + ES(x, x_coords, y, y_coords) + + # Error: x and y shapes not broadcastable + with pytest.raises(ValueError): + y_coords = OrderedDict({"phony": np.arange(1)}) + for c in x_coords: + if c != "ensemble": + y_coords[c] = x_coords[c] + + y_shape = [len(y_coords[c]) for c in y_coords] + y = torch.randn(y_shape, device=device) + ES(x, x_coords, y, y_coords) + + # Error: reduction_dimension not in x_coords + with pytest.raises(ValueError): + y_coords = OrderedDict() + for c in x_coords: + if c != "ensemble": + y_coords[c] = x_coords[c] + + y_shape = [len(y_coords[c]) for c in y_coords] + y = torch.randn(y_shape, device=device) + + bad_x_coords = copy.deepcopy(x_coords) + bad_x_coords.pop("ensemble") + ES(x, bad_x_coords, y, y_coords) + + # Error: fair with < 2 ensemble members + with pytest.raises(ValueError): + ES_fair = energy_score("ensemble", multivariate_dimensions=["lon"], fair=True) + x1 = torch.randn((1, 1, 2, 10, 20), device=device) + x1_coords = copy.deepcopy(x_coords) + x1_coords["ensemble"] = np.arange(1) + y_coords = OrderedDict() + for c in x1_coords: + if c != "ensemble": + y_coords[c] = x1_coords[c] + y_shape = [len(y_coords[c]) for c in y_coords] + y = torch.randn(y_shape, device=device) + ES_fair(x1, x1_coords, y, y_coords) + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_energy_score_accuracy( + device: str, rtol: float = 1e-2, atol: float = 1e-2 +) -> None: + """Test Energy Score accuracy against brute-force computation.""" + torch.manual_seed(42) + M = 50 + D = 5 + + # Generate ensemble and truth + x = torch.randn((M, D), device=device, dtype=torch.float64) + y = torch.randn((D,), device=device, dtype=torch.float64) + + # Manual brute-force computation + # Term 1: (1/M) * sum_m ||x_m - y|| + term1 = torch.norm(x - y.unsqueeze(0), dim=1).mean() + + # Term 2: 1/(2*M^2) * sum_m sum_m' ||x_m - x_m'|| + dists = torch.cdist(x.unsqueeze(0), x.unsqueeze(0), p=2).squeeze(0) + term2 = dists.sum() / (2.0 * M * M) + + expected_es = term1 - term2 + + # Compute via our function + x_coords = OrderedDict( + { + "ensemble": np.arange(M), + "variable": np.array([f"v{i}" for i in range(D)]), + } + ) + + computed_es = _energy_score_compute( + x, + y, + x_coords, + ensemble_dimension="ensemble", + multivariate_dimensions=["variable"], + fair=False, + ) + + assert torch.allclose( + computed_es, + expected_es, + rtol=rtol, + atol=atol, + ), f"Expected {expected_es.item()}, got {computed_es.item()}" + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_energy_score_nonnegativity(device: str) -> None: + """Energy Score should always be non-negative.""" + torch.manual_seed(123) + M = 20 + x = torch.randn((M, 3, 10), device=device) + y = torch.randn((3, 10), device=device) + + x_coords = OrderedDict( + { + "ensemble": np.arange(M), + "variable": np.array(["a", "b", "c"]), + "lon": np.linspace(0, 360, 10), + } + ) + + es = _energy_score_compute( + x, + y, + x_coords, + ensemble_dimension="ensemble", + multivariate_dimensions=["variable", "lon"], + fair=False, + ) + + assert es.item() >= 0.0, f"Energy Score should be non-negative, got {es.item()}" + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_energy_score_perfect_ensemble(device: str) -> None: + """Energy score should be zero for a perfect deterministic ensemble.""" + torch.manual_seed(0) + M = 10 + D = 8 + + y = torch.randn((D,), device=device, dtype=torch.float64) + x = y.unsqueeze(0).repeat(M, 1) + + x_coords = OrderedDict( + { + "ensemble": np.arange(M), + "variable": np.array([f"v{i}" for i in range(D)]), + } + ) + + es = _energy_score_compute( + x, + y, + x_coords, + ensemble_dimension="ensemble", + multivariate_dimensions=["variable"], + fair=False, + ) + + assert torch.allclose( + es, + torch.zeros_like(es), + atol=1e-6, + ), f"Perfect ensemble should have ES ~= 0, got {es.item()}" + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_energy_score_str(device: str) -> None: + """Test string representation.""" + ES = energy_score( + "ensemble", + multivariate_dimensions=["lon"], + reduction_dimensions=["lat"], + ) + assert str(ES) == "lat_energy_score" + + ES2 = energy_score( + "ensemble", + multivariate_dimensions=["lon"], + ) + assert str(ES2) == "energy_score" + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_energy_score_multiple_mv_dims(device: str) -> None: + """Test with multiple multivariate dimensions.""" + torch.manual_seed(99) + M = 5 + x = torch.randn((M, 1, 2, 10, 20), device=device) + y = torch.randn((1, 2, 10, 20), device=device) + + x_coords = OrderedDict( + { + "ensemble": np.arange(M), + "time": np.array([np.datetime64("2024-01-01")]), + "variable": np.array(["t2m", "u10m"]), + "lat": np.linspace(-90.0, 90.0, 10), + "lon": np.linspace(0.0, 360.0, 20, endpoint=False), + } + ) + + y_coords = copy.deepcopy(x_coords) + y_coords.pop("ensemble") + + ES = energy_score( + "ensemble", + multivariate_dimensions=["variable", "lat", "lon"], + ) + + z, c = ES(x, x_coords, y, y_coords) + + assert "ensemble" not in c + assert "variable" not in c + assert "lat" not in c + assert "lon" not in c + assert "time" in c + assert list(z.shape) == [len(val) for val in c.values()] From 9c31eee32301c9aa254b88423b724dd8825eda7e Mon Sep 17 00:00:00 2001 From: Nicholas Geneva Date: Tue, 17 Mar 2026 17:17:27 +0000 Subject: [PATCH 2/3] Little clean up / updates --- CHANGELOG.md | 1 + earth2studio/statistics/energy_score.py | 75 ++++++----------------- test/statistics/test_energy_score.py | 81 +++++++++++++++++++++++-- 3 files changed, 97 insertions(+), 60 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ff5216be8..802931c2a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added data assimilation model class - Added equirectangular interpolation data assimilation model - Adding Beta serve utils with inference server and client implementations +- Added `energy_score` metric for multivariate ensemble forecast verification ### Changed diff --git a/earth2studio/statistics/energy_score.py b/earth2studio/statistics/energy_score.py index 453afa0f0..ca9af4997 100644 --- a/earth2studio/statistics/energy_score.py +++ b/earth2studio/statistics/energy_score.py @@ -32,13 +32,20 @@ class energy_score: ES = (1/M) * sum_m ||x_m - y|| - 1/(2*M^2) * sum_m sum_m' ||x_m - x_m'|| where ||.|| denotes the Euclidean norm computed over the multivariate - dimensions. This is a proper scoring rule that is minimized when the - forecast distribution matches the true distribution. + dimensions. This is a proper scoring rule that is minimized when the forecast + distribution matches the true distribution. Unlike CRPS which evaluates each variable/grid point independently, the Energy Score captures whether the ensemble preserves spatial correlations across variables and grid points. + .. warning:: + Setting ``multivariate_dimensions`` to large spatial grids (e.g., + ``['lat', 'lon']`` with 721x1440 = ~1M elements) produces a feature + vector of that size per ensemble member. For M=50 members this requires + ~200 MB per tensor in float32. Prefer selecting a subset of dimensions + unless full-field verification is explicitly needed. + Parameters ---------- ensemble_dimension: str @@ -55,8 +62,10 @@ class energy_score: Weights for the reduction dimensions. Must have the same number of dimensions as passed in reduction_dimensions. By default None. fair: bool, optional - If True, use the fair (unbiased) Energy Score estimator with the - correction factor M/(M-1). By default False. + If True, use the fair (unbiased) Energy Score estimator, which + replaces the ``1/(2*M^2)`` denominator with ``1/(2*M*(M-1))``, + excluding the zero self-distance diagonal from the denominator count. + Requires at least 2 ensemble members. By default False. References ---------- @@ -246,82 +255,38 @@ def _energy_score_compute( coord_keys = list(x_coords.keys()) ens_dim = coord_keys.index(ensemble_dimension) M = ensemble.shape[ens_dim] - - # Get indices for multivariate dims in ensemble tensor mv_dims = [coord_keys.index(d) for d in multivariate_dimensions] # Term 1: (1/M) * sum_m ||x_m - y|| - # Expand truth to match ensemble shape along ensemble dim diff_xy = ensemble - truth.unsqueeze(ens_dim) - # Compute L2 norm over multivariate dims - # We need to square, sum over mv_dims, then sqrt - term1 = _l2_norm_over_dims(diff_xy, mv_dims).mean(dim=ens_dim) - - # Term 2: 1/(2*M^2) * sum_m sum_m' ||x_m - x_m'|| - # or for fair: 1/(2*M*(M-1)) * sum_{m!=m'} ||x_m - x_m'|| - # We use torch.cdist for efficiency - - # We need to reshape ensemble so that ens_dim and mv_dims are isolated - # Move ens_dim to position 0 and mv_dims to the end, flatten the rest - # into a batch dimension for cdist - - # Strategy: flatten multivariate dims into a single dim, then use cdist - # First, move ensemble dim to -2 and multivariate dims to -1 (flattened) + term1 = torch.sqrt((diff_xy * diff_xy).sum(dim=mv_dims)).mean(dim=ens_dim) - # Determine permutation: [remaining_dims..., ens_dim, mv_dims_flattened...] + # Term 2: pairwise ensemble spread via cdist + # Permute to (*remaining, M, *mv) then flatten to (batch, M, D) all_dims = list(range(ensemble.ndim)) remaining_dims = [d for d in all_dims if d != ens_dim and d not in mv_dims] - perm = remaining_dims + [ens_dim] + mv_dims - x_perm = ensemble.permute(*perm) + x_perm = ensemble.permute(*remaining_dims, ens_dim, *mv_dims) - # Now shape is (*remaining, M, *mv_sizes) remaining_shape = x_perm.shape[: len(remaining_dims)] mv_size = 1 for d in mv_dims: mv_size *= ensemble.shape[d] - - # Reshape to (batch, M, D) where D = product of multivariate dim sizes batch_size = 1 for s in remaining_shape: batch_size *= s x_flat = x_perm.reshape(batch_size, M, mv_size) - # Use cdist for pairwise distances - pairwise_dists = torch.cdist(x_flat, x_flat, p=2) # (batch, M, M) + pairwise_dists = torch.cdist(x_flat, x_flat, p=2) if fair: - # Fair estimator: exclude diagonal (m == m'), divide by M*(M-1) if M < 2: - raise ValueError("Fair Energy Score requires at least 2 ensemble members.") - # Zero out diagonal then sum, which gives sum_{m!=m'} + raise ValueError(f"Fair Energy Score requires ensemble size >= 2, got {M}.") mask = ~torch.eye(M, device=pairwise_dists.device, dtype=torch.bool) pairwise_sum = (pairwise_dists * mask.unsqueeze(0)).sum(dim=(-1, -2)) term2 = pairwise_sum / (2.0 * M * (M - 1)) else: - # Standard estimator: sum all pairs including diagonal, divide by 2*M^2 pairwise_sum = pairwise_dists.sum(dim=(-1, -2)) term2 = pairwise_sum / (2.0 * M * M) - # Reshape term2 back to remaining_shape term2 = term2.reshape(remaining_shape) - - es = term1 - term2 - return es - - -def _l2_norm_over_dims(x: torch.Tensor, dims: list[int]) -> torch.Tensor: - """Compute L2 norm over specified dimensions. - - Parameters - ---------- - x : torch.Tensor - Input tensor. - dims : list[int] - Dimensions to compute the norm over (will be reduced). - - Returns - ------- - torch.Tensor - Tensor with specified dimensions removed. - """ - return torch.sqrt((x * x).sum(dim=dims)) + return term1 - term2 diff --git a/test/statistics/test_energy_score.py b/test/statistics/test_energy_score.py index d485d3ef5..cd4c86665 100644 --- a/test/statistics/test_energy_score.py +++ b/test/statistics/test_energy_score.py @@ -21,17 +21,20 @@ import pytest import torch -from earth2studio.statistics import energy_score, lat_weight +from earth2studio.statistics import crps, energy_score, lat_weight +from earth2studio.statistics.base import Metric from earth2studio.statistics.energy_score import _energy_score_compute from earth2studio.utils.coords import handshake_coords, handshake_dim lat_weights = lat_weight(torch.as_tensor(np.linspace(-90.0, 90.0, 10))) +@pytest.mark.parametrize("fair", [True, False]) @pytest.mark.parametrize( "ensemble_dimension", [ "ensemble", + "time", ], ) @pytest.mark.parametrize( @@ -46,6 +49,7 @@ def test_energy_score( ensemble_dimension: str, reduction_weights: tuple[list[str], np.ndarray], device: str, + fair: bool, ) -> None: x = torch.randn((8, 1, 2, 10, 20), device=device) @@ -60,6 +64,14 @@ def test_energy_score( } ) + # Swap ensemble_dimension into coords (e.g. treat "time" as the ensemble dim) + if ensemble_dimension != "ensemble": + keys = list(x_coords.keys()) + ens_idx = keys.index("ensemble") + mv_idx = keys.index(ensemble_dimension) + keys[ens_idx], keys[mv_idx] = keys[mv_idx], keys[ens_idx] + x_coords = OrderedDict((k, x_coords[k]) for k in keys) + y_coords = copy.deepcopy(x_coords) y_coords.pop(ensemble_dimension) y_shape = [len(y_coords[c]) for c in y_coords] @@ -75,6 +87,7 @@ def test_energy_score( multivariate_dimensions=["lon"], reduction_dimensions=reduction_dimensions, weights=weights, + fair=fair, ) z, c = ES(x, x_coords, y, y_coords) @@ -273,8 +286,9 @@ def test_energy_score_nonnegativity(device: str) -> None: assert es.item() >= 0.0, f"Energy Score should be non-negative, got {es.item()}" +@pytest.mark.parametrize("fair", [True, False]) @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) -def test_energy_score_perfect_ensemble(device: str) -> None: +def test_energy_score_perfect_ensemble(device: str, fair: bool) -> None: """Energy score should be zero for a perfect deterministic ensemble.""" torch.manual_seed(0) M = 10 @@ -296,7 +310,7 @@ def test_energy_score_perfect_ensemble(device: str) -> None: x_coords, ensemble_dimension="ensemble", multivariate_dimensions=["variable"], - fair=False, + fair=fair, ) assert torch.allclose( @@ -306,8 +320,7 @@ def test_energy_score_perfect_ensemble(device: str) -> None: ), f"Perfect ensemble should have ES ~= 0, got {es.item()}" -@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) -def test_energy_score_str(device: str) -> None: +def test_energy_score_str() -> None: """Test string representation.""" ES = energy_score( "ensemble", @@ -323,6 +336,12 @@ def test_energy_score_str(device: str) -> None: assert str(ES2) == "energy_score" +def test_energy_score_metric_protocol() -> None: + """energy_score must conform to the Metric protocol.""" + ES = energy_score("ensemble", multivariate_dimensions=["lon"]) + assert isinstance(ES, Metric) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) def test_energy_score_multiple_mv_dims(device: str) -> None: """Test with multiple multivariate dimensions.""" @@ -357,3 +376,55 @@ def test_energy_score_multiple_mv_dims(device: str) -> None: assert "lon" not in c assert "time" in c assert list(z.shape) == [len(val) for val in c.values()] + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_energy_score_univariate_equals_crps( + device: str, rtol: float = 1e-4, atol: float = 1e-4 +) -> None: + # For a single multivariate dimension of size 1, ES must equal CRPS. + torch.manual_seed(7) + M = 30 + + x = torch.randn((M, 5), device=device, dtype=torch.float64) + y = torch.randn((5,), device=device, dtype=torch.float64) + + x_coords = OrderedDict( + { + "ensemble": np.arange(M), + "lat": np.linspace(-90.0, 90.0, 5), + } + ) + y_coords = OrderedDict({"lat": x_coords["lat"]}) + + # Energy score over a single scalar 'variable' dimension (D=1) == CRPS + # Reshape to add a size-1 variable dimension + x_mv = x.unsqueeze(-1) # (M, 5, 1) + y_mv = y.unsqueeze(-1) # (5, 1) + x_coords_mv = OrderedDict( + { + "ensemble": np.arange(M), + "lat": x_coords["lat"], + "variable": np.array(["v0"]), + } + ) + + es_vals = _energy_score_compute( + x_mv, + y_mv, + x_coords_mv, + ensemble_dimension="ensemble", + multivariate_dimensions=["variable"], + fair=False, + ) # shape: (5,) + + # CRPS per grid point: sort ensemble members and compute + CRPS = crps("ensemble", fair=False) + crps_vals, _ = CRPS(x, x_coords, y, y_coords) # shape: (5,) + + assert torch.allclose( + es_vals, + crps_vals, + rtol=rtol, + atol=atol, + ), f"ES (D=1) should equal CRPS but max diff = {(es_vals - crps_vals).abs().max().item()}" From 9c4373c36aeb365516ace1d8c3a09b288b75f521 Mon Sep 17 00:00:00 2001 From: Nicholas Geneva Date: Tue, 17 Mar 2026 17:23:15 +0000 Subject: [PATCH 3/3] Docs --- docs/modules/statistics.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/modules/statistics.rst b/docs/modules/statistics.rst index 7800ccbf7..7f7c54d89 100644 --- a/docs/modules/statistics.rst +++ b/docs/modules/statistics.rst @@ -17,6 +17,7 @@ Various statistic and metric calculations for analysing inference data. statistics.acc statistics.brier statistics.crps + statistics.energy_score statistics.fss statistics.lsd statistics.mae