
Deduplicate split_di_dw_graph by reusing upstream _extract_fwd_bwd_modules #359

Open

xmfan wants to merge 1 commit into main from xmfan/stack/29

Conversation


@xmfan xmfan commented Mar 10, 2026

Stacked PRs:


Requires pytorch/pytorch#177058

Deduplicate split_di_dw_graph by reusing upstream _extract_fwd_bwd_modules

Delete the ~120-line local copy of _extract_fwd_bwd_modules in favor of
calling the upstream PyTorch function with two new parameters:
ignore_must_be_in_fw_bw=True and omit_aot_autograd_runtime=True.

Requires the corresponding PyTorch change that adds these parameters to
torch/_functorch/partitioners.py.

Authored with Claude.


stack-info: PR: #359, branch: xmfan/stack/29
meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) Mar 10, 2026
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Mar 12, 2026
…177058)

Allows for better out-of-tree path-splitting UX: meta-pytorch/autoparallel#359.

Add two new keyword arguments to _extract_fwd_bwd_modules to support
callers that re-partition a backward graph (e.g. splitting into dI/dW
subgraphs for pipelined training):

- ignore_must_be_in_fw_bw: threads through to _extract_graph_with_inputs_outputs
  to disable forward/backward placement enforcement.
- omit_aot_autograd_runtime: skips tangent input handling, version-counter
  sorting, opaque object separation, and fp8 activation quantization —
  postprocessing only needed when wrapping in a custom autograd.Function.

Both default to False, so all existing callers are unaffected.
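Conceptually, the extraction walks the backward graph's dependency edges in reverse from each requested output set, collecting every node needed to compute it. A minimal standalone sketch of that reverse-reachability idea, using the example graph's node names (this is an illustration, not the actual partitioner code):

```python
def reachable_nodes(deps, outputs):
    """Walk dependency edges backward from the requested outputs,
    collecting every node needed to compute them."""
    needed, stack = set(), list(outputs)
    while stack:
        node = stack.pop()
        if node in needed:
            continue
        needed.add(node)
        stack.extend(deps.get(node, ()))
    return needed

# Dependency graph of the example backward pass below:
# mm_1 (dW) needs t and expand; mm_2 (dI) needs expand and t_1.
deps = {
    "expand": ["not_tngnt_1"],
    "t": ["primals_1"],
    "mm_1": ["t", "expand"],
    "t_1": ["primals_2"],
    "mm_2": ["expand", "t_1"],
}

di_nodes = reachable_nodes(deps, ["mm_2"])  # grads wrt input activations
dw_nodes = reachable_nodes(deps, ["mm_1"])  # grads wrt weights
shared = di_nodes & dw_nodes                # e.g. "expand"
```

Nodes in the intersection (here `expand`) are computed once in the dI subgraph and returned as extra outputs, so the dW subgraph can take them as inputs instead of recomputing them.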

Test graphs:

bw_gm (full backward):
```python
def forward(self, primals_1, primals_2, not_tngnt_1):
    expand = torch.ops.aten.expand.default(not_tngnt_1, [4, 4]);  not_tngnt_1 = None
    t = torch.ops.aten.t.default(primals_1);  primals_1 = None
    mm_1 = torch.ops.aten.mm.default(t, expand);  t = None
    t_1 = torch.ops.aten.t.default(primals_2);  primals_2 = None
    mm_2 = torch.ops.aten.mm.default(expand, t_1);  expand = t_1 = None
    return (mm_2, mm_1)
```

di_mod (only grads wrt input activations):
```python
def forward(self, primals_1, primals_2, not_tngnt_1):
    expand = torch.ops.aten.expand.default(not_tngnt_1, [4, 4]);  not_tngnt_1 = None
    t_1 = torch.ops.aten.t.default(primals_2);  primals_2 = None
    mm_2 = torch.ops.aten.mm.default(expand, t_1);  t_1 = None
    return (mm_2, primals_1, expand)
```

dw_mod (only grads wrt weights):
```python
def forward(self, primals_1, expand):
    t = torch.ops.aten.t.default(primals_1);  primals_1 = None
    mm_1 = torch.ops.aten.mm.default(t, expand);  t = expand = None
    return (mm_1,)
```
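As a sanity check (a hand-written sketch, not part of the PR; shapes shrunk to 2x2 and the `expand` folded directly into the tangent input for brevity), the two split modules compose to reproduce the full backward's outputs, with `di_mod` forwarding `primals_1` and the recomputed gradient to `dw_mod`:

```python
def t(m):  # transpose a nested-list matrix
    return [list(r) for r in zip(*m)]

def mm(a, b):  # matmul on nested lists
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

def full_bw(p1, p2, grad):          # mirrors bw_gm
    return mm(grad, t(p2)), mm(t(p1), grad)

def di_mod(p1, p2, grad):           # only dI, forwards p1 and grad
    return mm(grad, t(p2)), p1, grad

def dw_mod(p1, grad):               # dW, reusing the saved grad
    return (mm(t(p1), grad),)

p1 = [[1.0, 2.0], [3.0, 4.0]]
p2 = [[5.0, 6.0], [7.0, 8.0]]
g  = [[1.0, 1.0], [1.0, 1.0]]       # stands in for the expanded tangent

ref_di, ref_dw = full_bw(p1, p2, g)
di, saved_p1, saved_g = di_mod(p1, p2, g)
(dw,) = dw_mod(saved_p1, saved_g)
assert di == ref_di and dw == ref_dw
```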

Authored with Claude.
Pull Request resolved: #177058
Approved by: https://github.com/soulitzer
@xmfan xmfan marked this pull request as ready for review March 12, 2026 06:07
@xmfan xmfan requested a review from sanketpurandare March 12, 2026 06:07
