From 82089015f4ce4bcd6a14b8456f6bda499f3fde85 Mon Sep 17 00:00:00 2001
From: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Date: Tue, 28 Oct 2025 03:00:50 -0700
Subject: [PATCH] [https://nvbugs/5613089][fix] Fix the rank to access
 all_rank_chunk_size_list when chunked MoE is used

Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
---
 tensorrt_llm/_torch/autotuner.py                             | 2 +-
 tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py   | 6 +++---
 tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py  | 4 ++--
 .../_torch/modules/fused_moe/fused_moe_trtllm_gen.py         | 2 +-
 tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py   | 4 ++--
 tensorrt_llm/_torch/modules/fused_moe/interface.py           | 1 +
 6 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py
index 00219cde522..c685748058e 100644
--- a/tensorrt_llm/_torch/autotuner.py
+++ b/tensorrt_llm/_torch/autotuner.py
@@ -727,10 +727,10 @@ def choose_one(
 
         new_tuning_failure_occured = False
         for p in profiles:
-            tensors = self._prepare_input_tensors(p, inputs)
             is_cache_hit, *_ = self.profiling_cache.search_cache(
                 custom_op, runners, p.get_opt_shapes(), tuning_config)
             if not is_cache_hit:
+                tensors = self._prepare_input_tensors(p, inputs)
                 # Initialize runner and tactic as None in case of no valid tactic or runners are found
                 best_runner_id, best_tactic, min_time, has_tuning_failure_occured = self._profile_runners(
                     custom_op, runners, tensors, p, tuning_config, **kwargs)
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
index 13325a2b832..647179bff8b 100755
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
@@ -626,7 +626,7 @@ def forward_impl(
             all_rank_num_tokens_list = [[
                 val[idx_chunk] for val in all_rank_chunk_size_list
             ] for idx_chunk in range(num_chunks)]
-            chunk_size_list = all_rank_chunk_size_list[self.rank]
+            chunk_size_list = all_rank_chunk_size_list[self.parallel_rank]
         else:
             all_rank_num_tokens_list = [None] * num_chunks
             chunk_size_list = self.split_chunk(x.shape[0], num_chunks)
@@ -685,7 +685,7 @@ def _reducescatter_or_allreduce(x_, idx):
             outputs = torch.cat(outputs_list)
 
         if self.use_dp and self.parallel_size > 1:
-            rank = self.mapping.tp_rank
+            rank = self.parallel_rank
             outputs = outputs[:all_rank_num_tokens[rank]]
         return outputs
 
@@ -714,7 +714,7 @@ def forward_fake(
         is_nvfp4_input = isinstance(x, Fp4QuantizedTensor)
         data_type = output_dtype if is_nvfp4_input else x.dtype
         num_tokens = all_rank_num_tokens[
-            self.mapping.tp_rank] if all_rank_num_tokens else x.shape[0]
+            self.parallel_rank] if all_rank_num_tokens else x.shape[0]
         hidden_size = x.shape[1] * (2 if is_nvfp4_input else 1)
         top_k = self.routing_method.experts_per_token
         return x.new_empty((num_tokens, top_k, hidden_size),
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
index 3eed7d00bb9..69968a8be4f 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
@@ -706,7 +706,7 @@ def forward_impl(
             all_rank_num_tokens_list = [[
                 val[idx_chunk] for val in all_rank_chunk_size_list
             ] for idx_chunk in range(num_chunks)]
-            chunk_size_list = all_rank_chunk_size_list[self.rank]
+            chunk_size_list = all_rank_chunk_size_list[self.parallel_rank]
         else:
             all_rank_num_tokens_list = [None] * num_chunks
             chunk_size_list = self.split_chunk(x.shape[0], num_chunks)
@@ -778,6 +778,6 @@ def _reducescatter_or_allreduce(x_, idx):
             outputs = torch.cat(outputs_list)
 
         if self.use_dp and self.parallel_size > 1:
-            rank = self.mapping.tp_rank
+            rank = self.parallel_rank
             outputs = outputs[:all_rank_num_tokens[rank]]
         return outputs
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
index 1aca60ce417..fad075b74bf 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
@@ -661,7 +661,7 @@ def forward_impl(
         )
 
         if use_dp_padding:
-            rank = self.mapping.tp_rank
+            rank = self.parallel_rank
             final_hidden_states = final_hidden_states[:
                                                       all_rank_num_tokens[rank]]
         return final_hidden_states
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
index b9ae687918b..a3c0ee4c622 100755
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -828,7 +828,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
             ] for idx_chunk in range(num_chunks)]
             all_rank_max_num_tokens_list = split_chunk(all_rank_max_num_tokens,
                                                        num_chunks)
-            chunk_size_list = all_rank_chunk_size_list[self.rank]
+            chunk_size_list = all_rank_chunk_size_list[self.parallel_rank]
             if use_all_to_all:
                 all_rank_num_tokens_list = [[
                     1 if val == 0 else val for val in val_list
@@ -916,7 +916,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                 self.event_dict[EventType.MoeChunkingOverlap].record()
             self.event_dict[EventType.MoeChunkingOverlap].wait()
             outputs = torch.cat(outputs_list)
-            rank = self.mapping.tp_rank
+            rank = self.parallel_rank
             outputs = outputs[:all_rank_num_tokens[rank]]
         self.repeat_idx = 0 if self.repeat_idx == self.repeat_count - 1 else self.repeat_idx + 1
         return outputs
diff --git a/tensorrt_llm/_torch/modules/fused_moe/interface.py b/tensorrt_llm/_torch/modules/fused_moe/interface.py
index 54589a71e5a..5d5079d9c8c 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/interface.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/interface.py
@@ -181,6 +181,7 @@ def __init__(
 
         # All ranks participate in allreduce regardless of EP/TP combination
         self.mapping = model_config.mapping
+        self.parallel_rank = self.mapping.tp_rank
        self.parallel_size = self.mapping.tp_size

        self.intermediate_size_per_partition = intermediate_size // self.tp_size
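
A minimal sketch of the invariant behind these changes, for context: all_rank_chunk_size_list (like all_rank_num_tokens) carries one entry per rank of the TP/attention-DP group, so it must be indexed with the group-local rank (mapping.tp_rank, cached here as self.parallel_rank), not with a rank index that can diverge from it. The Mapping and split_chunk definitions below are simplified stand-ins assumed for illustration, not the actual TensorRT-LLM code.

from dataclasses import dataclass


@dataclass
class Mapping:
    rank: int     # global rank across all parallel dimensions
    tp_rank: int  # rank inside the TP/attention-DP group
    tp_size: int  # size of that group


def split_chunk(num_tokens: int, num_chunks: int) -> list[int]:
    # Split num_tokens into num_chunks near-equal chunk sizes.
    base, rem = divmod(num_tokens, num_chunks)
    return [base + (1 if i < rem else 0) for i in range(num_chunks)]


def chunk_sizes_for_this_rank(all_rank_num_tokens: list[int],
                              num_chunks: int,
                              mapping: Mapping) -> list[int]:
    # all_rank_num_tokens has one entry per rank in the TP/DP group
    # (length == tp_size), so the group-local tp_rank is the only safe
    # index. Any rank value that can differ from tp_rank either picks up
    # another rank's chunk sizes or falls out of range.
    all_rank_chunk_size_list = [
        split_chunk(val, num_chunks) for val in all_rank_num_tokens
    ]
    return all_rank_chunk_size_list[mapping.tp_rank]


# Two ranks in the TP/DP group, but this process happens to be global
# rank 3; indexing with mapping.rank would raise IndexError here.
mapping = Mapping(rank=3, tp_rank=1, tp_size=2)
print(chunk_sizes_for_this_rank([10, 7], num_chunks=2, mapping=mapping))
# -> [4, 3]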