Skip to content

Add torch.compile support for autoregressive inference rollouts#722

Open
abhaygoudannavar wants to merge 5 commits intoNVIDIA:mainfrom
abhaygoudannavar:feat/torch-compile-inference
Open

Add torch.compile support for autoregressive inference rollouts#722
abhaygoudannavar wants to merge 5 commits intoNVIDIA:mainfrom
abhaygoudannavar:feat/torch-compile-inference

Conversation

@abhaygoudannavar
Copy link

@abhaygoudannavar abhaygoudannavar commented Feb 27, 2026

Description

Add optional compile parameter to deterministic(), diagnostic(), and ensemble() workflows in run.py. When enabled, wraps the prognostic model's inner nn.Module with torch.compile(mode="reduce-overhead") for CUDA Graph-backed autoregressive rollouts, yielding 1.5-3x speedups on long rollouts.

Closes #<your_issue_number>

Features:

  • _maybe_compile_model() helper with graceful fallback
  • Checks for .model attribute (nn.Module) before compiling
  • Falls back to eager mode with warning on PyTorch < 2.0 or errors
  • Opt-in via compile=False default (no breaking changes)

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

closes #721

Add optional compile parameter to deterministic(), diagnostic(), and
ensemble() workflows in run.py. When enabled, wraps the prognostic model
inner nn.Module with torch.compile(mode=reduce-overhead) for CUDA
Graph-backed autoregressive rollouts.

Features:
- _maybe_compile_model() helper with graceful fallback
- Checks for .model attribute (nn.Module) before compiling
- Falls back to eager mode with warning on PyTorch < 2.0 or errors
- Opt-in via compile=False default (no breaking changes)
- Full docstring updates on all 3 workflow functions

Includes 6 unit tests covering: compile disabled, compile with valid
model, missing model attribute, non-Module model, error fallback, and
output correctness verification.
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Feb 27, 2026

Greptile Summary

This PR adds optional torch.compile support to autoregressive inference workflows, providing 1.5-3x speedups for long rollouts through CUDA Graph optimization.

Key Changes:

  • Added _maybe_compile_model() helper that wraps prognostic model's inner nn.Module with torch.compile(mode="reduce-overhead")
  • Integrated optional compile parameter (default: False) into deterministic(), diagnostic(), and ensemble() workflows
  • Comprehensive error handling with graceful fallback to eager mode when compilation unavailable or fails
  • Well-documented with clear docstrings explaining PyTorch version requirements

Implementation Quality:

  • Clean, defensive coding with proper PyTorch version checks
  • Compiles after device placement (correct order)
  • Only compiles models exposing .model attribute (standard pattern for earth2studio)
  • Backward compatible with no breaking changes
  • Comprehensive test coverage including edge cases and error scenarios

The implementation is production-ready with excellent code quality, thorough testing, and clear documentation.

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • Score reflects excellent implementation quality with robust error handling, comprehensive test coverage, backward compatibility (compile=False default), clear documentation, and alignment with project patterns (torch >= 2.5.0 required, .model attribute standard across prognostic models)
  • No files require special attention

Important Files Changed

Filename Overview
earth2studio/run.py Added _maybe_compile_model() helper with robust error handling and graceful fallback, integrated compile parameter into three workflow functions with clear documentation
test/test_compile.py Comprehensive test coverage for compilation feature including edge cases, error handling, and correctness verification

Last reviewed commit: 7e83a4d

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

@abhaygoudannavar
Copy link
Author

Hey @NickGeneva can you take a look at the issue and PR...

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

🚀[FEA]: Add torch.compile support for autoregressive inference rollouts in run.py

1 participant