File tree Expand file tree Collapse file tree 1 file changed +4
-3
lines changed
Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Original file line number Diff line number Diff line change @@ -210,6 +210,7 @@ void dynamic_batching_flashdecoding_cache_attention_int8kv_diverse_kernel(
210210 const int64_t head_idx = blockIdx .x ;
211211 const int64_t batch_idx = blockIdx .y ;
212212 const int64_t seq_block_idx = blockIdx .z ;
213+ const int64_t output_seq_block_idx = seq_block_idx + (b_shared_seq_len[batch_idx] + seq_block_size - 1 ) / seq_block_size;
213214
214215 const int64_t seq_len = b_seq_len[batch_idx] - b_shared_seq_len[batch_idx];
215216 const int64_t cur_req_idx = b_req_idx[batch_idx];
@@ -435,12 +436,12 @@ void dynamic_batching_flashdecoding_cache_attention_int8kv_diverse_kernel(
435436
436437 __syncthreads ();
437438
438- seq_block_idx += (b_shared_seq_len[batch_idx] + seq_block_size - 1 ) / seq_block_size;
439+
439440 for (int64_t i = threadIdx .x ; i < HEAD_SIZE; i += TPB) {
440- output_emb[batch_idx * output_emb_stride_b + head_idx * output_emb_stride_h + seq_block_idx * output_emb_stride_s + i] = logits[i];
441+ output_emb[batch_idx * output_emb_stride_b + head_idx * output_emb_stride_h + output_seq_block_idx * output_emb_stride_s + i] = logits[i];
441442 }
442443
443- output_logexpsum[batch_idx * output_logexpsum_stride_b + head_idx * output_logexpsum_stride_h + seq_block_idx ] = logf (exp_sum) + qk_max;
444+ output_logexpsum[batch_idx * output_logexpsum_stride_b + head_idx * output_logexpsum_stride_h + output_seq_block_idx ] = logf (exp_sum) + qk_max;
444445}
445446
446447
You can’t perform that action at this time.
0 commit comments