From 8b273848a4163180b5acfdaa98ccf5feffbeccf0 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Mon, 16 Mar 2026 17:50:56 -0700 Subject: [PATCH 1/2] Scope CustomShardingPropagator to test_dtensor tests via pytest fixture The module-level `dispatcher.sharding_propagator = CustomShardingPropagator()` was leaking into other test files (e.g. test_api.py) when run in the same pytest process, causing `aten.copy_` failures because the custom propagator doesn't have rules for ops that the default DTensor propagator handles. test_dtensor.py's two test classes (ImplicitRegistrationTest, DimShardingTest) inherit from DTensorTestBase which uses MultiProcessTestCase -- each test spawns subprocesses that re-import the module. Those subprocesses don't run pytest fixtures, so they need the custom propagator installed at module level. We gate the module-level install on `multiprocessing.current_process().name` to only run in spawned workers, and use a module-scoped autouse pytest fixture to install/restore the propagator in the main process. Authored with Claude. stack-info: PR: https://github.com/meta-pytorch/autoparallel/pull/367, branch: xmfan/stack/32 --- tests/test_dtensor.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/tests/test_dtensor.py b/tests/test_dtensor.py index 824d4afd..5628ac84 100644 --- a/tests/test_dtensor.py +++ b/tests/test_dtensor.py @@ -4,8 +4,10 @@ # LICENSE file in the root directory of this source tree. import functools +import multiprocessing import numpy as np +import pytest import torch from torch.distributed.device_mesh import init_device_mesh from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor @@ -26,6 +28,7 @@ with_comms, ) +import autoparallel.shardings.dtensor_sharding_helpers as dtensor_sharding_helpers from autoparallel.shardings.dtensor_sharding_helpers import ( batch_shard_strategy, get_op_strategy, @@ -346,9 +349,27 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin return output_sharding -dispatcher = DTensor._op_dispatcher -# change to the customized sharding_propagator for testing implicit fallback -dispatcher.sharding_propagator = CustomShardingPropagator() +# Install the custom propagator so that subprocesses spawned by +# MultiProcessTestCase (which use ``spawn`` and re-import this module) see it. +# In the main pytest process we use a module-scoped fixture to install/restore +# it only while this module's tests run, preventing leaks into other modules. +_custom_propagator = CustomShardingPropagator() +_orig_dispatcher_propagator = DTensor._op_dispatcher.sharding_propagator +_orig_helpers_propagator = dtensor_sharding_helpers.propagator + +if multiprocessing.current_process().name != "MainProcess": + # Subprocess worker (spawned by MultiProcessTestCase): install immediately. + DTensor._op_dispatcher.sharding_propagator = _custom_propagator + dtensor_sharding_helpers.propagator = _custom_propagator + + +@pytest.fixture(autouse=True, scope="module") +def _install_custom_propagator(): + DTensor._op_dispatcher.sharding_propagator = _custom_propagator + dtensor_sharding_helpers.propagator = _custom_propagator + yield + DTensor._op_dispatcher.sharding_propagator = _orig_dispatcher_propagator + dtensor_sharding_helpers.propagator = _orig_helpers_propagator class ImplicitRegistrationTest(DTensorTestBase): From 02225522419dacd0d8ac29994b2969aa523d9499 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 10 Mar 2026 13:24:10 -0700 Subject: [PATCH 2/2] Add lazy Inductor compilation to graph_pp_runner Add _execute_graph() that lazily compiles graph modules with compile_fx_inner on first invocation. Controlled by an inductor kwarg threaded through all _run_* functions. GraphPPRunner accepts inductor=True and propagates it to all GraphPipelineStage instances, which the stage_* action functions read when calling _run_*. Authored with Claude. stack-info: PR: https://github.com/meta-pytorch/autoparallel/pull/360, branch: xmfan/stack/30 --- autoparallel/graph_passes/graph_pp_runner.py | 85 ++++++++++++++++---- examples/example_ds3_pp.py | 2 +- 2 files changed, 70 insertions(+), 17 deletions(-) diff --git a/autoparallel/graph_passes/graph_pp_runner.py b/autoparallel/graph_passes/graph_pp_runner.py index 4ad52b20..02c1c0ee 100644 --- a/autoparallel/graph_passes/graph_pp_runner.py +++ b/autoparallel/graph_passes/graph_pp_runner.py @@ -28,6 +28,19 @@ logger.setLevel(logging.DEBUG) +def _execute_graph( + gm: fx.GraphModule, args: list[Any], *, inductor: bool = False +) -> Any: + """Execute a graph module, optionally compiling with Inductor on first call.""" + if inductor: + if not hasattr(gm, "_compiled"): + from torch._inductor.compile_fx import compile_fx_inner + + gm._compiled = compile_fx_inner(gm, args) # type: ignore[assignment, attr-defined] + return gm._compiled(args) # type: ignore[operator, attr-defined] + return fx.Interpreter(gm).boxed_run(args) + + @dataclass class GraphCallables: fw: fx.GraphModule @@ -161,6 +174,7 @@ def __init__( "sharded_grads": [], "unsharded_grads": [], } + self.inductor: bool = False self.bwd_activation_cache: dict[int, tuple[Any]] = {} def scale_grads(self, grad_scale_factor: int) -> None: @@ -202,13 +216,14 @@ def _run_fw_module( graph_meta: GraphMeta, fw_args: list[Any], numerics_logs: Optional[list[str]] = None, + inductor: bool = False, ) -> tuple[Any, tuple[tuple[Any], tuple[Any]]]: if numerics_logs is not None: debug_interpreter = DebugInterpreter(fw_module) fw_outputs = debug_interpreter.boxed_run(fw_args) numerics_logs += debug_interpreter.get_logs() else: - fw_outputs = fx.Interpreter(fw_module).boxed_run(fw_args) + fw_outputs = _execute_graph(fw_module, fw_args, inductor=inductor) num_inner_fwd_outputs = graph_meta.num_mutate_inputs + graph_meta.num_user_outputs saved_intermediates = fw_outputs[num_inner_fwd_outputs:] @@ -225,9 +240,9 @@ def _run_fw_module( def _run_full_bw_module( - bw_module: fx.GraphModule, graph_meta: GraphMeta, bw_args + bw_module: fx.GraphModule, graph_meta: GraphMeta, bw_args, inductor: bool = False ) -> tuple[list[Any], list[Any]]: - bw_outputs = fx.Interpreter(bw_module).boxed_run(bw_args) + bw_outputs = _execute_graph(bw_module, bw_args, inductor=inductor) num_params_buffers = graph_meta.num_params + graph_meta.num_buffers param_buffer_grads = bw_outputs[:num_params_buffers] input_grads = bw_outputs[num_params_buffers:] @@ -235,9 +250,14 @@ def _run_full_bw_module( def _run_dI_bw_module( - bw_dI_module: fx.GraphModule, graph_meta: GraphMeta, bw_dI_args + bw_dI_module: fx.GraphModule, + graph_meta: GraphMeta, + bw_dI_args, + inductor: bool = False, ) -> tuple[list[Any], list[Any]]: - inp_grads_and_activations = fx.Interpreter(bw_dI_module).boxed_run(bw_dI_args) + inp_grads_and_activations = _execute_graph( + bw_dI_module, bw_dI_args, inductor=inductor + ) inp_grads, activations = inp_grads_and_activations[ : graph_meta.num_input_grads ], list(inp_grads_and_activations[graph_meta.num_input_grads :]) @@ -245,23 +265,34 @@ def _run_dI_bw_module( def _run_dW_bw_module( - bw_dW_module: fx.GraphModule, graph_meta: GraphMeta, bw_dW_args + bw_dW_module: fx.GraphModule, + graph_meta: GraphMeta, + bw_dW_args, + inductor: bool = False, ) -> list[Any]: - param_buffer_grads = fx.Interpreter(bw_dW_module).boxed_run(bw_dW_args) + param_buffer_grads = _execute_graph(bw_dW_module, bw_dW_args, inductor=inductor) return param_buffer_grads def _run_unshard_module( - unshard_module: fx.GraphModule, graph_meta: GraphMeta, unshard_args + unshard_module: fx.GraphModule, + graph_meta: GraphMeta, + unshard_args, + inductor: bool = False, ) -> list[Any]: - unsharded_params = fx.Interpreter(unshard_module).boxed_run(unshard_args) + unsharded_params = _execute_graph(unshard_module, unshard_args, inductor=inductor) return unsharded_params def _run_reduce_grad_module( - reduce_grad_module: fx.GraphModule, graph_meta: GraphMeta, reduce_grad_args + reduce_grad_module: fx.GraphModule, + graph_meta: GraphMeta, + reduce_grad_args, + inductor: bool = False, ) -> list[Any]: - sharded_grads = fx.Interpreter(reduce_grad_module).boxed_run(reduce_grad_args) + sharded_grads = _execute_graph( + reduce_grad_module, reduce_grad_args, inductor=inductor + ) return sharded_grads @@ -270,8 +301,11 @@ def _run_multiplexed_fw_bw_module( fw_graph_meta: GraphMeta, bw_graph_meta: GraphMeta, bw_fw_args, + inductor: bool = False, ) -> tuple[list[Any], list[Any], Any, tuple[tuple[Any], tuple[Any]]]: - multiplexed_outs = fx.Interpreter(multiplexed_fw_bw_module).boxed_run(bw_fw_args) + multiplexed_outs = _execute_graph( + multiplexed_fw_bw_module, bw_fw_args, inductor=inductor + ) num_params_buffers = bw_graph_meta.num_params + bw_graph_meta.num_buffers num_bw_outs = bw_graph_meta.num_input_grads + num_params_buffers @@ -507,7 +541,11 @@ def stage_forward( action, ) output, saved_intermediates = _run_fw_module( - stage.graph_callables.fw, stage.graph_meta, fw_args, numerics_logs=numerics_logs + stage.graph_callables.fw, + stage.graph_meta, + fw_args, + numerics_logs=numerics_logs, + inductor=stage.inductor, ) _post_fwd_common( @@ -673,7 +711,10 @@ def stage_full_backward( action, ) input_grads, param_buffer_grads = _run_full_bw_module( - bw_stage.graph_callables.full_bw, bw_stage.graph_meta, bw_args + bw_stage.graph_callables.full_bw, + bw_stage.graph_meta, + bw_args, + inductor=bw_stage.inductor, ) bw_stage._accumulate_stage_unsharded_grads(param_buffer_grads) @@ -732,7 +773,10 @@ def stage_backward_input( ) assert bw_stage.graph_callables.bw_dI is not None input_grads, activations_for_backward = _run_dI_bw_module( - bw_stage.graph_callables.bw_dI, bw_stage.graph_meta, bw_args + bw_stage.graph_callables.bw_dI, + bw_stage.graph_meta, + bw_args, + inductor=bw_stage.inductor, ) bw_stage.bwd_activation_cache[bw_mb_index] = ( @@ -785,7 +829,10 @@ def stage_backward_weight( del activations_for_backward assert bw_stage.graph_callables.bw_dW is not None param_buffer_grads = _run_dW_bw_module( - bw_stage.graph_callables.bw_dW, bw_stage.graph_meta, bw_args + bw_stage.graph_callables.bw_dW, + bw_stage.graph_meta, + bw_args, + inductor=bw_stage.inductor, ) bw_stage._accumulate_stage_unsharded_grads(param_buffer_grads) @@ -894,6 +941,7 @@ def stage_unshard( stage.graph_callables.unshard, stage.graph_meta, sharded_params, + inductor=stage.inductor, ) stage.state["unsharded_params"] = unsharded_params @@ -926,6 +974,7 @@ def stage_reduce_grad( stage.graph_callables.reduce_grad, stage.graph_meta, stage.state["unsharded_grads"], + inductor=stage.inductor, ) stage.state["sharded_grads"] = sharded_grads @@ -934,6 +983,7 @@ class GraphPPRunner: def __init__( self, schedule: _PipelineScheduleRuntime, + inductor: bool = False, ): self.schedule = schedule if not schedule._backward_requires_autograd: @@ -949,6 +999,9 @@ def __init__( for stage in schedule._stages ) self.schedule._has_backward = True + for stage in schedule._stages: + assert isinstance(stage, GraphPipelineStage) + stage.inductor = inductor def _populate_stage_states(self, stage: GraphPipelineStage) -> None: sharded_params = [ diff --git a/examples/example_ds3_pp.py b/examples/example_ds3_pp.py index 663722c2..8895c6f9 100644 --- a/examples/example_ds3_pp.py +++ b/examples/example_ds3_pp.py @@ -619,7 +619,7 @@ def init_weights(self, *args, **kwargs): ) # Step 7. Register the schedule with the graph runner - graph_pp_runner = GraphPPRunner(schedule) + graph_pp_runner = GraphPPRunner(schedule) # inductor=True to compile with Inductor # Step 8. Run the whole pipeline once using the graph runner has_last_stage = (total_pp_stages - 1) in stage_mods