From 7e83a4d148765cc878ae91b6ff51e2efb1335524 Mon Sep 17 00:00:00 2001
From: abhaygoudannavar
Date: Fri, 27 Feb 2026 19:08:23 +0000
Subject: [PATCH] feat: Add torch.compile support for autoregressive inference
 rollouts

Add optional compile parameter to deterministic(), diagnostic(), and
ensemble() workflows in run.py. When enabled, wraps the prognostic model
inner nn.Module with torch.compile(mode=reduce-overhead) for CUDA
Graph-backed autoregressive rollouts.

Features:
- _maybe_compile_model() helper with graceful fallback
- Checks for .model attribute (nn.Module) before compiling
- Falls back to eager mode with warning on PyTorch < 2.0 or errors
- Opt-in via compile=False default (no breaking changes)
- Full docstring updates on all 3 workflow functions

Includes 6 unit tests covering: compile disabled, compile with valid
model, missing model attribute, non-Module model, error fallback, and
output correctness verification.
---
 earth2studio/run.py  |  66 +++++++++++++++++++++++-
 test/test_compile.py | 116 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 180 insertions(+), 2 deletions(-)
 create mode 100644 test/test_compile.py

diff --git a/earth2studio/run.py b/earth2studio/run.py
index 562a5d082..c86cba5fb 100644
--- a/earth2studio/run.py
+++ b/earth2studio/run.py
@@ -35,6 +35,55 @@
 logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True)
 
 
+def _maybe_compile_model(
+    prognostic: PrognosticModel,
+    compile: bool,
+) -> PrognosticModel:
+    """Optionally apply torch.compile to the prognostic model's internal network.
+
+    When enabled, this wraps the model's underlying neural network with
+    ``torch.compile(mode="reduce-overhead")`` which uses CUDA Graphs under
+    the hood for faster autoregressive rollouts.
+
+    Parameters
+    ----------
+    prognostic : PrognosticModel
+        Prognostic model instance
+    compile : bool
+        Whether to apply torch.compile
+
+    Returns
+    -------
+    PrognosticModel
+        The same model, potentially with compiled internals
+    """
+    if not compile:
+        return prognostic
+
+    if not hasattr(torch, "compile"):
+        logger.warning(
+            "torch.compile requested but not available (requires PyTorch >= 2.0). "
+            "Falling back to eager mode."
+        )
+        return prognostic
+
+    try:
+        if hasattr(prognostic, "model") and isinstance(
+            prognostic.model, torch.nn.Module
+        ):
+            logger.info("Compiling prognostic model with torch.compile")
+            prognostic.model = torch.compile(prognostic.model, mode="reduce-overhead")
+        else:
+            logger.warning(
+                "torch.compile requested but prognostic model does not expose "
+                "a 'model' attribute of type nn.Module. Skipping compilation."
+            )
+    except Exception as e:
+        logger.warning(f"torch.compile failed, falling back to eager mode: {e}")
+
+    return prognostic
+
+
 # sphinx - deterministic start
 def deterministic(
     time: list[str] | list[datetime] | list[np.datetime64],
@@ -45,6 +94,7 @@ def deterministic(
     output_coords: CoordSystem = OrderedDict({}),
     device: torch.device | None = None,
     verbose: bool = True,
+    compile: bool = False,
 ) -> IOBackend:
     """Built in deterministic workflow.
     This workflow creates a determinstic inference pipeline to produce a forecast
@@ -68,6 +118,9 @@
         Device to run inference on, by default None
     verbose : bool, optional
         Print inference progress, by default True
+    compile : bool, optional
+        If True, apply torch.compile to the model for faster autoregressive
+        rollouts using CUDA Graphs. Requires PyTorch >= 2.0. By default False
 
     Returns
     -------
@@ -84,6 +137,7 @@
     )
     logger.info(f"Inference device: {device}")
     prognostic = prognostic.to(device)
+    prognostic = _maybe_compile_model(prognostic, compile)
     # sphinx - fetch data start
     # Fetch data from data source and load onto device
     prognostic_ic = prognostic.input_coords()
@@ -163,6 +217,7 @@
     output_coords: CoordSystem = OrderedDict({}),
     device: torch.device | None = None,
     verbose: bool = True,
+    compile: bool = False,
 ) -> IOBackend:
     """Built in diagnostic workflow.
     This workflow creates a determinstic inference pipeline that couples a prognostic
@@ -188,6 +243,9 @@
         Device to run inference on, by default None
     verbose : bool, optional
         Print inference progress, by default True
+    compile : bool, optional
+        If True, apply torch.compile to the model for faster autoregressive
+        rollouts using CUDA Graphs. Requires PyTorch >= 2.0. By default False
 
     Returns
     -------
@@ -204,6 +262,7 @@
     )
     logger.info(f"Inference device: {device}")
     prognostic = prognostic.to(device)
