From 7c0b309929ade6344c6c393b0d7167856fc04475 Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 Date: Fri, 31 Oct 2025 14:50:13 +0800 Subject: [PATCH 1/3] support eplb noaux --- custom_ops/gpu_ops/cpp_extensions.cc | 14 ++ custom_ops/gpu_ops/noauxtc_kernel.h | 232 ++++++++++++++++++++ custom_ops/setup_ops.py | 1 + fastdeploy/model_executor/layers/moe/ep.py | 38 +++- fastdeploy/model_executor/layers/moe/moe.py | 36 ++- 5 files changed, 301 insertions(+), 20 deletions(-) diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index a6fdef88345..f4b0b390196 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -570,6 +570,18 @@ std::vector NoauxTc( int topk, float routed_scaling_factor); +std::vector NoauxTcRedundant( + paddle::Tensor& scores, + paddle::Tensor& scores_with_bias, + paddle::Tensor& expert_id_to_ep_rank_array, + paddle::Tensor& expert_in_rank_num_list, + paddle::Tensor& tokens_per_expert_stats_list, + int n_group, + int topk_group, + int topk, + float routed_scaling_factor, + int redundant_ep_rank_num_plus_one); + #ifdef ENABLE_FP8 paddle::Tensor cutlass_fp8_fp8_half_gemm_func( const paddle::Tensor& x, @@ -1251,6 +1263,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("noaux_tc",&NoauxTc, "noaux_tc for Deepseekv3 MoE compute"); + m.def("noaux_tc_redundant",&NoauxTcRedundant, "noaux_tc_redundant for MoE compute"); + #ifdef ENABLE_FP8 m.def("cutlass_fp8_fp8_half_gemm_fused", &cutlass_fp8_fp8_half_gemm_func, py::arg("x"), py::arg("y"), py::arg("bias"), py::arg("transpose_x"), diff --git a/custom_ops/gpu_ops/noauxtc_kernel.h b/custom_ops/gpu_ops/noauxtc_kernel.h index e8a3f450803..d1ba8efad71 100644 --- a/custom_ops/gpu_ops/noauxtc_kernel.h +++ b/custom_ops/gpu_ops/noauxtc_kernel.h @@ -306,6 +306,14 @@ class WarpSelect : public WarpSort { }; // end class WarpSelect } // namespace warp_topk + +inline __device__ unsigned int xorwow_moe(unsigned int &state) { + state ^= state >> 7; + state ^= 
state << 9; + state ^= state >> 13; + return state; +} + template __device__ void topk_with_k2(T* output, T const* input, @@ -507,6 +515,156 @@ __global__ void group_idx_and_topk_idx_kernel( } } +template +__global__ void group_idx_and_topk_idx_redundant_kernel( + T* scores, + T const* group_scores, + T* topk_values, + IdxT* topk_indices, + T* scores_with_bias, + int32_t* expert_id_to_ep_rank_array, + int32_t* expert_in_rank_num_list, + int32_t* tokens_per_expert_stats_list, + int64_t const num_tokens, + int64_t const n_group, + int64_t const topk_group, + int64_t const topk, + int64_t const num_experts, + int64_t const num_experts_per_group, + double routed_scaling_factor, + int64_t const redundant_ep_rank_num_plus_one) { + int32_t warp_id = threadIdx.x / WARP_SIZE; + int32_t lane_id = threadIdx.x % WARP_SIZE; + int32_t case_id = + blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; // one per token + unsigned int state = case_id; + scores_with_bias += case_id * num_experts; + scores += case_id * num_experts; + group_scores += case_id * n_group; + topk_values += case_id * topk; + topk_indices += case_id * topk; + int32_t align_num_experts_per_group = + warp_topk::round_up_to_multiple_of(num_experts_per_group); + + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block); + + extern __shared__ char smem_buf[]; // NOTE: reuse the shared memory here to + // store the target topk idx + int32_t* s_topk_idx = reinterpret_cast(smem_buf) + warp_id * topk; + T* s_topk_value = + reinterpret_cast(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) + + warp_id * topk; + + T value = cuda::std::numeric_limits::min(); + T topk_group_value = cuda::std::numeric_limits::min(); + int32_t num_equalto_topkth_group; + + if ((n_group > topk_group) && (case_id < num_tokens)) { + // calculate group_idx + int32_t target_num_min = WARP_SIZE - n_group + topk_group; + if (lane_id < n_group) { + value = group_scores[lane_id]; + } + + int 
count_equal_to_top_value = WARP_SIZE - n_group; + int pre_count_equal_to_top_value = 0; + // Use loop to find the largest top_group + while (count_equal_to_top_value < target_num_min) { + __syncwarp(); // Ensure all threads have valid data before reduction + topk_group_value = cg::reduce(tile, value, cg::greater()); + if (value == topk_group_value) { + value = cuda::std::numeric_limits::min(); + } + pre_count_equal_to_top_value = count_equal_to_top_value; + count_equal_to_top_value = __popc(__ballot_sync( + FULL_WARP_MASK, (value == cuda::std::numeric_limits::min()))); + } + num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value; + } + __syncthreads(); + + warp_topk::WarpSelect + queue((int32_t)topk, cuda::std::numeric_limits::min()); + + int count_equalto_topkth_group = 0; + bool if_proceed_next_topk = (topk_group_value != cuda::std::numeric_limits::min()); + if (case_id < num_tokens) { + for (int i_group = 0; i_group < n_group; i_group++) { + if ((group_scores[i_group] > topk_group_value) || + ((group_scores[i_group] == topk_group_value) && + (count_equalto_topkth_group < num_equalto_topkth_group))) { + int32_t offset = i_group * num_experts_per_group; + for (int32_t i = lane_id; i < align_num_experts_per_group; + i += WARP_SIZE) { + T candidates = i < num_experts_per_group + ? scores_with_bias[offset + i] + : cuda::std::numeric_limits::min(); + queue.add(candidates, offset + i); + } + if (group_scores[i_group] == topk_group_value) { + count_equalto_topkth_group++; + } + } + } + queue.done(); + __syncwarp(); + // Get the topk_idx + queue.dumpIdx(s_topk_idx); + __syncwarp(); + } + + // Load the valid score value + // Calculate the summation + float topk_sum = 1e-20; + if (case_id < num_tokens) { + for (int i = lane_id; + i < warp_topk::round_up_to_multiple_of(topk); + i += WARP_SIZE) { + T value = i < topk ? 
scores[s_topk_idx[i]] + : 0.0f; // Load the valid value of expert + if (i < topk) { + s_topk_value[i] = value; + } + topk_sum += reduce(tile, value, cg::plus()); + } + } + + __syncthreads(); + if (case_id < num_tokens) { + for (int i = lane_id; i < num_experts; i += WARP_SIZE) { + scores[i] = 0; + } + } + __threadfence(); + __syncthreads(); + + if (case_id < num_tokens) { + for (int i = lane_id; i < topk; i += WARP_SIZE) { + float value = s_topk_value[i] / topk_sum * routed_scaling_factor; + scores[s_topk_idx[i]] = value; + if (if_proceed_next_topk) { + int expert_topk = s_topk_idx[i]; + int len = expert_in_rank_num_list[expert_topk]; + int select = len > 0 ? (int)(xorwow_moe(state) % (unsigned int)len) : 0; + int selected_rank = expert_id_to_ep_rank_array[expert_topk * redundant_ep_rank_num_plus_one + select]; + atomicAdd(&tokens_per_expert_stats_list[expert_topk], 1); + topk_indices[i] = (IdxT)selected_rank; + topk_values[i] = static_cast(value); + } + else { + int expert_topk = i; + int len = expert_in_rank_num_list[expert_topk]; + int select = len > 0 ? (int)(xorwow_moe(state) % (unsigned int)len) : 0; + int selected_rank = expert_id_to_ep_rank_array[expert_topk * redundant_ep_rank_num_plus_one + select]; + atomicAdd(&tokens_per_expert_stats_list[expert_topk], 1); + topk_indices[i] = (IdxT)selected_rank; + topk_values[i] = static_cast(1.0f / topk); + } + } + } +} + template void invokeNoAuxTc(T* scores, T* group_scores, @@ -553,6 +711,60 @@ void invokeNoAuxTc(T* scores, routed_scaling_factor); } +template +void invokeNoAuxTcRedundant(T* scores, + T* group_scores, + T* topk_values, + IdxT* topk_indices, + T* scores_with_bias, + int32_t* expert_id_to_ep_rank_array, + int32_t* expert_in_rank_num_list, + int32_t* tokens_per_expert_stats_list, + int64_t const num_tokens, + int64_t const num_experts, + int64_t const n_group, + int64_t const topk_group, + int64_t const topk, + double const routed_scaling_factor, + int64_t const redundant_ep_rank_num_plus_one, + cudaStream_t const stream) { + int64_t num_cases = num_tokens * 
n_group; + int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1; + topk_with_k2_kernel<<>>( + group_scores, + scores_with_bias, + num_tokens, + num_cases, + n_group, + num_experts / n_group); + + int64_t topk_with_k_group_num_blocks = + (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1; + size_t dynamic_smem_in_bytes = + warp_topk::calc_smem_size_for_block_wide(NUM_WARPS_PER_BLOCK, + topk); + + group_idx_and_topk_idx_redundant_kernel<<>>(scores, + group_scores, + topk_values, + topk_indices, + scores_with_bias, + expert_id_to_ep_rank_array, + expert_in_rank_num_list, + tokens_per_expert_stats_list, + num_tokens, + n_group, + topk_group, + topk, + num_experts, + num_experts / n_group, + routed_scaling_factor, + redundant_ep_rank_num_plus_one); +} + #define INSTANTIATE_NOAUX_TC(T, IdxT) \ template void invokeNoAuxTc(T * scores, \ T * group_scores, \ @@ -568,3 +780,23 @@ void invokeNoAuxTc(T* scores, cudaStream_t const stream); INSTANTIATE_NOAUX_TC(float, int32_t); + +#define INSTANTIATE_NOAUX_TC_Redundant(T, IdxT) \ + template void invokeNoAuxTcRedundant(T * scores, \ + T * group_scores, \ + T* topk_values, \ + IdxT* topk_indices, \ + T * scores_with_bias, \ + int32_t* expert_id_to_ep_rank_array, \ + int32_t* expert_in_rank_num_list, \ + int32_t* tokens_per_expert_stats_list, \ + int64_t const num_tokens, \ + int64_t const num_experts, \ + int64_t const n_group, \ + int64_t const topk_group, \ + int64_t const topk, \ + double const routed_scaling_factor, \ + int64_t const redundant_ep_rank_num_plus_one, \ + cudaStream_t const stream); + +INSTANTIATE_NOAUX_TC_Redundant(float, int32_t); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 331f0e6f5a2..981a19ac8cc 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -298,6 +298,7 @@ def find_end_files(directory, end_str): "gpu_ops/get_position_ids_and_mask_encoder_batch.cu", "gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/noaux_tc.cu", + 
"gpu_ops/noaux_tc_redundant.cu", "gpu_ops/custom_all_reduce/all_reduce.cu", "gpu_ops/merge_prefill_decode_output.cu", ] diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py index d2bdeceb8cf..4b4de0b5ac6 100644 --- a/fastdeploy/model_executor/layers/moe/ep.py +++ b/fastdeploy/model_executor/layers/moe/ep.py @@ -437,17 +437,33 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor): tokens_per_expert_stats_list, ) = layer.redundant_table_manger.get_ep_rank_to_expert_id_list_by_layer(layer.layer_idx) - topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_redundant_topk_select( - gating_logits=gate_out, - expert_id_to_ep_rank_array=expert_id_to_ep_rank_array, - expert_in_rank_num_list=expert_in_rank_num_list, - tokens_per_expert_stats_list=tokens_per_expert_stats_list, - bias=layer.gate_correction_bias, - moe_topk=self.top_k, - apply_norm_weight=True, - enable_softmax_top_k_fused=False, - redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1, - ) + if layer.topk_method == "noaux_tc": + from .moe import get_moe_scores + + score, topk_weights, topk_idx = get_moe_scores( + gate_out, + layer.n_group, + layer.topk_group, + layer.top_k, + layer.routed_scaling_factor, + layer.gate_correction_bias, + expert_id_to_ep_rank_array=expert_id_to_ep_rank_array, + expert_in_rank_num_list=expert_in_rank_num_list, + tokens_per_expert_stats_list=tokens_per_expert_stats_list, + redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1, + ) + else: + topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_redundant_topk_select( + gating_logits=gate_out, + expert_id_to_ep_rank_array=expert_id_to_ep_rank_array, + expert_in_rank_num_list=expert_in_rank_num_list, + tokens_per_expert_stats_list=tokens_per_expert_stats_list, + bias=layer.gate_correction_bias, + moe_topk=self.top_k, + apply_norm_weight=True, + enable_softmax_top_k_fused=False, + 
redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1, + ) else: if layer.topk_method == "noaux_tc": from .moe import get_moe_scores diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 76c962069cd..49a215f2014 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -28,7 +28,7 @@ from fastdeploy.worker.experts_manager import RedundantExpertManger try: - from fastdeploy.model_executor.ops.gpu import noaux_tc + from fastdeploy.model_executor.ops.gpu import noaux_tc, noaux_tc_redundant except: logger.warning("import noaux_tc Failed!") @@ -66,6 +66,10 @@ def get_moe_scores( top_k, routed_scaling_factor, e_score_correction_bias, + expert_id_to_ep_rank_array=None, + expert_in_rank_num_list=None, + tokens_per_expert_stats_list=None, + redundant_ep_rank_num_plus_one=1, ) -> paddle.Tensor: """ compute moe scores using e_score_correction_bias. @@ -73,14 +77,28 @@ def get_moe_scores( scores = paddle.nn.functional.sigmoid(gating_output) assert e_score_correction_bias is not None, "e_score_correction_bias is none!" 
scores_with_bias = scores + e_score_correction_bias - scores, topk_values, topk_idx = noaux_tc( - scores, - scores_with_bias, - n_group if n_group > 0 else 1, - topk_group if topk_group > 0 else 1, - top_k, - routed_scaling_factor, - ) + if expert_id_to_ep_rank_array is None: + scores, topk_values, topk_idx = noaux_tc( + scores, + scores_with_bias, + n_group if n_group > 0 else 1, + topk_group if topk_group > 0 else 1, + top_k, + routed_scaling_factor, + ) + else: + scores, topk_values, topk_idx, _ = noaux_tc_redundant( + scores, + scores_with_bias, + expert_id_to_ep_rank_array, + expert_in_rank_num_list, + tokens_per_expert_stats_list, + n_group if n_group > 0 else 1, + topk_group if topk_group > 0 else 1, + top_k, + routed_scaling_factor, + redundant_ep_rank_num_plus_one, + ) return scores, topk_values, topk_idx From 2bd56ef7aaa05f2083e280bf51802a171e313a7f Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 Date: Fri, 31 Oct 2025 14:52:44 +0800 Subject: [PATCH 2/3] support eplb noaux --- custom_ops/gpu_ops/noaux_tc_redundant.cu | 92 ++++++++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 custom_ops/gpu_ops/noaux_tc_redundant.cu diff --git a/custom_ops/gpu_ops/noaux_tc_redundant.cu b/custom_ops/gpu_ops/noaux_tc_redundant.cu new file mode 100644 index 00000000000..785261299e8 --- /dev/null +++ b/custom_ops/gpu_ops/noaux_tc_redundant.cu @@ -0,0 +1,92 @@ + +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "helper.h" +#include "noauxtc_kernel.h" + +std::vector NoauxTcRedundant(paddle::Tensor& scores, + paddle::Tensor& scores_with_bias, + paddle::Tensor& expert_id_to_ep_rank_array, + paddle::Tensor& expert_in_rank_num_list, + paddle::Tensor& tokens_per_expert_stats_list, + int n_group, + int topk_group, + int topk, + float routed_scaling_factor, + int redundant_ep_rank_num_plus_one) { + auto input_shape = scores_with_bias.shape(); + PD_CHECK(input_shape.size() == 2); + int64_t num_tokens = input_shape[0]; + int64_t num_experts = input_shape[1]; + auto input_type = scores_with_bias.dtype(); + auto place = scores_with_bias.place(); + auto group_scores = paddle::empty({num_tokens, n_group}, input_type, place); + auto topk_values = paddle::empty({num_tokens, topk}, input_type, place); + auto topk_indices = paddle::empty({num_tokens, topk}, paddle::DataType::INT64, place); + auto stream = scores_with_bias.stream(); + + invokeNoAuxTcRedundant(reinterpret_cast(scores.data()), + reinterpret_cast(group_scores.data()), + reinterpret_cast(topk_values.data()), + reinterpret_cast(topk_indices.data()), + reinterpret_cast(scores_with_bias.data()), + reinterpret_cast(expert_id_to_ep_rank_array.data()), + reinterpret_cast(expert_in_rank_num_list.data()), + reinterpret_cast(tokens_per_expert_stats_list.data()), + num_tokens, + num_experts, + n_group, + topk_group, + topk, + routed_scaling_factor, + redundant_ep_rank_num_plus_one, + stream); + + return {scores, topk_values, topk_indices}; +} + +std::vector NoauxTcRedundantInferDtype( + const paddle::DataType& scores_dtype, + const paddle::DataType& scores_with_bias_dtype) { + return {scores_dtype, scores_dtype, paddle::DataType::INT64}; +} + +std::vector> NoauxTcRedundantInferShape( + const std::vector& scores_shape, + const std::vector& , + const int topk) { + auto num_tokens = 
scores_shape[0]; + auto topk_values_shape = std::vector{num_tokens, topk}; + auto topk_indices_shape = std::vector{num_tokens, topk}; + return {scores_shape, topk_values_shape, topk_indices_shape}; +} + +PD_BUILD_STATIC_OP(noaux_tc_redundant) + .Inputs({"scores", "scores_with_bias", "expert_id_to_ep_rank_array", "expert_in_rank_num_list", "tokens_per_expert_stats_list"}) + .Outputs({"output_tensor", "topk_values", "topk_indices", "tokens_per_expert_stats_list_out"}) + .Attrs({"n_group: int", + "topk_group: int", + "topk:int", + "routed_scaling_factor: float", + "redundant_ep_rank_num_plus_one:int"}) + .SetInplaceMap({{"tokens_per_expert_stats_list", "tokens_per_expert_stats_list_out"}}) + .SetKernelFn(PD_KERNEL(NoauxTcRedundant)) + .SetInferShapeFn(PD_INFER_SHAPE(NoauxTcRedundantInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(NoauxTcRedundantInferDtype)); From 7f3a4014e04245255e17130ec7664a3717427b39 Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 Date: Mon, 3 Nov 2025 20:29:48 +0800 Subject: [PATCH 3/3] add eplb noaux test --- tests/operators/test_noaux_tc_redundant.py | 84 ++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 tests/operators/test_noaux_tc_redundant.py diff --git a/tests/operators/test_noaux_tc_redundant.py b/tests/operators/test_noaux_tc_redundant.py new file mode 100644 index 00000000000..a8d2fbdb9d1 --- /dev/null +++ b/tests/operators/test_noaux_tc_redundant.py @@ -0,0 +1,84 @@ +import unittest + +import paddle + +from fastdeploy.model_executor.ops.gpu import noaux_tc_redundant + + +class TestMoeRouting(unittest.TestCase): + def setUp(self): + self.num_tokens = 10 + self.num_experts = 64 + self.gating_output = paddle.rand([self.num_tokens, self.num_experts]) + self.e_score_correction_bias = paddle.rand([self.num_experts]) + self.n_group = 8 + self.topk_group = 4 + self.top_k = 8 + self.routed_scaling_factor = 1.5 + self.redundant_ep_rank_num_plus_one = 1 + + def node_limit_routing(self, gate_probs): + """将所有专家分组, 
只在topk_group个group内选择专家""" + assert len(gate_probs.shape) == 2 + seq_length, n_experts = gate_probs.shape + + group_scores = gate_probs.reshape([seq_length, 8, -1]).topk(2, axis=-1)[0].sum(axis=-1) + group_idx = paddle.topk(group_scores, k=4, axis=-1, sorted=True)[1] + group_mask = paddle.zeros_like(group_scores).put_along_axis( + group_idx, paddle.ones([], dtype="float32"), axis=-1 + ) + score_mask = group_mask.unsqueeze(-1).expand([seq_length, 8, n_experts // 8]).reshape([seq_length, -1]) + gate_probs = gate_probs.masked_fill(~score_mask.astype(paddle.bool), float("-inf")) + return gate_probs + + def ref_moe_routing(self): + scores = paddle.nn.functional.sigmoid(self.gating_output) + prob_for_choice = scores + self.e_score_correction_bias.unsqueeze(0) + prob_for_choice = self.node_limit_routing(prob_for_choice) + top_logits, topk_idx_ref = paddle.topk(prob_for_choice, self.top_k, axis=1) + + token_num, top_k = topk_idx_ref.shape + _, num_expert = prob_for_choice.shape + topk_idx_expanded = paddle.unsqueeze(topk_idx_ref, axis=-1) + indices = paddle.concat( + [ + paddle.arange(token_num, dtype="int64").unsqueeze(1).tile([1, top_k]).unsqueeze(-1), + topk_idx_expanded, + ], + axis=-1, + ) + selected_gate_probs = paddle.gather_nd(scores, indices) + + selected_gate_probs_sum = paddle.sum(selected_gate_probs, axis=1, keepdim=True) + topk_weights_ref = selected_gate_probs / selected_gate_probs_sum + topk_weights_ref = topk_weights_ref * self.routed_scaling_factor + return topk_weights_ref, topk_idx_ref + + def test_moe_select(self): + scores = paddle.nn.functional.sigmoid(self.gating_output) + scores_with_bias = scores + self.e_score_correction_bias.unsqueeze(0) + expert_id_to_ep_rank_array = paddle.arange(self.num_experts, dtype="int32").reshape([self.num_experts, 1]) + expert_in_rank_num_list = paddle.arange(self.num_experts, dtype="int32").reshape([self.num_experts, 1]) + tokens_per_expert_stats_list = paddle.arange(self.num_experts, 
dtype="int32").reshape([self.num_experts, 1]) + + scores, topk_values, topk_idx, _ = noaux_tc_redundant( + scores, + scores_with_bias, + expert_id_to_ep_rank_array, + expert_in_rank_num_list, + tokens_per_expert_stats_list, + self.n_group, + self.topk_group, + self.top_k, + self.routed_scaling_factor, + self.redundant_ep_rank_num_plus_one, + ) + + ref_topk_values, ref_topk_idx = self.ref_moe_routing() + + paddle.allclose(topk_values, ref_topk_values) + paddle.allclose(topk_idx.cast(int), ref_topk_idx.cast(int)) + + +if __name__ == "__main__": + unittest.main()