feat(checkpoint): zero-copy storage sharing in CheckpointWithoutOutput #3649
Open
Victarry wants to merge 2 commits into NVIDIA:main from
Conversation
Replace `UntypedStorage.resize_() + copy_()` with a StorageImpl-level data pointer swap in `_recompute()`. This eliminates a full GPU memcpy and the temporary double allocation during the backward recomputation phase.

The key insight is that `resize_() + copy_()` restores data into the same StorageImpl object, so all views (reshape/split) see it, but at the cost of a full copy. A naive `Tensor.set_()` is zero-copy but only redirects one TensorImpl, leaving views (e.g. TE GroupedLinear's inp.reshape() + torch.split() saved for backward) pointing at the old empty storage.

The solution operates at the StorageImpl level: a ~20-line C++ extension (`share_storage`) replaces the existing StorageImpl's `data_ptr_` to point to the recomputed tensor's memory, with a custom DataPtr deleter that holds a refcounted reference to the source StorageImpl. This is the same `StorageImpl::set_data_ptr()` API that PyTorch uses internally for IPC storage sharing (`_share_cuda_` in StorageSharing.cpp).

Benefits:
- Zero-copy: no GPU memcpy, just a pointer swap
- No temporary double allocation during restore
- View-compatible: all TensorImpls sharing the StorageImpl see the data
- No autograd version counter bump (operates below TensorImpl)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Summary
Replace `UntypedStorage.resize_() + copy_()` with a StorageImpl-level data pointer swap in `CheckpointWithoutOutput._recompute()`. This eliminates a full GPU memcpy and the temporary double allocation during the backward recomputation phase, while remaining compatible with downstream modules that create views (e.g., TransformerEngine GroupedLinear).

Cherry-picked from dev PR: #3641
Motivation
How CheckpointWithoutOutput works
Standard activation checkpointing (`torch.utils.checkpoint`) skips saving intermediate activations during forward and recomputes them during backward. However, the output tensors of the checkpointed function are still held by downstream modules for their own backward, so that memory is never freed. CheckpointWithoutOutput goes one step further: it explicitly frees the output storage right after forward and restores it just-in-time during backward.
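A minimal conceptual sketch of this lifecycle (illustrative only; `forward_fn` and the inline restore steps are placeholders, not the actual CheckpointWithoutOutput API):

```python
import torch

def forward_fn(x):
    return torch.nn.functional.gelu(x)

x = torch.randn(4, 8)
out = forward_fn(x)                        # output of the checkpointed function
print(out.untyped_storage().nbytes())      # 128: the activation's bytes

# Right after forward: free the output's bytes but keep the tensor object,
# so downstream modules can keep holding references (and views) to it.
out.untyped_storage().resize_(0)
print(out.untyped_storage().nbytes())      # 0: memory released

# Just-in-time during backward (normally triggered from a grad hook):
recomp = forward_fn(x)                     # recompute the activation...
# ...then restore it into out's storage. How that restore is done is
# exactly what this PR changes (see the sections below).
```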
Problem with the original restore logic

The original code restores the recomputed data by allocating and copying:
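Roughly, the old restore path looked like this sketch (illustrative variable names, not the exact code in `_recompute()`):

```python
import torch

out = torch.randn(4, 8)                    # checkpointed output...
out.untyped_storage().resize_(0)           # ...whose bytes were freed after forward
recomp = torch.randn(4, 8)                 # freshly recomputed activation

# Re-allocate out's storage to full size, then copy the recomputed bytes in.
# On GPU this is a full memcpy, and both allocations briefly coexist.
out.untyped_storage().resize_(recomp.untyped_storage().nbytes())
out.untyped_storage().copy_(recomp.untyped_storage())
```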
This means two copies of the same activation coexist on the GPU momentarily (`output`'s fresh allocation plus `recomp`'s allocation), and the `copy_` has a bandwidth cost.

Goal

Make `output`'s storage point directly to `recomp`'s memory: zero-copy, zero temporary double allocation.

Why a naive Tensor.set_() does not work

`output.set_(recomp.untyped_storage(), ...)` is zero-copy at the tensor level, but breaks downstream views. When downstream modules (e.g. TE GroupedLinear) create views via `inp.reshape()` + `torch.split()` and save them for backward, those view TensorImpls still reference the old, empty StorageImpl after `set_()`.
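A minimal CPU repro of the view problem (toy tensors for illustration, not the PR's test):

```python
import torch

out = torch.arange(6.0)                     # checkpointed output
views = torch.split(out.reshape(2, 3), 1)   # downstream views saved for backward
out.untyped_storage().resize_(0)            # storage freed after forward

recomp = torch.arange(6.0) * 2              # recomputed activation

# Zero-copy, but only out's TensorImpl is redirected to the new storage:
out.set_(recomp.untyped_storage(), 0, out.shape)
print(out)                                  # sees the recomputed values
print(out.untyped_storage().nbytes())       # 24: out now uses recomp's storage
print(views[0].untyped_storage().nbytes())  # 0: views still point at the old, empty StorageImpl
```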
Solution: StorageImpl-level data pointer sharing

A ~20-line C++ extension (`share_storage`) replaces the existing StorageImpl's `data_ptr_` to point to the recomputed tensor's memory, with a custom DataPtr deleter that holds a refcounted reference to the source StorageImpl. This is the same `StorageImpl::set_data_ptr()` API that PyTorch uses internally for IPC storage sharing (`_share_cuda_` in StorageSharing.cpp). Since all TensorImpls, including views, reference the same StorageImpl object, they all see the new data immediately.
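A sketch of how such an extension could look, compiled at runtime with torch.utils.cpp_extension.load_inline. This is illustrative only: the PR's actual `share_storage` may differ in signature and details, and the exact StorageImpl accessors vary slightly across PyTorch versions.

```python
import torch
from torch.utils.cpp_extension import load_inline

cpp_src = r"""
#include <torch/extension.h>
#include <c10/core/StorageImpl.h>
#include <c10/util/intrusive_ptr.h>

// Point dst's existing StorageImpl at src's memory. The custom deleter keeps a
// strong (refcounted) reference to src's StorageImpl so the underlying memory
// stays alive for as long as dst's StorageImpl uses it.
void share_storage(const at::Tensor& dst, const at::Tensor& src) {
  c10::StorageImpl* dst_storage = dst.storage().unsafeGetStorageImpl();
  c10::StorageImpl* src_storage = src.storage().unsafeGetStorageImpl();

  c10::raw::intrusive_ptr::incref(src_storage);  // released by the deleter below
  auto deleter = [](void* ctx) {
    c10::raw::intrusive_ptr::decref(static_cast<c10::StorageImpl*>(ctx));
  };

  at::DataPtr shared(src_storage->data_ptr().get(), src_storage, deleter,
                     src_storage->device());
  dst_storage->set_nbytes(src_storage->nbytes());
  // Same API PyTorch uses internally for IPC storage sharing; the returned old
  // DataPtr (the empty allocation) is released when it goes out of scope.
  dst_storage->set_data_ptr(std::move(shared));
}
"""

share_ext = load_inline(name="share_storage_sketch", cpp_sources=cpp_src,
                        functions=["share_storage"], verbose=False)
```

With an extension like this, the restore in `_recompute()` becomes a single `share_storage` call instead of `resize_` + `copy_`, and every view of `output`'s storage sees the recomputed data immediately.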
Benefits

- Zero-copy: no GPU memcpy, just a pointer swap
- No temporary double allocation during restore
- View-compatible: all TensorImpls sharing the StorageImpl see the data
- No autograd version counter bump (operates below TensorImpl)
Files changed
- `megatron/core/tensor_parallel/random.py`: `share_storage` C++ extension; replace `resize_` + `copy_` with `share_storage` in `_recompute()`
- `tests/unit_tests/tensor_parallel/test_random.py`

Test plan
- Unit test: fails with `Tensor.set_()`, passes with `share_storage`
- `local_test_e2e.sh` (8 GPU, PP=4, EP=2, MoE + MLA, bf16)

Made with Cursor