diff --git a/autoparallel/graph_passes/split_di_dw_graph.py b/autoparallel/graph_passes/split_di_dw_graph.py
index ed6d6b97..05a19e34 100644
--- a/autoparallel/graph_passes/split_di_dw_graph.py
+++ b/autoparallel/graph_passes/split_di_dw_graph.py
@@ -4,25 +4,15 @@
 # LICENSE file in the root directory of this source tree.
 
 import copy
-import itertools
 import operator
 
-import sympy
 import torch
 import torch.fx as fx
 from torch._functorch.partitioners import (
-    SavedForBackwardsAOTOutput,
+    _extract_fwd_bwd_modules,
     _extract_fwd_bwd_outputs,
     _extract_graph_with_inputs_outputs,
-    _is_backward_state,
-    _is_bwd_seed_offset,
-    _is_fwd_seed_offset,
-    _is_primal,
-    _remove_by_name,
-    find_symbol_binding_fx_nodes,
-    free_symbols,
     is_sym_node,
-    is_symbol_binding_fx_node,
 )
 from torch.utils._ordered_set import OrderedSet
 
@@ -64,131 +54,6 @@ def reorder_output_grads(bw_gm, num_weight_gradients):
     return len(grad_inputs)
 
 
-# This is a copy of the function used by the default partitioner,
-# which does *not* reorder symint activations.
-# This is reordering is needed by the custom autograd.Function in AOTDispatcher,
-# but isn't needed in our dI/dW splitting since there is no autograd in the loop.
-# TODO: provide a way to gt this behavior automatically out of the default partitioner
-def _extract_fwd_bwd_modules(
-    joint_module: fx.GraphModule,
-    saved_values: list[fx.Node],
-    saved_sym_nodes: list[fx.Node],
-    *,
-    num_fwd_outputs: int,
-) -> tuple[fx.GraphModule, fx.GraphModule]:
-    (
-        fwd_outputs,
-        bwd_outputs,
-        fwd_outputs_descs,
-        bwd_outputs_descs,
-    ) = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
-    placeholders = joint_module.graph.find_nodes(op="placeholder")
-    primal_inputs = [*filter(_is_primal, placeholders)]
-    fwd_seed_offset_inputs = [*filter(_is_fwd_seed_offset, placeholders)]
-    bwd_seed_offset_inputs = [*filter(_is_bwd_seed_offset, placeholders)]
-    backward_state_inputs = [*filter(_is_backward_state, placeholders)]
-
-    bwd_graph = _extract_graph_with_inputs_outputs(
-        joint_module.graph,
-        saved_values + saved_sym_nodes + bwd_seed_offset_inputs,
-        bwd_outputs,
-        bwd_outputs_descs,
-        "backward",
-        ignore_must_be_in_fw_bw=True,
-    )
-
-    distributed_enabled = torch.distributed.is_available()
-
-    for node in bwd_graph.find_nodes(op="placeholder"):
-        # This is to filter out saved values that don't actually end up being used by the backwards pass
-        if not node.users:
-            _remove_by_name(saved_values, node.name)
-            _remove_by_name(saved_sym_nodes, node.name)
-        # wait_tensor is a bit special: if we have a "dead activation" that is not used in the bw,
-        # but this dead activation is actually a collective,
-        # then the collective will generally by followed by a wait_tensor() call.
-        # we need to peak one node further to see if this wait_tensor is dead as well.
-        elif distributed_enabled and all(
-            n.target is torch.ops._c10d_functional.wait_tensor.default
-            and len(n.users) == 0
-            for n in node.users
-        ):
-            _remove_by_name(saved_values, node.name)
-            _remove_by_name(saved_sym_nodes, node.name)
-        elif _is_backward_state(node):
-            # BackwardState is saved directly
-            _remove_by_name(saved_values, node.name)
-            assert backward_state_inputs
-
-    # Now that we have the finalized list of saved values, we need to ensure
-    # we propagate all symbols which are referenced by backwards inputs.
-    # These are not directly used in the graph but are required for downstream
-    # sizevar assignment
-    saved_symbols: OrderedSet[sympy.Symbol] = OrderedSet()
-    saved_sym_nodes_binding = []
-    saved_sym_nodes_derived = []
-
-    # Some symbols may already be bound in the directly saved_sym_nodes,
-    # keep track of them so we don't re-bind them
-    for node in saved_sym_nodes:
-        symbol = is_symbol_binding_fx_node(node)
-        if symbol:
-            saved_symbols.add(symbol)
-            saved_sym_nodes_binding.append(node)
-        else:
-            saved_sym_nodes_derived.append(node)
-
-    # Now go through all of the prospective backward inputs and track any
-    # other symbols we need to bind
-    symbol_bindings = find_symbol_binding_fx_nodes(joint_module.graph)
-    for node in itertools.chain(saved_sym_nodes_derived, saved_values):
-        if "val" not in node.meta:
-            continue
-        new_symbols = free_symbols(node.meta["val"]) - saved_symbols
-        # NB: Deterministic order please!
-        for s in sorted(new_symbols, key=lambda s: s.name):
-            # NB: For well formed graphs, the symbol should always be present,
-            # but we also have ways to produce ill-formed graphs, e.g., direct
-            # make_fx usages, so don't choke in this case
-            if s not in symbol_bindings:
-                continue
-            saved_sym_nodes_binding.append(symbol_bindings[s])
-        saved_symbols |= new_symbols
-
-    # Update saved_sym_nodes that are now reordered to have all bindings at
-    # front. This can also be used later on to figure out the position of saved
-    # sym nodes in the output of fwd graph.
-    saved_sym_nodes.clear()
-    saved_sym_nodes.extend(saved_sym_nodes_binding + saved_sym_nodes_derived)
-
-    # Now, we re-generate the fwd/bwd graphs.
-    # NB: This might increase compilation time, but I doubt it matters
-    fwd_graph = _extract_graph_with_inputs_outputs(
-        joint_module.graph,
-        primal_inputs + fwd_seed_offset_inputs,
-        fwd_outputs + saved_values + saved_sym_nodes,
-        fwd_outputs_descs
-        + [
-            SavedForBackwardsAOTOutput(i)
-            for i in range(len(saved_values) + len(saved_sym_nodes))
-        ],
-        "forward",
-        ignore_must_be_in_fw_bw=True,
-    )
-    bwd_graph = _extract_graph_with_inputs_outputs(
-        joint_module.graph,
-        saved_values + saved_sym_nodes + bwd_seed_offset_inputs + backward_state_inputs,
-        bwd_outputs,
-        bwd_outputs_descs,
-        "backward",
-        ignore_must_be_in_fw_bw=True,
-    )
-
-    fwd_module = fx._lazy_graph_module._make_graph_module(joint_module, fwd_graph)
-    bwd_module = fx._lazy_graph_module._make_graph_module(joint_module, bwd_graph)
-    return fwd_module, bwd_module
-
-
 # TODO: in theory we can infer num_weight_gradients from the graph metadata directly
 def split_di_dw_graph(
     bw_gm_old: fx.GraphModule, *, num_weight_gradients: int
@@ -230,6 +95,11 @@ def split_di_dw_graph(
     saved_values = []
     saved_sym_nodes = []
 
+    # TODO: this classification loop is a simplified version of default_partition's
+    # node classification. It does not handle: get_attr nodes, _assert_scalar/profiler
+    # ops, MUST_SAVE tags, impure/effectful ops, force_save_collectives,
+    # force_save_bw_mutation_src, must_recompute skipping, or post-split DCE.
+    # Ideally we would call default_partition directly instead of reimplementing.
     for node in bw_gm.graph.nodes:
         if node.name not in bw_inputs_gm_node_names:
             # Not handling mutations for now,
@@ -262,5 +132,7 @@ def split_di_dw_graph(
         saved_values,
         saved_sym_nodes=saved_sym_nodes,
         num_fwd_outputs=num_input_gradients,
+        ignore_must_be_in_fw_bw=True,
+        omit_aot_autograd_runtime=True,
     )
     return bw_inputs, bw_weights, num_input_gradients
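For context, a minimal usage sketch of the pass this diff touches, assuming only the signature visible in the hunks above (`split_di_dw_graph(bw_gm, *, num_weight_gradients)` returning the dI module, the dW module, and the number of input gradients). The module path matches the file in the diff; the surrounding names (`split_backward`, `num_params`) are illustrative:

```python
# Hypothetical sketch: split a backward GraphModule into a dI graph
# (gradients w.r.t. inputs/activations) and a dW graph (gradients
# w.r.t. weights), so the weight-gradient work can be scheduled separately.
import torch.fx as fx

from autoparallel.graph_passes.split_di_dw_graph import split_di_dw_graph


def split_backward(bw_gm: fx.GraphModule, num_params: int):
    # num_weight_gradients tells the pass how many of bw_gm's outputs are
    # weight gradients (keyword-only, per the signature above).
    bw_di_gm, bw_dw_gm, num_input_grads = split_di_dw_graph(
        bw_gm, num_weight_gradients=num_params
    )
    return bw_di_gm, bw_dw_gm, num_input_grads
```

Design note: with this change the fwd/bwd re-extraction is delegated to the upstream `_extract_fwd_bwd_modules` instead of a local copy. `ignore_must_be_in_fw_bw=True` matches what the deleted copy already passed to `_extract_graph_with_inputs_outputs`, and `omit_aot_autograd_runtime=True` appears to be the upstream knob the copy's TODO asked for, i.e. skipping the symint-activation reordering that only AOTDispatcher's custom autograd.Function needs.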