Commit 3346541
[https://nvbugs/5613089][fix] Fix the rank used to access all_rank_chunk_size_list when chunked MoE is used
Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Parent: 6b9b73e

File tree

6 files changed: +10 −9 lines

tensorrt_llm/_torch/autotuner.py

Lines changed: 1 addition & 1 deletion
@@ -603,10 +603,10 @@ def choose_one(
         new_tuning_failure_occured = False
 
         for p in profiles:
-            tensors = self._prepare_input_tensors(p, inputs)
             is_cache_hit, *_ = self.profiling_cache.search_cache(
                 custom_op, runners, p.get_opt_shapes(), tuning_config)
             if not is_cache_hit:
+                tensors = self._prepare_input_tensors(p, inputs)
                 # Initialize runner and tactic as None in case of no valid tactic or runners are found
                 best_runner_id, best_tactic, min_time, has_tuning_failure_occured = self._profile_runners(
                     custom_op, runners, tensors, p, tuning_config, **kwargs)
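The hunk above defers `_prepare_input_tensors` into the cache-miss branch, so a profile that already has a cached tactic never allocates profiling tensors. A minimal sketch of the pattern, assuming a plain-dict cache and stand-in helpers (`prepare_tensors`, `profile_runners`) rather than the module's real API:

def choose_one_sketch(profiles, cache, prepare_tensors, profile_runners):
    """Prepare profiling tensors only on a tuning-cache miss."""
    for p in profiles:
        if p in cache:                      # cache hit: no tensor allocation
            continue
        tensors = prepare_tensors(p)        # deferred to the miss path
        cache[p] = profile_runners(tensors, p)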

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 3 additions & 3 deletions
@@ -582,7 +582,7 @@ def forward_impl(
             all_rank_num_tokens_list = [[
                 val[idx_chunk] for val in all_rank_chunk_size_list
             ] for idx_chunk in range(num_chunks)]
-            chunk_size_list = all_rank_chunk_size_list[self.rank]
+            chunk_size_list = all_rank_chunk_size_list[self.parallel_rank]
         else:
             all_rank_num_tokens_list = [None] * num_chunks
             chunk_size_list = self.split_chunk(x.shape[0], num_chunks)
@@ -641,7 +641,7 @@ def _reducescatter_or_allreduce(x_, idx):
         outputs = torch.cat(outputs_list)
 
         if self.use_dp and self.parallel_size > 1:
-            rank = self.mapping.tp_rank
+            rank = self.parallel_rank
             outputs = outputs[:all_rank_num_tokens[rank]]
         return outputs
@@ -670,7 +670,7 @@ def forward_fake(
         is_nvfp4_input = isinstance(x, Fp4QuantizedTensor)
         data_type = output_dtype if is_nvfp4_input else x.dtype
         num_tokens = all_rank_num_tokens[
-            self.mapping.tp_rank] if all_rank_num_tokens else x.shape[0]
+            self.parallel_rank] if all_rank_num_tokens else x.shape[0]
         hidden_size = x.shape[1] * (2 if is_nvfp4_input else 1)
         top_k = self.routing_method.experts_per_token
         return x.new_empty((num_tokens, top_k, hidden_size),
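All three hunks replace `self.rank` / `self.mapping.tp_rank` with the new `self.parallel_rank` alias (added in `interface.py` below), so chunk sizes and output trims are indexed by the same rank everywhere. A rough sketch of the per-rank chunk-size layout, with an illustrative `split_chunk` that may differ from the module's actual splitting rule:

def split_chunk(total_tokens, num_chunks):
    # Even split, remainder spread over the leading chunks (illustrative).
    base, rem = divmod(total_tokens, num_chunks)
    return [base + 1 if i < rem else base for i in range(num_chunks)]

all_rank_num_tokens = [5, 7, 6, 6]             # made-up token counts, 4 ranks
num_chunks = 2
all_rank_chunk_size_list = [split_chunk(n, num_chunks)
                            for n in all_rank_num_tokens]
parallel_rank = 2                              # this rank's slot in the group
chunk_size_list = all_rank_chunk_size_list[parallel_rank]   # -> [3, 3]

Indexing with any other rank would silently pick up a neighbor's chunk sizes, which is the bug this commit fixes.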

tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

Lines changed: 2 additions & 2 deletions
@@ -706,7 +706,7 @@ def forward_impl(
             all_rank_num_tokens_list = [[
                 val[idx_chunk] for val in all_rank_chunk_size_list
             ] for idx_chunk in range(num_chunks)]
-            chunk_size_list = all_rank_chunk_size_list[self.rank]
+            chunk_size_list = all_rank_chunk_size_list[self.parallel_rank]
         else:
             all_rank_num_tokens_list = [None] * num_chunks
             chunk_size_list = self.split_chunk(x.shape[0], num_chunks)
@@ -778,6 +778,6 @@ def _reducescatter_or_allreduce(x_, idx):
         outputs = torch.cat(outputs_list)
 
         if self.use_dp and self.parallel_size > 1:
-            rank = self.mapping.tp_rank
+            rank = self.parallel_rank
             outputs = outputs[:all_rank_num_tokens[rank]]
         return outputs
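The comprehension in the first hunk transposes `all_rank_chunk_size_list` from a `[rank][chunk]` table into a `[chunk][rank]` view, so each chunk carries every rank's token count. A small worked example with made-up values:

all_rank_chunk_size_list = [[3, 2], [4, 3], [3, 3]]    # [rank][chunk]
num_chunks = 2
all_rank_num_tokens_list = [[val[idx_chunk] for val in all_rank_chunk_size_list]
                            for idx_chunk in range(num_chunks)]
# -> [[3, 4, 3], [2, 3, 3]]                            # [chunk][rank]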

tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py

Lines changed: 1 addition & 1 deletion
@@ -618,7 +618,7 @@ def forward_impl(
         )
 
         if use_dp_padding:
-            rank = self.mapping.tp_rank
+            rank = self.parallel_rank
             final_hidden_states = final_hidden_states[:
                                                       all_rank_num_tokens[rank]]
         return final_hidden_states
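With `use_dp_padding`, every rank's output is padded to a common row count, so each rank slices back to its own token count, indexed by the same `parallel_rank`. A sketch with illustrative shapes and counts:

import torch

all_rank_num_tokens = [5, 7, 6, 6]       # made-up per-rank token counts
rank = 0
hidden = 16
# Padded to the max token count across ranks:
final_hidden_states = torch.zeros(max(all_rank_num_tokens), hidden)
final_hidden_states = final_hidden_states[:all_rank_num_tokens[rank]]
assert final_hidden_states.shape == (5, hidden)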

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 2 additions & 2 deletions
@@ -843,7 +843,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
             ] for idx_chunk in range(num_chunks)]
             all_rank_max_num_tokens_list = split_chunk(all_rank_max_num_tokens,
                                                        num_chunks)
-            chunk_size_list = all_rank_chunk_size_list[self.rank]
+            chunk_size_list = all_rank_chunk_size_list[self.parallel_rank]
             if use_all_to_all:
                 all_rank_num_tokens_list = [[
                     1 if val == 0 else val for val in val_list
@@ -931,7 +931,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
             self.event_dict[EventType.MoeChunkingOverlap].record()
             self.event_dict[EventType.MoeChunkingOverlap].wait()
             outputs = torch.cat(outputs_list)
-            rank = self.mapping.tp_rank
+            rank = self.parallel_rank
             outputs = outputs[:all_rank_num_tokens[rank]]
             self.repeat_idx = 0 if self.repeat_idx == self.repeat_count - 1 else self.repeat_idx + 1
             return outputs
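Beyond the rank fix, the context of the first hunk shows the all-to-all path replacing zero-token entries with 1, presumably so no rank ever posts an empty slot to the communication kernel. A small worked example, values made up:

all_rank_num_tokens_list = [[0, 3], [2, 0], [4, 1]]    # [chunk][rank]
all_rank_num_tokens_list = [[1 if val == 0 else val for val in val_list]
                            for val_list in all_rank_num_tokens_list]
# -> [[1, 3], [2, 1], [4, 1]]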

tensorrt_llm/_torch/modules/fused_moe/interface.py

Lines changed: 1 addition & 0 deletions
@@ -169,6 +169,7 @@ def __init__(
 
         # All ranks participate in allreduce regardless of EP/TP combination
         self.mapping = model_config.mapping
+        self.parallel_rank = self.mapping.tp_rank
         self.parallel_size = self.mapping.tp_size
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
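This one-line addition is the heart of the fix: `parallel_rank` is defined next to the existing `parallel_size`, both read from the TP fields of the mapping, so every MoE call site above indexes all-rank lists with one consistent rank. A minimal sketch of the aliasing, using a stand-in `Mapping` rather than TensorRT-LLM's real class:

from dataclasses import dataclass

@dataclass
class Mapping:                 # stand-in for tensorrt_llm's Mapping
    tp_rank: int
    tp_size: int

class MoESketch:
    def __init__(self, mapping: Mapping):
        self.mapping = mapping
        # Alias once in __init__ so forward paths never mix
        # self.mapping.tp_rank with some other notion of rank.
        self.parallel_rank = mapping.tp_rank
        self.parallel_size = mapping.tp_size

moe = MoESketch(Mapping(tp_rank=2, tp_size=4))
assert (moe.parallel_rank, moe.parallel_size) == (2, 4)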
