Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 42 additions & 24 deletions autoparallel/_testing/models/dsv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@

from autoparallel.collectives import all_to_all, axis_size, local_map

# When True, MoE uses uniform token routing and balanced all-to-all splits,
# eliminating data-dependent ops (.tolist(), dynamic grouped_mm offsets) that
# prevent Inductor compilation.
# NOTE: intended to be toggled at module level by callers (e.g. the example
# runner sets `dsv3_module.FORCE_BALANCED_ROUTING = True` when compiling with
# Inductor) before the model is traced.
FORCE_BALANCED_ROUTING: bool = False


# parallelized kernel
@triton.jit
Expand Down Expand Up @@ -633,33 +638,35 @@ def forward(

def _token_dispatch(routed_input, num_tokens_per_expert, axis_name):
with fx_traceback.annotate({"comm_region": "token_dispatch"}):
# annotate module input placements/sharding with input_layouts
# ep_size = device_mesh.shape[0]
ep_size = axis_size(axis_name)

# generate the input splits and output splits for all-to-all
with torch.no_grad():
num_tokens_per_expert_group = all_to_all(
num_tokens_per_expert,
None,
None,
axis_name,
)
input_splits = (
num_tokens_per_expert.view(ep_size, -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=True)
)
# NOTE: this would incur a device-to-host sync
output_splits = (
num_tokens_per_expert_group.view(ep_size, -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=False)
)
input_splits = input_splits.tolist()
output_splits = output_splits.tolist()

# perform all-to-all
if FORCE_BALANCED_ROUTING:
input_splits = None
output_splits = None
else:
with torch.no_grad():
input_splits = (
num_tokens_per_expert.view(ep_size, -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=True)
)
# NOTE: this would incur a device-to-host sync
output_splits = (
num_tokens_per_expert_group.view(ep_size, -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=False)
)
input_splits = input_splits.tolist()
output_splits = output_splits.tolist()

routed_input = all_to_all(
routed_input,
output_splits,
Expand Down Expand Up @@ -708,13 +715,24 @@ def local_mapped_region(

dim = x.shape[-1]

# num_tokens_per_expert = torch.ops.autoparallel.batched_histc(
num_tokens_per_expert = torch.histc(
selected_experts_indices.flatten(),
bins=num_experts,
min=0,
max=num_experts,
)
if FORCE_BALANCED_ROUTING:
# Uniform distribution: same number of tokens per expert.
# Eliminates data-dependent grouped_mm offsets for Inductor.
total_tokens = selected_experts_indices.numel()
num_tokens_per_expert = torch.full(
(num_experts,),
total_tokens // num_experts,
device=x.device,
dtype=torch.int32,
)
else:
# num_tokens_per_expert = torch.ops.autoparallel.batched_histc(
num_tokens_per_expert = torch.histc(
selected_experts_indices.flatten(),
bins=num_experts,
min=0,
max=num_experts,
)

# total_tokens_per_expert = all_reduce(num_tokens_per_expert, axis_name)
total_tokens_per_expert = num_tokens_per_expert
Expand Down
17 changes: 16 additions & 1 deletion examples/example_ds3_pp.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.testing._internal.distributed.fake_pg import FakeStore

import autoparallel._testing.models.dsv3 as dsv3_module
from autoparallel._testing.models.dsv3 import (
DeepSeekV3Model,
DeepSeekV3ModelArgs,
Expand Down Expand Up @@ -137,6 +138,7 @@ def run_test(
rng_seed: Optional[int],
logs_dir: str,
use_cache: bool,
use_inductor: bool = False,
):
if not fake_evaluate:
pp_degree = 2
Expand Down Expand Up @@ -619,7 +621,7 @@ def init_weights(self, *args, **kwargs):
)

# Step 7. Register the schedule with the graph runner
graph_pp_runner = GraphPPRunner(schedule) # inductor=True to compile with Inductor
graph_pp_runner = GraphPPRunner(schedule, inductor=use_inductor)

# Step 8. Run the whole pipeline once using the graph runner
has_last_stage = (total_pp_stages - 1) in stage_mods
Expand Down Expand Up @@ -714,6 +716,12 @@ def init_weights(self, *args, **kwargs):
default=False,
help="Use cached graph files if available (default: False)",
)
parser.add_argument(
"--inductor",
action="store_true",
default=False,
help="Compile subgraphs with Inductor (also forces balanced MoE routing)",
)
args = parser.parse_args()

if args.use_cache and not args.fake_evaluate:
Expand All @@ -723,11 +731,18 @@ def init_weights(self, *args, **kwargs):
torch.use_deterministic_algorithms(True)
torch.manual_seed(args.rng_seed)

if args.inductor:
# The DSv3 MoE implementation uses .tolist() and data-dependent grouped_mm
# offsets, which Inductor cannot compile. Force balanced routing to make
# all token counts static.
dsv3_module.FORCE_BALANCED_ROUTING = True

run_test(
fake_evaluate=args.fake_evaluate,
use_loss_fn=args.use_loss_fn,
schedule_name=args.schedule_name,
rng_seed=args.rng_seed,
logs_dir=args.logs_dir,
use_cache=args.use_cache,
use_inductor=args.inductor,
)
Loading