Deduplicate split_di_dw_graph by reusing upstream _extract_fwd_bwd_modules#359
Open
Conversation
…dules

Delete the ~120-line local copy of `_extract_fwd_bwd_modules` in favor of calling the upstream PyTorch function with two new parameters: `ignore_must_be_in_fw_bw=True` and `omit_aot_autograd_runtime=True`. Requires the corresponding PyTorch change that adds these parameters to `torch/_functorch/partitioners.py`.

Authored with Claude.

stack-info: PR: #359, branch: xmfan/stack/29
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request on Mar 12, 2026
…177058)

Allows for a better out-of-tree path-splitting UX: meta-pytorch/autoparallel#359.

Add two new keyword arguments to `_extract_fwd_bwd_modules` to support callers that re-partition a backward graph (e.g. splitting into dI/dW subgraphs for pipelined training):

- `ignore_must_be_in_fw_bw`: threads through to `_extract_graph_with_inputs_outputs` to disable forward/backward placement enforcement.
- `omit_aot_autograd_runtime`: skips tangent input handling, version-counter sorting, opaque object separation, and fp8 activation quantization — postprocessing only needed when wrapping in a custom `autograd.Function`.

Both default to `False`, so all existing callers are unaffected.

Test graphs:

`bw_gm` (full backward):

```python
def forward(self, primals_1, primals_2, not_tngnt_1):
    expand = torch.ops.aten.expand.default(not_tngnt_1, [4, 4]);  not_tngnt_1 = None
    t = torch.ops.aten.t.default(primals_1);  primals_1 = None
    mm_1 = torch.ops.aten.mm.default(t, expand);  t = None
    t_1 = torch.ops.aten.t.default(primals_2);  primals_2 = None
    mm_2 = torch.ops.aten.mm.default(expand, t_1);  expand = t_1 = None
    return (mm_2, mm_1)
```

`di_mod` (only grads wrt input activations):

```python
def forward(self, primals_1, primals_2, not_tngnt_1):
    expand = torch.ops.aten.expand.default(not_tngnt_1, [4, 4]);  not_tngnt_1 = None
    t_1 = torch.ops.aten.t.default(primals_2);  primals_2 = None
    mm_2 = torch.ops.aten.mm.default(expand, t_1);  t_1 = None
    return (mm_2, primals_1, expand)
```

`dw_mod` (only grads wrt weights):

```python
def forward(self, primals_1, expand):
    t = torch.ops.aten.t.default(primals_1);  primals_1 = None
    mm_1 = torch.ops.aten.mm.default(t, expand);  t = expand = None
    return (mm_1,)
```

Authored with Claude.

Pull Request resolved: #177058
Approved by: https://github.com/soulitzer
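To see concretely what the split preserves, here is an illustrative plain-Python sketch (no torch; `matmul` and `transpose` are local helpers, and the concrete values are made up) checking that the dI/dW pieces together reproduce the full backward of `out = x @ w`: the dI graph computes `grad_out @ w.T` and carries `x` and `grad_out` forward, and the dW graph later consumes them to compute `x.T @ grad_out`:

```python
def transpose(m):
    return [list(row) for row in zip(*m)]

def matmul(a, b):
    bt = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in bt] for row in a]

x = [[1.0, 2.0], [3.0, 4.0]]         # primals_1 (input activations)
w = [[0.5, -1.0], [2.0, 0.0]]        # primals_2 (weight)
grad_out = [[1.0, 1.0], [1.0, 1.0]]  # expand(not_tngnt_1) in the graphs

# Full backward (bw_gm): mm_2 = grad_out @ w.T, mm_1 = x.T @ grad_out
grad_input_full = matmul(grad_out, transpose(w))
grad_weight_full = matmul(transpose(x), grad_out)

# di_mod: only the input-activation grad; it also returns x and grad_out
grad_input = matmul(grad_out, transpose(w))
carried = (x, grad_out)

# dw_mod: consumes the carried values to produce the weight grad
grad_weight = matmul(transpose(carried[0]), carried[1])

assert grad_input == grad_input_full
assert grad_weight == grad_weight_full
```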
Stacked PRs:
Requires pytorch/pytorch#177058
Deduplicate split_di_dw_graph by reusing upstream _extract_fwd_bwd_modules
Delete the ~120-line local copy of _extract_fwd_bwd_modules in favor of
calling the upstream PyTorch function with two new parameters:
ignore_must_be_in_fw_bw=True and omit_aot_autograd_runtime=True.
Requires the corresponding PyTorch change that adds these parameters to
torch/_functorch/partitioners.py.
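As a rough mental model of what the dI/dW split does (a toy sketch only — the real implementation lives in `torch/_functorch/partitioners.py` and operates on FX graphs, not dicts), one can partition a backward graph's nodes by which outputs depend on them, and turn any dI-side value that the dW side still needs into an extra dI output:

```python
# Toy DAG mirroring the matmul backward: node -> list of input nodes.
graph = {
    "expand": ["grad_out"],
    "t": ["primals_1"],
    "mm_1": ["t", "expand"],    # weight grad (dW output)
    "t_1": ["primals_2"],
    "mm_2": ["expand", "t_1"],  # input grad (dI output)
}

def ancestors(outputs):
    """All nodes reachable backwards from the given outputs."""
    seen, stack = set(), list(outputs)
    while stack:
        n = stack.pop()
        if n in seen:
            continue
        seen.add(n)
        stack.extend(graph.get(n, []))
    return seen

di_nodes = ancestors(["mm_2"])             # everything dI needs
dw_nodes = ancestors(["mm_1"]) - di_nodes  # what remains for dW
# Values computed on the dI side but consumed by dW must be carried
# across as extra dI outputs (like `expand` in di_mod's return).
carried = {inp for n in dw_nodes
           for inp in graph.get(n, []) if inp in di_nodes}
```

Here `carried` comes out as `{"expand"}`, matching how `di_mod` returns `expand` so that `dw_mod` can take it as an input.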
Authored with Claude.