Scope CustomShardingPropagator to test_dtensor tests via pytest fixture#367

Merged
fmassa merged 1 commit into main from xmfan/stack/32 on Mar 17, 2026

Conversation

@xmfan
Member

@xmfan xmfan commented Mar 17, 2026

Stacked PRs:

* Scope CustomShardingPropagator to test_dtensor tests via setUp/tearDown

The module-level `dispatcher.sharding_propagator = CustomShardingPropagator()`
was leaking into other test files (e.g. test_api.py) when run in the same
pytest process, causing `aten.copy_` KeyError failures because the custom
propagator doesn't have rules for ops that the default DTensor propagator handles.

Replace the global mutation with a setUp/tearDown mixin that installs the
custom propagator before each test and restores the original afterwards.
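The setUp/tearDown approach can be sketched roughly as follows. `dispatcher` and the two classes here are simplified stand-ins for illustration, not the real DTensor internals:

```python
import unittest


class CustomShardingPropagator:  # stand-in for the real custom propagator
    pass


class _Dispatcher:  # stand-in for the DTensor op dispatcher
    sharding_propagator = "default"


dispatcher = _Dispatcher()


class CustomPropagatorMixin:
    """Install the custom propagator before each test, restore it afterwards."""

    def setUp(self):
        self._saved_propagator = dispatcher.sharding_propagator
        dispatcher.sharding_propagator = CustomShardingPropagator()
        super().setUp()

    def tearDown(self):
        super().tearDown()
        dispatcher.sharding_propagator = self._saved_propagator


class ExampleTest(CustomPropagatorMixin, unittest.TestCase):
    def test_custom_propagator_installed(self):
        # Inside a test, the custom propagator is active.
        assert isinstance(dispatcher.sharding_propagator, CustomShardingPropagator)
```

Because the mixin restores the saved propagator in tearDown, tests in other files that run later in the same process see the default propagator again.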

Authored with Claude.

@xmfan xmfan requested a review from zpcore March 17, 2026 00:57
@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) Mar 17, 2026
@xmfan xmfan requested a review from sanketpurandare March 17, 2026 00:57


@pytest.fixture(autouse=True, scope="module")
def _restore_propagator():
Contributor


This is not being called?

Member Author


fixture magic

The module-level `dispatcher.sharding_propagator = CustomShardingPropagator()`
was leaking into other test files (e.g. test_api.py) when run in the same
pytest process, causing `aten.copy_` failures because the custom propagator
doesn't have rules for ops that the default DTensor propagator handles.

test_dtensor.py's two test classes (ImplicitRegistrationTest, DimShardingTest)
inherit from DTensorTestBase which uses MultiProcessTestCase -- each test
spawns subprocesses that re-import the module. Those subprocesses don't run
pytest fixtures, so they need the custom propagator installed at module level.
We gate the module-level install on `multiprocessing.current_process().name`
to only run in spawned workers, and use a module-scoped autouse pytest fixture
to install/restore the propagator in the main process.

Authored with Claude.

stack-info: PR: #367, branch: xmfan/stack/32
@xmfan xmfan changed the title from "Scope CustomShardingPropagator to test_dtensor tests via setUp/tearDown" to "Scope CustomShardingPropagator to test_dtensor tests via pytest fixture" Mar 17, 2026
Contributor

@fmassa fmassa left a comment


Thanks for fixing this!

@fmassa fmassa merged commit 83342fd into main Mar 17, 2026
10 checks passed
@fmassa fmassa deleted the xmfan/stack/32 branch March 17, 2026 14:28
xmfan added a commit that referenced this pull request Mar 17, 2026
* Scope CustomShardingPropagator to test_dtensor tests via pytest fixture


* Add lazy Inductor compilation to graph_pp_runner

Add _execute_graph() that lazily compiles graph modules with
compile_fx_inner on first invocation. Controlled by an inductor kwarg
threaded through all _run_* functions.

GraphPPRunner accepts inductor=True and propagates it to all
GraphPipelineStage instances, which the stage_* action functions
read when calling _run_*.

Authored with Claude.

stack-info: PR: #360, branch: xmfan/stack/30
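The lazy-compilation pattern can be sketched as below. `fake_compile` and `GraphStageSketch` are hypothetical stand-ins for Inductor's `compile_fx_inner` and the real `GraphPipelineStage`, used only to show the compile-once caching:

```python
compile_calls = []  # record compile invocations to demonstrate laziness


def fake_compile(graph_fn):
    # Stand-in for the real Inductor compile entry point: pretend to
    # optimize the graph and return a callable.
    compile_calls.append(graph_fn)
    return graph_fn


class GraphStageSketch:
    """Hypothetical stage that compiles each graph once, on first use."""

    def __init__(self, inductor=False):
        self.inductor = inductor
        self._compiled = {}  # graph name -> compiled callable

    def _execute_graph(self, name, graph_fn, *args):
        if self.inductor:
            if name not in self._compiled:
                # First invocation for this graph: compile and cache.
                self._compiled[name] = fake_compile(graph_fn)
            graph_fn = self._compiled[name]
        return graph_fn(*args)
```

Repeated calls with the same graph name reuse the cached callable, so compilation cost is paid at most once per graph module.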
xmfan added a commit that referenced this pull request Mar 17, 2026
* Scope CustomShardingPropagator to test_dtensor tests via pytest fixture

* Add lazy Inductor compilation to graph_pp_runner

* Add --inductor flag to example_ds3_pp with FORCE_BALANCED_ROUTING

The DSv3 MoE implementation uses .tolist() and data-dependent grouped_mm
offsets that Inductor cannot compile. Add FORCE_BALANCED_ROUTING to the
model that makes token-per-expert counts uniform and uses balanced
all-to-all splits, eliminating all data-dependent ops.

The --inductor CLI flag enables both Inductor compilation and forced
balanced routing together.

Authored with Claude.

stack-info: PR: #361, branch: xmfan/stack/31
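The balanced-routing idea, under the semantics described above (a hypothetical helper, not the actual model code): when every expert receives the same number of tokens, the all-to-all split sizes are static and carry no data-dependent values for the compiler.

```python
def balanced_splits(num_tokens: int, num_experts: int) -> list[int]:
    # Force uniform token-per-expert counts so split sizes are
    # compile-time constants rather than data-dependent offsets.
    assert num_tokens % num_experts == 0, "pad tokens to a multiple of num_experts"
    return [num_tokens // num_experts] * num_experts
```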