14 changes: 14 additions & 0 deletions custom_ops/gpu_ops/cpp_extensions.cc
@@ -570,6 +570,18 @@ std::vector<paddle::Tensor> NoauxTc(
int topk,
float routed_scaling_factor);

std::vector<paddle::Tensor> NoauxTcRedundant(
paddle::Tensor& scores,
paddle::Tensor& scores_with_bias,
paddle::Tensor& expert_id_to_ep_rank_array,
paddle::Tensor& expert_in_rank_num_list,
paddle::Tensor& tokens_per_expert_stats_list,
int n_group,
int topk_group,
int topk,
float routed_scaling_factor,
int redundant_ep_rank_num_plus_one);

#ifdef ENABLE_FP8
paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
const paddle::Tensor& x,
@@ -1251,6 +1263,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {

m.def("noaux_tc",&NoauxTc, "noaux_tc for Deepseekv3 MoE compute");

m.def("noaux_tc_redunant",&NoauxTcRedundant, "noaux_tc_redundant for MoE compute");

#ifdef ENABLE_FP8
m.def("cutlass_fp8_fp8_half_gemm_fused", &cutlass_fp8_fp8_half_gemm_func,
py::arg("x"), py::arg("y"), py::arg("bias"), py::arg("transpose_x"),
92 changes: 92 additions & 0 deletions custom_ops/gpu_ops/noaux_tc_redundant.cu
@@ -0,0 +1,92 @@

// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <optional>

#include "helper.h"
#include "noauxtc_kernel.h"

std::vector<paddle::Tensor> NoauxTcRedundant(paddle::Tensor& scores,
paddle::Tensor& scores_with_bias,
paddle::Tensor& expert_id_to_ep_rank_array,
paddle::Tensor& expert_in_rank_num_list,
paddle::Tensor& tokens_per_expert_stats_list,
int n_group,
int topk_group,
int topk,
float routed_scaling_factor,
int redundant_ep_rank_num_plus_one) {
auto input_shape = scores_with_bias.shape();
PD_CHECK(input_shape.size() == 2);
int64_t num_tokens = input_shape[0];
int64_t num_experts = input_shape[1];
auto input_type = scores_with_bias.dtype();
auto place = scores_with_bias.place();
auto group_scores = paddle::empty({num_tokens, n_group}, input_type, place);
auto topk_values = paddle::empty({num_tokens, topk}, input_type, place);
auto topk_indices = paddle::empty({num_tokens, topk}, paddle::DataType::INT64, place);
auto stream = scores_with_bias.stream();

invokeNoAuxTcRedundant<float, int64_t>(reinterpret_cast<float*>(scores.data<float>()),
reinterpret_cast<float*>(group_scores.data<float>()),
reinterpret_cast<float*>(topk_values.data<float>()),
reinterpret_cast<int64_t*>(topk_indices.data<int64_t>()),
reinterpret_cast<float*>(scores_with_bias.data<float>()),
reinterpret_cast<int*>(expert_id_to_ep_rank_array.data<int>()),
reinterpret_cast<int*>(expert_in_rank_num_list.data<int>()),
reinterpret_cast<int*>(tokens_per_expert_stats_list.data<int>()),
num_tokens,
num_experts,
n_group,
topk_group,
topk,
routed_scaling_factor,
redundant_ep_rank_num_plus_one,
stream);

return {scores, topk_values, topk_indices};
}

std::vector<paddle::DataType> NoauxTcRedundantInferDtype(
const paddle::DataType& scores_dtype,
const paddle::DataType& scores_with_bias_dtype) {
return {scores_dtype, scores_dtype, paddle::DataType::INT64};
}

std::vector<std::vector<int64_t>> NoauxTcRedundantInferShape(
const std::vector<int64_t>& scores_shape,
const std::vector<int64_t>& ,
const int topk) {
auto num_tokens = scores_shape[0];
auto topk_values_shape = std::vector<int64_t>{num_tokens, topk};
auto topk_indices_shape = std::vector<int64_t>{num_tokens, topk};
return {scores_shape, topk_values_shape, topk_indices_shape};
}

PD_BUILD_STATIC_OP(noaux_tc_redundant)
.Inputs({"scores", "scores_with_bias", "expert_id_to_ep_rank_array", "expert_in_rank_num_list", "tokens_per_expert_stats_list"})
.Outputs({"output_tensor", "topk_values", "topk_indices", "tokens_per_expert_stats_list_out"})
.Attrs({"n_group: int",
"topk_group: int",
"topk:int",
"routed_scaling_factor: float",
"redundant_ep_rank_num_plus_one:int"})
.SetInplaceMap({{"tokens_per_expert_stats_list", "tokens_per_expert_stats_list_out"}})
.SetKernelFn(PD_KERNEL(NoauxTcRedundant))
.SetInferShapeFn(PD_INFER_SHAPE(NoauxTcRedundantInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(NoauxTcRedundantInferDtype));
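A minimal usage sketch of the newly registered noaux_tc_redundant static op, for reviewers. The import path, device setup, shapes, and illustrative values are assumptions and not part of this PR; only the op name, argument order, dtypes, and the in-place stats update follow the registration above.

# Hedged sketch: assumes the custom ops were rebuilt and are importable from this path.
import paddle
from fastdeploy.model_executor.ops.gpu import noaux_tc_redundant

paddle.set_device("gpu")
num_tokens, num_experts, num_ranks = 4, 64, 8
n_group, topk_group, topk = 8, 4, 8
redundant_ep_rank_num_plus_one = 2  # each expert has at most one redundant replica

scores = paddle.rand([num_tokens, num_experts], dtype="float32")
scores_with_bias = scores + paddle.rand([num_tokens, num_experts], dtype="float32")
# Row e lists the EP ranks hosting the replicas of expert e (read as a flat int32 buffer by the kernel).
expert_id_to_ep_rank_array = (
    paddle.arange(num_experts * redundant_ep_rank_num_plus_one, dtype="int32")
    .reshape([num_experts, redundant_ep_rank_num_plus_one])
    % num_ranks
)
expert_in_rank_num_list = paddle.full([num_experts], 1, dtype="int32")  # replica count per expert
tokens_per_expert_stats_list = paddle.zeros([num_experts], dtype="int32")

outs = noaux_tc_redundant(
    scores,
    scores_with_bias,
    expert_id_to_ep_rank_array,
    expert_in_rank_num_list,
    tokens_per_expert_stats_list,
    n_group,
    topk_group,
    topk,
    1.0,  # routed_scaling_factor
    redundant_ep_rank_num_plus_one,
)
routed_scores, topk_values, topk_ep_ranks = outs[0], outs[1], outs[2]
# tokens_per_expert_stats_list is updated in place (see SetInplaceMap above);
# topk_ep_ranks holds the selected EP ranks rather than raw expert ids.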
232 changes: 232 additions & 0 deletions custom_ops/gpu_ops/noauxtc_kernel.h
@@ -306,6 +306,14 @@ class WarpSelect : public WarpSort<capacity, greater, T, idxT> {
}; // end class WarpSelect
} // namespace warp_topk


// Cheap xorshift-style PRNG (seeded per token with case_id below); used to pick
// one of an expert's redundant EP-rank replicas.
inline __device__ unsigned int xorwow_moe(unsigned int &state) {
state ^= state >> 7;
state ^= state << 9;
state ^= state >> 13;
return state;
}

template <typename T>
__device__ void topk_with_k2(T* output,
T const* input,
@@ -507,6 +515,156 @@ __global__ void group_idx_and_topk_idx_kernel(
}
}

template <typename T, typename IdxT>
__global__ void group_idx_and_topk_idx_redundant_kernel(
T* scores,
T const* group_scores,
T* topk_values,
IdxT* topk_indices,
T* scores_with_bias,
int32_t* expert_id_to_ep_rank_array,
int32_t* expert_in_rank_num_list,
int32_t* tokens_per_expert_stats_list,
int64_t const num_tokens,
int64_t const n_group,
int64_t const topk_group,
int64_t const topk,
int64_t const num_experts,
int64_t const num_experts_per_group,
double routed_scaling_factor,
int64_t const redundant_ep_rank_num_plus_one) {
int32_t warp_id = threadIdx.x / WARP_SIZE;
int32_t lane_id = threadIdx.x % WARP_SIZE;
int32_t case_id =
blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; // one per token
unsigned int state = case_id;
scores_with_bias += case_id * num_experts;
scores += case_id * num_experts;
group_scores += case_id * n_group;
topk_values += case_id * topk;
topk_indices += case_id * topk;
int32_t align_num_experts_per_group =
warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);

cg::thread_block block = cg::this_thread_block();
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);

extern __shared__ char smem_buf[]; // NOTE: reuse the shared memory here to
// store the target topk idx
int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf) + warp_id * topk;
T* s_topk_value =
reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
warp_id * topk;

T value = cuda::std::numeric_limits<T>::min();
T topk_group_value = cuda::std::numeric_limits<T>::min();
int32_t num_equalto_topkth_group;

if ((n_group > topk_group) && (case_id < num_tokens)) {
// calculate group_idx
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
if (lane_id < n_group) {
value = group_scores[lane_id];
}

int count_equal_to_top_value = WARP_SIZE - n_group;
int pre_count_equal_to_top_value = 0;
// Use a loop to find the topk_group-th largest group score
while (count_equal_to_top_value < target_num_min) {
__syncwarp(); // Ensure all threads have valid data before reduction
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
if (value == topk_group_value) {
value = cuda::std::numeric_limits<T>::min();
}
pre_count_equal_to_top_value = count_equal_to_top_value;
count_equal_to_top_value = __popc(__ballot_sync(
FULL_WARP_MASK, (value == cuda::std::numeric_limits<T>::min())));
}
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
}
__syncthreads();

warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t>
queue((int32_t)topk, cuda::std::numeric_limits<T>::min());

int count_equalto_topkth_group = 0;
bool if_proceed_next_topk = (topk_group_value != cuda::std::numeric_limits<T>::min());
if (case_id < num_tokens) {
for (int i_group = 0; i_group < n_group; i_group++) {
if ((group_scores[i_group] > topk_group_value) ||
((group_scores[i_group] == topk_group_value) &&
(count_equalto_topkth_group < num_equalto_topkth_group))) {
int32_t offset = i_group * num_experts_per_group;
for (int32_t i = lane_id; i < align_num_experts_per_group;
i += WARP_SIZE) {
T candidates = i < num_experts_per_group
? scores_with_bias[offset + i]
: cuda::std::numeric_limits<T>::min();
queue.add(candidates, offset + i);
}
if (group_scores[i_group] == topk_group_value) {
count_equalto_topkth_group++;
}
}
}
queue.done();
__syncwarp();
// Get the topk_idx
queue.dumpIdx(s_topk_idx);
__syncwarp();
}

// Load the valid score value
// Calculate the summation
float topk_sum = 1e-20;
if (case_id < num_tokens) {
for (int i = lane_id;
i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
i += WARP_SIZE) {
T value = i < topk ? scores[s_topk_idx[i]]
: 0.0f; // Load the valid value of expert
if (i < topk) {
s_topk_value[i] = value;
}
topk_sum += reduce(tile, value, cg::plus<float>());
}
}

__syncthreads();
if (case_id < num_tokens) {
for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
scores[i] = 0;
}
}
__threadfence();
__syncthreads();

if (case_id < num_tokens) {
for (int i = lane_id; i < topk; i += WARP_SIZE) {
float value = s_topk_value[i] / topk_sum * routed_scaling_factor;
scores[s_topk_idx[i]] = value;
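      // Pick one of this expert's replicated EP ranks with a uniform xorshift
      // draw, and count the token for per-expert load statistics.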
if (if_proceed_next_topk) {
int expert_topk = s_topk_idx[i];
int len = expert_in_rank_num_list[expert_topk];
int select = (int)xorwow_moe(state) % len;
int selected_rank = expert_id_to_ep_rank_array[expert_topk * redundant_ep_rank_num_plus_one + select];
atomicAdd(&tokens_per_expert_stats_list[expert_topk], 1);
topk_indices[i] = (IdxT)selected_rank;
topk_values[i] = static_cast<T>(value);
}
else {
int expert_topk = i;
int len = expert_in_rank_num_list[expert_topk];
int select = (int)xorwow_moe(state) % len;
int selected_rank = expert_id_to_ep_rank_array[expert_topk * redundant_ep_rank_num_plus_one + select];
atomicAdd(&tokens_per_expert_stats_list[expert_topk], 1);
topk_indices[i] = (IdxT)selected_rank;
topk_values[i] = static_cast<float>(1.0f / topk);
}
}
}
}

template <typename T, typename IdxT>
void invokeNoAuxTc(T* scores,
T* group_scores,
@@ -553,6 +711,60 @@ void invokeNoAuxTc(T* scores,
routed_scaling_factor);
}

template <typename T, typename IdxT>
void invokeNoAuxTcRedundant(T* scores,
T* group_scores,
T* topk_values,
IdxT* topk_indices,
T* scores_with_bias,
int32_t* expert_id_to_ep_rank_array,
int32_t* expert_in_rank_num_list,
int32_t* tokens_per_expert_stats_list,
int64_t const num_tokens,
int64_t const num_experts,
int64_t const n_group,
int64_t const topk_group,
int64_t const topk,
double const routed_scaling_factor,
int64_t const redundant_ep_rank_num_plus_one,
cudaStream_t const stream) {
int64_t num_cases = num_tokens * n_group;
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
group_scores,
scores_with_bias,
num_tokens,
num_cases,
n_group,
num_experts / n_group);

int64_t topk_with_k_group_num_blocks =
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
size_t dynamic_smem_in_bytes =
warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
topk);

group_idx_and_topk_idx_redundant_kernel<T><<<topk_with_k_group_num_blocks,
BLOCK_SIZE,
dynamic_smem_in_bytes,
stream>>>(scores,
group_scores,
topk_values,
topk_indices,
scores_with_bias,
expert_id_to_ep_rank_array,
expert_in_rank_num_list,
tokens_per_expert_stats_list,
num_tokens,
n_group,
topk_group,
topk,
num_experts,
num_experts / n_group,
routed_scaling_factor,
redundant_ep_rank_num_plus_one);
}

#define INSTANTIATE_NOAUX_TC(T, IdxT) \
template void invokeNoAuxTc<T, IdxT>(T * scores, \
T * group_scores, \
@@ -568,3 +780,23 @@ void invokeNoAuxTc(T* scores,
cudaStream_t const stream);

INSTANTIATE_NOAUX_TC(float, int32_t);

#define INSTANTIATE_NOAUX_TC_Redundant(T, IdxT) \
template void invokeNoAuxTcRedundant<T, IdxT>(T * scores, \
T * group_scores, \
T* topk_values, \
IdxT* topk_indices, \
T * scores_with_bias, \
int32_t* expert_id_to_ep_rank_array, \
int32_t* expert_in_rank_num_list, \
int32_t* tokens_per_expert_stats_list, \
int64_t const num_tokens, \
int64_t const num_experts, \
int64_t const n_group, \
int64_t const topk_group, \
int64_t const topk, \
double const routed_scaling_factor, \
int64_t const redundant_ep_rank_num_plus_one, \
cudaStream_t const stream);

INSTANTIATE_NOAUX_TC_Redundant(float, int32_t);
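To document the routing intent, here is a host-side Python mirror of the replica selection done at the end of group_idx_and_topk_idx_redundant_kernel: seed the xorshift state with the token index, draw once, pick a replica modulo the expert's replica count, look up the hosting EP rank, and bump the load counter. This is a reviewer-facing sketch of the logic, not code from the PR.

# Host-side mirror of xorwow_moe() and the per-expert replica pick (32-bit arithmetic).
def xorwow_moe(state: int) -> int:
    state ^= state >> 7
    state = (state ^ (state << 9)) & 0xFFFFFFFF  # keep the left shift within 32 bits
    state ^= state >> 13
    return state

def pick_ep_rank(token_id, expert_id,
                 expert_id_to_ep_rank_array,    # flat list, num_experts * redundant_ep_rank_num_plus_one
                 expert_in_rank_num_list,       # number of EP ranks hosting each expert
                 tokens_per_expert_stats_list,  # running per-expert token counts
                 redundant_ep_rank_num_plus_one):
    state = token_id                            # kernel seeds the PRNG with case_id
    draw = xorwow_moe(state)
    replica = draw % expert_in_rank_num_list[expert_id]
    rank = expert_id_to_ep_rank_array[expert_id * redundant_ep_rank_num_plus_one + replica]
    tokens_per_expert_stats_list[expert_id] += 1  # kernel uses atomicAdd here
    return rank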
1 change: 1 addition & 0 deletions custom_ops/setup_ops.py
@@ -298,6 +298,7 @@ def find_end_files(directory, end_str):
"gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
"gpu_ops/fused_rotary_position_encoding.cu",
"gpu_ops/noaux_tc.cu",
"gpu_ops/noaux_tc_redundant.cu",
"gpu_ops/custom_all_reduce/all_reduce.cu",
"gpu_ops/merge_prefill_decode_output.cu",
]