Add torch.compile support for autoregressive inference rollouts#722
Open
abhaygoudannavar wants to merge 5 commits intoNVIDIA:mainfrom
Open
Add torch.compile support for autoregressive inference rollouts#722abhaygoudannavar wants to merge 5 commits intoNVIDIA:mainfrom
abhaygoudannavar wants to merge 5 commits intoNVIDIA:mainfrom
Conversation
Add optional compile parameter to deterministic(), diagnostic(), and ensemble() workflows in run.py. When enabled, wraps the prognostic model inner nn.Module with torch.compile(mode=reduce-overhead) for CUDA Graph-backed autoregressive rollouts. Features: - _maybe_compile_model() helper with graceful fallback - Checks for .model attribute (nn.Module) before compiling - Falls back to eager mode with warning on PyTorch < 2.0 or errors - Opt-in via compile=False default (no breaking changes) - Full docstring updates on all 3 workflow functions Includes 6 unit tests covering: compile disabled, compile with valid model, missing model attribute, non-Module model, error fallback, and output correctness verification.
Contributor
Greptile SummaryThis PR adds optional Key Changes:
Implementation Quality:
The implementation is production-ready with excellent code quality, thorough testing, and clear documentation.
|
| Filename | Overview |
|---|---|
| earth2studio/run.py | Added _maybe_compile_model() helper with robust error handling and graceful fallback, integrated compile parameter into three workflow functions with clear documentation |
| test/test_compile.py | Comprehensive test coverage for compilation feature including edge cases, error handling, and correctness verification |
Last reviewed commit: 7e83a4d
Author
|
Hey @NickGeneva can you take a look at the issue and PR... |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
Add optional
compileparameter todeterministic(),diagnostic(), andensemble()workflows inrun.py. When enabled, wraps the prognostic model's innernn.Modulewithtorch.compile(mode="reduce-overhead")for CUDA Graph-backed autoregressive rollouts, yielding 1.5-3x speedups on long rollouts.Closes #<your_issue_number>
Features:
_maybe_compile_model()helper with graceful fallback.modelattribute (nn.Module) before compilingcompile=Falsedefault (no breaking changes)Checklist
closes #721