@@ -306,6 +306,14 @@ class WarpSelect : public WarpSort<capacity, greater, T, idxT> {
 }; // end class WarpSelect
 } // namespace warp_topk
 
+
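+// Small PRNG (a plain 32-bit xorshift despite the "xorwow" name; there is no
+// Weyl counter); used by the redundant-expert kernel below to spread tokens
+// across the replicas of a selected expert.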
+inline __device__ unsigned int xorwow_moe(unsigned int& state) {
+  state ^= state >> 7;
+  state ^= state << 9;
+  state ^= state >> 13;
+  return state;
+}
+
 template <typename T>
 __device__ void topk_with_k2(T* output,
                              T const* input,
@@ -507,6 +515,156 @@ __global__ void group_idx_and_topk_idx_kernel(
   }
 }
 
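+// Layout assumed by the kernel below (implied by its indexing, not spelled
+// out in this diff): expert_id_to_ep_rank_array is a row-major
+// [num_experts x redundant_ep_rank_num_plus_one] table whose row e lists the
+// EP ranks hosting replicas of expert e, and expert_in_rank_num_list[e]
+// gives the number of valid entries in that row.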
+template <typename T, typename IdxT>
+__global__ void group_idx_and_topk_idx_redundant_kernel(
+    T* scores,
+    T const* group_scores,
+    T* topk_values,
+    IdxT* topk_indices,
+    T* scores_with_bias,
+    int32_t* expert_id_to_ep_rank_array,
+    int32_t* expert_in_rank_num_list,
+    int32_t* tokens_per_expert_stats_list,
+    int64_t const num_tokens,
+    int64_t const n_group,
+    int64_t const topk_group,
+    int64_t const topk,
+    int64_t const num_experts,
+    int64_t const num_experts_per_group,
+    double routed_scaling_factor,
+    int64_t const redundant_ep_rank_num_plus_one) {
+  int32_t warp_id = threadIdx.x / WARP_SIZE;
+  int32_t lane_id = threadIdx.x % WARP_SIZE;
+  int32_t case_id =
+      blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;  // one warp per token
+  // Seed the per-token PRNG; the +1 keeps case_id 0 away from the all-zero
+  // xorshift state, which would never advance.
+  unsigned int state = static_cast<unsigned int>(case_id) + 1u;
+  scores_with_bias += case_id * num_experts;
+  scores += case_id * num_experts;
+  group_scores += case_id * n_group;
+  topk_values += case_id * topk;
+  topk_indices += case_id * topk;
+  int32_t align_num_experts_per_group =
+      warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
+
+  cg::thread_block block = cg::this_thread_block();
+  cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
+
+  extern __shared__ char smem_buf[];  // NOTE: reuse the shared memory here to
+                                      // store the target topk idx
+  int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf) + warp_id * topk;
+  T* s_topk_value =
+      reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
+      warp_id * topk;
+
+  T value = cuda::std::numeric_limits<T>::min();
+  T topk_group_value = cuda::std::numeric_limits<T>::min();
+  int32_t num_equalto_topkth_group = 0;
+
+  if ((n_group > topk_group) && (case_id < num_tokens)) {
+    // calculate group_idx
+    int32_t target_num_min = WARP_SIZE - n_group + topk_group;
+    if (lane_id < n_group) {
+      value = group_scores[lane_id];
+    }
+
+    int count_equal_to_top_value = WARP_SIZE - n_group;
+    int pre_count_equal_to_top_value = 0;
+    // Repeatedly retire the current maximum until at least topk_group group
+    // scores (plus the WARP_SIZE - n_group padding lanes) have been retired;
+    // this finds the topk_group-th largest group score.
+    while (count_equal_to_top_value < target_num_min) {
+      __syncwarp();  // Ensure all threads have valid data before reduction
+      topk_group_value = cg::reduce(tile, value, cg::greater<T>());
+      if (value == topk_group_value) {
+        value = cuda::std::numeric_limits<T>::min();
+      }
+      pre_count_equal_to_top_value = count_equal_to_top_value;
+      count_equal_to_top_value = __popc(__ballot_sync(
+          FULL_WARP_MASK, (value == cuda::std::numeric_limits<T>::min())));
+    }
+    // How many groups tying with the topk_group-th largest score may still
+    // be admitted below.
+    num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
+  }
+  __syncthreads();
+
+  warp_topk::WarpSelect</*capacity=*/WARP_SIZE, /*greater=*/true, T, int32_t>
+      queue((int32_t)topk, cuda::std::numeric_limits<T>::min());
+
+  int count_equalto_topkth_group = 0;
+  bool if_proceed_next_topk =
+      (topk_group_value != cuda::std::numeric_limits<T>::min());
+  if (case_id < num_tokens) {
+    for (int i_group = 0; i_group < n_group; i_group++) {
+      if ((group_scores[i_group] > topk_group_value) ||
+          ((group_scores[i_group] == topk_group_value) &&
+           (count_equalto_topkth_group < num_equalto_topkth_group))) {
+        int32_t offset = i_group * num_experts_per_group;
+        for (int32_t i = lane_id; i < align_num_experts_per_group;
+             i += WARP_SIZE) {
+          T candidates = i < num_experts_per_group
+                             ? scores_with_bias[offset + i]
+                             : cuda::std::numeric_limits<T>::min();
+          queue.add(candidates, offset + i);
+        }
+        if (group_scores[i_group] == topk_group_value) {
+          count_equalto_topkth_group++;
+        }
+      }
+    }
+    queue.done();
+    __syncwarp();
+    // Get the topk_idx
+    queue.dumpIdx(s_topk_idx);
+    __syncwarp();
+  }
+
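+  // s_topk_idx now holds this warp's selected expert ids in shared memory;
+  // the unbiased scores are re-read below and normalized so the final
+  // weights sum to (approximately) routed_scaling_factor.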
+  // Load the valid score values and accumulate their sum for normalization.
+  float topk_sum = 1e-20;
+  if (case_id < num_tokens) {
+    for (int i = lane_id;
+         i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
+         i += WARP_SIZE) {
+      T value = i < topk ? scores[s_topk_idx[i]]
+                         : 0.0f;  // Load the valid value of expert
+      if (i < topk) {
+        s_topk_value[i] = value;
+      }
+      topk_sum += cg::reduce(tile, value, cg::plus<float>());
+    }
+  }
+
+  __syncthreads();
+  if (case_id < num_tokens) {
+    for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
+      scores[i] = 0;
+    }
+  }
+  __threadfence();
+  __syncthreads();
+
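+  // Final pass: write the normalized weight of each selected expert back
+  // into `scores`, and map the expert to one of its redundant EP-rank
+  // replicas chosen pseudo-randomly, so replicas share the token load.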
+  if (case_id < num_tokens) {
+    for (int i = lane_id; i < topk; i += WARP_SIZE) {
+      float value = s_topk_value[i] / topk_sum * routed_scaling_factor;
+      scores[s_topk_idx[i]] = value;
+      // When no valid top-k group was found, fall back to experts 0..topk-1
+      // with uniform 1/topk weights.
+      int expert_topk = if_proceed_next_topk ? s_topk_idx[i] : i;
+      int len = expert_in_rank_num_list[expert_topk];
+      // Reduce modulo len in unsigned arithmetic: casting the raw PRNG word
+      // to int first could yield a negative, out-of-bounds index.
+      int select = (int)(xorwow_moe(state) % (unsigned int)len);
+      int selected_rank = expert_id_to_ep_rank_array
+          [expert_topk * redundant_ep_rank_num_plus_one + select];
+      atomicAdd(&tokens_per_expert_stats_list[expert_topk], 1);
+      topk_indices[i] = (IdxT)selected_rank;
+      topk_values[i] = if_proceed_next_topk ? static_cast<T>(value)
+                                            : static_cast<T>(1.0f / topk);
+    }
+  }
+}
+
 template <typename T, typename IdxT>
 void invokeNoAuxTc(T* scores,
                    T* group_scores,
@@ -553,6 +711,60 @@ void invokeNoAuxTc(T* scores,
                    routed_scaling_factor);
 }
 
+template <typename T, typename IdxT>
+void invokeNoAuxTcRedundant(T* scores,
+                            T* group_scores,
+                            T* topk_values,
+                            IdxT* topk_indices,
+                            T* scores_with_bias,
+                            int32_t* expert_id_to_ep_rank_array,
+                            int32_t* expert_in_rank_num_list,
+                            int32_t* tokens_per_expert_stats_list,
+                            int64_t const num_tokens,
+                            int64_t const num_experts,
+                            int64_t const n_group,
+                            int64_t const topk_group,
+                            int64_t const topk,
+                            double const routed_scaling_factor,
+                            int64_t const redundant_ep_rank_num_plus_one,
+                            cudaStream_t const stream) {
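+  // Two stages, mirroring the non-redundant invokeNoAuxTc path above:
+  // (1) topk_with_k2 reduces scores_with_bias to per-group scores, one case
+  // per (token, group); (2) one warp per token selects the top-k experts and
+  // maps each one to a redundant EP rank.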
+  int64_t num_cases = num_tokens * n_group;
+  int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
+  topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
+      group_scores,
+      scores_with_bias,
+      num_tokens,
+      num_cases,
+      n_group,
+      num_experts / n_group);
+
+  int64_t topk_with_k_group_num_blocks =
+      (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
+  size_t dynamic_smem_in_bytes =
+      warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
+                                                           topk);
+
+  group_idx_and_topk_idx_redundant_kernel<T>
+      <<<topk_with_k_group_num_blocks,
+         BLOCK_SIZE,
+         dynamic_smem_in_bytes,
+         stream>>>(scores,
+                   group_scores,
+                   topk_values,
+                   topk_indices,
+                   scores_with_bias,
+                   expert_id_to_ep_rank_array,
+                   expert_in_rank_num_list,
+                   tokens_per_expert_stats_list,
+                   num_tokens,
+                   n_group,
+                   topk_group,
+                   topk,
+                   num_experts,
+                   num_experts / n_group,
+                   routed_scaling_factor,
+                   redundant_ep_rank_num_plus_one);
+}
+
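+// Hypothetical call, for illustration only; the sizes sketch a
+// DeepSeek-style grouped router and none of them are fixed by this diff:
+//   invokeNoAuxTcRedundant<float, int32_t>(
+//       scores, group_scores, topk_values, topk_indices, scores_with_bias,
+//       expert_id_to_ep_rank_array, expert_in_rank_num_list,
+//       tokens_per_expert_stats_list, /*num_tokens=*/4096,
+//       /*num_experts=*/256, /*n_group=*/8, /*topk_group=*/4, /*topk=*/8,
+//       /*routed_scaling_factor=*/2.5, redundant_ep_rank_num_plus_one,
+//       stream);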
 #define INSTANTIATE_NOAUX_TC(T, IdxT)                       \
   template void invokeNoAuxTc<T, IdxT>(T* scores,           \
                                        T* group_scores,     \
@@ -568,3 +780,23 @@ void invokeNoAuxTc(T* scores,
                                        cudaStream_t const stream);
 
 INSTANTIATE_NOAUX_TC(float, int32_t);
+
+#define INSTANTIATE_NOAUX_TC_REDUNDANT(T, IdxT)            \
+  template void invokeNoAuxTcRedundant<T, IdxT>(           \
+      T* scores,                                           \
+      T* group_scores,                                     \
+      T* topk_values,                                      \
+      IdxT* topk_indices,                                  \
+      T* scores_with_bias,                                 \
+      int32_t* expert_id_to_ep_rank_array,                 \
+      int32_t* expert_in_rank_num_list,                    \
+      int32_t* tokens_per_expert_stats_list,               \
+      int64_t const num_tokens,                            \
+      int64_t const num_experts,                           \
+      int64_t const n_group,                               \
+      int64_t const topk_group,                            \
+      int64_t const topk,                                  \
+      double const routed_scaling_factor,                  \
+      int64_t const redundant_ep_rank_num_plus_one,        \
+      cudaStream_t const stream);
+
+INSTANTIATE_NOAUX_TC_REDUNDANT(float, int32_t);