Skip to content

Commit 1789c88

Browse files
committed
fix
1 parent c5de4a5 commit 1789c88

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

csrc/attention/decode_attention_kernel_in8kv_flashdecoding_diverse.cu

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -210,6 +210,7 @@ void dynamic_batching_flashdecoding_cache_attention_int8kv_diverse_kernel(
210210
const int64_t head_idx = blockIdx.x;
211211
const int64_t batch_idx = blockIdx.y;
212212
const int64_t seq_block_idx = blockIdx.z;
213+
const int64_t output_seq_block_idx = seq_block_idx + (b_shared_seq_len[batch_idx] + seq_block_size - 1) / seq_block_size;
213214

214215
const int64_t seq_len = b_seq_len[batch_idx] - b_shared_seq_len[batch_idx];
215216
const int64_t cur_req_idx = b_req_idx[batch_idx];
@@ -435,12 +436,12 @@ void dynamic_batching_flashdecoding_cache_attention_int8kv_diverse_kernel(
435436

436437
__syncthreads();
437438

438-
seq_block_idx += (b_shared_seq_len[batch_idx] + seq_block_size - 1) / seq_block_size;
439+
439440
for (int64_t i = threadIdx.x; i < HEAD_SIZE; i += TPB) {
440-
output_emb[batch_idx * output_emb_stride_b + head_idx * output_emb_stride_h + seq_block_idx * output_emb_stride_s + i] = logits[i];
441+
output_emb[batch_idx * output_emb_stride_b + head_idx * output_emb_stride_h + output_seq_block_idx * output_emb_stride_s + i] = logits[i];
441442
}
442443

443-
output_logexpsum[batch_idx * output_logexpsum_stride_b + head_idx * output_logexpsum_stride_h + seq_block_idx] = logf(exp_sum) + qk_max;
444+
output_logexpsum[batch_idx * output_logexpsum_stride_b + head_idx * output_logexpsum_stride_h + output_seq_block_idx] = logf(exp_sum) + qk_max;
444445
}
445446

446447

0 commit comments

Comments
 (0)