
[Dev] feat(checkpoint): zero-copy storage sharing in CheckpointWithoutOutput#3641

Merged
Victarry merged 3 commits into NVIDIA:dev from Victarry:denliu/zero_copy_recompute
Mar 2, 2026
Conversation


@Victarry Victarry commented Feb 28, 2026

Zero-Copy Storage Sharing for CheckpointWithoutOutput

PR to main #3649

Summary

Replace UntypedStorage.resize_() + copy_() with a StorageImpl-level data pointer swap in CheckpointWithoutOutput._recompute(). This eliminates a full GPU memcpy and the double allocation during the backward recomputation phase, while remaining compatible with downstream modules that create views (e.g., TransformerEngine GroupedLinear).

Motivation

How CheckpointWithoutOutput works

Standard activation checkpointing (torch.utils.checkpoint) skips saving intermediate activations during forward and recomputes them during backward. However, the output tensors of the checkpointed function are still held by downstream modules for their own backward, so that memory is never freed.

CheckpointWithoutOutput goes one step further — it explicitly frees the output storage right after forward and restores it just-in-time during backward:

forward:
  output = fn(*inputs)                       # normal forward pass
  output.untyped_storage().resize_(0)        # free output's storage (keep metadata)
  register_hook(downstream_tensor)           # register a backward hook

backward:
  hook fires → recompute fn(*inputs) → restore data into output's storage
  → downstream backward reads output's data normally

The GPU memory occupied by the checkpoint output is freed between the end of forward and the hook firing in backward.
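The free-then-restore trick can be observed in isolation. A minimal CPU sketch (illustrative only, not the Megatron code) shows that resize_(0) releases the backing bytes while the tensor's shape metadata survives:

```python
import torch

t = torch.arange(6, dtype=torch.float32)   # 6 floats -> 24 bytes of storage
assert t.untyped_storage().nbytes() == 24

t.untyped_storage().resize_(0)             # free the backing memory...
assert t.shape == torch.Size([6])          # ...while shape metadata survives
assert t.untyped_storage().nbytes() == 0   # no bytes behind the tensor now

t.untyped_storage().resize_(24)            # re-allocate; contents stay
                                           # uninitialized until restored
```

On CUDA tensors the same calls go through the caching allocator, which is what makes the memory reusable between forward and the backward hook.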

Problem with the original restore logic

The original code restores the recomputed data by allocating + copying:

for output, recomp in zip(self.outputs, recomputed_outputs):
    output.untyped_storage().resize_(recomp.untyped_storage().size())   # allocate
    output.untyped_storage().copy_(recomp.untyped_storage())            # memcpy

This means two copies of the same activation momentarily coexist on the GPU (output's fresh allocation plus recomp's allocation), and the copy_ incurs memory-bandwidth cost.

Goal

Make output's storage point directly to recomp's memory — zero-copy, zero temporary double allocation.

Why a naive Tensor.set_() does not work

The obvious approach — output.set_(recomp.untyped_storage(), ...) — is zero-copy at the tensor level, but it breaks downstream views.

When downstream modules consume the checkpoint output, they commonly create views. For example, TransformerEngine's GroupedLinear.forward() does:

inp_view  = inp.reshape(-1, in_features)      # new TensorImpl, shares inp's StorageA
inputmats = torch.split(inp_view, m_splits)    # more TensorImpls, share StorageA
ctx.save_for_backward(*inputmats, ...)         # saved for backward

Multiple TensorImpls all reference the same StorageImpl (StorageA):

  output (orig)  ──→  TensorImpl  ──→  StorageImpl A
  inp_view       ──→  TensorImpl  ──┘       ↑
  inputmats[0]   ──→  TensorImpl  ──────────┘

Tensor.set_() only redirects output's TensorImpl to a new StorageImpl B. The view TensorImpls (inp_view, inputmats) still reference the old, empty StorageA:

  output (orig)  ──→  TensorImpl  ──→  StorageImpl B   ← has data ✓
  inp_view       ──→  TensorImpl  ──→  StorageImpl A   ← still empty ✗
  inputmats[0]   ──→  TensorImpl  ──→  StorageImpl A   ← still empty ✗

This causes TE's backward to crash: Assertion failed: A.has_data() ... Input A does not hold any data!

By contrast, resize_() + copy_() writes data back into the same StorageImpl A, so all views automatically see the restored data. But it requires a full copy.
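The difference can be reproduced with a small CPU experiment (a sketch, not TE code): set_() leaves a pre-existing view pointing at the old storage, while writing into the shared storage is visible through every view.

```python
import torch

base = torch.zeros(4)
view = base.view(2, 2)                       # shares base's StorageImpl
fresh = torch.arange(4, dtype=torch.float32)

# Tensor.set_(): only base's TensorImpl is redirected to fresh's storage.
base.set_(fresh.untyped_storage(), 0, (4,))
assert torch.equal(base, fresh)              # base sees the new data...
assert torch.equal(view, torch.zeros(2, 2))  # ...but the view does not

# resize_ + copy_ writes into the shared StorageImpl, so views see it.
base2 = torch.zeros(4)
view2 = base2.view(2, 2)
base2.untyped_storage().copy_(fresh.untyped_storage())
assert torch.equal(view2, fresh.view(2, 2))
```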

Solution: StorageImpl-level data pointer sharing

Key insight

Instead of redirecting the TensorImpl (tensor level), we replace the data pointer inside the existing StorageImpl (storage level). Since all TensorImpls — including views — reference the same StorageImpl object, they all see the new data immediately.

  share_storage(output, recomp):

  StorageImpl A's data_ptr  ──→  StorageImpl B's memory (zero-copy)
                                       ↑
             refcounted reference keeps B alive

  Result:
  output (orig)  ──→  TensorImpl  ──→  StorageImpl A ─→ [B's data]  ✓
  inp_view       ──→  TensorImpl  ──┘       ↑                       ✓
  inputmats[0]   ──→  TensorImpl  ──────────┘                       ✓

C++ extension (~20 lines)

PyTorch does not expose StorageImpl::set_data_ptr() to Python (see Appendix: PyTorch API survey). We compile a minimal C++ extension via torch.utils.cpp_extension.load_inline:

#include <torch/extension.h>

void share_storage(at::Tensor dst, at::Tensor src) {
    auto* dst_impl = dst.storage().unsafeGetStorageImpl();

    // Copy src's c10::Storage — increments StorageImpl refcount,
    // preventing src's memory from being freed while dst uses it.
    auto* src_storage_ref = new c10::Storage(src.storage());

    void*       data   = src_storage_ref->data_ptr().get();
    size_t      nbytes = src_storage_ref->nbytes();
    c10::Device device = src_storage_ref->device();

    // Build a DataPtr with a custom deleter that releases the refcount.
    c10::DataPtr shared(
        data,
        static_cast<void*>(src_storage_ref),
        [](void* ctx) { delete static_cast<c10::Storage*>(ctx); },
        device);

    dst_impl->set_data_ptr(std::move(shared));
    dst_impl->set_nbytes(nbytes);
}

This uses StorageImpl::set_data_ptr(), the same C++ API that PyTorch's own _share_cuda_() IPC path uses internally (torch/csrc/StorageSharing.cpp:322).

Call site

share_storage = _get_share_storage()   # lazy compile, cached after first call
for orig, recomp in zip(self.outputs, outputs):
    share_storage(orig, recomp)

Safety analysis

View compatible — All TensorImpls reference the same StorageImpl. Replacing its data_ptr makes the change visible to every view immediately.

No version counter bump — StorageImpl::set_data_ptr() operates below the TensorImpl layer and does not trigger TensorImpl::bump_version(). No autograd "modified by an inplace operation" errors.

No dangling pointers — The custom DataPtr holds a refcounted c10::Storage reference to src's StorageImpl. Memory lifetime:

| Phase | What happens | Data valid? |
|---|---|---|
| `_recompute()` | StorageA's data_ptr → StorageB's memory; DataPtr context holds StorageB ref | Yes |
| `backward()` runs | Downstream ops read through StorageA → StorageB's memory | Yes |
| `ctx.outputs = None` | Python releases recomp tensors, but StorageA's DataPtr still holds StorageB ref | Yes |
| orig GC'd | StorageA destroyed → DataPtr deleter releases StorageB ref → memory returned to CUDA caching allocator | N/A |

Recomp autograd graph unaffected — Only StorageA (dst) is modified. StorageB (src) and all intermediate tensors in the recomp graph are untouched. torch.autograd.backward(outputs, grad_outputs) executes normally.

Autograd metadata preserved — share_storage only touches StorageImpl.data_ptr_ and StorageImpl.size_bytes_. requires_grad, grad_fn, grad, sizes_, strides_, storage_offset_, and the version counter on TensorImpl are all unchanged.
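The memory-lifetime handoff in the table above can be mirrored by a hypothetical pure-Python analogy (FakeStorage, set_data_ptr, and live_refs here are stand-ins, not PyTorch APIs): the deleter fires only when the borrowing storage's data pointer is replaced or destroyed, releasing the borrowed buffer's reference.

```python
class FakeStorage:
    """Stand-in for c10::StorageImpl: a buffer plus a swappable deleter."""
    def __init__(self, data):
        self.data = data          # plays the role of data_ptr_
        self._deleter = None      # plays the role of the DataPtr deleter

    def set_data_ptr(self, data, deleter=None):
        if self._deleter is not None:
            self._deleter()       # old DataPtr destroyed -> its deleter runs
        self.data, self._deleter = data, deleter

def share_storage(dst, src, live_refs):
    # "Refcount" src by parking it in live_refs; the deleter drops it when
    # dst's DataPtr goes away, mirroring `delete c10::Storage*` in the C++.
    live_refs.append(src)
    dst.set_data_ptr(src.data, deleter=lambda: live_refs.remove(src))

live_refs = []
a, b = FakeStorage([0, 0, 0]), FakeStorage([1, 2, 3])
share_storage(a, b, live_refs)
assert a.data == [1, 2, 3] and b in live_refs   # dst reads src's buffer
a.set_data_ptr(None)                            # dst's DataPtr replaced
assert b not in live_refs                       # borrowed ref released
```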

Comparison

| | resize_ + copy_ (before) | share_storage (this PR) |
|---|---|---|
| Memory copy | Full GPU memcpy | None (pointer swap) |
| Peak memory during restore | 2x activation (alloc + recomp) | 1x activation |
| View compatible | Yes | Yes |
| Version counter | No impact | No impact |
| TE GroupedLinear | Yes | Yes |
| Implementation | Pure Python | ~20 lines C++ extension |

Verification

Convergence Test

(screenshot: convergence curves, baseline vs. this PR)

POC tests (all 5 passed)

Correctness (output & gradient match)              ✓
Memory savings (snapshot-based)                     ✓  (24 MB savings)
Training loop (10 steps, loss decreasing)           ✓
Multi-checkpoint model (2 sub-layers)               ✓
View-sharing (reshape+split saved for backward)     ✓  ← new regression test

The view-sharing test specifically mimics TE GroupedLinear's pattern: the downstream layer performs inp.reshape() + torch.split() and saves those views via ctx.save_for_backward(). This test fails with Tensor.set_() but passes with share_storage.
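The storage-sharing invariant that the regression test relies on can be checked directly (an illustrative CPU sketch; the split sizes are arbitrary): reshape and split produce new TensorImpls that all resolve to the same storage pointer, which is why a storage-level restore reaches every saved view.

```python
import torch

inp = torch.arange(12, dtype=torch.float32)
inp_view = inp.reshape(-1, 4)           # new TensorImpl, same StorageImpl
mats = torch.split(inp_view, [1, 2])    # more views of the same storage

ptr = inp.untyped_storage().data_ptr()
assert inp_view.untyped_storage().data_ptr() == ptr
assert all(m.untyped_storage().data_ptr() == ptr for m in mats)
```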

Megatron-LM E2E

bash local_test_e2e.sh  # 8 GPU, PP=4, EP=2, MoE + MLA, bf16

Files changed

| File | Change |
|---|---|
| megatron/core/tensor_parallel/random.py | Add share_storage C++ extension; replace resize_ + copy_ with share_storage in _recompute() |

Appendix: PyTorch API survey

We reviewed every candidate PyTorch API. None provides a general-purpose way to modify an existing StorageImpl's data_ptr_ from Python.

| API | Level | Modifies existing StorageImpl? | View compatible? | Why not usable |
|---|---|---|---|---|
| Tensor.set_() | TensorImpl | No (switches StorageImpl ref) | No | Views keep referencing old StorageImpl; bumps version counter |
| torch.utils.swap_tensors() | TensorImpl | No (swaps TensorImpl ptrs) | No | Also swaps grad_fn, breaks autograd graph |
| _construct_storage_from_data_pointer() | New StorageImpl | No (creates new) | No | Non-owning DataPtr (dangling risk); can't reassign old StorageImpl |
| UntypedStorage._set_cdata() | Python wrapper | No (swaps Python ref) | No | C++ TensorImpls unaffected |
| _share_cuda_() / _share_fd_cpu_() | StorageImpl | Yes | Yes | IPC-only; cannot point to arbitrary in-process CUDA memory |
| safely_set_viewless_tensor_data() | TensorImpl | No (tensor.data = ...) | No | Equivalent to set_() |

The only internal PyTorch code paths that call StorageImpl::set_data_ptr() are the IPC sharing methods (_share_cuda_, _share_fd_cpu_ in StorageSharing.cpp). Our extension uses the same C++ API.


Appendix: PyTorch internals reference

torch.Tensor
  │
  ├─ .data, .grad, .grad_fn, .requires_grad   (autograd metadata)
  │
  └─ TensorImpl  (C++)
       ├─ sizes_, strides_, storage_offset_    (shape metadata)
       ├─ version_counter_                     (inplace op counter)
       └─ Storage  (c10::Storage)
            └─ intrusive_ptr<StorageImpl>
                 ├─ data_ptr_  (c10::DataPtr)  ← share_storage modifies here
                 │    ├─ data_   (void*)        ← raw data pointer
                 │    ├─ ctx_    (void*)        ← ref-holding context
                 │    └─ deleter (function)     ← called on destruction
                 └─ size_bytes_                 ← share_storage modifies here

Key:
  Tensor.set_()   → changes TensorImpl's Storage ref   → other TensorImpls unaffected
  resize_/copy_   → writes into the same StorageImpl    → all TensorImpls see the change
  share_storage   → replaces StorageImpl's data_ptr_    → all TensorImpls see the change, zero-copy


copy-pr-bot bot commented Feb 28, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@Victarry
Contributor Author

/ok to test 8008bec

@svcnvidia-nemo-ci svcnvidia-nemo-ci added this to the Core 0.16 milestone Feb 28, 2026

Replace `UntypedStorage.resize_() + copy_()` with a StorageImpl-level
data pointer swap in `_recompute()`. This eliminates a full GPU memcpy
and the temporary double allocation during the backward recomputation
phase.

The key insight is that `resize_() + copy_()` restores data into the
same StorageImpl object, so all views (reshape/split) see it — but at
the cost of a full copy. A naive `Tensor.set_()` is zero-copy but only
redirects one TensorImpl, leaving views (e.g. TE GroupedLinear's
inp.reshape() + torch.split() saved for backward) pointing at the old
empty storage.

The solution operates at the StorageImpl level: a ~20-line C++ extension
(`share_storage`) replaces the existing StorageImpl's `data_ptr_` to
point to the recomputed tensor's memory, with a custom DataPtr deleter
that holds a refcounted reference to the source StorageImpl. This is
the same `StorageImpl::set_data_ptr()` API that PyTorch uses internally
for IPC storage sharing (`_share_cuda_` in StorageSharing.cpp).

Benefits:
- Zero-copy: no GPU memcpy, just a pointer swap
- No temporary double allocation during restore
- View-compatible: all TensorImpls sharing the StorageImpl see the data
- No autograd version counter bump (operates below TensorImpl)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@Victarry Victarry force-pushed the denliu/zero_copy_recompute branch from 8008bec to c833c0a Compare February 28, 2026 04:21
@Victarry
Contributor Author

/ok to test c833c0a


@Victarry
Contributor Author

/ok to test 33d4d26

@Victarry Victarry marked this pull request as ready for review February 28, 2026 10:31
@Victarry Victarry requested review from a team as code owners February 28, 2026 10:31
@yaox12 yaox12 added the Expert Review Apply this label to indicate that your PR is ready for expert review. label Mar 2, 2026
@Victarry
Contributor Author

Victarry commented Mar 2, 2026

/ok to test 74d10da

@Victarry Victarry added this pull request to the merge queue Mar 2, 2026
@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/22567814399

Merged via the queue into NVIDIA:dev with commit bc9298c Mar 2, 2026
78 of 80 checks passed
@Victarry Victarry deleted the denliu/zero_copy_recompute branch March 2, 2026 09:42