Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 8 additions & 136 deletions autoparallel/graph_passes/split_di_dw_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,15 @@
# LICENSE file in the root directory of this source tree.

import copy
import itertools
import operator

import sympy
import torch
import torch.fx as fx
from torch._functorch.partitioners import (
SavedForBackwardsAOTOutput,
_extract_fwd_bwd_modules,
_extract_fwd_bwd_outputs,
_extract_graph_with_inputs_outputs,
_is_backward_state,
_is_bwd_seed_offset,
_is_fwd_seed_offset,
_is_primal,
_remove_by_name,
find_symbol_binding_fx_nodes,
free_symbols,
is_sym_node,
is_symbol_binding_fx_node,
)
from torch.utils._ordered_set import OrderedSet

Expand Down Expand Up @@ -64,131 +54,6 @@ def reorder_output_grads(bw_gm, num_weight_gradients):
return len(grad_inputs)


# This is a copy of the function used by the default partitioner,
# which does *not* reorder symint activations.
# This reordering is needed by the custom autograd.Function in AOTDispatcher,
# but isn't needed in our dI/dW splitting since there is no autograd in the loop.
# TODO: provide a way to get this behavior automatically out of the default partitioner
def _extract_fwd_bwd_modules(
    joint_module: fx.GraphModule,
    saved_values: list[fx.Node],
    saved_sym_nodes: list[fx.Node],
    *,
    num_fwd_outputs: int,
) -> tuple[fx.GraphModule, fx.GraphModule]:
    """Split ``joint_module`` into a forward and a backward graph module.

    A provisional backward graph is extracted first and used to prune saved
    nodes that the backward does not actually consume. The forward and
    backward graphs are then re-extracted from the finalized saved lists.

    Args:
        joint_module: The joint graph module to split.
        saved_values: Nodes to pass from forward to backward. Mutated in
            place: entries whose backward placeholder has no users, whose
            only users are unused ``wait_tensor`` calls (when
            ``torch.distributed`` is available), or which are backward-state
            nodes are removed by name.
        saved_sym_nodes: Symbolic nodes to pass from forward to backward.
            Mutated in place: pruned like ``saved_values`` (except for the
            backward-state case), extended with binding nodes for any free
            symbols found in the ``meta["val"]`` of the saved nodes, and
            reordered so that symbol-binding nodes precede derived ones.
        num_fwd_outputs: Forwarded to ``_extract_fwd_bwd_outputs`` to
            separate the joint outputs into forward and backward outputs.

    Returns:
        A ``(fwd_module, bwd_module)`` tuple of graph modules built from
        ``joint_module``.
    """
    (
        fwd_outputs,
        bwd_outputs,
        fwd_outputs_descs,
        bwd_outputs_descs,
    ) = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
    # Classify the joint graph's placeholders by their role.
    placeholders = joint_module.graph.find_nodes(op="placeholder")
    primal_inputs = [*filter(_is_primal, placeholders)]
    fwd_seed_offset_inputs = [*filter(_is_fwd_seed_offset, placeholders)]
    bwd_seed_offset_inputs = [*filter(_is_bwd_seed_offset, placeholders)]
    backward_state_inputs = [*filter(_is_backward_state, placeholders)]

    # Provisional backward graph: only used below to discover which saved
    # nodes are actually consumed. It is rebuilt after pruning.
    bwd_graph = _extract_graph_with_inputs_outputs(
        joint_module.graph,
        saved_values + saved_sym_nodes + bwd_seed_offset_inputs,
        bwd_outputs,
        bwd_outputs_descs,
        "backward",
        ignore_must_be_in_fw_bw=True,
    )

    distributed_enabled = torch.distributed.is_available()

    for node in bwd_graph.find_nodes(op="placeholder"):
        # This is to filter out saved values that don't actually end up being used by the backwards pass
        if not node.users:
            _remove_by_name(saved_values, node.name)
            _remove_by_name(saved_sym_nodes, node.name)
        # wait_tensor is a bit special: if we have a "dead activation" that is not used in the bw,
        # but this dead activation is actually a collective,
        # then the collective will generally be followed by a wait_tensor() call.
        # we need to peek one node further to see if this wait_tensor is dead as well.
        elif distributed_enabled and all(
            n.target is torch.ops._c10d_functional.wait_tensor.default
            and len(n.users) == 0
            for n in node.users
        ):
            _remove_by_name(saved_values, node.name)
            _remove_by_name(saved_sym_nodes, node.name)
        elif _is_backward_state(node):
            # BackwardState is saved directly
            _remove_by_name(saved_values, node.name)
            assert backward_state_inputs

    # Now that we have the finalized list of saved values, we need to ensure
    # we propagate all symbols which are referenced by backwards inputs.
    # These are not directly used in the graph but are required for downstream
    # sizevar assignment
    saved_symbols: OrderedSet[sympy.Symbol] = OrderedSet()
    saved_sym_nodes_binding = []
    saved_sym_nodes_derived = []

    # Some symbols may already be bound in the directly saved_sym_nodes,
    # keep track of them so we don't re-bind them
    for node in saved_sym_nodes:
        symbol = is_symbol_binding_fx_node(node)
        if symbol:
            saved_symbols.add(symbol)
            saved_sym_nodes_binding.append(node)
        else:
            saved_sym_nodes_derived.append(node)

    # Now go through all of the prospective backward inputs and track any
    # other symbols we need to bind
    symbol_bindings = find_symbol_binding_fx_nodes(joint_module.graph)
    for node in itertools.chain(saved_sym_nodes_derived, saved_values):
        if "val" not in node.meta:
            continue
        new_symbols = free_symbols(node.meta["val"]) - saved_symbols
        # NB: Deterministic order please!
        for s in sorted(new_symbols, key=lambda s: s.name):
            # NB: For well formed graphs, the symbol should always be present,
            # but we also have ways to produce ill-formed graphs, e.g., direct
            # make_fx usages, so don't choke in this case
            if s not in symbol_bindings:
                continue
            saved_sym_nodes_binding.append(symbol_bindings[s])
        # Record every symbol seen for this node (bound or not) so it is not
        # considered again for later nodes.
        saved_symbols |= new_symbols

    # Update saved_sym_nodes that are now reordered to have all bindings at
    # front. This can also be used later on to figure out the position of saved
    # sym nodes in the output of fwd graph.
    saved_sym_nodes.clear()
    saved_sym_nodes.extend(saved_sym_nodes_binding + saved_sym_nodes_derived)

    # Now, we re-generate the fwd/bwd graphs.
    # NB: This might increase compilation time, but I doubt it matters
    fwd_graph = _extract_graph_with_inputs_outputs(
        joint_module.graph,
        primal_inputs + fwd_seed_offset_inputs,
        fwd_outputs + saved_values + saved_sym_nodes,
        fwd_outputs_descs
        + [
            SavedForBackwardsAOTOutput(i)
            for i in range(len(saved_values) + len(saved_sym_nodes))
        ],
        "forward",
        ignore_must_be_in_fw_bw=True,
    )
    # Unlike the provisional extraction above, the final backward graph also
    # takes the backward-state placeholders as inputs.
    bwd_graph = _extract_graph_with_inputs_outputs(
        joint_module.graph,
        saved_values + saved_sym_nodes + bwd_seed_offset_inputs + backward_state_inputs,
        bwd_outputs,
        bwd_outputs_descs,
        "backward",
        ignore_must_be_in_fw_bw=True,
    )

    fwd_module = fx._lazy_graph_module._make_graph_module(joint_module, fwd_graph)
    bwd_module = fx._lazy_graph_module._make_graph_module(joint_module, bwd_graph)
    return fwd_module, bwd_module


# TODO: in theory we can infer num_weight_gradients from the graph metadata directly
def split_di_dw_graph(
bw_gm_old: fx.GraphModule, *, num_weight_gradients: int
Expand Down Expand Up @@ -230,6 +95,11 @@ def split_di_dw_graph(
saved_values = []
saved_sym_nodes = []

# TODO: this classification loop is a simplified version of default_partition's
# node classification. It does not handle: get_attr nodes, _assert_scalar/profiler
# ops, MUST_SAVE tags, impure/effectful ops, force_save_collectives,
# force_save_bw_mutation_src, must_recompute skipping, or post-split DCE.
# Ideally we would call default_partition directly instead of reimplementing.
for node in bw_gm.graph.nodes:
if node.name not in bw_inputs_gm_node_names:
# Not handling mutations for now,
Expand Down Expand Up @@ -262,5 +132,7 @@ def split_di_dw_graph(
saved_values,
saved_sym_nodes=saved_sym_nodes,
num_fwd_outputs=num_input_gradients,
ignore_must_be_in_fw_bw=True,
omit_aot_autograd_runtime=True,
)
return bw_inputs, bw_weights, num_input_gradients
Loading