[Dev] feat(checkpoint): zero-copy storage sharing in CheckpointWithoutOutput#3641
Merged
Victarry merged 3 commits into NVIDIA:dev on Mar 2, 2026
Conversation
Author (Contributor): /ok to test 8008bec
Replace `UntypedStorage.resize_() + copy_()` with a StorageImpl-level data pointer swap in `_recompute()`. This eliminates a full GPU memcpy and the temporary double allocation during the backward recomputation phase.

The key insight is that `resize_() + copy_()` restores data into the same StorageImpl object, so all views (reshape/split) see it — but at the cost of a full copy. A naive `Tensor.set_()` is zero-copy but only redirects one TensorImpl, leaving views (e.g. TE GroupedLinear's `inp.reshape()` + `torch.split()` saved for backward) pointing at the old empty storage.

The solution operates at the StorageImpl level: a ~20-line C++ extension (`share_storage`) replaces the existing StorageImpl's `data_ptr_` to point to the recomputed tensor's memory, with a custom DataPtr deleter that holds a refcounted reference to the source StorageImpl. This is the same `StorageImpl::set_data_ptr()` API that PyTorch uses internally for IPC storage sharing (`_share_cuda_` in StorageSharing.cpp).

Benefits:
- Zero-copy: no GPU memcpy, just a pointer swap
- No temporary double allocation during restore
- View-compatible: all TensorImpls sharing the StorageImpl see the data
- No autograd version counter bump (operates below TensorImpl)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Force-pushed 8008bec to c833c0a
Author (Contributor): /ok to test c833c0a
Author (Contributor): /ok to test 33d4d26
hxbai approved these changes on Mar 2, 2026
Author (Contributor): /ok to test 74d10da
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/22567814399
Zero-Copy Storage Sharing for CheckpointWithoutOutput
PR to main #3649
Summary
Replace `UntypedStorage.resize_() + copy_()` with a StorageImpl-level data pointer swap in `CheckpointWithoutOutput._recompute()`. This eliminates a full GPU memcpy and the double allocation during the backward recomputation phase, while remaining compatible with downstream modules that create views (e.g., TransformerEngine `GroupedLinear`).

Motivation
How `CheckpointWithoutOutput` works
Standard activation checkpointing (`torch.utils.checkpoint`) skips saving intermediate activations during forward and recomputes them during backward. However, the output tensors of the checkpointed function are still held by downstream modules for their own backward, so that memory is never freed. `CheckpointWithoutOutput` goes one step further: it explicitly frees the output storage right after forward and restores it just-in-time during backward. The GPU memory occupied by the checkpoint output is freed between the end of forward and the hook firing in backward.
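A minimal CPU sketch of that lifecycle (the names `fn` and `restore` are illustrative, not the Megatron API; the real class wraps this machinery and registers the recompute as a grad hook):

```python
import torch

def fn(x):                      # stand-in for the checkpointed function
    return x * 3.0

x = torch.randn(4, requires_grad=True)
out = fn(x)                     # forward
loss = (out * out).sum()        # downstream forward consumes out; its
                                # backward will need out's data again
nbytes = out.untyped_storage().size()
out.untyped_storage().resize_(0)        # free out's storage after forward

def restore(grad):
    # Fires at the start of backward, before any node reads out's data.
    with torch.no_grad():
        recomp = fn(x)                  # recompute
        out.untyped_storage().resize_(nbytes)
        # Storage-level copy: refills the same StorageImpl without bumping
        # the tensor version counter (this is the restore path the PR replaces).
        out.untyped_storage().copy_(recomp.untyped_storage())
    return grad

loss.register_hook(restore)
loss.backward()
assert torch.allclose(x.grad, (18 * x).detach())   # d/dx sum((3x)^2) = 18x
```

Between `resize_(0)` and the hook firing, `out`'s data does not occupy memory, which is exactly the window the technique exploits.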
Problem with the original restore logic
The original code restores the recomputed data by allocating + copying:
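A hedged reconstruction of that restore step, with CPU stand-ins for the real checkpoint output and recomputed tensor:

```python
import torch

output = torch.randn(1024)
recomp = output.clone()
output.untyped_storage().resize_(0)      # output was freed earlier

# Restore: allocate fresh memory, then a full copy. For a moment, output's
# new allocation and recomp's allocation are both live (double allocation).
output.untyped_storage().resize_(recomp.untyped_storage().size())
output.untyped_storage().copy_(recomp.untyped_storage())
assert torch.equal(output, recomp)
```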
This means two copies of the same activation coexist on the GPU momentarily (`output`'s fresh allocation plus `recomp`'s allocation), and the `copy_` has bandwidth cost.

Goal
Make `output`'s storage point directly to `recomp`'s memory: zero-copy, no temporary double allocation.

Why a naive `Tensor.set_()` does not work
The obvious approach, `output.set_(recomp.untyped_storage(), ...)`, is zero-copy at the tensor level, but it breaks downstream views.

When downstream modules consume the checkpoint output, they commonly create views. For example, TransformerEngine's `GroupedLinear.forward()` performs `inp.reshape()` followed by `torch.split()`, so multiple TensorImpls all reference the same StorageImpl (Storage A). `Tensor.set_()` only redirects `output`'s TensorImpl to a new StorageImpl B; the view TensorImpls (`inp_view`, `inputmats`) still reference the old, empty Storage A. This causes TE's backward to crash:
`Assertion failed: A.has_data() ... Input A does not hold any data!`
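The breakage is reproducible on CPU with plain torch; in this sketch the final storage-level copy stands in for `share_storage`, since both operate on the StorageImpl that the views still reference:

```python
import torch

inp = torch.zeros(6, 4)
inp_view = inp.reshape(3, 8)                       # shares inp's StorageImpl
inputmats = torch.split(inp_view, 1, dim=0)        # views "saved for backward"
recomp = torch.full((6, 4), 7.0)                   # recomputed data

# Naive Tensor.set_(): redirects only inp's TensorImpl to recomp's storage.
inp.set_(recomp.untyped_storage(), 0, inp.size(), inp.stride())
assert inp[0, 0].item() == 7.0                     # inp sees the new data...
# ...but the views still reference the OLD, unchanged StorageImpl:
assert inputmats[0].untyped_storage().data_ptr() != inp.untyped_storage().data_ptr()
assert inputmats[0][0, 0].item() == 0.0            # stale data

# Writing into the storage the views reference fixes all of them at once:
inputmats[0].untyped_storage().copy_(recomp.untyped_storage())
assert inputmats[0][0, 0].item() == 7.0
```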
resize_() + copy_()writes data back into the same StorageImpl A, so all views automatically see the restored data. But it requires a full copy.Solution: StorageImpl-level data pointer sharing
Key insight
Instead of redirecting the TensorImpl (tensor level), we replace the data pointer inside the existing StorageImpl (storage level). Since all TensorImpls — including views — reference the same StorageImpl object, they all see the new data immediately.
C++ extension (~20 lines)
PyTorch does not expose `StorageImpl::set_data_ptr()` to Python (see Appendix: PyTorch API survey). We compile a minimal C++ extension via `torch.utils.cpp_extension.load_inline`:
StorageImpl::set_data_ptr(), the same C++ API that PyTorch's own_share_cuda_()IPC path uses internally (torch/csrc/StorageSharing.cpp:322).Call site
Safety analysis
- View compatible: all TensorImpls reference the same StorageImpl, so replacing its `data_ptr` makes the change visible to every view immediately.
- No version counter bump: `StorageImpl::set_data_ptr()` operates below the TensorImpl layer and does not trigger `TensorImpl::bump_version()`. No autograd "modified by an inplace operation" errors.
- No dangling pointers: the custom `DataPtr` holds a refcounted `c10::Storage` reference to `src`'s StorageImpl. Memory lifetime: `_recompute()` → `backward()` runs → `ctx.outputs = None` → orig GC'd.
- Recomp autograd graph unaffected: only Storage A (`dst`) is modified. Storage B (`src`) and all intermediate tensors in the recomp graph are untouched; `torch.autograd.backward(outputs, grad_outputs)` executes normally.
- Autograd metadata preserved: `share_storage` only touches `StorageImpl.data_ptr_` and `StorageImpl.size_bytes_`. `requires_grad`, `grad_fn`, `grad`, `sizes_`, `strides_`, `storage_offset_`, and the version counter on TensorImpl are all unchanged.

Comparison
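One difference between the paths can be spot-checked from Python: a tensor-level `copy_` bumps the autograd version counter, while storage-level operations do not:

```python
import torch

t = torch.randn(4)
src = torch.randn(4)
v0 = t._version

t.untyped_storage().copy_(src.untyped_storage())   # storage-level restore path
assert t._version == v0                            # no version bump

t.copy_(src)                                       # tensor-level in-place op
assert t._version == v0 + 1                        # bumps the version counter
```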
Verification
Convergence Test
POC tests (all 5 passed)
The view-sharing test specifically mimics TE `GroupedLinear`'s pattern: the downstream layer performs `inp.reshape()` + `torch.split()` and saves those views via `ctx.save_for_backward()`. This test fails with `Tensor.set_()` but passes with `share_storage`.

Megatron-LM E2E
`bash local_test_e2e.sh  # 8 GPU, PP=4, EP=2, MoE + MLA, bf16`

Files changed
- `megatron/core/tensor_parallel/random.py`: `share_storage` C++ extension; replace `resize_` + `copy_` with `share_storage` in `_recompute()`

Appendix: PyTorch API survey
We reviewed every candidate PyTorch API. None provides a general-purpose way to modify an existing StorageImpl's `data_ptr_` from Python:
- `Tensor.set_()`
- `torch.utils.swap_tensors()`
- `_construct_storage_from_data_pointer()`
- `UntypedStorage._set_cdata()`
- `_share_cuda_()` / `_share_fd_cpu_()`
- `safely_set_viewless_tensor_data()` (`tensor.data = ...`, built on `set_()`)
StorageImpl::set_data_ptr()are the IPC sharing methods (_share_cuda_,_share_fd_cpu_inStorageSharing.cpp). Our extension uses the same C++ API.Appendix: PyTorch internals reference