Skip to content

Commit d31425d

Browse files
committed
simplify
1 parent 1789c88 commit d31425d

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

csrc/attention/decode_attention_kernel_in8kv_flashdecoding_diverse.cu

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -210,11 +210,12 @@ void dynamic_batching_flashdecoding_cache_attention_int8kv_diverse_kernel(
210210
const int64_t head_idx = blockIdx.x;
211211
const int64_t batch_idx = blockIdx.y;
212212
const int64_t seq_block_idx = blockIdx.z;
213-
const int64_t output_seq_block_idx = seq_block_idx + (b_shared_seq_len[batch_idx] + seq_block_size - 1) / seq_block_size;
213+
const int64_t shared_seq_len = b_shared_seq_len[batch_idx];
214+
const int64_t output_seq_block_idx = seq_block_idx + (shared_seq_len + seq_block_size - 1) / seq_block_size;
214215

215-
const int64_t seq_len = b_seq_len[batch_idx] - b_shared_seq_len[batch_idx];
216+
const int64_t seq_len = b_seq_len[batch_idx] - shared_seq_len;
216217
const int64_t cur_req_idx = b_req_idx[batch_idx];
217-
const int32_t * b_start_loc = req_to_tokens + cur_req_idx * req_to_tokens_stride + seq_block_idx * seq_block_size + b_shared_seq_len[batch_idx];
218+
const int32_t * b_start_loc = req_to_tokens + cur_req_idx * req_to_tokens_stride + seq_block_idx * seq_block_size + shared_seq_len;
218219

219220
// 向量化访问配置
220221
// 128-bit (16 bytes) 是最常用的向量化内存访问宽度,在所有 GPU 架构上都有良好支持

0 commit comments

Comments
 (0)