diff --git a/autoparallel/_testing/models/dsv3.py b/autoparallel/_testing/models/dsv3.py index 6f3e2752..2d50b9c1 100644 --- a/autoparallel/_testing/models/dsv3.py +++ b/autoparallel/_testing/models/dsv3.py @@ -21,6 +21,11 @@ from autoparallel.collectives import all_to_all, axis_size, local_map +# When True, MoE uses uniform token routing and balanced all-to-all splits, +# eliminating data-dependent ops (.tolist(), dynamic grouped_mm offsets) that +# prevent Inductor compilation. +FORCE_BALANCED_ROUTING: bool = False + # parallelized kernel @triton.jit @@ -633,11 +638,8 @@ def forward( def _token_dispatch(routed_input, num_tokens_per_expert, axis_name): with fx_traceback.annotate({"comm_region": "token_dispatch"}): - # annotate module input placements/sharding with input_layouts - # ep_size = device_mesh.shape[0] ep_size = axis_size(axis_name) - # generate the input splits and output splits for all-to-all with torch.no_grad(): num_tokens_per_expert_group = all_to_all( num_tokens_per_expert, @@ -645,21 +647,26 @@ def _token_dispatch(routed_input, num_tokens_per_expert, axis_name): None, axis_name, ) - input_splits = ( - num_tokens_per_expert.view(ep_size, -1) - .sum(dim=1) - .to(torch.device("cpu"), non_blocking=True) - ) - # NOTE: this would incur a device-to-host sync - output_splits = ( - num_tokens_per_expert_group.view(ep_size, -1) - .sum(dim=1) - .to(torch.device("cpu"), non_blocking=False) - ) - input_splits = input_splits.tolist() - output_splits = output_splits.tolist() - # perform all-to-all + if FORCE_BALANCED_ROUTING: + input_splits = None + output_splits = None + else: + with torch.no_grad(): + input_splits = ( + num_tokens_per_expert.view(ep_size, -1) + .sum(dim=1) + .to(torch.device("cpu"), non_blocking=True) + ) + # NOTE: this would incur a device-to-host sync + output_splits = ( + num_tokens_per_expert_group.view(ep_size, -1) + .sum(dim=1) + .to(torch.device("cpu"), non_blocking=False) + ) + input_splits = input_splits.tolist() + output_splits = output_splits.tolist() + routed_input = all_to_all( routed_input, output_splits, @@ -708,13 +715,24 @@ def local_mapped_region( dim = x.shape[-1] - # num_tokens_per_expert = torch.ops.autoparallel.batched_histc( - num_tokens_per_expert = torch.histc( - selected_experts_indices.flatten(), - bins=num_experts, - min=0, - max=num_experts, - ) + if FORCE_BALANCED_ROUTING: + # Uniform distribution: same number of tokens per expert. + # Eliminates data-dependent grouped_mm offsets for Inductor. + total_tokens = selected_experts_indices.numel() + num_tokens_per_expert = torch.full( + (num_experts,), + total_tokens // num_experts, + device=x.device, + dtype=torch.int32, + ) + else: + # num_tokens_per_expert = torch.ops.autoparallel.batched_histc( + num_tokens_per_expert = torch.histc( + selected_experts_indices.flatten(), + bins=num_experts, + min=0, + max=num_experts, + ) # total_tokens_per_expert = all_reduce(num_tokens_per_expert, axis_name) total_tokens_per_expert = num_tokens_per_expert diff --git a/examples/example_ds3_pp.py b/examples/example_ds3_pp.py index 8895c6f9..7fa52ee4 100644 --- a/examples/example_ds3_pp.py +++ b/examples/example_ds3_pp.py @@ -34,6 +34,7 @@ from torch.fx.experimental.symbolic_shapes import ShapeEnv from torch.testing._internal.distributed.fake_pg import FakeStore +import autoparallel._testing.models.dsv3 as dsv3_module from autoparallel._testing.models.dsv3 import ( DeepSeekV3Model, DeepSeekV3ModelArgs, @@ -137,6 +138,7 @@ def run_test( rng_seed: Optional[int], logs_dir: str, use_cache: bool, + use_inductor: bool = False, ): if not fake_evaluate: pp_degree = 2 @@ -619,7 +621,7 @@ def init_weights(self, *args, **kwargs): ) # Step 7. Register the schedule with the graph runner - graph_pp_runner = GraphPPRunner(schedule) # inductor=True to compile with Inductor + graph_pp_runner = GraphPPRunner(schedule, inductor=use_inductor) # Step 8. Run the whole pipeline once using the graph runner has_last_stage = (total_pp_stages - 1) in stage_mods @@ -714,6 +716,12 @@ def init_weights(self, *args, **kwargs): default=False, help="Use cached graph files if available (default: False)", ) + parser.add_argument( + "--inductor", + action="store_true", + default=False, + help="Compile subgraphs with Inductor (also forces balanced MoE routing)", + ) args = parser.parse_args() if args.use_cache and not args.fake_evaluate: @@ -723,6 +731,12 @@ def init_weights(self, *args, **kwargs): torch.use_deterministic_algorithms(True) torch.manual_seed(args.rng_seed) + if args.inductor: + # The DSv3 MoE implementation uses .tolist() and data-dependent grouped_mm + # offsets, which Inductor cannot compile. Force balanced routing to make + # all token counts static. + dsv3_module.FORCE_BALANCED_ROUTING = True + run_test( fake_evaluate=args.fake_evaluate, use_loss_fn=args.use_loss_fn, @@ -730,4 +744,5 @@ def init_weights(self, *args, **kwargs): rng_seed=args.rng_seed, logs_dir=args.logs_dir, use_cache=args.use_cache, + use_inductor=args.inductor, )