+    prognostic = _maybe_compile_model(prognostic, compile)
     diagnostic = diagnostic.to(device)
     # Fetch data from data source and load onto device
     prognostic_ic = prognostic.input_coords()
@@ -262,7 +321,6 @@
         total=nsteps + 1, desc="Running inference", position=1, disable=(not verbose)
     ) as pbar:
         for step, (x, coords) in enumerate(model):
-
             # Run diagnostic
             x, coords = map_coords(x, coords, diagnostic_ic)
             x, coords = diagnostic(x, coords)
@@ -290,6 +348,7 @@
     output_coords: CoordSystem = OrderedDict({}),
     device: torch.device | None = None,
     verbose: bool = True,
+    compile: bool = False,
 ) -> IOBackend:
     """Built in ensemble workflow.
 
@@ -318,6 +377,9 @@
         Device to run inference on, by default None
     verbose : bool, optional
         Print inference progress, by default True
+    compile : bool, optional
+        If True, apply torch.compile to the model for faster autoregressive
+        rollouts using CUDA Graphs. Requires PyTorch >= 2.0. By default False
 
     Returns
     -------
@@ -335,6 +397,7 @@
     )
     logger.info(f"Inference device: {device}")
     prognostic = prognostic.to(device)
+    prognostic = _maybe_compile_model(prognostic, compile)
 
     # Fetch data from data source and load onto device
     prognostic_ic = prognostic.input_coords()
@@ -395,7 +458,6 @@
         position=2,
         disable=(not verbose),
     ):
-
         # Get fresh batch data
         x = x0.to(device)
 
diff --git a/test/test_compile.py b/test/test_compile.py
new file mode 100644
index 000000000..e5b953993
--- /dev/null
+++ b/test/test_compile.py
@@ -0,0 +1,116 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+
+from earth2studio.run import _maybe_compile_model
+
+
+class DummyNetwork(torch.nn.Module):
+    """A minimal nn.Module for testing torch.compile."""
+
+    def __init__(self, channels: int = 4):
+        super().__init__()
+        self.linear = torch.nn.Linear(channels, channels)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.linear(x)
+
+
+class MockPrognosticWithModel:
+    """Mock prognostic model that exposes a .model attribute."""
+
+    def __init__(self):
+        self.model = DummyNetwork(channels=4)
+
+
+class MockPrognosticWithoutModel:
+    """Mock prognostic model without a .model attribute."""
+
+    pass
+
+
+class MockPrognosticNonModuleModel:
+    """Mock prognostic model where .model is not an nn.Module."""
+
+    def __init__(self):
+        self.model = "not_a_module"
+
+
+def test_compile_disabled():
+    """When compile=False, model should be returned unchanged."""
+    prog = MockPrognosticWithModel()
+    original_model = prog.model
+    result = _maybe_compile_model(prog, compile=False)
+    assert result is prog
+    assert result.model is original_model
+
+
+def test_compile_enabled_with_model():
+    """When compile=True and model has .model attribute, it should be compiled."""
+    prog = MockPrognosticWithModel()
+    original_model = prog.model
+    result = _maybe_compile_model(prog, compile=True)
+    assert result is prog
+    # After compilation, model should be wrapped (not the same object)
+    assert result.model is not original_model
+
+
+def test_compile_without_model_attribute():
+    """When model lacks .model attribute, should warn and return unchanged."""
+    prog = MockPrognosticWithoutModel()
+    result = _maybe_compile_model(prog, compile=True)
+    assert result is prog
+
+
+def test_compile_non_module_model():
+    """When .model is not an nn.Module, should warn and return unchanged."""
+    prog = MockPrognosticNonModuleModel()
+    result = _maybe_compile_model(prog, compile=True)
+    assert result is prog
+    assert result.model == "not_a_module"
+
+
+def test_compile_fallback_on_error():
+    """If torch.compile raises, should fall back gracefully."""
+    prog = MockPrognosticWithModel()
+
+    with patch("torch.compile", side_effect=RuntimeError("compilation failed")):
+        result = _maybe_compile_model(prog, compile=True)
+    # Should return the model unchanged on error
+    assert result is prog
+
+
+def test_compiled_model_produces_output():
+    """Verify that a compiled model still produces correct output."""
+    prog = MockPrognosticWithModel()
+
+    # Get reference output before compilation
+    x = torch.randn(2, 4)
+    with torch.no_grad():
+        ref_output = prog.model(x)
+
+    # Compile
+    _maybe_compile_model(prog, compile=True)
+
+    # Get output after compilation (first call triggers compilation)
+    with torch.no_grad():
+        compiled_output = prog.model(x)
+
+    assert torch.allclose(ref_output, compiled_output, atol=1e-5)