diff --git a/autoparallel/api.py b/autoparallel/api.py index a52bb549..1aab612b 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -28,7 +28,9 @@ from .apply_sharding import apply_sharding_to_model from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast -from .graph_passes.activation_checkpointing import ac_joint_pass +from .graph_passes.activation_checkpointing import ( + tag_fsdp_collectives_for_recomputation, +) from .graph_passes.graph_utils import ( _add_alias, _replace_view_mm_view_with_einsum, @@ -257,9 +259,6 @@ def __init__( mesh: DeviceMesh, mp_policy: Optional[MixedPrecisionPolicy] = None, compile: bool = False, - enable_ac: bool = True, - # None means 'auto' - ac_stage_size_in_GiB: Optional[Union[float, str]] = "auto", reshard_after_forward: bool = True, dynamic: bool = False, numerics_logger: NumericsLogger | None = None, @@ -305,8 +304,6 @@ def __init__( ) else: self.compiler_fn = boxed_nop_preserve_node_meta # type: ignore[assignment] - self.enable_ac = enable_ac - self.ac_stage_size_in_GiB = ac_stage_size_in_GiB self.reshard_after_forward = reshard_after_forward if dynamic: @@ -552,10 +549,9 @@ def _apply_placement_common(self, sharding_placement): ), ) - if self.enable_ac: - ac_joint_pass( - parallel_gm.graph, self.ac_stage_size_in_GiB, self.reshard_after_forward - ) + tag_fsdp_collectives_for_recomputation( + parallel_gm.graph, self.reshard_after_forward + ) # now rename input/param/tangent/output/grad_param/grad_input nodes following # our convention # apply_node_renaming( @@ -921,8 +917,6 @@ def auto_parallel( mesh, mp_policy=mp_policy, compile=compile, - # enable_ac=True, - enable_ac=False, ) as autop: # Add constraints # autop.add_parameter_memory_constraint(low=None, high=None) diff --git a/autoparallel/graph_passes/activation_checkpointing.py b/autoparallel/graph_passes/activation_checkpointing.py index 060491c8..820a3d30 100644 --- a/autoparallel/graph_passes/activation_checkpointing.py +++ b/autoparallel/graph_passes/activation_checkpointing.py @@ -4,17 +4,10 @@ # LICENSE file in the root directory of this source tree. import logging import operator -from collections import defaultdict -from dataclasses import dataclass -from typing import Optional, Union +from typing import Optional import torch -from torch._functorch.partitioners import ( - _has_tag_is_backward, - _has_tag_must_be_in_forward, - _size_of, -) -from torch.utils._ordered_set import OrderedSet +from torch._functorch.partitioners import _has_tag_must_be_in_forward from torch.utils.checkpoint import CheckpointPolicy logger: logging.Logger = logging.getLogger(__name__) @@ -26,14 +19,6 @@ AP_AC_GRAPH_ID = 100000 -# reimplement torch._functorch.partitioners.must_recompute -# to only check for MUST_RECOMPUTE tag, and not PREFER_RECOMPUTE -# For now there isn't any distinction in the partitioner between both -# and I think this is a bug -def must_recompute(node: torch.fx.Node) -> bool: - return node.meta.get("recompute", None) is CheckpointPolicy.MUST_RECOMPUTE - - def is_graph_input(node: torch.fx.Node) -> bool: return node.op == "placeholder" @@ -266,265 +251,11 @@ def force_recompute_node(node): force_recompute_node(ag_node.all_input_nodes[0]) -INT_INF = int(1e9) - - -# NOTE: this is taken from PyTorch partitioner -def get_required_fwd_nodes( - joint_graph: torch.fx.Graph, -) -> OrderedSet[torch.fx.Node]: - """ - Return the set of nodes that are required in the forward graph. - NOTE: this is doing similar things as classify_nodes() in _functorch/partitioenrs.py - where nodes are classified into three types -- fwd, bwd, and unclaimed - both bwd and unclaimed nodes have partitioner_tag equal to "is_backward" - """ - required_fwd_nodes: OrderedSet[torch.fx.Node] = OrderedSet() - for node in joint_graph.nodes: - if node.op == "placeholder" and "tangents" in node.target: - continue - if node.op == "output": - continue - if _has_tag_is_backward(node): - continue - required_fwd_nodes.add(node) - return required_fwd_nodes - - -# NOTE: this is taken from PyTorch partitioner -def get_node_distance_to_bwd( - joint_graph: torch.fx.Graph, - get_required_fwd_nodes: OrderedSet[torch.fx.Node], -) -> dict[torch.fx.Node, int]: - """ - Compute and return the distance of all nodes to the closest backward node. - If a node is not an ancestor of a backward node, then its distance is INT_INF. - NOTE: this is adapted from - https://github.com/pytorch/pytorch/blob/3196a3aca0f16792820158cfd451cb977f99ac7e/torch/_functorch/partitioners.py#L2089-L2097 - """ - dist_from_bw = {} - for node in reversed(joint_graph.nodes): - if node.op == "output": - dist_from_bw[node] = INT_INF - elif node not in get_required_fwd_nodes: - dist_from_bw[node] = 0 - else: - dist_from_bw[node] = INT_INF - for user in node.users: - dist_from_bw[node] = min(dist_from_bw[node], dist_from_bw[user] + 1) - return dist_from_bw - - -# NOTE: this is taken from PyTorch partitioner -def get_all_recomputable_forward_nodes( - joint_graph: torch.fx.Graph, -) -> OrderedSet[torch.fx.Node]: - """ - Return the set of all forward nodes that are recomputable - """ - required_fwd_nodes = get_required_fwd_nodes(joint_graph) - dist_from_bw = get_node_distance_to_bwd(joint_graph, required_fwd_nodes) - fwd_recomputable_nodes: OrderedSet[torch.fx.Node] = OrderedSet() - for node in joint_graph.nodes: - if ( - node in required_fwd_nodes - and dist_from_bw[node] < INT_INF - and node.op != "placeholder" - ): - fwd_recomputable_nodes.add(node) - return fwd_recomputable_nodes - - -def _mark_nodes_as_must_save(must_save_nodes: list[torch.fx.Node]) -> None: - """ - Given a list of nodes, mark them as must save. - """ - skipped_nodes = {} - for node in must_save_nodes: - if ( - node.meta.get("recompute", None) is not None - and node.meta.get("ac_graph_id", -1) != AP_AC_GRAPH_ID - ): - # Let user annotations take precedence - skipped_nodes[node] = node.meta["recompute"] - continue - node.meta["recompute"] = CheckpointPolicy.MUST_SAVE - print(f"mark_nodes_as_must_save, attempting to mark nodes: {must_save_nodes}") - print(f"mark_nodes_as_must_save, skipping already marked nodes: {skipped_nodes}") - - -def mark_nodes_as_must_save_to_stage_recomputation( - joint_graph: torch.fx.Graph, - stage_size_in_GiB: Optional[Union[float, str]] = "auto", -) -> None: - """ - Marks specific nodes as "must save" to optimize memory usage during recomputation. - With aggressive recomputation strategies, we often encounter situations where long chains - of forward nodes must be recomputed before executing backward pass nodes, causing high - peak memory usage. This function breaks these recomputation chains into smaller stages - based by periodically saving itermediate nodes, keeping peak memory usage below. - Args: - joint_graph: The joint graph containing both forward and backward nodes - stage_size_in_GiB: Target memory size per stage in GiB. None means no stage - recomputation, "auto" means we use sqrt(total_used_memory) as stage size. - """ - if stage_size_in_GiB is None: - return - - fwd_recomputable_nodes = get_all_recomputable_forward_nodes(joint_graph) - - # Initialize all nodes as 'prefer recompute' and then adjust only the must-save ones below - for node in fwd_recomputable_nodes: - if node.meta.get("recompute", None) is not None: - # do not mess with allgather nodes that have already been marked recompute! - continue - if node.target is operator.getitem: - # we need to be a bit careful: we are trying to manually emulate setting "precompute" tags - # in the same way that compiel does when it encounters userland SAC. - # - # torch.compile does this by using TorchDispatchModes to intercept ops as they are traced, - # and setting their "recompute" tag. - # - # However, TorchDispatchModes *only* intercept OpOverloads (and HOPs) - # getitem is neither, and so in vanilla torch.compile usage, - # getitem nodes recieve no tags. - # - # What happens if we blindly set all nodes to PREFER_RECOMPUTE? Example bad outcome: - # - user is using attention, so we see this series of ops in the joint graph: - # attention_fw -> getitem -> attention_bw (the getitem is an output used for the bw) - # - user runs SAC, and marks attention_fw as MUST_SAVE - # - if we mark getitem as PREFER_RECOMPUTE, and attention_fw as MUST_SAVE, - # the partitioner ends up generating an invalid graph. - # Today the partitioner relies on the fact that getitem's recompute behavior - # is implicitly determined by the recompute behavior of the multi-output op preceding it. - continue - node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE - node.meta["ac_graph_id"] = AP_AC_GRAPH_ID - - # get the mapping between node name and node - name_to_node_mapping = {} - for node in fwd_recomputable_nodes: - name_to_node_mapping[node.name] = node - - # populate node_to_predecessors, accounting for must_recompute nodes. In particular, - # if a node is marked as must recompute, then for its users, their predecessors should - # be updated to be instead the predecessors of the must recompute node. - node_to_predecessors = defaultdict(OrderedSet) - for node in fwd_recomputable_nodes: - node_to_predecessors[node] = OrderedSet( - [pred for pred in node.all_input_nodes if pred in fwd_recomputable_nodes] - ) - for node in fwd_recomputable_nodes: - if not must_recompute(node): - continue - for user in node.users: - if user in fwd_recomputable_nodes: - node_to_predecessors[user].remove(node) - node_to_predecessors[user].update(node_to_predecessors[node]) - - # populate node_to_last_usage - # if A is last used by B, then A \in node_to_last_usage[B] - node_to_last_usage = defaultdict(OrderedSet) - last_used_by = {} - for node in fwd_recomputable_nodes: - last_used_by[node] = node - for pred in node_to_predecessors[node]: - last_used_by[pred] = node - for producer, consumer in last_used_by.items(): - node_to_last_usage[consumer].add(producer) - - # loop through nodes in order of the forward graph and keep track of the following: - # for each node, right before its execution, the output of what nodes are in memory. - @dataclass - class NodeCutScore: - tot_mem: float - alive_node_names: OrderedSet[str] - - alive_nodes = OrderedSet() - node2score = {} - for node in fwd_recomputable_nodes: - if not must_recompute(node): - alive_nodes.add(node) - for a in node_to_last_usage[node]: - alive_nodes.remove(a) - tot_mem = sum(_size_of(node) for node in alive_nodes) - node2score[node] = NodeCutScore( - tot_mem, OrderedSet([n.name for n in alive_nodes]) - ) - - # divide the graph into stages with roughly equal memory usage. - stages = defaultdict(OrderedSet) - cum_mem_so_far = 0 - curr_stage_idx = 0 - - if stage_size_in_GiB == "auto": - total_used_memory = sum( - _size_of(node) - for node in fwd_recomputable_nodes - if not must_recompute(node) - ) - total_used_memory_in_GiB = total_used_memory / 2**30 - stage_size_in_GiB = total_used_memory_in_GiB**0.5 - print(f"Computed stage_size {stage_size_in_GiB=}") - - target_mem = stage_size_in_GiB * 2**30 - for node in fwd_recomputable_nodes: - stages[curr_stage_idx].add(node) - if not must_recompute(node): - cum_mem_so_far += _size_of(node) - if cum_mem_so_far >= target_mem: - curr_stage_idx += 1 - cum_mem_so_far = 0 - - # loop through each stage and pick the best node to cut on, and save - # the nodes that will be marked as must save. - nodes_to_save = OrderedSet() - for stage in stages.values(): - best_node = min(stage, key=lambda x: node2score[x].tot_mem) - nodes_to_save.update(node2score[best_node].alive_node_names) - _mark_nodes_as_must_save([name_to_node_mapping[n] for n in nodes_to_save]) - - -def _apply_ac_policy(joint_graph: torch.fx.Graph, save_list: set[torch.ops.OpOverload]): - """ - This is not very generic, and just applies an AC policy similar to what - TorchTitan is doing. I think we should just replace this altogether with - torch._functorch.config.activation_memory_budget - """ - fwd_recomputable_nodes = get_all_recomputable_forward_nodes(joint_graph) - must_save_nodes = [] - counter = 0 - for node in fwd_recomputable_nodes: - if node.target in save_list: - if node.target == torch.ops.aten.mm.default: - if counter % 2 == 0: - counter += 1 - else: - counter += 1 - continue - must_save_nodes.append(node) - _mark_nodes_as_must_save(must_save_nodes) - - -def ac_joint_pass( +def tag_fsdp_collectives_for_recomputation( graph: torch.fx.Graph, - ac_stage_size_in_GiB: Optional[Union[float, str]] = 2.0, reshard_after_forward: bool = True, -): +) -> None: if reshard_after_forward: force_recompute_fsdp_all_gather(graph) else: force_save_fsdp_all_gather(graph) - mark_nodes_as_must_save_to_stage_recomputation( - graph, stage_size_in_GiB=ac_stage_size_in_GiB - ) - - # TODO: we need to also enable sdpa perfectly mimic the TorchTitan - # policy, but this is not working yet - save_list = { - torch.ops.aten.mm.default, - torch.ops.aten._scaled_dot_product_efficient_attention.default, - torch.ops.aten._scaled_dot_product_flash_attention.default, - torch.ops.aten._scaled_dot_product_cudnn_attention.default, - } - _apply_ac_policy(graph, save_list=save_list)