Skip to content

Update Graphcast Latitudes#742

Merged
NickGeneva merged 2 commits intoNVIDIA:mainfrom
NickGeneva:ngeneva/graphcast_update
Mar 11, 2026
Merged

Update Graphcast Latitudes#742
NickGeneva merged 2 commits intoNVIDIA:mainfrom
NickGeneva:ngeneva/graphcast_update

Conversation

@NickGeneva
Copy link
Collaborator

@NickGeneva NickGeneva commented Mar 11, 2026

Earth2Studio Pull Request

Description

Closes: #710

Validated with

Code
from earth2studio.models.px import GraphCastOperational
from earth2studio.data import GFS
from earth2studio.io import ZarrBackend
from earth2studio.run import deterministic as run

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

# Run forecast
package = GraphCastOperational.load_default_package()
model = GraphCastOperational.load_model(package)
data = GFS()
io = ZarrBackend("outputs/graphcast_operational_forecast.zarr")
run(["2025-01-01T00:00:00"], 4, model, data, io)

# Open results
ds = xr.open_zarr("outputs/graphcast_operational_forecast.zarr")

# Plot t2m and z500 at each lead time
variables = ["t2m", "z500"]
lead_times = ds.coords["lead_time"].values

fig, axes = plt.subplots(
    len(variables), len(lead_times), figsize=(5 * len(lead_times), 4 * len(variables))
)
if axes.ndim == 1:
    axes = axes[np.newaxis, :]

for i, var in enumerate(variables):
    arr = ds[var].isel(time=0)  # first init time
    for j, lt in enumerate(lead_times):
        ax = axes[i, j]
        field = arr.sel(lead_time=lt).values.squeeze()
        im = ax.imshow(
            field,
            origin="upper",
            extent=[0, 360, -90, 90],
            aspect="auto",
            cmap="RdBu_r" if var == "t2m" else "viridis",
        )
        hours = int(lt / np.timedelta64(1, "h"))
        ax.set_title(f"{var}  +{hours}h")
        ax.set_xlabel("Longitude")
        ax.set_ylabel("Latitude")
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

plt.tight_layout()
plt.savefig("outputs/graphcast_operational_forecast.png", dpi=150)
print("Plot saved to outputs/graphcast_operational_forecast.png")

Main
graphcast_operational_forecast

New
graphcast_operational_forecast_1

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.
  • Assess and address Greptile feedback (AI code review bot for guidance; use discretion, addressing all feedback is not required).

Dependencies

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Mar 11, 2026

Greptile Summary

This PR aligns both GraphCastOperational and GraphCastSmall with Earth2Studio's standard latitude convention by changing _input_coords and _output_coords lat arrays from ascending [-90, 90] to descending [90, -90], and adding a .flip(-2) call in iterator_result_to_tensor to flip the latitude dimension of the JAX output before returning it.

Key observations:

  • The existing out_data.reindex(lat=sorted(out_data.lat.values)) call inside from_dataarray_to_dataset (pre-existing, unchanged) correctly converts descending input lat to ascending order for JAX — this is the symmetrical counterpart to the new output flip.
  • The static land_sea_mask and geopotential_at_surface arrays are assigned positionally (no explicit lat coordinate) to out_data after the reindex, so they remain consistent as long as the pre-packaged ERA5 sample NetCDF stores them in ascending lat order — behaviour unchanged by this PR.
  • The first step yielded by _default_generator passes the input tensor x directly (before JAX), which is now also in descending lat order, keeping the convention consistent for all steps of the iterator.
  • Tests verify output shape but do not assert lat value ordering in the output tensor; a sanity assertion on out_coords["lat"][0] > out_coords["lat"][-1] (i.e. descending) would add an extra safety net for future regressions.

Confidence Score: 4/5

  • Safe to merge — the coordinate flip is internally consistent and well-contained.
  • The implementation is correct: the pre-existing reindex(lat=sorted(...)) handles descending→ascending for JAX input, and the new flip(-2) handles ascending→descending for the output. No logic errors found. Minor concern is the absence of an explicit lat-ordering assertion in the unit tests, which means a future regression could go undetected, but this does not block merging.
  • No files require special attention beyond the two model files, which have been reviewed.

Important Files Changed

Filename Overview
earth2studio/models/px/graphcast_operational.py Updated _input_coords and _output_coords lat from ascending [-90, 90] to descending [90, -90], and added out.flip(-2) in iterator_result_to_tensor to convert JAX's native ascending output back to descending order. Logic is correct — the pre-existing reindex(lat=sorted(...)) in from_dataarray_to_dataset already handles the ascending-to-descending conversion for JAX inputs.
earth2studio/models/px/graphcast_small.py Same lat coordinate direction fix as graphcast_operational.py: _input_coords/_output_coords now use [90, -90] and iterator_result_to_tensor applies out.flip(-2). Minor docstring improvement removing "south-pole including" in favour of "pole including".
CHANGELOG.md Added a changelog entry under the "Changed" section noting the updated GraphCast latitude convention. Clear and accurate.

Last reviewed commit: 28e6c0c

@NickGeneva
Copy link
Collaborator Author

/blossom-ci

Copy link
Collaborator

@loliverhennigh loliverhennigh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lgtm

@NickGeneva
Copy link
Collaborator Author

/blossom-ci

@NickGeneva NickGeneva merged commit 1ab4eb8 into NVIDIA:main Mar 11, 2026
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

🐛[BUG]: Graphcast output data latitude from -90 to 90, instead of 90 to -90

2 participants