From 78e51a31c5ef269cb091bb65318a4de187dc4718 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 30 Oct 2025 20:14:27 +0800 Subject: [PATCH 01/16] add flops and bandwidth to test_ffn.py --- tests/layers/test_ffn.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/layers/test_ffn.py b/tests/layers/test_ffn.py index 9e24630531f..ffb96060b9e 100644 --- a/tests/layers/test_ffn.py +++ b/tests/layers/test_ffn.py @@ -125,7 +125,7 @@ def test_ffn(self): moe_cuda_graphs = [None] * 100 cache_hidden_states = [None] * 100 - for idx, num_tokens in enumerate([10, 20, 40, 60, 80, 100, 128, 160, 192, 256]): + for idx, num_tokens in enumerate([10, 20, 40, 60, 80, 100, 128, 160, 192, 256, 512, 1024, 2048, 4096]): cache_hidden_states[idx] = paddle.rand((num_tokens, self.model_config.hidden_size), dtype=paddle.bfloat16) @@ -153,6 +153,14 @@ def test_ffn(self): print("num_tokens:", num_tokens) print(times[-5:]) + flops = num_layers * 2 * num_tokens * self.model_config.hidden_size * ffn.intermediate_size * 3 + memory = num_layers * self.model_config.hidden_size * ffn.intermediate_size * 3 + # memory += (num_layers * num_tokens * ffn.intermediate_size * 2) + + print(round(flops / times[-1] / (1024**3), 1), "TFLOPS") + + print(round(memory / times[-1] / (1024**3), 1), "TB/s") + shutil.rmtree(self.model_name_or_path) return out From f67981cee1fc75af1fbff733e0feae7fdcfe5cb9 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 30 Oct 2025 21:36:31 +0800 Subject: [PATCH 02/16] add flops and bandwidth to test_ffn.py --- tests/layers/test_fusedmoe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/layers/test_fusedmoe.py b/tests/layers/test_fusedmoe.py index ae2e1a4b631..221c3d2476d 100644 --- a/tests/layers/test_fusedmoe.py +++ b/tests/layers/test_fusedmoe.py @@ -612,9 +612,9 @@ def fake_model_run(): times = np.array([round(s.elapsed_time(e), 1) for s, e in zip(start_events, end_events)])[1:] print("num_token:", num_tokens) print(times[-5:]) - GB = 1.0 * num_tokens * self.moe_k * self.hidden_size * 3.0 / (1e9) + rdma_GB = 3.0 * num_tokens * self.moe_k * self.hidden_size / (1e9) times_s = (times[-1] / num_layers) / (1e3) - print(times[-1], round(GB / times_s, 1)) + print(times[-1], round(rdma_GB / times_s, 1)) shutil.rmtree(self.model_name_or_path) From d65fd873f69a0dee27cf110353a67cee3001549e Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 30 Oct 2025 21:53:00 +0800 Subject: [PATCH 03/16] add flops and bandwidth to test_ffn.py --- tests/layers/test_fusedmoe.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/tests/layers/test_fusedmoe.py b/tests/layers/test_fusedmoe.py index 221c3d2476d..c64ff0dd0ee 100644 --- a/tests/layers/test_fusedmoe.py +++ b/tests/layers/test_fusedmoe.py @@ -447,6 +447,7 @@ def __init__( ), quant_config=BlockWiseFP8Config(weight_block_size=[128, 128]), # quant_config=WINT8Config({}), + # quant_config=WINT4Config({}), scheduler_config=SchedulerConfig({}), cache_config=CacheConfig({}), graph_opt_config=GraphOptimizationConfig({}), @@ -458,7 +459,7 @@ def __init__( self.fd_config.parallel_config.expert_parallel_size = self.ep_size if self.ep_size > 1: self.fd_config.parallel_config.ep_group = fleet.get_hybrid_communicate_group().get_model_parallel_group() - self.fd_config.scheduler_config.splitwise_role = "decode" + self.fd_config.scheduler_config.splitwise_role = "mixed" 
self.fd_config.model_config.moe_phase.phase = "decode" weight_key_map = { @@ -573,7 +574,11 @@ def test_fused_moe(self): # 这行代码必须保留,否则影响均匀性! paddle.seed(ep_rank + 100) - fused_moe = FuseMoEWrapper(self.model_config, tp_size, tp_rank, ep_size, ep_rank, nnodes=nnodes) + num_layers = 80 + real_weight_layers = 20 + fused_moe = [None] * real_weight_layers + for i in range(real_weight_layers): + fused_moe[i] = FuseMoEWrapper(self.model_config, tp_size, tp_rank, ep_size, ep_rank, nnodes=nnodes) moe_cuda_graphs = [None] * 100 cache_hidden_states = [None] * 100 @@ -583,11 +588,9 @@ def test_fused_moe(self): cache_hidden_states[idx] = paddle.rand((num_tokens, self.model_config.hidden_size), dtype=paddle.bfloat16) - num_layers = 80 - def fake_model_run(): - for _ in range(num_layers): - out = fused_moe.fused_moe(cache_hidden_states[idx], gating) + for j in range(num_layers): + out = fused_moe[j % real_weight_layers].fused_moe(cache_hidden_states[idx], gating) return out @@ -616,6 +619,17 @@ def fake_model_run(): times_s = (times[-1] / num_layers) / (1e3) print(times[-1], round(rdma_GB / times_s, 1)) + tmp_layer = fused_moe[0].fused_moe + memory_GB = ( + tmp_layer.num_local_experts + * tmp_layer.hidden_size + * tmp_layer.moe_intermediate_size + * 3 + / (1e9) + * num_layers + ) + print(round(memory_GB / times[-1], 1), "TB/s") + shutil.rmtree(self.model_name_or_path) From 5263e1079b0761df0fbc10b01bb82523f8ef32fa Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Fri, 31 Oct 2025 13:24:21 +0800 Subject: [PATCH 04/16] commit --- .../multiquery_attention_c16_impl.cuh | 172 ++++++++++-------- 1 file changed, 92 insertions(+), 80 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh index 99cc613d8ba..8026c4de4d1 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -30,13 +30,14 @@ template __global__ void multi_query_append_attention_kernel( - T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] - T *__restrict__ cache_k, // [max_block_num, num_heads, block_size, + const T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * // head_dim] - T *__restrict__ cache_v, + const T *__restrict__ cache_k, // [max_block_num, num_heads, block_size, + // head_dim] + const T *__restrict__ cache_v, const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - const T *__restrict__ sinks, // [q_num_heads] + const T *__restrict__ sinks, // [q_num_heads] const int *__restrict__ seq_lens, const int *__restrict__ seq_lens_kv, const int *__restrict__ batch_ids, @@ -60,6 +61,9 @@ __global__ void multi_query_append_attention_kernel( OutT *__restrict__ out, const int speculate_max_draft_token_num = 5, const int sliding_window = 0) { + static_assert(num_frags_y * 16 == HEAD_DIM); + static_assert(num_frags_z * 16 == BLOCK_SIZE); + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; const uint32_t kv_num_heads = gridDim.z; @@ -76,8 +80,8 @@ __global__ void multi_query_append_attention_kernel( block_table_now = block_table + batch_id * max_block_num_per_seq; - //When cudagraph capture prefill, may launch more gridDim.x - if(btid >= static_cast(num_blocks_x_cpu)){ + // When cudagraph capture prefill, may launch more gridDim.x + if (btid 
>= static_cast(num_blocks_x_cpu)) { return; } @@ -131,7 +135,7 @@ __global__ void multi_query_append_attention_kernel( const uint32_t o_offset = q_start_seq_id * q_n_stride + q_head_idx * HEAD_DIM + tid % 8 * num_elems_per_128b(); - T *q_base_ptr = q + q_offset; + const T *q_base_ptr = q + q_offset; T *o_base_ptr_T = nullptr; OutT *o_base_ptr_int8 = nullptr; if constexpr (partition_kv) { @@ -149,7 +153,8 @@ __global__ void multi_query_append_attention_kernel( } else { o_base_ptr_int8 = out + o_offset; } - const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; + const int *mask_offset_this_seq = + mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; smem_t qo_smem(smem); uint32_t q_smem_offset_r = smem_t::get_permuted_offset( @@ -172,7 +177,6 @@ __global__ void multi_query_append_attention_kernel( v_smem(smem + (NUM_WARPS * num_frags_x + num_frags_z) * 16 * HEAD_DIM * sizeof(T)); - const uint32_t num_iterations = div_up( CAUSAL ? (min(chunk_len, @@ -183,12 +187,13 @@ __global__ void multi_query_append_attention_kernel( : chunk_len, num_frags_z * 16); const uint32_t mask_check_iteration = - (CAUSAL ? (min(chunk_len, + (CAUSAL ? (min(chunk_len, sub_if_greater_or_zero( kv_len - q_len + tile_id * num_rows_per_block / GROUP_SIZE, chunk_start))) - : mask_offset ? 0 : chunk_len) / + : mask_offset ? 0 + : chunk_len) / (num_frags_z * 16); uint32_t k_smem_offset_r = smem_t::get_permuted_offset( 8 * (tid / 16) + tid % 8, (tid % 16) / 8); @@ -204,8 +209,8 @@ __global__ void multi_query_append_attention_kernel( const uint32_t const_offset = kv_head_idx * kv_h_stride + (wid * 4 + tid / 8) * kv_b_stride + tid % 8 * num_elems_per_128b(); - T *cache_k_now = cache_k + block_id * kv_n_stride + const_offset; - T *cache_v_now = cache_v + block_id * kv_n_stride + const_offset; + const T *cache_k_now = cache_k + block_id * kv_n_stride + const_offset; + const T *cache_v_now = cache_v + block_id * kv_n_stride + const_offset; produce_kv_blockwise(); __syncthreads(); - if constexpr (!partition_kv ) { + if constexpr (!partition_kv) { if (sinks) { float current_sinks[num_frags_x][2]; - #pragma unroll +#pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { - #pragma unroll +#pragma unroll for (uint32_t j = 0; j < 2; ++j) { - const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE; - current_sinks[fx][j] = static_cast(sinks[q_head_idx + h_offset]); + const uint32_t h_offset = + (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % + GROUP_SIZE; + current_sinks[fx][j] = + static_cast(sinks[q_head_idx + h_offset]); } } - normalize_d(o_frag, d_frag, m_frag, current_sinks); + normalize_d( + o_frag, d_frag, m_frag, current_sinks); } else { normalize_d(o_frag, d_frag); } @@ -375,7 +383,6 @@ __global__ void multi_query_append_attention_kernel( HEAD_DIM); } - if constexpr (partition_kv) { #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { @@ -421,13 +428,13 @@ template __global__ void multi_query_append_attention_warp1_4_kernel( - T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] + T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] T *__restrict__ cache_k, // [max_block_num, num_heads, block_size, // head_dim] T *__restrict__ cache_v, const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - const T *__restrict__ sinks, // [q_num_heads] + const T *__restrict__ sinks, // [q_num_heads] const int *__restrict__ 
seq_lens, const int *__restrict__ seq_lens_kv, const int *__restrict__ batch_ids, @@ -435,7 +442,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel( const int *__restrict__ cu_seqlens_q, const int *__restrict__ block_table, // [bsz, block_num_per_seq] const int *__restrict__ mask_offset, - const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask + const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask const int max_seq_len, const int max_dec_len, const int max_block_num_per_seq, @@ -469,8 +476,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel( const uint32_t num_rows_per_block = num_frags_x * 16; const int *block_table_now = block_table + batch_id * max_block_num_per_seq; - //When cudagraph capture prefill, may launch more gridDim.x - if(btid >= static_cast(num_blocks_x_cpu)){ + // When cudagraph capture prefill, may launch more gridDim.x + if (btid >= static_cast(num_blocks_x_cpu)) { return; } @@ -540,7 +547,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel( tid % 8 * num_elems_per_128b(); } } - const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; + const int *mask_offset_this_seq = + mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; smem_t qo_smem(smem); uint32_t q_smem_offset_r = smem_t::get_permuted_offset( @@ -576,11 +584,10 @@ __global__ void multi_query_append_attention_warp1_4_kernel( : chunk_len, NUM_WARP_KV * num_frags_z * 16); const uint32_t mask_check_iteration = - (CAUSAL ? (min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len, - chunk_start))) - : mask_offset ? 0 : chunk_len) / + (CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero(kv_len - q_len, chunk_start))) + : mask_offset ? 0 + : chunk_len) / (NUM_WARP_KV * num_frags_z * 16); uint32_t k_smem_offset_r = smem_t::get_permuted_offset( @@ -648,16 +655,18 @@ __global__ void multi_query_append_attention_warp1_4_kernel( NUM_WARPS, num_frags_x, num_frags_y, - num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr, - q_base_seq_id_this_block, - kv_idx_base + wid * num_frags_z * 16, - q_len, - kv_len, - chunk_end, - attn_mask_len, - s_frag, - mask_offset_this_seq, - sliding_window); + num_frags_z>( + attn_mask ? attn_mask + batch_id * attn_mask_len * attn_mask_len + : nullptr, + q_base_seq_id_this_block, + kv_idx_base + wid * num_frags_z * 16, + q_len, + kv_len, + chunk_end, + attn_mask_len, + s_frag, + mask_offset_this_seq, + sliding_window); } // update m,d @@ -720,15 +729,19 @@ __global__ void multi_query_append_attention_warp1_4_kernel( if (num_chunks_this_seq <= 1) { if (sinks) { float current_sinks[num_frags_x][2]; - #pragma unroll +#pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { - #pragma unroll +#pragma unroll for (uint32_t j = 0; j < 2; ++j) { - const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE; - current_sinks[fx][j] = static_cast(sinks[q_head_idx + h_offset]); + const uint32_t h_offset = + (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % + GROUP_SIZE; + current_sinks[fx][j] = + static_cast(sinks[q_head_idx + h_offset]); } } - normalize_d(o_frag, d_frag, m_frag, current_sinks); + normalize_d( + o_frag, d_frag, m_frag, current_sinks); } else { normalize_d(o_frag, d_frag); } @@ -933,8 +946,8 @@ void MultiQueryAppendAttention( const_cast(smooth_weight.get().data())) : nullptr, sinks ? 
reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, seq_lens_q.data(), seq_lens_kv.data(), batch_ids.data(), @@ -996,8 +1009,8 @@ void MultiQueryAppendAttention( const_cast(smooth_weight.get().data())) : nullptr, sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, seq_lens_q.data(), seq_lens_kv.data(), batch_ids.data(), @@ -1048,8 +1061,8 @@ void MultiQueryAppendAttention( smooth_weight.get().data())) : nullptr, sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, reinterpret_cast(out->data()), quant_max_bound, quant_min_bound, @@ -1087,8 +1100,8 @@ void MultiQueryAppendAttention( smooth_weight.get().data())) : nullptr, sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, reinterpret_cast(out->data()), quant_max_bound, quant_min_bound, @@ -1138,9 +1151,9 @@ void MultiQueryAppendAttention( uint32_t attn_mask_len; if (attn_mask) { - attn_mask_len = attn_mask.get().shape()[1]; + attn_mask_len = attn_mask.get().shape()[1]; } else { - attn_mask_len = -1; + attn_mask_len = -1; } const int num_chunks = div_up(max_seq_len, chunk_size); @@ -1179,8 +1192,8 @@ void MultiQueryAppendAttention( const_cast(smooth_weight.get().data())) : nullptr, sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, seq_lens_q.data(), seq_lens_kv.data(), batch_ids.data(), @@ -1189,7 +1202,7 @@ void MultiQueryAppendAttention( block_table.data(), meta_data.mask_offset, attn_mask ? const_cast(attn_mask.get().data()) - : nullptr, + : nullptr, max_seq_len, max_dec_len, max_block_num_per_seq, @@ -1250,14 +1263,14 @@ void MultiQueryAppendAttention( reinterpret_cast(const_cast(cache_k.data())), reinterpret_cast(const_cast(cache_v.data())), shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, + const_cast(shift_bias.get().data())) + : nullptr, smooth_weight ? reinterpret_cast( const_cast(smooth_weight.get().data())) : nullptr, sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, seq_lens_q.data(), seq_lens_kv.data(), batch_ids.data(), @@ -1266,7 +1279,7 @@ void MultiQueryAppendAttention( block_table.data(), meta_data.mask_offset, attn_mask ? const_cast(attn_mask.get().data()) - : nullptr, + : nullptr, max_seq_len, max_dec_len, max_block_num_per_seq, @@ -1306,14 +1319,14 @@ void MultiQueryAppendAttention( seq_lens_encoder.data(), cu_seqlens_q.data(), shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, + const_cast(shift_bias.get().data())) + : nullptr, smooth_weight ? reinterpret_cast(const_cast( smooth_weight.get().data())) : nullptr, sinks ? 
reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, reinterpret_cast(out->data()), quant_max_bound, quant_min_bound, @@ -1326,15 +1339,14 @@ void MultiQueryAppendAttention( } else { constexpr int blockx = HEAD_DIM / vec_size; constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(min(sm_count * 4, token_num), - num_heads); + dim3 grids_merge(min(sm_count * 4, token_num), num_heads); dim3 blocks_merge(blockx, blocky); merge_multi_chunks_v2_kernel + vec_size, + blocky, + HEAD_DIM, + OUT_NV_TYPE, + ENABLE_PREFILL> <<>>( reinterpret_cast(tmp_workspace->ptr()), static_cast(tmp_m->ptr()), @@ -1345,14 +1357,14 @@ void MultiQueryAppendAttention( batch_id_per_token.data(), cu_seqlens_q.data(), shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, + const_cast(shift_bias.get().data())) + : nullptr, smooth_weight ? reinterpret_cast(const_cast( smooth_weight.get().data())) : nullptr, sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, reinterpret_cast(out->data()), quant_max_bound, quant_min_bound, From a769971d23ce8121528dee345bdacfeb086851fc Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Fri, 31 Oct 2025 13:43:34 +0800 Subject: [PATCH 05/16] commit --- custom_ops/gpu_ops/append_attn/append_attention_func.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index c4afa3d1c2b..e2f45ad2dec 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -171,7 +171,7 @@ template __device__ __forceinline__ void load_q_global_smem( - T* q_ptr_base, + const T* q_ptr_base, smem_t* q_smem, uint32_t q_idx_base, const uint32_t qo_upper_bound, @@ -194,7 +194,7 @@ __device__ __forceinline__ void load_q_global_smem( const uint32_t offset_now = base_offset + j * 4; const uint32_t n_offset = offset_now / group_size; const uint32_t h_offset = offset_now % group_size; - T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; + const T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; #pragma unroll for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { From 6335c68af21f3633e28f45556a67a1ddebfc3070 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Fri, 31 Oct 2025 13:46:14 +0800 Subject: [PATCH 06/16] commit --- .../append_attn/append_attention_func.cuh | 321 +++++++++--------- 1 file changed, 156 insertions(+), 165 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index e2f45ad2dec..f09dbb99d56 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -142,7 +142,6 @@ __device__ __forceinline__ void load_q_global_smem_multi_warps( const uint32_t tx_offset = tx / 8; #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { - const uint32_t base_offset = q_idx_base + fx * 16 + tx_offset; #pragma unroll const int j = ty; @@ -151,8 +150,7 @@ __device__ __forceinline__ void load_q_global_smem_multi_warps( const uint32_t h_offset = offset_now % group_size; T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; #pragma unroll - for (uint32_t fyo = 0; fyo < num_frags_y / 4; - 
++fyo) { + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { q_smem->load_128b_async( q_smem_offset_w, q_ptr, n_offset < qo_upper_bound); q_smem_offset_w = @@ -194,10 +192,10 @@ __device__ __forceinline__ void load_q_global_smem( const uint32_t offset_now = base_offset + j * 4; const uint32_t n_offset = offset_now / group_size; const uint32_t h_offset = offset_now % group_size; - const T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; + const T* q_ptr = + q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; #pragma unroll - for (uint32_t fyo = 0; fyo < num_frags_y / 4; - ++fyo) { + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { q_smem->load_128b_async( q_smem_offset_w, q_ptr, n_offset < qo_upper_bound); q_smem_offset_w = @@ -223,8 +221,7 @@ __device__ __forceinline__ void q_smem_inplace_multiply_sm_scale_multi_warps( constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); #pragma unroll - for (uint32_t i = 0; i < num_frags_x * 16 * head_dim / 1024; - ++i) { + for (uint32_t i = 0; i < num_frags_x * 16 * head_dim / 1024; ++i) { const int offset = i * 1024 + ty * 256 + tx * 8; Load(reinterpret_cast(q_smem->base) + offset, &tmp_vec); #pragma unroll @@ -289,11 +286,9 @@ __device__ __forceinline__ void produce_kv_blockwise( const uint32_t tx = threadIdx.x, ty = threadIdx.y; uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8; // kv_idx used to check #pragma unroll - for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_z * 4 / num_warps; - ++i) { + for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_z * 4 / num_warps; ++i) { #pragma unroll - for (uint32_t j = 0; j < num_frags_y / 4; - ++j) { + for (uint32_t j = 0; j < num_frags_y / 4; ++j) { smem.load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); *smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j); *gptr += 8 * num_elems_per_128b(); @@ -332,9 +327,7 @@ __device__ __forceinline__ void produce_v_blockwise_c8( block_size / num_elems_per_128b(); // 8 constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; const uint32_t tx = threadIdx.x, ty = threadIdx.y; - uint32_t kv_idx = - kv_idx_base + - tx % 4 * num_elems_per_128b(); + uint32_t kv_idx = kv_idx_base + tx % 4 * num_elems_per_128b(); if constexpr (NUM_WARP_Q == 4) { int block_id = __ldg(&block_table_now[kv_idx / block_size]); if (block_id < 0) block_id = 0; @@ -343,8 +336,7 @@ __device__ __forceinline__ void produce_v_blockwise_c8( for (uint32_t i = 0; i < num_frags_y * 2 / num_warps; ++i) { // m (num_frags_y * 16 / (num_warps * 8)) #pragma unroll - for (uint32_t j = 0; j < num_frags_z / 4; - ++j) { + for (uint32_t j = 0; j < num_frags_z / 4; ++j) { smem.load_128b_async(*smem_offset, cache_v_now, true); *smem_offset = smem.advance_offset_by_column<4, num_vecs_per_blocksize>( *smem_offset, j); @@ -369,8 +361,7 @@ __device__ __forceinline__ void produce_v_blockwise_c8( for (uint32_t i = 0; i < num_frags_y * 2 / num_warps; ++i) { // m (num_frags_y * 16 / (num_warps * 8)) #pragma unroll - for (uint32_t j = 0; j < 2 * num_frags_z / 4; - ++j) { + for (uint32_t j = 0; j < 2 * num_frags_z / 4; ++j) { smem.load_128b_async(*smem_offset, cache_v_now, true); *smem_offset = smem.advance_offset_by_column<4, num_vecs_per_blocksize>( @@ -392,27 +383,28 @@ __device__ __forceinline__ void produce_v_blockwise_c8( } } -template +template __device__ __forceinline__ void produce_k_dynamic_scale( - T* k_smem_scale, - T* cache_k_reg, - const int* block_table_now, - const T* cache_k_scale, - const uint32_t kv_idx, - const uint32_t kv_num_heads, - const 
uint32_t kv_head_idx, - const uint32_t chunk_end -) { + T* k_smem_scale, + T* cache_k_reg, + const int* block_table_now, + const T* cache_k_scale, + const uint32_t kv_idx, + const uint32_t kv_num_heads, + const uint32_t kv_head_idx, + const uint32_t chunk_end) { const uint32_t tx = threadIdx.x, ty = threadIdx.y; if constexpr (NUM_WARP_Q == 4) { // 4 warps shared block_size const uint32_t tid = ty * 32 + tx; int block_id = __ldg(&block_table_now[kv_idx / block_size]); if (block_id < 0) block_id = 0; - const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size; + const T* cache_k_scale_now = cache_k_scale + + block_id * kv_num_heads * block_size + + kv_head_idx * block_size; if (tid < block_size) { k_smem_scale[tid] = cache_k_scale_now[tid]; } @@ -427,10 +419,12 @@ __device__ __forceinline__ void produce_k_dynamic_scale( const uint32_t kv_idx_now = kv_idx + block_size * ty / 2; int block_id = __ldg(&block_table_now[kv_idx_now / block_size]); if (block_id < 0) block_id = 0; - const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size; + const T* cache_k_scale_now = cache_k_scale + + block_id * kv_num_heads * block_size + + kv_head_idx * block_size; const int kv_idx_this_thread = kv_idx + ty * 32 + tx; if (kv_idx_this_thread < chunk_end) { - k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx]; + k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx]; } else { k_smem_scale[ty * 32 + tx] = 0; } @@ -443,20 +437,19 @@ __device__ __forceinline__ void produce_k_dynamic_scale( } } -template +template __device__ __forceinline__ void produce_v_dynamic_scale( - T* v_smem_scale, - T* cache_v_reg, - const int* block_table_now, - const T* cache_v_scale, - const uint32_t kv_idx, - const uint32_t kv_num_heads, - const uint32_t kv_head_idx, - const uint32_t chunk_end -) { + T* v_smem_scale, + T* cache_v_reg, + const int* block_table_now, + const T* cache_v_scale, + const uint32_t kv_idx, + const uint32_t kv_num_heads, + const uint32_t kv_head_idx, + const uint32_t chunk_end) { const uint32_t tx = threadIdx.x, ty = threadIdx.y; if constexpr (NUM_WARP_Q == 4) { @@ -464,7 +457,9 @@ __device__ __forceinline__ void produce_v_dynamic_scale( const uint32_t tid = ty * 32 + tx; int block_id = __ldg(&block_table_now[kv_idx / block_size]); if (block_id < 0) block_id = 0; - const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size; + const T* cache_v_scale_now = cache_v_scale + + block_id * kv_num_heads * block_size + + kv_head_idx * block_size; if (tid < block_size) { v_smem_scale[tid] = cache_v_scale_now[tid]; } @@ -481,10 +476,12 @@ __device__ __forceinline__ void produce_v_dynamic_scale( const uint32_t kv_idx_now = kv_idx + block_size * ty / 2; int block_id = __ldg(&block_table_now[kv_idx_now / block_size]); if (block_id < 0) block_id = 0; - const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size; + const T* cache_v_scale_now = cache_v_scale + + block_id * kv_num_heads * block_size + + kv_head_idx * block_size; const int kv_idx_this_thread = kv_idx + ty * 32 + tx; if (kv_idx_this_thread < chunk_end) { - v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx]; + v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx]; } else { v_smem_scale[ty * 32 + tx] = 0; } @@ -560,8 +557,7 @@ __device__ __forceinline__ void produce_k_blockwise_c8( for (uint32_t i = 0; i < 2 * 
num_frags_z * 4 / num_warps; ++i) { // m num_frags_z * 16 / (num_warps * 4) #pragma unroll - for (uint32_t j = 0; j < num_frags_y / 8; - ++j) { + for (uint32_t j = 0; j < num_frags_y / 8; ++j) { smem.load_128b_async(*smem_offset, cache_k_now, true); *smem_offset = smem.advance_offset_by_column<8, num_vecs_per_head>( *smem_offset, j); @@ -614,8 +610,7 @@ __device__ __forceinline__ void produce_v_blockwise_c4( #pragma unroll for (uint32_t i = 0; i < num_frags_y / num_warps; ++i) { // m #pragma unroll - for (uint32_t j = 0; j < num_frags_z / 4; - ++j) { + for (uint32_t j = 0; j < num_frags_z / 4; ++j) { smem.load_128b_async(*smem_offset, cache_v_now, true); *smem_offset = smem.advance_offset_by_column<2, num_vecs_per_blocksize>( *smem_offset, j); @@ -671,8 +666,7 @@ __device__ __forceinline__ void produce_k_blockwise_c4( for (uint32_t i = 0; i < num_frags_z * 2 / num_warps; ++i) { // m num_frags_z * 16 / (num_warps * 8) #pragma unroll - for (uint32_t j = 0; j < num_frags_y / 8; - ++j) { + for (uint32_t j = 0; j < num_frags_y / 8; ++j) { smem.load_128b_async(*smem_offset, cache_k_now, true); *smem_offset = smem.advance_offset_by_column<4, num_vecs_per_head>( *smem_offset, j); @@ -937,7 +931,7 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem, uint32_t* q_smem_offset_r, smem_t* k_smem, uint32_t* k_smem_offset_r, - const T *cache_k_scale, + const T* cache_k_scale, float (*s_frag)[num_frags_z][8]) { constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t num_vecs_per_head_q = head_dim / num_elems_per_128b(); @@ -973,8 +967,8 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem, #pragma unroll for (uint32_t fy = 0; fy < 2; ++fy) { T* b_frag_dq_T = reinterpret_cast(b_frag_dq); - convert_c8(b_frag_dq_T, b_frag[fy * 2]); - convert_c8(b_frag_dq_T + 4, b_frag[fy * 2 + 1]); + convert_c8(b_frag_dq_T, b_frag[fy * 2]); + convert_c8(b_frag_dq_T + 4, b_frag[fy * 2 + 1]); // scale zp if constexpr (!IsDynamicC8) { if constexpr (is_scale_channel_wise) { @@ -1036,7 +1030,7 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask, const uint32_t chunk_end, const uint32_t attn_mask_len, float (*s_frag)[num_frags_z][8], - const int *mask_offset = nullptr, + const int* mask_offset = nullptr, const int sliding_window = 0) { const uint32_t tx = threadIdx.x; #pragma unroll @@ -1053,24 +1047,25 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask, 8 * (reg_id / 4) + reg_id % 2; bool out_of_boundary; if (mask_offset) { - out_of_boundary = q_idx < qo_len ? (kv_idx >= mask_offset[q_idx * 2 + 1] || kv_idx < mask_offset[q_idx * 2]) : true; - } - else if (sliding_window > 0) - { - bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx - (int)qo_len - sliding_window; - out_of_boundary = - (causal - ? (kv_idx > kv_len + q_idx - qo_len || out_of_window || (kv_idx >= chunk_end)) - : kv_idx >= chunk_end); - } - else - { - out_of_boundary = - (causal - ? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end)) - : kv_idx >= chunk_end); - if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && q_idx < attn_mask_len) { - const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len; + out_of_boundary = q_idx < qo_len + ? (kv_idx >= mask_offset[q_idx * 2 + 1] || + kv_idx < mask_offset[q_idx * 2]) + : true; + } else if (sliding_window > 0) { + bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx - + (int)qo_len - + sliding_window; + out_of_boundary = (causal ? 
(kv_idx > kv_len + q_idx - qo_len || + out_of_window || (kv_idx >= chunk_end)) + : kv_idx >= chunk_end); + } else { + out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len || + (kv_idx >= chunk_end)) + : kv_idx >= chunk_end); + if (attn_mask != nullptr && kv_idx > kv_len - qo_len && + kv_idx < chunk_end && q_idx < attn_mask_len) { + const int32_t mask_idx = + q_idx * attn_mask_len + kv_idx - kv_len + qo_len; bool mask = attn_mask[mask_idx]; out_of_boundary |= mask; } @@ -1236,7 +1231,7 @@ __device__ __forceinline__ void compute_sfm_v_c8( float (*s_frag)[num_frags_z][8], float (*o_frag)[num_frags_y][8], float (*d)[2], - const T *cache_v_scale) { + const T* cache_v_scale) { constexpr uint32_t num_vecs_per_blocksize = block_size / num_elems_per_128b(); T s_frag_f16[num_frags_x][num_frags_z][8]; @@ -1268,8 +1263,8 @@ __device__ __forceinline__ void compute_sfm_v_c8( #pragma unroll for (uint32_t fz = 0; fz < 2; ++fz) { T* b_frag_dq_T = reinterpret_cast(b_frag_dq); - convert_c8(b_frag_dq_T, b_frag[fz * 2]); - convert_c8(b_frag_dq_T + 4, b_frag[fz * 2 + 1]); + convert_c8(b_frag_dq_T, b_frag[fz * 2]); + convert_c8(b_frag_dq_T + 4, b_frag[fz * 2 + 1]); // scale zp if constexpr (!IsDynamicC8) { if constexpr (is_scale_channel_wise) { @@ -1300,7 +1295,6 @@ __device__ __forceinline__ void compute_sfm_v_c8( o_frag[fx][fy], (uint32_t*)(s_frag_f16[fx][kz * 2 + fz]), b_frag_dq); - } } } @@ -1328,7 +1322,7 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec( float (*s_frag)[num_frags_z][8], float (*o_frag)[num_frags_y][8], float (*d)[2], - T *cache_v_scale) { + T* cache_v_scale) { constexpr uint32_t num_vecs_per_blocksize = block_size / num_elems_per_128b(); @@ -1362,8 +1356,8 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec( for (uint32_t fz = 0; fz < 2; ++fz) { // dequant b_frag -> b_frag_dq T* b_frag_dq_T = reinterpret_cast(b_frag_dq); - convert_c8(b_frag_dq_T, b_frag[fz * 2]); - convert_c8(b_frag_dq_T + 4, b_frag[fz * 2 + 1]); + convert_c8(b_frag_dq_T, b_frag[fz * 2]); + convert_c8(b_frag_dq_T + 4, b_frag[fz * 2 + 1]); // scale zp if constexpr (!IsDynamicC8) { if constexpr (is_scale_channel_wise) { @@ -1372,7 +1366,7 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec( b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2]; } } else { - #pragma unroll +#pragma unroll for (uint32_t b_i = 0; b_i < 8; ++b_i) { b_frag_dq_T[b_i] *= cache_v_scale[0]; } @@ -1431,8 +1425,7 @@ __device__ __forceinline__ void compute_sfm_v(smem_t* v_smem, } #pragma unroll - for (uint32_t fz = 0; fz < num_frags_z; - ++fz) { + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { #pragma unroll for (uint32_t fy = 0; fy < num_frags_y; ++fy) { uint32_t b_frag[4]; @@ -1611,10 +1604,9 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps( for (uint32_t fy = 0; fy < num_frags_y; ++fy) { uint32_t o_frag_f16[4]; vec_cast((T*)o_frag_f16, o_frag[fx][fy]); - uint32_t o_smem_offset_w = smem_t::get_permuted_offset< - num_vecs_per_head>( - fx * 16 + tx / 4, - fy * 2); + uint32_t o_smem_offset_w = + smem_t::get_permuted_offset(fx * 16 + tx / 4, + fy * 2); ((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0]; ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1]; @@ -1627,8 +1619,8 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps( } __syncthreads(); - uint32_t o_smem_offset_w = smem_t::get_permuted_offset( - ty * 4 + tx / 8, tx % 8); + uint32_t o_smem_offset_w = + smem_t::get_permuted_offset(ty * 4 + tx / 8, tx % 8); 
o_idx_base += (tx / 8) / group_size; o_ptr_base += ((tx / 8) / group_size) * qo_n_stride + @@ -1642,8 +1634,7 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps( T* o_ptr = o_ptr_base + ((fx * 16 + j * 4) / group_size) * qo_n_stride + ((fx * 16 + j * 4) % group_size) * qo_h_stride; #pragma unroll - for (uint32_t fyo = 0; fyo < num_frags_y / 4; - ++fyo) { + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { if (o_idx < qo_upper_bound) { // need write o_smem->store_128b(o_smem_offset_w, o_ptr); @@ -1658,7 +1649,6 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps( } } - template struct StoreFunc { __device__ __forceinline__ void operator()( @@ -1717,7 +1707,6 @@ struct StoreFunc { } }; - template struct StoreFunc { __device__ __forceinline__ void operator()( @@ -1770,10 +1759,9 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant( for (uint32_t fy = 0; fy < num_frags_y; ++fy) { uint32_t o_frag_f16[4]; vec_cast((T*)o_frag_f16, o_frag[fx][fy]); - uint32_t o_smem_offset_w = smem_t::get_permuted_offset< - num_vecs_per_head>( - fx * 16 + tx / 4, - fy * 2); + uint32_t o_smem_offset_w = + smem_t::get_permuted_offset(fx * 16 + tx / 4, + fy * 2); ((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0]; ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1]; @@ -1786,8 +1774,8 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant( } __syncthreads(); - uint32_t o_smem_offset_w = smem_t::get_permuted_offset( - ty * 4 + tx / 8, tx % 8); + uint32_t o_smem_offset_w = + smem_t::get_permuted_offset(ty * 4 + tx / 8, tx % 8); const uint32_t tx_offset = tx / 8; #pragma unroll @@ -1804,8 +1792,7 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant( uint32_t shift_smooth_offset = (q_head_idx_base + h_offset) * head_dim + tx % 8 * num_elems_per_128b(); #pragma unroll - for (uint32_t fyo = 0; fyo < num_frags_y / 4; - ++fyo) { + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { if (n_offset < qo_upper_bound) { if constexpr (!partition_kv) { Load( @@ -1881,10 +1868,8 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant( for (uint32_t fy = 0; fy < num_frags_y; ++fy) { uint32_t o_frag_f16[4]; vec_cast((T*)o_frag_f16, o_frag[fx][fy]); - uint32_t o_smem_offset_w = smem_t::get_permuted_offset< - num_vecs_per_head>( - (ty * num_frags_x + fx) * 16 + tx / 4, - fy * 2); + uint32_t o_smem_offset_w = smem_t::get_permuted_offset( + (ty * num_frags_x + fx) * 16 + tx / 4, fy * 2); ((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0]; ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1]; @@ -1897,8 +1882,7 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant( __syncthreads(); uint32_t o_smem_offset_w = smem_t::get_permuted_offset( - ty * num_frags_x * 16 + tx / 8, - tx % 8); + ty * num_frags_x * 16 + tx / 8, tx % 8); const uint32_t tx_offset = tx / 8; #pragma unroll @@ -1914,13 +1898,12 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant( uint32_t shift_smooth_offset = (q_head_idx_base + h_offset) * head_dim + tx % 8 * num_elems_per_128b(); #pragma unroll - for (uint32_t fyo = 0; fyo < num_frags_y / 4; - ++fyo) { + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { if (n_offset < qo_upper_bound) { if (!partition_kv) { Load( - reinterpret_cast(o_smem->base + o_smem_offset_w), - &ori_out_vec); + reinterpret_cast(o_smem->base + 
o_smem_offset_w), + &ori_out_vec); if (in_scale > 0.0) { if (shift_bias) { Load(shift_bias + shift_smooth_offset, @@ -1929,16 +1912,16 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant( &smooth_weight_vec); } } - #pragma unroll +#pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { StoreFunc()(ori_out_vec, - shift_bias_vec, - smooth_weight_vec, - out_vec, - quant_max_bound, - quant_min_bound, - in_scale, - i); + shift_bias_vec, + smooth_weight_vec, + out_vec, + quant_max_bound, + quant_min_bound, + in_scale, + i); } Store(out_vec, o_ptr); } else { @@ -1979,10 +1962,8 @@ __device__ __forceinline__ void write_o_reg_gmem( for (uint32_t fy = 0; fy < num_frags_y; ++fy) { uint32_t o_frag_f16[4]; vec_cast((T*)o_frag_f16, o_frag[fx][fy]); - uint32_t o_smem_offset_w = smem_t::get_permuted_offset< - num_vecs_per_head>( - (ty * num_frags_x + fx) * 16 + tx / 4, - fy * 2); + uint32_t o_smem_offset_w = smem_t::get_permuted_offset( + (ty * num_frags_x + fx) * 16 + tx / 4, fy * 2); ((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0]; ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1]; @@ -1995,8 +1976,7 @@ __device__ __forceinline__ void write_o_reg_gmem( __syncthreads(); uint32_t o_smem_offset_w = smem_t::get_permuted_offset( - ty * num_frags_x * 16 + tx / 8, - tx % 8); + ty * num_frags_x * 16 + tx / 8, tx % 8); o_idx_base += (tx / 8) / group_size; o_ptr_base += ((tx / 8) / group_size) * qo_n_stride + @@ -2009,8 +1989,7 @@ __device__ __forceinline__ void write_o_reg_gmem( T* o_ptr = o_ptr_base + ((fx * 16 + j * 4) / group_size) * qo_n_stride + ((fx * 16 + j * 4) % group_size) * qo_h_stride; #pragma unroll - for (uint32_t fyo = 0; fyo < num_frags_y / 4; - ++fyo) { + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { if (o_idx < qo_upper_bound) { o_smem->store_128b(o_smem_offset_w, o_ptr); } @@ -2125,7 +2104,6 @@ __global__ void merge_multi_chunks_kernel( &out[(qid * num_heads + hid) * head_dim + vid * vec_size]); } - template __device__ __forceinline__ void merge_block_res(float (*o_frag)[num_frags_y][8], float* md_smem, @@ -2307,18 +2285,18 @@ template __global__ void merge_multi_chunks_decoder_kernel( - const T *__restrict__ multi_out, // [token_num, num_chunks, num_heads, + const T* __restrict__ multi_out, // [token_num, num_chunks, num_heads, // head_dim] - const float *__restrict__ multi_m, // [token_num, num_chunks, num_heads] - const float *__restrict__ multi_d, // [token_num, num_chunks, num_heads] - const int *__restrict__ seq_lens_q, - const int *__restrict__ seq_lens_kv, - const int *__restrict__ seq_lens_encoder, - const int *__restrict__ cu_seqlens_q, - const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] - const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - const T *__restrict__ sinks, // [q_num_heads] - OutT *__restrict__ out, + const float* __restrict__ multi_m, // [token_num, num_chunks, num_heads] + const float* __restrict__ multi_d, // [token_num, num_chunks, num_heads] + const int* __restrict__ seq_lens_q, + const int* __restrict__ seq_lens_kv, + const int* __restrict__ seq_lens_encoder, + const int* __restrict__ cu_seqlens_q, + const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const T* __restrict__ sinks, // [q_num_heads] + OutT* __restrict__ out, const float quant_max_bound, const float quant_min_bound, const float in_scale, @@ -2419,8 +2397,14 @@ __global__ void merge_multi_chunks_decoder_kernel( 
} #pragma unroll for (int i = 0; i < vec_size; ++i) { - StoreFunc()( - st.o, shift_bias_vec, smooth_weight_vec, out_vec, quant_max_bound, quant_min_bound, in_scale, i); + StoreFunc()(st.o, + shift_bias_vec, + smooth_weight_vec, + out_vec, + quant_max_bound, + quant_min_bound, + in_scale, + i); } Store( out_vec, @@ -2435,19 +2419,19 @@ template __global__ void merge_multi_chunks_v2_kernel( - const T *__restrict__ multi_out, // [token_num, num_chunks, num_heads, + const T* __restrict__ multi_out, // [token_num, num_chunks, num_heads, // head_dim] - const float *__restrict__ multi_m, // [token_num, num_chunks, num_heads] - const float *__restrict__ multi_d, // [token_num, num_chunks, num_heads] - const int *__restrict__ seq_lens_q, - const int *__restrict__ seq_lens_kv, - const int *__restrict__ seq_lens_encoder, - const int *__restrict__ batch_id_per_token, - const int *__restrict__ cu_seqlens_q, - const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] - const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - const T *__restrict__ sinks, // [q_num_heads] - OutT *__restrict__ out, + const float* __restrict__ multi_m, // [token_num, num_chunks, num_heads] + const float* __restrict__ multi_d, // [token_num, num_chunks, num_heads] + const int* __restrict__ seq_lens_q, + const int* __restrict__ seq_lens_kv, + const int* __restrict__ seq_lens_encoder, + const int* __restrict__ batch_id_per_token, + const int* __restrict__ cu_seqlens_q, + const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const T* __restrict__ sinks, // [q_num_heads] + OutT* __restrict__ out, const float quant_max_bound, const float quant_min_bound, const float in_scale, @@ -2464,7 +2448,7 @@ __global__ void merge_multi_chunks_v2_kernel( __shared__ float md_smem[bdy * 2]; for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) { const uint32_t bid = batch_id_per_token[qid]; - if(bid == -1){ + if (bid == -1) { continue; } const uint32_t local_seq_id = qid - cu_seqlens_q[bid]; @@ -2486,7 +2470,7 @@ __global__ void merge_multi_chunks_v2_kernel( const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size); if (num_chunks_this_seq <= 1) { continue; - }else if (!ENABLE_PREFILL){ + } else if (!ENABLE_PREFILL) { continue; } @@ -2496,12 +2480,12 @@ __global__ void merge_multi_chunks_v2_kernel( if constexpr (std::is_same::value) { #pragma unroll for (int i = 0; i < vec_size / 2; ++i) { - *((half2 *)(&res_vec) + i) = make_half2(0, 0); + *((half2*)(&res_vec) + i) = make_half2(0, 0); } } else { #pragma unroll for (int i = 0; i < vec_size / 2; ++i) { - *((nv_bfloat162 *)(&res_vec) + i) = make_bfloat162(0, 0); + *((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0); } } float m; @@ -2581,10 +2565,17 @@ __global__ void merge_multi_chunks_v2_kernel( Load(smooth_weight + shift_smooth_offset, &smooth_weight_vec); } + #pragma unroll for (int i = 0; i < vec_size; ++i) { - StoreFunc()( - st.o, shift_bias_vec, smooth_weight_vec, out_vec, quant_max_bound, quant_min_bound, in_scale, i); + StoreFunc()(st.o, + shift_bias_vec, + smooth_weight_vec, + out_vec, + quant_max_bound, + quant_min_bound, + in_scale, + i); } Store( out_vec, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]); From 6edc93b84ba9b35792506c2bf15d0268e2c28a05 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Fri, 31 Oct 2025 17:35:57 +0800 Subject: [PATCH 07/16] commit --- .../get_block_shape_and_split_kv_block.cu | 239 ++++++++++-------- 1 file changed, 
128 insertions(+), 111 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index 4fc43e34fa7..8c45fbd40e0 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -20,12 +20,11 @@ #include "utils.cuh" template -__global__ void -GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time, - const int *seq_lens_encoder, - const int *seq_lens_this_time_merged, - const int *seq_lens_encoder_merged, const int *seq_mapping, - const int *system_lens, int *max_lens, const int batch_size) { +__global__ void GetMaxLenKernel(const int *seq_lens_decoder, + const int *seq_lens_this_time, + const int *seq_lens_encoder, + int *max_lens, + const int batch_size) { const int tid = threadIdx.x; typedef cub::BlockReduce BlockReduce; @@ -36,9 +35,6 @@ GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time, int max_len_decoder_this_thread = 0; int max_len_this_thread = 0; int max_just_dec_len_this_thread = 0; - int max_just_dec_merged_len_this_time_this_thread = 0; - int max_system_len_this_thread = 0; - int max_dec_len_without_system_this_thread = 0; int max_len_kv_this_thread = 0; for (int i = tid; i < batch_size; i += blockDim.x) { const int seq_len_this_time = seq_lens_this_time[i]; @@ -47,17 +43,17 @@ GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time, max(seq_len_this_time, max_len_this_time_this_thread); max_len_encoder_this_thread = max(seq_lens_encoder[i], max_len_encoder_this_thread); - max_len_decoder_this_thread = max(seq_len_decoder, max_len_decoder_this_thread); - if (seq_len_this_time <= 0) - continue; - const int max_just_dec_len_now = seq_lens_encoder[i] > 0 ? 0 : seq_len_decoder; + max_len_decoder_this_thread = + max(seq_len_decoder, max_len_decoder_this_thread); + if (seq_len_this_time <= 0) continue; + const int max_just_dec_len_now = + seq_lens_encoder[i] > 0 ? 
0 : seq_len_decoder; max_len_this_thread = max(seq_len_decoder + seq_len_this_time, max_len_this_thread); max_just_dec_len_this_thread = max(max_just_dec_len_this_thread, max_just_dec_len_now); - if (seq_len_decoder == 0) - continue; + if (seq_len_decoder == 0) continue; max_len_kv_this_thread = max(seq_len_this_time + seq_len_decoder, max_len_kv_this_thread); } @@ -74,14 +70,6 @@ GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time, BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp()); int total_just_dec = BlockReduce(temp_storage) .Reduce(max_just_dec_len_this_thread, MaxOp()); - int total_just_dec_merged = - BlockReduce(temp_storage) - .Reduce(max_just_dec_merged_len_this_time_this_thread, MaxOp()); - int total_system_len = BlockReduce(temp_storage) - .Reduce(max_system_len_this_thread, MaxOp()); - int total_dec_len_without_system = - BlockReduce(temp_storage) - .Reduce(max_dec_len_without_system_this_thread, MaxOp()); int total_max_len_kv = BlockReduce(temp_storage).Reduce(max_len_kv_this_thread, MaxOp()); if (tid == 0) { @@ -90,9 +78,9 @@ GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time, max_lens[2] = total_max_len_decoder; max_lens[3] = total; max_lens[4] = total_just_dec; - max_lens[5] = total_just_dec_merged; - max_lens[6] = total_system_len; - max_lens[7] = total_dec_len_without_system; + max_lens[5] = 0; + max_lens[6] = 0; + max_lens[7] = 0; max_lens[8] = total_max_len_kv; } } @@ -100,12 +88,15 @@ GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time, void GetMaxLen(const paddle::Tensor &seq_lens_tensor, const paddle::Tensor &seq_lens_this_time, const paddle::Tensor &seq_lens_encoder, - paddle::Tensor &max_len_tensor, const int batch_size) { + paddle::Tensor &max_len_tensor, + const int batch_size) { constexpr int blockSize = 1024; GetMaxLenKernel<<<1, blockSize, 0, seq_lens_encoder.stream()>>>( - seq_lens_tensor.data(), seq_lens_this_time.data(), - seq_lens_encoder.data(), nullptr, nullptr, nullptr, nullptr, - max_len_tensor.data(), batch_size); + seq_lens_tensor.data(), + seq_lens_this_time.data(), + seq_lens_encoder.data(), + max_len_tensor.data(), + batch_size); } template @@ -154,11 +145,11 @@ __global__ void search_chunk_size_for_mla( uint32_t res_id = 0; uint32_t max_last_wave_block = 0; for (uint32_t i = 1; i < config_size; i++) { - uint32_t last_wave_block = gridx_shared[i] % sm_cout; - if (last_wave_block >= max_last_wave_block) { - res_id = i; - max_last_wave_block = last_wave_block; - } + uint32_t last_wave_block = gridx_shared[i] % sm_cout; + if (last_wave_block >= max_last_wave_block) { + res_id = i; + max_last_wave_block = last_wave_block; + } } *num_blocks_x = gridx_shared[res_id]; *res_chunk_size = block_size << res_id; @@ -185,11 +176,11 @@ __global__ void split_block_for_mla(const int *__restrict__ seq_lens_q, int loop_times; loop_times = cute::ceil_div(seq_len_decoder, chunk_size); if (seq_len_encoder > 0) { - loop_times = 0; + loop_times = 0; } for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) { - batch_ids[index] = bid; - tile_ids_per_batch[index++] = tile_id; + batch_ids[index] = bid; + tile_ids_per_batch[index++] = tile_id; } } } @@ -255,8 +246,10 @@ __global__ void split_kv_block(const int *__restrict__ seq_lens_decoder, const int *__restrict__ seq_lens_encoder, int *__restrict__ batch_ids, int *__restrict__ tile_ids_per_batch, - int *__restrict__ num_blocks_x, const int bsz, - const int pad_len, const int num_row_per_block) { + int *__restrict__ num_blocks_x, + const 
int bsz, + const int pad_len, + const int num_row_per_block) { if (threadIdx.x == 0) { int gridx = 0; int index = 0; @@ -281,31 +274,37 @@ void GetBlockShapeAndSplitKVBlock( const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_this_time, - paddle::Tensor &decoder_batch_ids, // Inplace - paddle::Tensor &decoder_tile_ids_per_batch, // Inplace - paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory - paddle::Tensor &decoder_num_blocks_device, // Inplace - paddle::Tensor &decoder_chunk_size_device, // Inplace - paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU - paddle::Tensor &encoder_batch_ids, // Inplace - paddle::Tensor &encoder_tile_ids_per_batch, // Inplace - paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, CPU - paddle::Tensor &kv_batch_ids, // Inplace - paddle::Tensor &kv_tile_ids_per_batch, // Inplace - paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, CPU + paddle::Tensor &decoder_batch_ids, // Inplace + paddle::Tensor &decoder_tile_ids_per_batch, // Inplace + paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory + paddle::Tensor &decoder_num_blocks_device, // Inplace + paddle::Tensor &decoder_chunk_size_device, // Inplace + paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU + paddle::Tensor &encoder_batch_ids, // Inplace + paddle::Tensor &encoder_tile_ids_per_batch, // Inplace + paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, CPU + paddle::Tensor &kv_batch_ids, // Inplace + paddle::Tensor &kv_tile_ids_per_batch, // Inplace + paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, CPU const int encoder_block_shape_q, const int decoder_block_shape_q, const int group_size, const int block_size, - const int decoder_step_token_num) -{ + const int decoder_step_token_num) { auto stream = seq_lens_encoder.stream(); int bsz = seq_lens_this_time.shape()[0]; - paddle::Tensor max_len_tensor_gpu = GetEmptyTensor({max_len_tensor_cpu.shape()[0]}, paddle::DataType::INT32, seq_lens_this_time.place()); - GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder, - max_len_tensor_gpu, bsz); - max_len_tensor_cpu.copy_(max_len_tensor_gpu, max_len_tensor_cpu.place(), false); + paddle::Tensor max_len_tensor_gpu = + GetEmptyTensor({max_len_tensor_cpu.shape()[0]}, + paddle::DataType::INT32, + seq_lens_this_time.place()); + GetMaxLen(seq_lens_decoder, + seq_lens_this_time, + seq_lens_encoder, + max_len_tensor_gpu, + bsz); + max_len_tensor_cpu.copy_( + max_len_tensor_gpu, max_len_tensor_cpu.place(), false); auto max_len_cpu_ptr = max_len_tensor_cpu.data(); int max_len_this_time = max_len_cpu_ptr[0]; @@ -320,7 +319,6 @@ void GetBlockShapeAndSplitKVBlock( // decoder if (max_dec_len_this_time > 0) { - const bool mla_backend = checkAttentionBackend(); if (mla_backend && group_size <= 64) { const int set_chunk_size = get_mla_dec_chunk_size(bsz); @@ -356,8 +354,9 @@ void GetBlockShapeAndSplitKVBlock( const int chunk_size = decoder_chunk_size_cpu.data()[0]; // NOTE: (changwenbin) When using auto_chunk, - // decode_max_tile_size must take into account the maximum case, where * 1024 can cover 128K. - // const uint32_t decoder_batch_shape = seq_lens_decoder.dims()[0] * 1024; + // decode_max_tile_size must take into account the maximum case, where * + // 1024 can cover 128K. 
const uint32_t decoder_batch_shape = + // seq_lens_decoder.dims()[0] * 1024; const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q); @@ -375,7 +374,6 @@ void GetBlockShapeAndSplitKVBlock( decoder_batch_shape * sizeof(int32_t), stream)); - split_block_for_mla<<<1, 32, 0, stream>>>( seq_lens_this_time.data(), seq_lens_encoder.data(), @@ -419,49 +417,72 @@ void GetBlockShapeAndSplitKVBlock( decoder_num_blocks_cpu.copy_( decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false); PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( - decoder_chunk_size_device.data(), 64, sizeof(int32_t), stream)); + decoder_chunk_size_device.data(), 64, sizeof(int32_t), stream)); } } else { - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( - decoder_chunk_size_device.data(), 64, sizeof(int32_t), stream)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( - decoder_num_blocks_device.data(), 0, sizeof(int32_t), stream)); - decoder_num_blocks_cpu.copy_( - decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( + decoder_chunk_size_device.data(), 64, sizeof(int32_t), stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( + decoder_num_blocks_device.data(), 0, sizeof(int32_t), stream)); + decoder_num_blocks_cpu.copy_( + decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false); } // encoder if (max_enc_len_this_time > 0) { - const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size); + const uint32_t max_tile_size_per_bs_kv = + div_up(max_enc_dec_len_this_time, block_size); const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv; - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_batch_ids.data(), 0, kv_batch_shape * sizeof(int32_t), stream)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_tile_ids_per_batch.data(), 0, kv_batch_shape * sizeof(int32_t), stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( + kv_batch_ids.data(), 0, kv_batch_shape * sizeof(int32_t), stream)); + PADDLE_ENFORCE_GPU_SUCCESS( + cudaMemsetAsync(kv_tile_ids_per_batch.data(), + 0, + kv_batch_shape * sizeof(int32_t), + stream)); auto kv_num_blocks_x = GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place()); split_kv_block<<<1, 32, 0, seq_lens_encoder.stream()>>>( seq_lens_decoder.data(), // sequence_lengths->data(), - seq_lens_encoder.data(), kv_batch_ids.data(), - kv_tile_ids_per_batch.data(), kv_num_blocks_x.data(), bsz, - block_size, block_size); - - kv_num_blocks_x_cpu.copy_(kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false); + seq_lens_encoder.data(), + kv_batch_ids.data(), + kv_tile_ids_per_batch.data(), + kv_num_blocks_x.data(), + bsz, + block_size, + block_size); + + kv_num_blocks_x_cpu.copy_( + kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false); // Clear buffer - const uint32_t encoder_max_tile_size_per_bs_q = div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q); + const uint32_t encoder_max_tile_size_per_bs_q = + div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q); const uint32_t encoder_batch_shape = bsz * encoder_max_tile_size_per_bs_q; - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_batch_ids.data(), 0, encoder_batch_shape * sizeof(int32_t), stream)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_tile_ids_per_batch.data(), 0, encoder_batch_shape * sizeof(int32_t), stream)); + PADDLE_ENFORCE_GPU_SUCCESS( + cudaMemsetAsync(encoder_batch_ids.data(), + 0, + encoder_batch_shape * sizeof(int32_t), + stream)); + 
PADDLE_ENFORCE_GPU_SUCCESS( + cudaMemsetAsync(encoder_tile_ids_per_batch.data(), + 0, + encoder_batch_shape * sizeof(int32_t), + stream)); auto encoder_num_blocks_x = GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place()); - split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data(), nullptr, + split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data(), + nullptr, encoder_batch_ids.data(), encoder_tile_ids_per_batch.data(), - encoder_num_blocks_x.data(), bsz, - encoder_block_shape_q, group_size); - encoder_num_blocks_x_cpu.copy_(encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false); + encoder_num_blocks_x.data(), + bsz, + encoder_block_shape_q, + group_size); + encoder_num_blocks_x_cpu.copy_( + encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false); } - } std::vector> GetBlockShapeAndSplitKVBlockInferShape( @@ -472,8 +493,7 @@ std::vector> GetBlockShapeAndSplitKVBlockInferShape( const int decoder_block_shape_q, const int group_size, const int block_size, - const int decoder_step_token_num -) { + const int decoder_step_token_num) { return {}; } @@ -485,39 +505,36 @@ std::vector GetBlockShapeAndSplitKVBlockInferDtype( const int decoder_block_shape_q, const int group_size, const int block_size, - const int decoder_step_token_num -) { + const int decoder_step_token_num) { return {}; } PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block) .Inputs({ - "seq_lens_encoder", - "seq_lens_decoder", - "seq_lens_this_time", - "decoder_batch_ids", - "decoder_tile_ids_per_batch", - "decoder_num_blocks_cpu", - "decoder_num_blocks_device", - "decoder_chunk_size_device", - "max_len_tensor_cpu", - "encoder_batch_ids", - "encoder_tile_ids_per_batch", - "encoder_num_blocks_x_cpu", - "kv_batch_ids", - "kv_tile_ids_per_batch", - "kv_num_blocks_x_cpu", + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "decoder_batch_ids", + "decoder_tile_ids_per_batch", + "decoder_num_blocks_cpu", + "decoder_num_blocks_device", + "decoder_chunk_size_device", + "max_len_tensor_cpu", + "encoder_batch_ids", + "encoder_tile_ids_per_batch", + "encoder_num_blocks_x_cpu", + "kv_batch_ids", + "kv_tile_ids_per_batch", + "kv_num_blocks_x_cpu", }) .Outputs({ }) - .Attrs({ - "encoder_block_shape_q: int", - "decoder_block_shape_q: int", - "group_size: int", - "block_size: int", - "decoder_step_token_num: int" - }) + .Attrs({"encoder_block_shape_q: int", + "decoder_block_shape_q: int", + "group_size: int", + "block_size: int", + "decoder_step_token_num: int"}) .SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock)) .SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype)); From e383754cf1264d57c24e4f69c157d95760920096 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Fri, 31 Oct 2025 21:55:37 +0800 Subject: [PATCH 08/16] commit --- .../get_block_shape_and_split_kv_block.cu | 27 ++++++------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index 8c45fbd40e0..e9585c6f701 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -85,20 +85,6 @@ __global__ void GetMaxLenKernel(const int *seq_lens_decoder, } } -void GetMaxLen(const paddle::Tensor &seq_lens_tensor, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor 
&seq_lens_encoder, - paddle::Tensor &max_len_tensor, - const int batch_size) { - constexpr int blockSize = 1024; - GetMaxLenKernel<<<1, blockSize, 0, seq_lens_encoder.stream()>>>( - seq_lens_tensor.data(), - seq_lens_this_time.data(), - seq_lens_encoder.data(), - max_len_tensor.data(), - batch_size); -} - template __global__ void search_chunk_size_for_mla( const int *__restrict__ seq_lens_q, @@ -298,11 +284,14 @@ void GetBlockShapeAndSplitKVBlock( GetEmptyTensor({max_len_tensor_cpu.shape()[0]}, paddle::DataType::INT32, seq_lens_this_time.place()); - GetMaxLen(seq_lens_decoder, - seq_lens_this_time, - seq_lens_encoder, - max_len_tensor_gpu, - bsz); + + GetMaxLenKernel<1024><<<1, 1024, 0, seq_lens_encoder.stream()>>>( + seq_lens_decoder.data(), + seq_lens_this_time.data(), + seq_lens_encoder.data(), + max_len_tensor_gpu.data(), + bsz); + max_len_tensor_cpu.copy_( max_len_tensor_gpu, max_len_tensor_cpu.place(), false); From 07f41fa99b15643b0422b9861890b0cb9cf333e9 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Fri, 31 Oct 2025 21:56:09 +0800 Subject: [PATCH 09/16] commit --- .../gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index e9585c6f701..f7620f765c2 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -285,7 +285,7 @@ void GetBlockShapeAndSplitKVBlock( paddle::DataType::INT32, seq_lens_this_time.place()); - GetMaxLenKernel<1024><<<1, 1024, 0, seq_lens_encoder.stream()>>>( + GetMaxLenKernel<1024><<<1, 1024, 0, stream>>>( seq_lens_decoder.data(), seq_lens_this_time.data(), seq_lens_encoder.data(), From 2774a1176abbc2e2ae8295819cb37fe84d97d603 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Fri, 31 Oct 2025 21:56:20 +0800 Subject: [PATCH 10/16] commit --- .../gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index f7620f765c2..b3e1cc1f8be 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -292,6 +292,7 @@ void GetBlockShapeAndSplitKVBlock( max_len_tensor_gpu.data(), bsz); + max_len_tensor_cpu.copy_( max_len_tensor_gpu, max_len_tensor_cpu.place(), false); From e7895b3c3f193156addaa3d1a118a4f0ea24d462 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Fri, 31 Oct 2025 21:56:37 +0800 Subject: [PATCH 11/16] commit --- .../get_block_shape_and_split_kv_block.cu | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index b3e1cc1f8be..d5c8cb73f8f 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -285,13 +285,11 @@ void GetBlockShapeAndSplitKVBlock( paddle::DataType::INT32, seq_lens_this_time.place()); - GetMaxLenKernel<1024><<<1, 1024, 0, stream>>>( - seq_lens_decoder.data(), - seq_lens_this_time.data(), - seq_lens_encoder.data(), - 
max_len_tensor_gpu.data(), - bsz); - + GetMaxLenKernel<1024><<<1, 1024, 0, stream>>>(seq_lens_decoder.data(), + seq_lens_this_time.data(), + seq_lens_encoder.data(), + max_len_tensor_gpu.data(), + bsz); max_len_tensor_cpu.copy_( max_len_tensor_gpu, max_len_tensor_cpu.place(), false); From 732317f1c37ef4d00eab7a18e85a7f9532db01a1 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Fri, 31 Oct 2025 22:18:34 +0800 Subject: [PATCH 12/16] commit --- .../gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu | 3 --- 1 file changed, 3 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index d5c8cb73f8f..4a42235f59b 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -78,9 +78,6 @@ __global__ void GetMaxLenKernel(const int *seq_lens_decoder, max_lens[2] = total_max_len_decoder; max_lens[3] = total; max_lens[4] = total_just_dec; - max_lens[5] = 0; - max_lens[6] = 0; - max_lens[7] = 0; max_lens[8] = total_max_len_kv; } } From 639c5c310c3f2b54b9a0d7840f4529d7bf64f856 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Sat, 1 Nov 2025 00:18:20 +0800 Subject: [PATCH 13/16] commit --- .../append_attn/multiquery_attention_c16_impl.cuh | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh index 8026c4de4d1..edfbae68646 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -61,9 +61,6 @@ __global__ void multi_query_append_attention_kernel( OutT *__restrict__ out, const int speculate_max_draft_token_num = 5, const int sliding_window = 0) { - static_assert(num_frags_y * 16 == HEAD_DIM); - static_assert(num_frags_z * 16 == BLOCK_SIZE); - constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; const uint32_t kv_num_heads = gridDim.z; @@ -76,9 +73,7 @@ __global__ void multi_query_append_attention_kernel( const uint32_t batch_id = batch_ids[btid]; const uint32_t tile_id = tile_ids_per_batch[btid]; const uint32_t num_rows_per_block = NUM_WARPS * num_frags_x * 16; - const int *block_table_now = nullptr; - - block_table_now = block_table + batch_id * max_block_num_per_seq; + const int *block_table_now = block_table + batch_id * max_block_num_per_seq; // When cudagraph capture prefill, may launch more gridDim.x if (btid >= static_cast(num_blocks_x_cpu)) { @@ -115,6 +110,9 @@ __global__ void multi_query_append_attention_kernel( const uint32_t chunk_len = chunk_end - chunk_start; extern __shared__ uint8_t smem[]; + static_assert(num_frags_y * 16 == HEAD_DIM); + static_assert(num_frags_z * 16 == BLOCK_SIZE); + float s_frag[num_frags_x][num_frags_z][8]; float o_frag[num_frags_x][num_frags_y][8]; float m_frag[num_frags_x][2]; From 9a8ddada423069253a72a3171da2eeddf533fac9 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Sat, 1 Nov 2025 00:31:11 +0800 Subject: [PATCH 14/16] put q_end --- .../append_attn/multiquery_attention_c16_impl.cuh | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh 
b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh index edfbae68646..b49c3770efa 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -84,8 +84,7 @@ __global__ void multi_query_append_attention_kernel( if (q_len <= 0) { return; } - const uint32_t q_end = - min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); + uint32_t kv_len = seq_lens_kv[batch_id]; if (ENABLE_PREFILL) { kv_len += q_len; @@ -157,6 +156,10 @@ __global__ void multi_query_append_attention_kernel( uint32_t q_smem_offset_r = smem_t::get_permuted_offset( wid * num_frags_x * 16 + tid % 16, tid / 16); // 16 * 16 + + const uint32_t q_end = + min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); + load_q_global_smem( q_base_ptr, &qo_smem, @@ -483,8 +486,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel( if (q_len <= 0) { return; } - const uint32_t q_end = - min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); + uint32_t kv_len = seq_lens_kv[batch_id]; if (ENABLE_PREFILL) { kv_len += q_len; @@ -551,6 +553,10 @@ __global__ void multi_query_append_attention_warp1_4_kernel( uint32_t q_smem_offset_r = smem_t::get_permuted_offset( tid % 16, tid / 16); // 16 * 16 + + const uint32_t q_end = + min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); + load_q_global_smem_multi_warps Date: Sat, 1 Nov 2025 09:33:23 +0800 Subject: [PATCH 15/16] remove max_dec_len --- .../gpu_ops/append_attn/multiquery_attention_c16_impl.cuh | 6 ------ 1 file changed, 6 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh index b49c3770efa..b31d4ec8fe1 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -46,7 +46,6 @@ __global__ void multi_query_append_attention_kernel( const int *__restrict__ block_table, // [bsz, block_num_per_seq] const int *__restrict__ mask_offset, const int max_seq_len, - const int max_dec_len, const int max_block_num_per_seq, const float scale, const float quant_max_bound, @@ -445,7 +444,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel( const int *__restrict__ mask_offset, const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask const int max_seq_len, - const int max_dec_len, const int max_block_num_per_seq, const float scale, const float quant_max_bound, @@ -960,7 +958,6 @@ void MultiQueryAppendAttention( block_table.data(), meta_data.mask_offset, max_seq_len, - max_dec_len, max_block_num_per_seq, scale, quant_max_bound, @@ -1023,7 +1020,6 @@ void MultiQueryAppendAttention( block_table.data(), meta_data.mask_offset, max_seq_len, - max_dec_len, max_block_num_per_seq, scale, quant_max_bound, @@ -1208,7 +1204,6 @@ void MultiQueryAppendAttention( attn_mask ? const_cast(attn_mask.get().data()) : nullptr, max_seq_len, - max_dec_len, max_block_num_per_seq, scale, quant_max_bound, @@ -1285,7 +1280,6 @@ void MultiQueryAppendAttention( attn_mask ? 
const_cast(attn_mask.get().data()) : nullptr, max_seq_len, - max_dec_len, max_block_num_per_seq, scale, quant_max_bound, From 9eb36ae82d35df816828a45fd136e5d25cb253b9 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Sat, 1 Nov 2025 09:43:03 +0800 Subject: [PATCH 16/16] remove NUM_WARP_KV from multi_query_append_attention_kernel --- .../gpu_ops/append_attn/multiquery_attention_c16_impl.cuh | 3 --- 1 file changed, 3 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh index b31d4ec8fe1..90fd7079c50 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -21,7 +21,6 @@ template