Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 69 additions & 16 deletions autoparallel/graph_passes/graph_pp_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,19 @@
logger.setLevel(logging.DEBUG)


def _execute_graph(
gm: fx.GraphModule, args: list[Any], *, inductor: bool = False
) -> Any:
"""Execute a graph module, optionally compiling with Inductor on first call."""
if inductor:
if not hasattr(gm, "_compiled"):
from torch._inductor.compile_fx import compile_fx_inner

gm._compiled = compile_fx_inner(gm, args) # type: ignore[assignment, attr-defined]
return gm._compiled(args) # type: ignore[operator, attr-defined]
return fx.Interpreter(gm).boxed_run(args)


@dataclass
class GraphCallables:
fw: fx.GraphModule
Expand Down Expand Up @@ -161,6 +174,7 @@ def __init__(
"sharded_grads": [],
"unsharded_grads": [],
}
self.inductor: bool = False
self.bwd_activation_cache: dict[int, tuple[Any]] = {}

def scale_grads(self, grad_scale_factor: int) -> None:
Expand Down Expand Up @@ -202,13 +216,14 @@ def _run_fw_module(
graph_meta: GraphMeta,
fw_args: list[Any],
numerics_logs: Optional[list[str]] = None,
inductor: bool = False,
) -> tuple[Any, tuple[tuple[Any], tuple[Any]]]:
if numerics_logs is not None:
debug_interpreter = DebugInterpreter(fw_module)
fw_outputs = debug_interpreter.boxed_run(fw_args)
numerics_logs += debug_interpreter.get_logs()
else:
fw_outputs = fx.Interpreter(fw_module).boxed_run(fw_args)
fw_outputs = _execute_graph(fw_module, fw_args, inductor=inductor)

