
Score-based Data Assimilation StormCast #730

Open
NickGeneva wants to merge 68 commits into NVIDIA:main from NickGeneva:ngeneva/sda_stormcast

NickGeneva (Collaborator) commented Mar 5, 2026

Earth2Studio Pull Request

Description

  • Adds the StormCast SDA model, a stateful data assimilation method. This is probably one of the more complex SDA implementations, compared to something like CorrDiff.
  • Adds an example demonstrating how to use it with the Integrated Surface Database provided by HRRR.
  • Adds a requirement to the DA model protocol: init_coords, which describes the initial inputs needed by create_generator and the first set of parameters in __call__.
  • Also updates the protocol to allow None observations.
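
A minimal sketch of what the protocol change above could look like, for readers unfamiliar with the DA interface. The class names, signatures, and the toy averaging update are illustrative assumptions, not the actual Earth2Studio definitions; only init_coords, create_generator, __call__, and the None-observation behaviour come from the description.

```python
from collections import OrderedDict
from typing import Any, Generator, Optional, Protocol


class AssimilationModelSketch(Protocol):
    """Hypothetical, simplified stand-in for the DA model protocol."""

    def init_coords(self) -> Optional[tuple]:
        """Coordinates of the initial inputs, or None if none are needed."""
        ...

    def __call__(self, *init_args: Any, obs: Optional[Any] = None) -> Any:
        """Stateless step; obs may be None (no observations available)."""
        ...

    def create_generator(self, *init_args: Any) -> Generator[Any, Optional[Any], None]:
        """Stateful stepping; observations are passed in with send()."""
        ...


class ToyStatefulModel:
    """Toy model (assumption): nudges a scalar state halfway toward obs."""

    def init_coords(self):
        return (OrderedDict(time=[0], variable=["t2m"]),)

    def __call__(self, x0, obs=None):
        # With no observations the state passes through unchanged.
        return x0 if obs is None else 0.5 * (x0 + obs)

    def create_generator(self, x0):
        state = x0
        obs = yield state  # first next() returns the initial state
        while True:
            state = self(state, obs)
            obs = yield state


model = ToyStatefulModel()
gen = model.create_generator(2.0)
first = next(gen)          # initial state
second = gen.send(None)    # step without observations
third = gen.send(4.0)      # step with one observation
```

The generator form keeps state between steps, which is the reason the initial inputs need their own coordinate description separate from the observations.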

Coverage:

Name                                          Stmts   Miss  Cover   Missing
earth2studio/models/da/__init__.py                2      0   100%
earth2studio/models/da/base.py                    8      0   100%
earth2studio/models/da/interp.py                202    176    13%   86-100, 104, 116, 152-157, 170-174, 202-208, 227-311, 346-380, 397-411, 438-589
earth2studio/models/da/sda_stormcast.py         279     15    95%   285, 386-387, 391-393, 433, 597, 691, 881, 884, 890, 894, 903-906, 909
[Figure: rendered example page for examples/21_stormcast_sda]

Closer results

[Figure: 21_stormcast_sda_comparison]

[Figure: 21_stormcast_sda_gt_comparison]

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.
  • Assess and address Greptile feedback (AI code review bot for guidance; use discretion, addressing all feedback is not required).

Dependencies

greptile-apps bot (Contributor) commented Mar 5, 2026

Greptile Summary

This PR introduces StormCastSDA, a stateful score-based data assimilation model that combines a regression network and a diffusion model with DPS (Diffusion Posterior Sampling) guidance to assimilate sparse point observations into convection-allowing regional forecasts. It also extends the AssimilationModel protocol with init_coords() and optional None observation support, adds a legacy=False mode to fetch_data that returns cupy-backed xr.DataArray directly (avoiding unnecessary CPU round-trips), and introduces a GPU-safe map_coords_xr utility.

Key changes:

  • earth2studio/models/da/sda_stormcast.py: Full SDA StormCast implementation, including KD-tree observation mapping, ray-casting point-in-polygon guard, GPU bilinear conditioning interpolation, and averaged scatter-add for duplicate observations. One open concern: NaN observation values are not filtered before the scatter-add in _build_obs_tensors, which can silently corrupt the DPS guidance if a DataFrame contains missing readings.
  • earth2studio/data/utils.py: New legacy=False path in fetch_data is well-guarded (device.index or 0); minor: interp_method is silently ignored in this path when interp_to=None.
  • earth2studio/utils/coords.py: New map_coords_xr correctly handles cupy backends via direct numpy/cupy indexing instead of xarray.interp().
  • earth2studio/models/da/base.py: Clean protocol extension.
  • examples/21_stormcast_sda.py: End-to-end example with minor duplicate imports.
  • Test coverage is thorough, including GPU vs. scipy reference validation for the bilinear interpolation path.
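
For reference, the ray-casting point-in-polygon guard mentioned above can be sketched as follows. This is the textbook even-odd test in pure Python, not the code from sda_stormcast.py; the function name and the idea of using the grid boundary as the polygon are assumptions.

```python
def point_in_polygon(x, y, polygon):
    """Even-odd ray casting: cast a ray in the +x direction and count crossings.

    polygon is a sequence of (x, y) vertices; the last vertex connects back
    to the first. Points exactly on an edge are not handled specially.
    """
    inside = False
    n = len(polygon)
    for i in range(n):
        x1, y1 = polygon[i]
        x2, y2 = polygon[(i + 1) % n]
        # Only edges that straddle the horizontal line through y can be crossed.
        if (y1 > y) != (y2 > y):
            # x position where the edge meets that horizontal line.
            x_cross = x1 + (y - y1) * (x2 - x1) / (y2 - y1)
            if x < x_cross:
                inside = not inside
    return inside


# Hypothetical grid boundary: a 4 x 4 square.
boundary = [(0.0, 0.0), (4.0, 0.0), (4.0, 4.0), (0.0, 4.0)]
```

A guard like this lets observations outside the model domain be dropped before any nearest-grid-point lookup, where they would otherwise snap to a boundary cell.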

Confidence Score: 4/5

  • PR is safe to merge with one recommended fix — NaN observation filtering in _build_obs_tensors.
  • The implementation is well-structured with good test coverage and most previously-raised issues have been addressed. The primary remaining concern is that NaN observation values (common in real weather station data) are not filtered before the scatter-add in _build_obs_tensors, which can silently corrupt the DPS guidance tensor. The interp_method silent-ignore issue is minor. Everything else — GPU path correctness, averaging of duplicate observations, conditioning interpolation, ascending-order guard, docstrings — looks solid.
  • earth2studio/models/da/sda_stormcast.py — NaN filtering in _build_obs_tensors (line 599).
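
A rough sketch of the recommended fix, written with NumPy rather than the tensors used in _build_obs_tensors. The function name, arguments, and the flat grid indexing are assumptions made for illustration; the point is only that NaN readings are dropped before the averaged scatter-add.

```python
import numpy as np


def build_obs_grid(flat_idx, values, grid_size):
    """Average observations per grid cell, skipping NaN readings.

    Returns (obs, mask): obs holds the per-cell mean and mask marks cells
    that received at least one valid observation.
    """
    flat_idx = np.asarray(flat_idx, dtype=np.int64)
    values = np.asarray(values, dtype=np.float64)

    # Drop missing readings first so they cannot poison the cell sums.
    valid = ~np.isnan(values)
    flat_idx, values = flat_idx[valid], values[valid]

    sums = np.zeros(grid_size)
    counts = np.zeros(grid_size)
    # np.add.at accumulates correctly when the same index appears repeatedly.
    np.add.at(sums, flat_idx, values)
    np.add.at(counts, flat_idx, 1.0)

    mask = counts > 0
    obs = np.zeros(grid_size)
    obs[mask] = sums[mask] / counts[mask]
    return obs, mask


# Two stations share cell 0; the reading mapped to cell 2 is missing.
obs, mask = build_obs_grid([0, 0, 2, 3], [1.0, 3.0, np.nan, 5.0], grid_size=4)
```

Without the filter, a single NaN would make the sum for its cell NaN while the mask still marks that cell as observed, which is how the guidance term could be corrupted silently.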

Important Files Changed

Filename | Overview
earth2studio/models/da/sda_stormcast.py | Core StormCast SDA implementation: large, well-structured new file. NaN observation values are not filtered before scatter-add in _build_obs_tensors, which can silently corrupt DPS guidance. Most previously-flagged issues (debug prints, mutable default args, typos, sorting guard, duplicate obs averaging, sampler_args key validation) have been addressed.
earth2studio/data/utils.py | Adds a legacy mode to fetch_data that returns a raw xr.DataArray with cupy backing on CUDA. The device.index or 0 guard is correctly applied in the new non-legacy path. Minor: interp_method is silently ignored in the non-legacy path when interp_to=None.
earth2studio/models/da/base.py | Protocol extended to allow None observations in __call__ and create_generator, and adds optional *args init parameters and init_coords() method. Changes are clean and well-documented.
earth2studio/models/da/interp.py | Renames tolerance to time_tolerance and adds init_coords() returning None. The rename is a breaking API change but accepted per prior thread discussion. Smolyak interpolation logic is unchanged and looks correct.
earth2studio/utils/coords.py | New map_coords_xr function implementing GPU/CPU-aware nearest-neighbor coordinate mapping without calling xarray.interp() (avoiding the previous scipy/CPU-only path). Uses sort-based searchsorted with correct ascending-order handling via np.argsort. Logic looks correct.
examples/21_stormcast_sda.py | Well-written end-to-end example. cartopy/matplotlib are imported twice (lines 165–167 and 235–237), which is redundant. The .get() calls are CUDA-specific but the example is documented as GPU-only. Other previously flagged issues appear to be addressed.
test/models/da/test_da_sda_stormcast.py | Comprehensive test suite covering polygon point-in-polygon, observation tensor building (including None, out-of-grid, and duplicate-averaging cases), conditioning fetch, __call__, create_generator, and exception handling. GPU interpolation test validates against scipy reference. Good coverage.
test/data/test_data_utils.py | New tests for the legacy=False mode of fetch_data and updated prep_data_inputs behaviour. Coverage looks correct and complete.
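
The sort-based nearest-neighbor lookup described for map_coords_xr can be illustrated with a small 1-D helper. This is a sketch under assumptions: nearest_indices is a hypothetical name, it works on plain NumPy arrays only, and it is not the function added in earth2studio/utils/coords.py.

```python
import numpy as np


def nearest_indices(coord, targets):
    """Index into coord of the value nearest to each target.

    coord may be ascending or descending (for example latitudes stored
    north to south) and must contain at least two values.
    """
    coord = np.asarray(coord, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.float64)

    # searchsorted requires ascending input, so sort and remember the order.
    order = np.argsort(coord)
    sorted_coord = coord[order]

    # Bracket each target between a left and a right neighbour.
    right = np.clip(np.searchsorted(sorted_coord, targets), 1, len(sorted_coord) - 1)
    left = right - 1

    # Choose whichever neighbour is closer; ties go to the left one.
    pick_left = (targets - sorted_coord[left]) <= (sorted_coord[right] - targets)
    nearest_sorted = np.where(pick_left, left, right)

    # Map positions in the sorted array back to the original ordering.
    return order[nearest_sorted]


lat = [50.0, 40.0, 30.0, 20.0]  # descending coordinate
idx = nearest_indices(lat, [41.0, 19.0, 55.0, 30.0])
```

Because the lookup is plain array indexing, the same pattern should carry over to cupy arrays, which is presumably why it avoids xarray.interp().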

Last reviewed commit: 45d157f

NickGeneva (Collaborator, Author) commented:

@greptile-ai

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

"""Stateless forward pass"""
input_coords = self.input_coords()
- (output_coords,) = self.output_coords(input_coords, **x.attrs)
+ (output_coords,) = self.output_coords(input_coords, **obs.attrs)
Collaborator commented:

Should we have some check to make sure the obs.attrs contains the required request_time arg? Is this part of the general "not having super extensive checks/handshakes" status of the DA?

Author (Collaborator) replied:

Yeah, deferring this for some later PR focused on these utils
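
For when that follow-up happens, the check discussed in this thread could look roughly like the sketch below. The helper name require_attrs is hypothetical, and request_time is taken from the review comment above as the only required key.

```python
def require_attrs(attrs, required=("request_time",)):
    """Raise a clear error if observation attrs lack keys the model needs."""
    missing = [key for key in required if key not in attrs]
    if missing:
        raise ValueError(f"Observation attrs are missing required keys: {missing}")


# Passes silently when the key is present.
require_attrs({"request_time": ["2026-03-05T00:00"], "source": "isd"})

# Fails early, before the attrs are unpacked into output_coords(...).
try:
    require_attrs({"source": "isd"})
    error_message = None
except ValueError as err:
    error_message = str(err)
```

Calling it right before the output_coords line would turn a confusing missing-keyword failure into an explicit message about the observation metadata.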

NickGeneva (Collaborator, Author) commented:

/blossom-ci


3 participants