Commit ee37882

[NewFeature] support eplb noaux (#4725)
* support eplb noaux
* support eplb noaux
* add eplb noaux test
1 parent 1e88754 commit ee37882

File tree: 7 files changed, +477 −20 lines


custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 14 additions & 0 deletions
@@ -570,6 +570,18 @@ std::vector<paddle::Tensor> NoauxTc(
     int topk,
     float routed_scaling_factor);
 
+std::vector<paddle::Tensor> NoauxTcRedundant(
+    paddle::Tensor& scores,
+    paddle::Tensor& scores_with_bias,
+    paddle::Tensor& expert_id_to_ep_rank_array,
+    paddle::Tensor& expert_in_rank_num_list,
+    paddle::Tensor& tokens_per_expert_stats_list,
+    int n_group,
+    int topk_group,
+    int topk,
+    float routed_scaling_factor,
+    int redundant_ep_rank_num_plus_one);
+
 #ifdef ENABLE_FP8
 paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
     const paddle::Tensor& x,

@@ -1251,6 +1263,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
 
   m.def("noaux_tc",&NoauxTc, "noaux_tc for Deepseekv3 MoE compute");
 
+  m.def("noaux_tc_redunant",&NoauxTcRedundant, "noaux_tc_redundant for MoE compute");
+
 #ifdef ENABLE_FP8
   m.def("cutlass_fp8_fp8_half_gemm_fused", &cutlass_fp8_fp8_half_gemm_func,
         py::arg("x"), py::arg("y"), py::arg("bias"), py::arg("transpose_x"),
custom_ops/gpu_ops/noaux_tc_redundant.cu

Lines changed: 92 additions & 0 deletions

@@ -0,0 +1,92 @@
+
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <optional>
+
+#include "helper.h"
+#include "noauxtc_kernel.h"
+
+std::vector<paddle::Tensor> NoauxTcRedundant(paddle::Tensor& scores,
+                                             paddle::Tensor& scores_with_bias,
+                                             paddle::Tensor& expert_id_to_ep_rank_array,
+                                             paddle::Tensor& expert_in_rank_num_list,
+                                             paddle::Tensor& tokens_per_expert_stats_list,
+                                             int n_group,
+                                             int topk_group,
+                                             int topk,
+                                             float routed_scaling_factor,
+                                             int redundant_ep_rank_num_plus_one) {
+  auto input_shape = scores_with_bias.shape();
+  PD_CHECK(input_shape.size() == 2);
+  int64_t num_tokens = input_shape[0];
+  int64_t num_experts = input_shape[1];
+  auto input_type = scores_with_bias.dtype();
+  auto place = scores_with_bias.place();
+  auto group_scores = paddle::empty({num_tokens, n_group}, input_type, place);
+  auto topk_values = paddle::empty({num_tokens, topk}, input_type, place);
+  auto topk_indices = paddle::empty({num_tokens, topk}, paddle::DataType::INT64, place);
+  auto stream = scores_with_bias.stream();
+
+  invokeNoAuxTcRedundant<float, int64_t>(
+      reinterpret_cast<float*>(scores.data<float>()),
+      reinterpret_cast<float*>(group_scores.data<float>()),
+      reinterpret_cast<float*>(topk_values.data<float>()),
+      reinterpret_cast<int64_t*>(topk_indices.data<int64_t>()),
+      reinterpret_cast<float*>(scores_with_bias.data<float>()),
+      reinterpret_cast<int*>(expert_id_to_ep_rank_array.data<int>()),
+      reinterpret_cast<int*>(expert_in_rank_num_list.data<int>()),
+      reinterpret_cast<int*>(tokens_per_expert_stats_list.data<int>()),
+      num_tokens,
+      num_experts,
+      n_group,
+      topk_group,
+      topk,
+      routed_scaling_factor,
+      redundant_ep_rank_num_plus_one,
+      stream);
+
+  return {scores, topk_values, topk_indices};
+}
+
+std::vector<paddle::DataType> NoauxTcRedundantInferDtype(
+    const paddle::DataType& scores_dtype,
+    const paddle::DataType& scores_with_bias_dtype) {
+  return {scores_dtype, scores_dtype, paddle::DataType::INT64};
+}
+
+std::vector<std::vector<int64_t>> NoauxTcRedundantInferShape(
+    const std::vector<int64_t>& scores_shape,
+    const std::vector<int64_t>&,
+    const int topk) {
+  auto num_tokens = scores_shape[0];
+  auto topk_values_shape = std::vector<int64_t>{num_tokens, topk};
+  auto topk_indices_shape = std::vector<int64_t>{num_tokens, topk};
+  return {scores_shape, topk_values_shape, topk_indices_shape};
+}
+
+PD_BUILD_STATIC_OP(noaux_tc_redundant)
+    .Inputs({"scores", "scores_with_bias", "expert_id_to_ep_rank_array", "expert_in_rank_num_list", "tokens_per_expert_stats_list"})
+    .Outputs({"output_tensor", "topk_values", "topk_indices", "tokens_per_expert_stats_list_out"})
+    .Attrs({"n_group: int",
+            "topk_group: int",
+            "topk:int",
+            "routed_scaling_factor: float",
+            "redundant_ep_rank_num_plus_one:int"})
+    .SetInplaceMap({{"tokens_per_expert_stats_list", "tokens_per_expert_stats_list_out"}})
+    .SetKernelFn(PD_KERNEL(NoauxTcRedundant))
+    .SetInferShapeFn(PD_INFER_SHAPE(NoauxTcRedundantInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(NoauxTcRedundantInferDtype));
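For context, a minimal host-side sketch (not part of the commit) of how this op could be driven from C++ built against Paddle's custom-op extension API. The harness, shapes, and routing parameters below are invented for illustration; only the 2-D input requirement enforced by the PD_CHECK above is taken from the commit.

// Hypothetical harness: builds 2-D inputs and calls the new op directly.
// paddle::full / paddle::GPUPlace are standard extension API; everything
// else here (sizes, values) is an assumption for demonstration.
#include "paddle/extension.h"

std::vector<paddle::Tensor> run_noaux_tc_redundant_demo() {
  const int64_t num_tokens = 8, num_experts = 64;     // invented sizes
  const int redundant_ep_rank_num_plus_one = 3;       // invented slot count
  auto place = paddle::GPUPlace();
  auto scores = paddle::full({num_tokens, num_experts}, 0.5f,
                             paddle::DataType::FLOAT32, place);
  auto scores_with_bias = paddle::full({num_tokens, num_experts}, 0.5f,
                                       paddle::DataType::FLOAT32, place);
  // Flattened [num_experts x redundant_ep_rank_num_plus_one] replica table.
  auto expert_id_to_ep_rank_array =
      paddle::full({num_experts * redundant_ep_rank_num_plus_one}, 0,
                   paddle::DataType::INT32, place);
  // One live replica per expert, so the kernel's modulo always picks slot 0.
  auto expert_in_rank_num_list =
      paddle::full({num_experts}, 1, paddle::DataType::INT32, place);
  auto tokens_per_expert_stats_list =
      paddle::full({num_experts}, 0, paddle::DataType::INT32, place);
  return NoauxTcRedundant(scores, scores_with_bias,
                          expert_id_to_ep_rank_array, expert_in_rank_num_list,
                          tokens_per_expert_stats_list,
                          /*n_group=*/8, /*topk_group=*/4, /*topk=*/8,
                          /*routed_scaling_factor=*/1.0f,
                          redundant_ep_rank_num_plus_one);
}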

custom_ops/gpu_ops/noauxtc_kernel.h

Lines changed: 232 additions & 0 deletions
@@ -306,6 +306,14 @@ class WarpSelect : public WarpSort<capacity, greater, T, idxT> {
 };  // end class WarpSelect
 }  // namespace warp_topk
 
+
+inline __device__ unsigned int xorwow_moe(unsigned int &state) {
+  state ^= state >> 7;
+  state ^= state << 9;
+  state ^= state >> 13;
+  return state;
+}
+
 template <typename T>
 __device__ void topk_with_k2(T* output,
                              T const* input,
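Despite its name, `xorwow_moe` is a plain three-shift xorshift generator (true xorwow also carries a Weyl counter). The redundant kernel below seeds it with the token index (`case_id`), so each token's replica choice is deterministic across runs. A minimal host-side sketch of that behavior, with an invented seed and replica count; the harness is illustrative, not part of the commit:

// Host-side reproduction of xorwow_moe, seeded the way the kernel does it.
#include <cstdio>

unsigned int xorwow_moe_host(unsigned int& state) {
  state ^= state >> 7;
  state ^= state << 9;
  state ^= state >> 13;
  return state;
}

int main() {
  unsigned int state = 5;  // the kernel seeds this with case_id (token index)
  int len = 3;             // live replica count for the chosen expert
  int select = (int)xorwow_moe_host(state) % len;  // replica slot in [0, len)
  printf("replica slot = %d\n", select);
  // Caveat: state == 0 is a fixed point of xorshift, so a token with
  // case_id 0 always selects replica slot 0.
  return 0;
}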
@@ -507,6 +515,156 @@ __global__ void group_idx_and_topk_idx_kernel(
   }
 }
 
+template <typename T, typename IdxT>
+__global__ void group_idx_and_topk_idx_redundant_kernel(
+    T* scores,
+    T const* group_scores,
+    T* topk_values,
+    IdxT* topk_indices,
+    T* scores_with_bias,
+    int32_t* expert_id_to_ep_rank_array,
+    int32_t* expert_in_rank_num_list,
+    int32_t* tokens_per_expert_stats_list,
+    int64_t const num_tokens,
+    int64_t const n_group,
+    int64_t const topk_group,
+    int64_t const topk,
+    int64_t const num_experts,
+    int64_t const num_experts_per_group,
+    double routed_scaling_factor,
+    int64_t const redundant_ep_rank_num_plus_one) {
+  int32_t warp_id = threadIdx.x / WARP_SIZE;
+  int32_t lane_id = threadIdx.x % WARP_SIZE;
+  int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;  // one per token
+  unsigned int state = case_id;
+  scores_with_bias += case_id * num_experts;
+  scores += case_id * num_experts;
+  group_scores += case_id * n_group;
+  topk_values += case_id * topk;
+  topk_indices += case_id * topk;
+  int32_t align_num_experts_per_group =
+      warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
+
+  cg::thread_block block = cg::this_thread_block();
+  cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
+
+  extern __shared__ char smem_buf[];  // NOTE: reuse the shared memory here to
+                                      // store the target topk idx
+  int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf) + warp_id * topk;
+  T* s_topk_value =
+      reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
+      warp_id * topk;
+
+  T value = cuda::std::numeric_limits<T>::min();
+  T topk_group_value = cuda::std::numeric_limits<T>::min();
+  int32_t num_equalto_topkth_group;
+
+  if ((n_group > topk_group) && (case_id < num_tokens)) {
+    // calculate group_idx
+    int32_t target_num_min = WARP_SIZE - n_group + topk_group;
+    if (lane_id < n_group) {
+      value = group_scores[lane_id];
+    }
+
+    int count_equal_to_top_value = WARP_SIZE - n_group;
+    int pre_count_equal_to_top_value = 0;
+    // Use a loop to find the largest top_group
+    while (count_equal_to_top_value < target_num_min) {
+      __syncwarp();  // Ensure all threads have valid data before reduction
+      topk_group_value = cg::reduce(tile, value, cg::greater<T>());
+      if (value == topk_group_value) {
+        value = cuda::std::numeric_limits<T>::min();
+      }
+      pre_count_equal_to_top_value = count_equal_to_top_value;
+      count_equal_to_top_value = __popc(__ballot_sync(
+          FULL_WARP_MASK, (value == cuda::std::numeric_limits<T>::min())));
+    }
+    num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
+  }
+  __syncthreads();
+
+  warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t>
+      queue((int32_t)topk, cuda::std::numeric_limits<T>::min());
+
+  int count_equalto_topkth_group = 0;
+  bool if_proceed_next_topk =
+      (topk_group_value != cuda::std::numeric_limits<T>::min());
+  if (case_id < num_tokens) {
+    for (int i_group = 0; i_group < n_group; i_group++) {
+      if ((group_scores[i_group] > topk_group_value) ||
+          ((group_scores[i_group] == topk_group_value) &&
+           (count_equalto_topkth_group < num_equalto_topkth_group))) {
+        int32_t offset = i_group * num_experts_per_group;
+        for (int32_t i = lane_id; i < align_num_experts_per_group;
+             i += WARP_SIZE) {
+          T candidates = i < num_experts_per_group
+                             ? scores_with_bias[offset + i]
+                             : cuda::std::numeric_limits<T>::min();
+          queue.add(candidates, offset + i);
+        }
+        if (group_scores[i_group] == topk_group_value) {
+          count_equalto_topkth_group++;
+        }
+      }
+    }
+    queue.done();
+    __syncwarp();
+    // Get the topk_idx
+    queue.dumpIdx(s_topk_idx);
+    __syncwarp();
+  }
+
+  // Load the valid score values and compute their sum
+  float topk_sum = 1e-20;
+  if (case_id < num_tokens) {
+    for (int i = lane_id;
+         i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
+         i += WARP_SIZE) {
+      T value = i < topk ? scores[s_topk_idx[i]]
+                         : 0.0f;  // Load the valid value of expert
+      if (i < topk) {
+        s_topk_value[i] = value;
+      }
+      topk_sum += reduce(tile, value, cg::plus<float>());
+    }
+  }
+
+  __syncthreads();
+  if (case_id < num_tokens) {
+    for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
+      scores[i] = 0;
+    }
+  }
+  __threadfence();
+  __syncthreads();
+
+  if (case_id < num_tokens) {
+    for (int i = lane_id; i < topk; i += WARP_SIZE) {
+      float value = s_topk_value[i] / topk_sum * routed_scaling_factor;
+      scores[s_topk_idx[i]] = value;
+      if (if_proceed_next_topk) {
+        int expert_topk = s_topk_idx[i];
+        int len = expert_in_rank_num_list[expert_topk];
+        int select = (int)xorwow_moe(state) % len;
+        int selected_rank = expert_id_to_ep_rank_array
+            [expert_topk * redundant_ep_rank_num_plus_one + select];
+        atomicAdd(&tokens_per_expert_stats_list[expert_topk], 1);
+        topk_indices[i] = (IdxT)selected_rank;
+        topk_values[i] = static_cast<T>(value);
+      } else {
+        int expert_topk = i;
+        int len = expert_in_rank_num_list[expert_topk];
+        int select = (int)xorwow_moe(state) % len;
+        int selected_rank = expert_id_to_ep_rank_array
+            [expert_topk * redundant_ep_rank_num_plus_one + select];
+        atomicAdd(&tokens_per_expert_stats_list[expert_topk], 1);
+        topk_indices[i] = (IdxT)selected_rank;
+        topk_values[i] = static_cast<float>(1.0f / topk);
+      }
+    }
+  }
+}
+
 template <typename T, typename IdxT>
 void invokeNoAuxTc(T* scores,
                    T* group_scores,
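The replica lookup in the kernel above, `expert_id_to_ep_rank_array[expert_topk * redundant_ep_rank_num_plus_one + select]`, implies a flattened table of shape [num_experts, redundant_ep_rank_num_plus_one], with `expert_in_rank_num_list[e]` giving the live replica count for expert e. A minimal sketch of that assumed layout; the expert count, slot count, and rank values below are invented:

#include <vector>

int main() {
  // Invented sizes: 4 logical experts, at most 3 replica slots per expert.
  const int num_experts = 4;
  const int redundant_ep_rank_num_plus_one = 3;

  // Live replica counts per expert (expert 1 has two replicas).
  std::vector<int> expert_in_rank_num_list = {1, 2, 1, 1};

  // Flattened replica table; unused slots left as -1.
  std::vector<int> expert_id_to_ep_rank_array(
      num_experts * redundant_ep_rank_num_plus_one, -1);
  expert_id_to_ep_rank_array[1 * redundant_ep_rank_num_plus_one + 0] = 1;  // replica 0 -> EP rank 1
  expert_id_to_ep_rank_array[1 * redundant_ep_rank_num_plus_one + 1] = 5;  // replica 1 -> EP rank 5

  // A token routed to expert 1 picks a slot s in [0, expert_in_rank_num_list[1])
  // and is dispatched to the EP rank stored at row 1, column s.
  int s = 1;  // e.g. chosen via xorwow_moe(state) % expert_in_rank_num_list[1]
  int selected_rank =
      expert_id_to_ep_rank_array[1 * redundant_ep_rank_num_plus_one + s];
  return selected_rank == 5 ? 0 : 1;
}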
@@ -553,6 +711,60 @@ void invokeNoAuxTc(T* scores,
                                                routed_scaling_factor);
 }
 
+template <typename T, typename IdxT>
+void invokeNoAuxTcRedundant(T* scores,
+                            T* group_scores,
+                            T* topk_values,
+                            IdxT* topk_indices,
+                            T* scores_with_bias,
+                            int32_t* expert_id_to_ep_rank_array,
+                            int32_t* expert_in_rank_num_list,
+                            int32_t* tokens_per_expert_stats_list,
+                            int64_t const num_tokens,
+                            int64_t const num_experts,
+                            int64_t const n_group,
+                            int64_t const topk_group,
+                            int64_t const topk,
+                            double const routed_scaling_factor,
+                            int64_t const redundant_ep_rank_num_plus_one,
+                            cudaStream_t const stream) {
+  int64_t num_cases = num_tokens * n_group;
+  int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
+  topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
+      group_scores,
+      scores_with_bias,
+      num_tokens,
+      num_cases,
+      n_group,
+      num_experts / n_group);
+
+  int64_t topk_with_k_group_num_blocks =
+      (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
+  size_t dynamic_smem_in_bytes =
+      warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
+                                                           topk);
+
+  group_idx_and_topk_idx_redundant_kernel<T>
+      <<<topk_with_k_group_num_blocks,
+         BLOCK_SIZE,
+         dynamic_smem_in_bytes,
+         stream>>>(scores,
+                   group_scores,
+                   topk_values,
+                   topk_indices,
+                   scores_with_bias,
+                   expert_id_to_ep_rank_array,
+                   expert_in_rank_num_list,
+                   tokens_per_expert_stats_list,
+                   num_tokens,
+                   n_group,
+                   topk_group,
+                   topk,
+                   num_experts,
+                   num_experts / n_group,
+                   routed_scaling_factor,
+                   redundant_ep_rank_num_plus_one);
+}
+
 #define INSTANTIATE_NOAUX_TC(T, IdxT)                        \
   template void invokeNoAuxTc<T, IdxT>(T * scores,           \
                                        T * group_scores,     \
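Both launches in `invokeNoAuxTcRedundant` use the usual ceil-division grid sizing, one warp per case. Worked through with invented numbers; NUM_WARPS_PER_BLOCK is defined elsewhere in this header, and 4 is assumed below purely to make the arithmetic concrete:

#include <cstdio>

int main() {
  const long long kNumWarpsPerBlock = 4;  // assumed value for illustration
  long long num_tokens = 10, n_group = 8;
  long long num_cases = num_tokens * n_group;  // 80: one case per (token, group)
  long long k2_blocks = (num_cases - 1) / kNumWarpsPerBlock + 1;      // ceil(80/4) = 20
  long long group_blocks = (num_tokens - 1) / kNumWarpsPerBlock + 1;  // ceil(10/4) = 3
  printf("topk_with_k2 blocks = %lld, group/topk blocks = %lld\n",
         k2_blocks, group_blocks);
  return 0;
}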
@@ -568,3 +780,23 @@ void invokeNoAuxTc(T* scores,
                                        cudaStream_t const stream);
 
 INSTANTIATE_NOAUX_TC(float, int32_t);
+
+#define INSTANTIATE_NOAUX_TC_Redundant(T, IdxT)                               \
+  template void invokeNoAuxTcRedundant<T, IdxT>(                              \
+      T * scores,                                                             \
+      T * group_scores,                                                       \
+      T* topk_values,                                                         \
+      IdxT* topk_indices,                                                     \
+      T * scores_with_bias,                                                   \
+      int32_t* expert_id_to_ep_rank_array,                                    \
+      int32_t* expert_in_rank_num_list,                                       \
+      int32_t* tokens_per_expert_stats_list,                                  \
+      int64_t const num_tokens,                                               \
+      int64_t const num_experts,                                              \
+      int64_t const n_group,                                                  \
+      int64_t const topk_group,                                               \
+      int64_t const topk,                                                     \
+      double const routed_scaling_factor,                                     \
+      int64_t const redundant_ep_rank_num_plus_one,                           \
+      cudaStream_t const stream);
+
+INSTANTIATE_NOAUX_TC_Redundant(float, int32_t);

custom_ops/setup_ops.py

Lines changed: 1 addition & 0 deletions
@@ -298,6 +298,7 @@ def find_end_files(directory, end_str):
         "gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
         "gpu_ops/fused_rotary_position_encoding.cu",
         "gpu_ops/noaux_tc.cu",
+        "gpu_ops/noaux_tc_redundant.cu",
         "gpu_ops/custom_all_reduce/all_reduce.cu",
         "gpu_ops/merge_prefill_decode_output.cu",
     ]
