Decouple FSDP all-gather tagging from activation checkpointing #336
Draft
The `enable_ac` flag in AutoParallel was a temporary workaround from when PyTorch's tracing couldn't capture user-specified activation checkpointing. PyTorch now properly propagates user AC annotations (`node.meta["recompute"]` and `node.meta["ac_graph_id"]`) during AOT tracing, so AutoParallel no longer needs to apply its own AC staging policy.

The only non-user-AC concern bundled into `enable_ac` was tagging FSDP all-gather collectives for recomputation/saving, which is a parallelism-strategy decision that should happen unconditionally based on `reshard_after_forward`.

This PR:

- Renames `ac_joint_pass` to `tag_fsdp_collectives_for_recomputation`, keeping only the FSDP tagging logic
- Removes the AC staging/policy functions (`mark_nodes_as_must_save_to_stage_recomputation`, `_apply_ac_policy`, `_mark_nodes_as_must_save`, and related helpers) that are now handled by user-specified `torch.utils.checkpoint.checkpoint()`
- Removes the `enable_ac` and `ac_stage_size_in_GiB` parameters from `AutoParallel.__init__`
- Makes FSDP collective tagging unconditional (no longer gated behind `enable_ac`)

Authored with Claude.
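To illustrate the idea behind the renamed pass, here is a minimal sketch of tagging FSDP all-gather nodes in a traced graph. This is not the AutoParallel implementation: the real pass operates on `torch.fx` nodes, and the `Node` stand-in, the substring match on the target name, and the `recompute_tag` meta key are all hypothetical simplifications.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Toy stand-in for a torch.fx graph node (illustrative only)."""
    name: str
    target: str
    meta: dict = field(default_factory=dict)


def tag_fsdp_collectives_for_recomputation(nodes, reshard_after_forward):
    """Tag FSDP all-gather outputs based on the resharding strategy.

    If parameters are resharded after forward, the gathered tensors are
    not kept around, so the all-gather should be recomputed in backward;
    otherwise the gathered result can simply be saved.
    """
    tag = "recompute" if reshard_after_forward else "must_save"
    for node in nodes:
        # Crude structural match for collective ops; the real pass
        # identifies FSDP all-gathers precisely from the traced graph.
        if "all_gather" in node.target:
            node.meta["recompute_tag"] = tag  # hypothetical meta key
    return nodes
```

The key point mirrored from the PR: this decision depends only on `reshard_after_forward`, not on any activation-checkpointing flag, which is why the tagging now runs unconditionally.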
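For context on the user-driven AC path this PR defers to: users request activation checkpointing with `torch.utils.checkpoint.checkpoint()`, and those annotations are what PyTorch now propagates through AOT tracing. A minimal usage example (the `Block` module is illustrative, not from AutoParallel):

```python
import torch
from torch.utils.checkpoint import checkpoint


class Block(torch.nn.Module):
    """Tiny module standing in for a checkpointed region of a model."""

    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.lin(x))


block = Block()
x = torch.randn(2, 8, requires_grad=True)

# Non-reentrant checkpointing: activations inside Block are discarded
# after forward and recomputed during backward.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```

With user AC expressed this way, AutoParallel no longer needs its own `enable_ac` staging policy.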