num_inner_fwd_outputs = graph_meta.num_mutate_inputs + graph_meta.num_user_outputs
saved_intermediates = fw_outputs[num_inner_fwd_outputs:]
Expand All @@ -225,43 +240,59 @@ def _run_fw_module(


def _run_full_bw_module(
    bw_module: fx.GraphModule, graph_meta: GraphMeta, bw_args, inductor: bool = False
) -> tuple[list[Any], list[Any]]:
    """Run the full backward graph and split its outputs.

    The backward graph emits parameter/buffer gradients first, followed by
    input gradients; the split point is taken from ``graph_meta``.

    Returns:
        ``(input_grads, param_buffer_grads)``.
    """
    # NOTE(review): this span contained interleaved pre-change diff lines;
    # reconstructed here as the post-change version of the function.
    bw_outputs = _execute_graph(bw_module, bw_args, inductor=inductor)
    num_params_buffers = graph_meta.num_params + graph_meta.num_buffers
    param_buffer_grads = bw_outputs[:num_params_buffers]
    input_grads = bw_outputs[num_params_buffers:]
    return input_grads, param_buffer_grads


def _run_dI_bw_module(
    bw_dI_module: fx.GraphModule,
    graph_meta: GraphMeta,
    bw_dI_args,
    inductor: bool = False,
) -> tuple[list[Any], list[Any]]:
    """Run the dI (input-gradient) backward graph.

    The graph's outputs are input gradients first, then the activations that
    must be kept alive for the later dW (weight-gradient) pass.

    Returns:
        ``(inp_grads, activations)`` where ``activations`` is a fresh list.
    """
    # NOTE(review): this span contained interleaved pre-change diff lines;
    # reconstructed here as the post-change version of the function.
    inp_grads_and_activations = _execute_graph(
        bw_dI_module, bw_dI_args, inductor=inductor
    )
    inp_grads = inp_grads_and_activations[: graph_meta.num_input_grads]
    activations = list(inp_grads_and_activations[graph_meta.num_input_grads :])
    return inp_grads, activations


def _run_dW_bw_module(
    bw_dW_module: fx.GraphModule,
    graph_meta: GraphMeta,
    bw_dW_args,
    inductor: bool = False,
) -> list[Any]:
    """Run the dW (weight-gradient) backward graph.

    Returns:
        The parameter/buffer gradients produced by the graph.
    """
    # NOTE(review): this span contained interleaved pre-change diff lines;
    # reconstructed here as the post-change version of the function.
    param_buffer_grads = _execute_graph(bw_dW_module, bw_dW_args, inductor=inductor)
    return param_buffer_grads


def _run_unshard_module(
    unshard_module: fx.GraphModule,
    graph_meta: GraphMeta,
    unshard_args,
    inductor: bool = False,
) -> list[Any]:
    """Run the unshard graph, producing unsharded (gathered) parameters."""
    # NOTE(review): this span contained interleaved pre-change diff lines;
    # reconstructed here as the post-change version of the function.
    unsharded_params = _execute_graph(unshard_module, unshard_args, inductor=inductor)
    return unsharded_params


def _run_reduce_grad_module(
    reduce_grad_module: fx.GraphModule,
    graph_meta: GraphMeta,
    reduce_grad_args,
    inductor: bool = False,
) -> list[Any]:
    """Run the gradient-reduction graph, producing sharded gradients."""
    # NOTE(review): this span contained interleaved pre-change diff lines;
    # reconstructed here as the post-change version of the function.
    sharded_grads = _execute_graph(
        reduce_grad_module, reduce_grad_args, inductor=inductor
    )
    return sharded_grads


Expand All @@ -270,8 +301,11 @@ def _run_multiplexed_fw_bw_module(
fw_graph_meta: GraphMeta,
bw_graph_meta: GraphMeta,
bw_fw_args,
inductor: bool = False,
) -> tuple[list[Any], list[Any], Any, tuple[tuple[Any], tuple[Any]]]:
multiplexed_outs = fx.Interpreter(multiplexed_fw_bw_module).boxed_run(bw_fw_args)
multiplexed_outs = _execute_graph(
multiplexed_fw_bw_module, bw_fw_args, inductor=inductor
)

num_params_buffers = bw_graph_meta.num_params + bw_graph_meta.num_buffers
num_bw_outs = bw_graph_meta.num_input_grads + num_params_buffers
Expand Down Expand Up @@ -507,7 +541,11 @@ def stage_forward(
action,
)
output, saved_intermediates = _run_fw_module(
stage.graph_callables.fw, stage.graph_meta, fw_args, numerics_logs=numerics_logs
stage.graph_callables.fw,
stage.graph_meta,
fw_args,
numerics_logs=numerics_logs,
inductor=stage.inductor,
)

_post_fwd_common(
Expand Down Expand Up @@ -673,7 +711,10 @@ def stage_full_backward(
action,
)
input_grads, param_buffer_grads = _run_full_bw_module(
bw_stage.graph_callables.full_bw, bw_stage.graph_meta, bw_args
bw_stage.graph_callables.full_bw,
bw_stage.graph_meta,
bw_args,
inductor=bw_stage.inductor,
)
bw_stage._accumulate_stage_unsharded_grads(param_buffer_grads)

Expand Down Expand Up @@ -732,7 +773,10 @@ def stage_backward_input(
)
assert bw_stage.graph_callables.bw_dI is not None
input_grads, activations_for_backward = _run_dI_bw_module(
bw_stage.graph_callables.bw_dI, bw_stage.graph_meta, bw_args
bw_stage.graph_callables.bw_dI,
bw_stage.graph_meta,
bw_args,
inductor=bw_stage.inductor,
)

bw_stage.bwd_activation_cache[bw_mb_index] = (
Expand Down Expand Up @@ -785,7 +829,10 @@ def stage_backward_weight(
del activations_for_backward
assert bw_stage.graph_callables.bw_dW is not None
param_buffer_grads = _run_dW_bw_module(
bw_stage.graph_callables.bw_dW, bw_stage.graph_meta, bw_args
bw_stage.graph_callables.bw_dW,
bw_stage.graph_meta,
bw_args,
inductor=bw_stage.inductor,
)
bw_stage._accumulate_stage_unsharded_grads(param_buffer_grads)

Expand Down Expand Up @@ -894,6 +941,7 @@ def stage_unshard(
stage.graph_callables.unshard,
stage.graph_meta,
sharded_params,
inductor=stage.inductor,
)
stage.state["unsharded_params"] = unsharded_params

Expand Down Expand Up @@ -926,6 +974,7 @@ def stage_reduce_grad(
stage.graph_callables.reduce_grad,
stage.graph_meta,
stage.state["unsharded_grads"],
inductor=stage.inductor,
)
stage.state["sharded_grads"] = sharded_grads

Expand All @@ -934,6 +983,7 @@ class GraphPPRunner:
def __init__(
self,
schedule: _PipelineScheduleRuntime,
inductor: bool = False,
):
self.schedule = schedule
if not schedule._backward_requires_autograd:
Expand All @@ -949,6 +999,9 @@ def __init__(
for stage in schedule._stages
)
self.schedule._has_backward = True
for stage in schedule._stages:
assert isinstance(stage, GraphPipelineStage)
stage.inductor = inductor

def _populate_stage_states(self, stage: GraphPipelineStage) -> None:
sharded_params = [
Expand Down
2 changes: 1 addition & 1 deletion examples/example_ds3_pp.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ def init_weights(self, *args, **kwargs):
)

# Step 7. Register the schedule with the graph runner
graph_pp_runner = GraphPPRunner(schedule)
graph_pp_runner = GraphPPRunner(schedule) # inductor=True to compile with Inductor

# Step 8. Run the whole pipeline once using the graph runner
has_last_stage = (total_pp_stages - 1) in stage_mods
Expand Down
27 changes: 24 additions & 3 deletions tests/test_dtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
# LICENSE file in the root directory of this source tree.

import functools
import multiprocessing

import numpy as np
import pytest
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor
Expand All @@ -26,6 +28,7 @@
with_comms,
)

import autoparallel.shardings.dtensor_sharding_helpers as dtensor_sharding_helpers
from autoparallel.shardings.dtensor_sharding_helpers import (
batch_shard_strategy,
get_op_strategy,
Expand Down Expand Up @@ -346,9 +349,27 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin
return output_sharding


dispatcher = DTensor._op_dispatcher
# change to the customized sharding_propagator for testing implicit fallback
dispatcher.sharding_propagator = CustomShardingPropagator()
# Install the custom propagator so that subprocesses spawned by
# MultiProcessTestCase (which use ``spawn`` and re-import this module) see it.
# In the main pytest process we use a module-scoped fixture to install/restore
# it only while this module's tests run, preventing leaks into other modules.
_custom_propagator = CustomShardingPropagator()
_orig_dispatcher_propagator = DTensor._op_dispatcher.sharding_propagator
_orig_helpers_propagator = dtensor_sharding_helpers.propagator

if multiprocessing.current_process().name != "MainProcess":
# Subprocess worker (spawned by MultiProcessTestCase): install immediately.
DTensor._op_dispatcher.sharding_propagator = _custom_propagator
dtensor_sharding_helpers.propagator = _custom_propagator


@pytest.fixture(autouse=True, scope="module")
def _install_custom_propagator():
    """Install the custom sharding propagator for this test module only.

    Module-scoped and autouse: runs once before the first test in this file
    and restores the originals after the last one, so the patched propagator
    does not leak into other test modules in the same pytest process.
    """
    DTensor._op_dispatcher.sharding_propagator = _custom_propagator
    dtensor_sharding_helpers.propagator = _custom_propagator
    yield
    # Teardown: restore the propagators captured at import time.
    DTensor._op_dispatcher.sharding_propagator = _orig_dispatcher_propagator
    dtensor_sharding_helpers.propagator = _orig_helpers_propagator


class ImplicitRegistrationTest(DTensorTestBase):
Expand Down
Loading