diff --git a/CHANGELOG.md b/CHANGELOG.md index 8ba72b8e5..143086a1e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Changed ISD data frame return to master schema - handshake_coords is now accepting list of dimensions while remaining backwards-compatible - Updated CBottle infill to mixture of model checkpoints +- Updated GraphCastOperational and GraphCastSmall latitude input / output to be [90,-90] ### Deprecated diff --git a/earth2studio/models/px/graphcast_operational.py b/earth2studio/models/px/graphcast_operational.py index cd036fa25..8989c6063 100644 --- a/earth2studio/models/px/graphcast_operational.py +++ b/earth2studio/models/px/graphcast_operational.py @@ -237,7 +237,7 @@ def __init__( ] ), "variable": np.array(VARIABLES), - "lat": np.linspace(-90, 90, 721, endpoint=True), + "lat": np.linspace(90, -90, 721, endpoint=True), "lon": np.linspace(0, 360, 1440, endpoint=False), } ) @@ -248,7 +248,7 @@ def __init__( "time": np.empty(0), "lead_time": np.array([np.timedelta64(6, "h")]), "variable": np.array(VARIABLES + ["tp06"]), - "lat": np.linspace(-90, 90, 721, endpoint=True), + "lat": np.linspace(90, -90, 721, endpoint=True), "lon": np.linspace(0, 360, 1440, endpoint=False), } ) @@ -561,7 +561,9 @@ def iterator_result_to_tensor(self, dataset: xr.Dataset) -> torch.Tensor: .T.transpose(..., "time", "lead_time", "variable", "lat", "lon") ) - return torch.from_numpy(dataarray.to_numpy().copy()) + out = torch.from_numpy(dataarray.to_numpy().copy()) + out = out.flip(-2) # Flip lat from ascending (-90->90, JAX native) to (90->-90) + return out @staticmethod def get_jax_device_from_tensor(x: torch.Tensor) -> "jax.Device": diff --git a/earth2studio/models/px/graphcast_small.py b/earth2studio/models/px/graphcast_small.py index 0650df74d..c8d427a4e 100644 --- a/earth2studio/models/px/graphcast_small.py +++ b/earth2studio/models/px/graphcast_small.py @@ -169,7 +169,7 @@ class GraphCastSmall(torch.nn.Module, AutoModelMixin, PrognosticMixin): A smaller, low-resolution version of GraphCast (1 degree resolution, 13 pressure levels and a smaller mesh), trained on ERA5 data from 1979 to 2015. This model is useful for running with lower memory and compute constraints while maintaining good - forecast skill. The model operates on a 1-degree lat-lon grid (south-pole including) + forecast skill. The model operates on a 1-degree lat-lon grid (pole including) equirectangular grid with 85 variables including: - Surface variables (2m temperature, 10m winds, etc.) @@ -238,7 +238,7 @@ def __init__( ] ), "variable": np.array(VARIABLES), - "lat": np.linspace(-90, 90, 181, endpoint=True), + "lat": np.linspace(90, -90, 181, endpoint=True), "lon": np.linspace(0, 360, 360, endpoint=False), } ) @@ -249,7 +249,7 @@ def __init__( "time": np.empty(0), "lead_time": np.array([np.timedelta64(6, "h")]), "variable": np.array(VARIABLES), - "lat": np.linspace(-90, 90, 181, endpoint=True), + "lat": np.linspace(90, -90, 181, endpoint=True), "lon": np.linspace(0, 360, 360, endpoint=False), } ) @@ -559,7 +559,9 @@ def iterator_result_to_tensor(self, dataset: xr.Dataset) -> torch.Tensor: .T.transpose(..., "time", "lead_time", "variable", "lat", "lon") ) - return torch.from_numpy(dataarray.to_numpy().copy()) + out = torch.from_numpy(dataarray.to_numpy().copy()) + out = out.flip(-2) # Flip lat from ascending (-90->90, JAX native) to (90->-90) + return out @staticmethod def get_jax_device_from_tensor(x: torch.Tensor) -> "jax.Device":