
feat(checkpoint): zero-copy storage sharing in CheckpointWithoutOutput #3649

Open
Victarry wants to merge 2 commits into NVIDIA:main from Victarry:denliu/zero_copy_recompute_main

Conversation


@Victarry Victarry commented Mar 2, 2026

Summary

Replace UntypedStorage.resize_() + copy_() with a StorageImpl-level data pointer swap in CheckpointWithoutOutput._recompute(). This eliminates a full GPU memcpy and the temporary double allocation during the backward recomputation phase, while remaining compatible with downstream modules that create views (e.g., TransformerEngine GroupedLinear).

Cherry-picked from dev PR: #3641

Motivation

How CheckpointWithoutOutput works

Standard activation checkpointing (torch.utils.checkpoint) skips saving intermediate activations during forward and recomputes them during backward. However, the output tensors of the checkpointed function are still held by downstream modules for their own backward, so that memory is never freed.

CheckpointWithoutOutput goes one step further — it explicitly frees the output storage right after forward and restores it just-in-time during backward:

forward:
  output = fn(*inputs)                       # normal forward pass
  output.untyped_storage().resize_(0)        # free output's storage (keep metadata)
  register_hook(downstream_tensor)           # register a backward hook

backward:
  hook fires → recompute fn(*inputs) → restore data into output's storage
  → downstream backward reads output's data normally
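A minimal CPU sketch of this lifecycle (the `fn` here is an illustrative stand-in; the real class also manages hooks and multiple outputs), using the original `resize_` + `copy_` restore scheme:

```python
import torch

def fn(x):                                   # stand-in for the checkpointed function
    return x * 2

inp = torch.arange(4.0)
out = fn(inp)                                # normal forward pass
out.untyped_storage().resize_(0)             # free the data; tensor metadata survives
assert out.untyped_storage().size() == 0 and out.shape == (4,)

# backward: the hook recomputes and restores into the same StorageImpl
recomp = fn(inp)
out.untyped_storage().resize_(recomp.untyped_storage().size())  # allocate
out.untyped_storage().copy_(recomp.untyped_storage())           # memcpy
assert torch.equal(out, recomp)
```

Because the restore writes into the *same* `StorageImpl` that `out` already owns, any views of `out` see the data too; the cost is the allocation plus the copy, which this PR removes.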

Problem with the original restore logic

The original code restores the recomputed data by allocating + copying:

for output, recomp in zip(self.outputs, recomputed_outputs):
    output.untyped_storage().resize_(recomp.untyped_storage().size())   # allocate
    output.untyped_storage().copy_(recomp.untyped_storage())            # memcpy

This means two copies of the same activation momentarily coexist on the GPU (output's fresh allocation plus recomp's allocation), and the copy_ itself consumes memory bandwidth.

Goal

Make output's storage point directly to recomp's memory — zero-copy, zero temporary double allocation.

Why a naive Tensor.set_() does not work

output.set_(recomp.untyped_storage(), ...) is zero-copy at the tensor level, but breaks downstream views. When downstream modules (e.g. TE GroupedLinear) create views via inp.reshape() + torch.split() and save them for backward, those view TensorImpls still reference the old, empty StorageImpl after set_().
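The failure mode is easy to reproduce on CPU (tensor names here are illustrative):

```python
import torch

base = torch.arange(6.0)
view = base.reshape(2, 3)                 # view sharing base's StorageImpl
old_ptr = view.untyped_storage().data_ptr()

fresh = torch.ones(6)
base.set_(fresh.untyped_storage(), 0, base.shape)   # tensor-level redirect

# base now aliases fresh's memory, but the view is stranded on the old storage
assert base.untyped_storage().data_ptr() == fresh.untyped_storage().data_ptr()
assert view.untyped_storage().data_ptr() == old_ptr
assert view.untyped_storage().data_ptr() != base.untyped_storage().data_ptr()
```

`set_()` rewrites only one TensorImpl's storage reference; every other TensorImpl that was created as a view keeps pointing at the original (now empty, in the checkpointing case) StorageImpl.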

Solution: StorageImpl-level data pointer sharing

A ~20-line C++ extension (share_storage) replaces the existing StorageImpl's data_ptr_ to point to the recomputed tensor's memory, with a custom DataPtr deleter that holds a refcounted reference to the source StorageImpl. This is the same StorageImpl::set_data_ptr() API that PyTorch uses internally for IPC storage sharing (_share_cuda_ in StorageSharing.cpp).

Since all TensorImpls — including views — reference the same StorageImpl object, they all see the new data immediately.
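This invariant is straightforward to verify for the GroupedLinear-style view pattern: reshape and split produce views that all report the same storage base pointer, so a data-pointer swap inside that one StorageImpl is visible to every one of them (illustrative sketch):

```python
import torch

inp = torch.arange(8.0)
a = inp.reshape(2, 4)                     # view, as TE GroupedLinear creates
parts = torch.split(a, 1, dim=0)          # more views, saved for backward

base_ptr = inp.untyped_storage().data_ptr()
# all views share the single StorageImpl, hence the same base pointer
assert all(p.untyped_storage().data_ptr() == base_ptr for p in parts)
```

Swapping the data pointer held by that one StorageImpl (what `share_storage` does in C++) therefore updates `inp`, `a`, and every element of `parts` at once, with no per-view bookkeeping.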

Benefits

|                            | resize_ + copy_ (before)       | share_storage (this PR) |
| -------------------------- | ------------------------------ | ----------------------- |
| Memory copy                | Full GPU memcpy                | None (pointer swap)     |
| Peak memory during restore | 2× activation (alloc + recomp) | 1× activation           |
| View compatible            | Yes                            | Yes                     |
| Version counter            | No impact                      | No impact               |
| TE GroupedLinear           | Yes                            | Yes                     |
| Implementation             | Pure Python                    | ~20 lines C++ extension |

Files changed

| File | Change |
| --- | --- |
| megatron/core/tensor_parallel/random.py | Add share_storage C++ extension; replace resize_ + copy_ with share_storage in _recompute() |
| tests/unit_tests/tensor_parallel/test_random.py | Add view-sharing regression test mimicking TE GroupedLinear pattern |

Test plan

  • Convergence test (loss curves match baseline)
  • POC correctness test (output & gradient match)
  • Memory savings validated (24 MB savings via snapshot)
  • View-sharing regression test (reshape + split saved for backward) — fails with Tensor.set_(), passes with share_storage
  • Megatron-LM E2E: local_test_e2e.sh (8 GPU, PP=4, EP=2, MoE + MLA, bf16)

Made with Cursor

Victarry and others added 2 commits March 2, 2026 14:36
Replace `UntypedStorage.resize_() + copy_()` with a StorageImpl-level
data pointer swap in `_recompute()`. This eliminates a full GPU memcpy
and the temporary double allocation during the backward recomputation
phase.

The key insight is that `resize_() + copy_()` restores data into the
same StorageImpl object, so all views (reshape/split) see it — but at
the cost of a full copy. A naive `Tensor.set_()` is zero-copy but only
redirects one TensorImpl, leaving views (e.g. TE GroupedLinear's
inp.reshape() + torch.split() saved for backward) pointing at the old
empty storage.

The solution operates at the StorageImpl level: a ~20-line C++ extension
(`share_storage`) replaces the existing StorageImpl's `data_ptr_` to
point to the recomputed tensor's memory, with a custom DataPtr deleter
that holds a refcounted reference to the source StorageImpl. This is
the same `StorageImpl::set_data_ptr()` API that PyTorch uses internally
for IPC storage sharing (`_share_cuda_` in StorageSharing.cpp).

Benefits:
- Zero-copy: no GPU memcpy, just a pointer swap
- No temporary double allocation during restore
- View-compatible: all TensorImpls sharing the StorageImpl see the data
- No autograd version counter bump (operates below TensorImpl)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@Victarry Victarry requested review from a team as code owners March 2, 2026 06:37

copy-pr-bot bot commented Mar 2, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@svcnvidia-nemo-ci svcnvidia-nemo-ci requested a review from a team March 2, 2026 06:37
@Victarry Victarry self-assigned this Mar 2, 2026
@Victarry Victarry added the Expert Review Apply this label to indicate that your PR is ready for expert review. label Mar 2, 2026
@Phlip79 Phlip79 added Final Review PR is in the "final review" stage complexity: low and removed Expert Review Apply this label to indicate that your PR is ready for expert review. labels Mar 2, 2026