Merged
70 commits
a6d33a8
Converting to 2.0 samplers
NickGeneva Mar 3, 2026
c26b3b7
Updating stormcast
NickGeneva Mar 4, 2026
eb8cedc
SDA running
NickGeneva Mar 4, 2026
458f5b3
Random pytest fixes
NickGeneva Mar 4, 2026
c2e08cf
adding cupy support in fetch_data
NickGeneva Mar 5, 2026
65c5ec1
Xarray bits
NickGeneva Mar 5, 2026
3215a58
Adding obs
NickGeneva Mar 5, 2026
69ab063
generator
NickGeneva Mar 5, 2026
95ae396
Updates
NickGeneva Mar 5, 2026
9c57dda
Fixing iterator
NickGeneva Mar 5, 2026
e3356de
Updates
NickGeneva Mar 6, 2026
ba0244b
Adding tests
NickGeneva Mar 6, 2026
ab1d9d5
Updating tolerance name
NickGeneva Mar 6, 2026
60bec2b
Draft example
NickGeneva Mar 6, 2026
c389222
Updates
NickGeneva Mar 6, 2026
00ba3f6
Updates
NickGeneva Mar 6, 2026
447ae2f
Some updates
NickGeneva Mar 7, 2026
fa28e80
Example working better
NickGeneva Mar 7, 2026
82093af
Example update
NickGeneva Mar 7, 2026
bbd7290
Updates
NickGeneva Mar 9, 2026
467d171
Improvements
NickGeneva Mar 9, 2026
d02baec
Merge branch 'main' into ngeneva/sda_stormcast
NickGeneva Mar 9, 2026
ba75a90
Greptile 1
NickGeneva Mar 9, 2026
1f8aa7e
Greptile 2
NickGeneva Mar 9, 2026
71d1ad0
Update earth2studio/models/da/sda_stormcast.py
NickGeneva Mar 9, 2026
6c9a744
Update earth2studio/models/da/sda_stormcast.py
NickGeneva Mar 9, 2026
7dad722
Update earth2studio/models/da/sda_stormcast.py
NickGeneva Mar 9, 2026
aaf4659
Greptile
NickGeneva Mar 9, 2026
bf99be4
Clean up
NickGeneva Mar 9, 2026
b6a5463
Clean up
NickGeneva Mar 9, 2026
c2d1010
Clean up
NickGeneva Mar 9, 2026
f902c35
Simplify lead time
NickGeneva Mar 9, 2026
46cb97a
Simplify lead time
NickGeneva Mar 9, 2026
5ab87fb
Update example
NickGeneva Mar 9, 2026
d2f4338
Update example
NickGeneva Mar 9, 2026
7757082
revert
NickGeneva Mar 9, 2026
6136ab8
revert
NickGeneva Mar 9, 2026
5b06f12
Little improvements
NickGeneva Mar 9, 2026
9b73799
Adding average
NickGeneva Mar 9, 2026
6af783f
Revert original stormcast
NickGeneva Mar 9, 2026
2561807
Fix the interp
NickGeneva Mar 10, 2026
ff4a920
improvements
NickGeneva Mar 10, 2026
8709b16
improvements
NickGeneva Mar 10, 2026
021d01a
improvements
NickGeneva Mar 10, 2026
f15cc9c
improvements
NickGeneva Mar 10, 2026
54f903a
improvements
NickGeneva Mar 10, 2026
117c5e9
Update
NickGeneva Mar 10, 2026
51998d0
Merge branch 'main' into ngeneva/sda_stormcast
NickGeneva Mar 10, 2026
6ea3d2b
Update interp function
NickGeneva Mar 10, 2026
119c4d1
Greptile
NickGeneva Mar 10, 2026
45d157f
Greptile
NickGeneva Mar 10, 2026
65d5f01
Remove chardet
NickGeneva Mar 10, 2026
7f8a288
Testing
NickGeneva Mar 10, 2026
513e582
remove repeat imports
NickGeneva Mar 10, 2026
96f40ae
revert isd
NickGeneva Mar 10, 2026
720e442
Fixing interp
NickGeneva Mar 10, 2026
ae57064
Drop s3fs version
NickGeneva Mar 10, 2026
f0be113
Drop s3fs version
NickGeneva Mar 10, 2026
96ceff9
Merge branch 'main' into ngeneva/sda_stormcast
NickGeneva Mar 11, 2026
47dbaf4
Feedback
NickGeneva Mar 11, 2026
deebc08
Feedback
NickGeneva Mar 11, 2026
6958de5
Adding das to all install
NickGeneva Mar 11, 2026
7b9415b
Adding das to all install
NickGeneva Mar 11, 2026
bdd2760
Adding install
NickGeneva Mar 11, 2026
0246b63
Lint Fixes
NickGeneva Mar 11, 2026
dc9e312
Merge branch 'main' into ngeneva/sda_stormcast
NickGeneva Mar 11, 2026
8d2d420
fix test
NickGeneva Mar 11, 2026
e0fa6d9
fix test
NickGeneva Mar 11, 2026
349d9d5
Merge branch 'main' into ngeneva/sda_stormcast
NickGeneva Mar 11, 2026
86e1148
Merge branch 'main' into ngeneva/sda_stormcast
NickGeneva Mar 11, 2026
2 changes: 2 additions & 0 deletions .cursor/rules/e2s-009-prognostic-models.mdc
@@ -328,6 +328,8 @@ def to(self, device: torch.device | str) -> PrognosticModel:
- Call `super().to(device)` for PyTorch module
- Move any custom buffers/parameters to device
- Return `self` for chaining
- The `torch.nn.Module` parent class satisfies this requirement most of the time
- It is generally good practice to call `self.register_buffer("device_buffer", torch.empty(0))` in `__init__` to help track the model's current device

## Data Operations on GPU

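The device-buffer rule added above can be sketched as a minimal module. This is an illustrative pattern only; `TinyPrognostic` is a hypothetical class, not an actual earth2studio model:

```python
import torch


class TinyPrognostic(torch.nn.Module):
    """Hypothetical module showing the empty device-tracking buffer pattern."""

    def __init__(self) -> None:
        super().__init__()
        # Empty buffer whose only job is to record the module's current device.
        self.register_buffer("device_buffer", torch.empty(0))

    @property
    def device(self) -> torch.device:
        # The buffer moves with the module, so its device is the model's device.
        return self.device_buffer.device

    def to(self, *args, **kwargs) -> "TinyPrognostic":
        # nn.Module.to already moves all buffers/parameters; return self for chaining.
        super().to(*args, **kwargs)
        return self


model = TinyPrognostic().to("cpu")
print(model.device)  # cpu
```

Because `register_buffer` makes the tensor part of the module's state, a single `model.to(device)` call keeps the tracker in sync with no extra bookkeeping.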
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `fetch_dataframe` utility function
- Added data assimilation model class
- Added equirectangular interpolation data assimilation model
- Added StormCast SDA model
- Adding Beta serve utils with inference server and client implementations

### Changed
1 change: 1 addition & 0 deletions docs/modules/models.rst
@@ -119,6 +119,7 @@ or maintain internal state across time steps.
:template: dataassim.rst

InterpEquirectangular
StormCastSDA

:mod:`earth2studio.models`: Utilities
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
57 changes: 57 additions & 0 deletions docs/userguide/about/install.md
@@ -642,6 +642,63 @@ uv add earth2studio --extra windgust-afno
:::::
::::::

#### Data Assimilation

:::{admonition} Warning
:class: warning

Data assimilation model APIs are currently **in Beta** and may change in future
releases. Expect possible breaking changes as these APIs mature.
:::

:::{admonition} Warning
:class: warning

All data assimilation models require [CuPy](https://docs.cupy.dev/en/stable/) and [cuDF](https://docs.rapids.ai/api/cudf/stable/),
which are CUDA-dependent libraries.
The default installation uses CUDA 12 (i.e., `cupy-cuda12x`, `cudf-cu12`).
If your system uses a different CUDA version, you may need to adjust the dependencies.
:::

::::::{tab-set}
:::::{tab-item} InterpEquirectangular
::::{tab-set}
:::{tab-item} pip

```bash
pip install earth2studio[da-interp]
```

:::
:::{tab-item} uv

```bash
uv add earth2studio --extra da-interp
```

:::
::::
:::::
:::::{tab-item} StormCast SDA
::::{tab-set}
:::{tab-item} pip

```bash
pip install earth2studio[da-stormcast]
```

:::
:::{tab-item} uv

```bash
uv add earth2studio --extra da-stormcast
```

:::
::::
:::::
::::::

### Submodule Dependencies

A few features in various submodules require some specific dependencies that have been
53 changes: 42 additions & 11 deletions earth2studio/data/utils.py
@@ -67,12 +67,15 @@
from fsspec.implementations.cache_mapper import AbstractCacheMapper

try:
import cudf
import cupy as cp
except ImportError:
cudf = None
cp = None

try:
import cudf
except ImportError:
cudf = None


def fetch_data(
source: DataSource | ForecastSource,
@@ -82,7 +85,8 @@ def fetch_data(
device: torch.device = "cpu",
interp_to: CoordSystem | None = None,
interp_method: str = "nearest",
) -> tuple[torch.Tensor, CoordSystem]:
legacy: bool = True,
) -> tuple[torch.Tensor, CoordSystem] | xr.DataArray:
"""Utility function to fetch data arrays from particular sources and load data on
the target device. If desired, xarray interpolation/regridding in the spatial
domain can be used by passing a target coordinate system via the optional
@@ -106,13 +110,18 @@
specified by lat/lon arrays in this CoordSystem
interp_method : str
Interpolation method to use with xarray (by default 'nearest')
legacy : bool, optional
If True (default), returns tuple of (torch.Tensor, CoordSystem).
If False, returns xr.DataArray with numpy arrays for CPU or cupy arrays for CUDA.

Returns
-------
tuple[torch.Tensor, CoordSystem]
Tuple containing output tensor and coordinate OrderedDict
tuple[torch.Tensor, CoordSystem] | xr.DataArray
If legacy=True: Tuple containing output tensor and coordinate OrderedDict.
If legacy=False: xr.DataArray with numpy arrays (CPU) or cupy arrays (CUDA).
"""
sig = signature(source.__call__)
device = torch.device(device)

if "lead_time" in sig.parameters:
# Working with a Forecast Data Source
@@ -130,12 +139,31 @@

da = xr.concat(da, "lead_time")

return prep_data_array(
da,
device=device,
interp_to=interp_to,
interp_method=interp_method,
)
if legacy:
return prep_data_array(
da,
device=device,
interp_to=interp_to,
interp_method=interp_method,
)

# Non-legacy path: return xr.DataArray
else:
if interp_to is not None:
raise ValueError(
"The interp_to argument is not supported when legacy is False. Set legacy=True to use interpolation."
)
# Convert to cupy arrays if CUDA device and cupy is available
if device.type == "cuda":
if cp is not None:
with cp.cuda.Device(device.index or 0):
da = da.copy(data=cp.asarray(da.values))
else:
raise ImportError(
"cupy is required when using device='cuda' with legacy=False. "
"Install cupy or use legacy=True."
)
return da


def fetch_dataframe(
@@ -316,6 +344,9 @@ def prep_data_inputs(
if isinstance(variable, str):
variable = [variable]

if isinstance(variable, np.ndarray):
variable = variable.astype(str).tolist()

if isinstance(time, datetime):
time = [time]

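The `legacy` flag above gives `fetch_data` two return conventions: a `(tensor, CoordSystem)` tuple or a labeled array. The dispatch callers must handle can be sketched with a toy stand-in (`fetch_stub` is hypothetical; the real non-legacy path returns an `xr.DataArray`, not a dict):

```python
from collections import OrderedDict
from typing import Tuple, Union

import numpy as np


def fetch_stub(
    legacy: bool = True,
) -> Union[Tuple[np.ndarray, "OrderedDict[str, np.ndarray]"], dict]:
    """Hypothetical stand-in mirroring fetch_data's dual return convention."""
    data = np.zeros((1, 2, 4, 8), dtype=np.float32)
    coords = OrderedDict(
        time=np.array(["2024-01-01"], dtype="datetime64[ns]"),
        variable=np.array(["t2m", "u10m"]),
        lat=np.linspace(90, -90, 4),
        lon=np.linspace(0, 360, 8, endpoint=False),
    )
    if legacy:
        # Legacy path: plain array plus coordinate OrderedDict.
        return data, coords
    # Non-legacy path: one labeled container (an xr.DataArray in the real API).
    return {"data": data, "dims": tuple(coords), "coords": coords}


arr, cs = fetch_stub(legacy=True)      # tuple unpacking, existing call sites
labeled = fetch_stub(legacy=False)     # single labeled object, new call sites
```

Defaulting `legacy=True` keeps every existing call site working unchanged while new data-assimilation code can opt into the labeled-array form.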
1 change: 1 addition & 0 deletions earth2studio/models/da/__init__.py
@@ -15,3 +15,4 @@
# limitations under the License.

from .interp import InterpEquirectangular
from .sda_stormcast import StormCastSDA
41 changes: 34 additions & 7 deletions earth2studio/models/da/base.py
@@ -35,7 +35,7 @@ class AssimilationModel(Protocol):

def __call__(
self,
*args: pd.DataFrame | xr.DataArray,
*args: pd.DataFrame | xr.DataArray | None,
) -> tuple[pd.DataFrame | xr.DataArray, ...]:
"""Stateless iteration for the data assimilation model.

@@ -45,10 +45,11 @@ def __call__(

Parameters
----------
*args : pd.DataFrame | xr.DataArray
*args : pd.DataFrame | xr.DataArray | None
Variable number of observation arguments. Each argument can be a
DataFrame (pandas or cudf DataFrame) or xarray DataArray
containing observation data.
containing observation data. None can be passed for optional
arguments when no input data is available.

Returns
-------
@@ -60,9 +61,10 @@

def create_generator(
self,
*args: pd.DataFrame | xr.DataArray,
) -> Generator[
tuple[pd.DataFrame | xr.DataArray, ...],
tuple[pd.DataFrame | xr.DataArray, ...],
tuple[pd.DataFrame | xr.DataArray | None, ...],
None,
]:
"""Creates a generator which accepts collection of input observations and
@@ -73,6 +75,13 @@
method and yields assimilated data (DataFrame or DataArray) as output.
Supports any number of arguments (variadic).

Parameters
----------
*args : pd.DataFrame | xr.DataArray
Variable number of initialization arguments, if any are required by
the model. Each argument can be a DataFrame (pandas or cudf
DataFrame) or xarray DataArray containing initial state data.

Yields
------
tuple[pd.DataFrame | xr.DataArray, ...]
@@ -82,11 +91,12 @@

Receives
--------
tuple[pd.DataFrame | xr.DataArray, ...]
tuple[pd.DataFrame | xr.DataArray | None, ...]
Observations sent via generator.send() as multiple arguments. Each
argument can be a DataFrame (pandas or cudf DataFrame) or xarray
DataArray. None is sent initially to start the generator and can also be
sent for iterations where no input data is available. Supports any number
DataArray. None is sent initially to start the generator and can also be
sent for iterations where no input data is available. Supports any number
of arguments.

Examples
--------
Expand All @@ -100,6 +110,23 @@ def create_generator(
"""
pass

def init_coords(self) -> tuple[FrameSchema | CoordSystem, ...] | None:
"""Initialization coordinate system required by the assimilation model.

Specifies the coordinate system(s) for initial state data that must be provided
before the model can process observations. The returned coordinate systems should
match the expected input format for the first argument(s) passed to ``__call__``
or sent to ``create_generator`` when initializing the model state.

Returns
-------
tuple[FrameSchema | CoordSystem, ...] | None
Tuple of coordinate systems or frame schemas defining the structure of
required initialization data. Returns ``None`` if the model does not require
initialization data (e.g., stateless models).
"""
pass

def input_coords(self) -> tuple[FrameSchema | CoordSystem, ...]:
"""Input coordinate system of assimilation model.

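The `create_generator` protocol documented above (prime with `None`, then `send()` observations each cycle, with `None` meaning "no data this step") can be sketched in plain Python. The moving-average "assimilation" update is purely illustrative:

```python
from typing import Generator, Optional


def assimilation_generator(
    initial_state: float,
) -> Generator[float, Optional[float], None]:
    """Toy generator mirroring the create_generator protocol: yields the
    current analysis state, receives the next observation via send()."""
    state = initial_state
    while True:
        obs = yield state                    # yield analysis, receive observation
        if obs is not None:                  # None: no data this cycle, keep prior
            state = 0.5 * state + 0.5 * obs  # toy "assimilation" update


gen = assimilation_generator(10.0)
first = next(gen)        # prime the generator; returns the initial state, 10.0
second = gen.send(20.0)  # assimilate one observation -> 15.0
third = gen.send(None)   # no observation this cycle -> unchanged, 15.0
```

The initial `next(gen)` plays the role of the priming `None` send in the protocol; every subsequent `send()` delivers one cycle's observations and receives one cycle's assimilated output.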
18 changes: 11 additions & 7 deletions earth2studio/models/da/interp.py
@@ -62,7 +62,7 @@ class InterpEquirectangular(torch.nn.Module):
grid over CONUS)
interp_method : str, optional
Interpolation method to use: 'nearest' or 'smolyak', by default "smolyak"
tolerance : TimeTolerance, optional
time_tolerance : TimeTolerance, optional
Time tolerance for filtering observations. Observations within the tolerance
window around each requested time will be used for interpolation, by default
np.timedelta64(10, "m")
@@ -81,7 +81,7 @@ def __init__(
lat: np.ndarray | None = None,
lon: np.ndarray | None = None,
interp_method: str = "smolyak",
tolerance: TimeTolerance = np.timedelta64(10, "m"),
time_tolerance: TimeTolerance = np.timedelta64(10, "m"),
) -> None:
if interp_method not in ["nearest", "smolyak"]:
raise ValueError(
@@ -96,9 +96,13 @@
lon if lon is not None else np.linspace(235.0, 295.0, 241, dtype=np.float32)
)
self.interp_method = interp_method
self._tolerance = normalize_time_tolerance(tolerance)
self._tolerance = normalize_time_tolerance(time_tolerance)
self.register_buffer("device_buffer", torch.empty(0), persistent=False)

def init_coords(self) -> None:
"""Initialization coords (not required)"""
return None

def input_coords(self) -> tuple[FrameSchema]:
"""Input coordinate system specifying required DataFrame fields.

@@ -161,13 +165,13 @@ def output_coords(
),
)

def __call__(self, x: pd.DataFrame) -> xr.DataArray:
def __call__(self, obs: pd.DataFrame) -> xr.DataArray:
"""Stateless forward pass"""
input_coords = self.input_coords()
(output_coords,) = self.output_coords(input_coords, **x.attrs)
(output_coords,) = self.output_coords(input_coords, **obs.attrs)
# Validate observation types match input_coords
validate_observation_fields(x, required_fields=list(input_coords[0].keys()))
return self._interpolate_dataframe(x, output_coords)
validate_observation_fields(obs, required_fields=list(input_coords[0].keys()))
return self._interpolate_dataframe(obs, output_coords)

def create_generator(self) -> Generator[
xr.DataArray,
Expand Down
Loading