
Decouple FSDP all-gather tagging from activation checkpointing #336

Draft
fmassa wants to merge 1 commit into main from fmassa/ac_fsdp_tagging

Conversation

fmassa (Contributor) commented Mar 5, 2026

The `enable_ac` flag in AutoParallel was a temporary workaround from when PyTorch's tracing couldn't capture user-specified activation checkpointing. PyTorch now properly propagates user AC annotations (`node.meta["recompute"]` and `node.meta["ac_graph_id"]`) during AOT tracing, so AutoParallel no longer needs to apply its own AC staging policy.

The only non-user-AC concern bundled into `enable_ac` was tagging FSDP all-gather collectives for recomputation/saving, which is a parallelism strategy decision that should happen unconditionally based on `reshard_after_forward`.
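The tagging decision can be sketched as follows. This is a hypothetical stand-in for illustration only (the `Node` dataclass and function body are invented here, not the actual AutoParallel pass): walk the traced graph and mark FSDP all-gather nodes for recomputation whenever `reshard_after_forward` is set.

```python
# Illustrative sketch, NOT the real AutoParallel implementation:
# Node and tag_fsdp_collectives_for_recomputation are stand-ins
# mimicking FX-style nodes with a .meta dict.
from dataclasses import dataclass, field

@dataclass
class Node:
    target: str                      # op name, e.g. "all_gather_into_tensor"
    meta: dict = field(default_factory=dict)

def tag_fsdp_collectives_for_recomputation(nodes, reshard_after_forward):
    """Mark FSDP all-gathers so the partitioner recomputes them in backward."""
    for n in nodes:
        if "all_gather" in n.target and reshard_after_forward:
            # With reshard_after_forward, unsharded params are freed after
            # forward, so the all-gather must run again in backward rather
            # than have its output saved.
            n.meta["recompute"] = True
    return nodes

graph = [Node("all_gather_into_tensor"), Node("mm"), Node("relu")]
tag_fsdp_collectives_for_recomputation(graph, reshard_after_forward=True)
# graph[0] is now tagged; compute ops (mm, relu) are left untouched
```

The point of the PR is that this tagging runs unconditionally as a parallelism decision, independent of any user activation-checkpointing policy.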

This PR:

  • Renames `ac_joint_pass` to `tag_fsdp_collectives_for_recomputation`, keeping only the FSDP tagging logic
  • Removes the AC staging/policy functions (`mark_nodes_as_must_save_to_stage_recomputation`, `_apply_ac_policy`, `_mark_nodes_as_must_save`, and related helpers), which are now handled by user-specified `torch.utils.checkpoint.checkpoint()`
  • Removes the `enable_ac` and `ac_stage_size_in_GiB` parameters from `AutoParallel.__init__`
  • Makes FSDP collective tagging unconditional (no longer gated behind `enable_ac`)

Authored with Claude.
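For context, the user-driven AC path the PR defers to looks like the following minimal example (the `Block` module here is illustrative, not from this PR): wrapping a module call in `torch.utils.checkpoint.checkpoint` drops its intermediates after forward and recomputes them during backward, and these user annotations now survive AOT tracing.

```python
import torch
from torch.utils.checkpoint import checkpoint

# Illustrative module; any submodule whose activations you want
# recomputed instead of saved works the same way.
class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.lin(x))

block = Block()
x = torch.randn(2, 8, requires_grad=True)

# Intermediates of block are recomputed in backward rather than stored.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```

With this in place, AutoParallel no longer needs its own `enable_ac` staging policy; the user's checkpoint annotations flow through tracing unchanged.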

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 5, 2026
@fmassa fmassa marked this pull request as draft March 6, 2026 15:15
wconstab (Contributor) left a comment:

lgtm
