From cd59efe443e92644e03a7e965cad27f9bf107a5d Mon Sep 17 00:00:00 2001
From: Feng Li
Date: Mon, 4 Sep 2023 18:53:53 +0000
Subject: [PATCH 1/2] masked_tokens uses session_length

---
 .../models/multi_gpu_gpt/ParallelGpt.cc | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc
index 93b80ae6e..2423df5a2 100644
--- a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc
+++ b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc
@@ -158,7 +158,7 @@ void ParallelGpt::allocateBuffer(size_t batch_size,
     parent_ids_buf_ = (int*)(allocator_->reMalloc(parent_ids_buf_, sizeof(int) * batchxbeam * max_session_len, true));
     seq_limit_len_ = (uint32_t*)(allocator_->reMalloc(seq_limit_len_, sizeof(uint32_t) * batch_size, false));
     tiled_masked_tokens_ =
-        (bool*)(allocator_->reMalloc(tiled_masked_tokens_, sizeof(bool) * batchxbeam * memory_len, true));
+        (bool*)(allocator_->reMalloc(tiled_masked_tokens_, sizeof(bool) * batchxbeam * max_session_len, true));
     context_decoder_input_buf_ = (T*)(allocator_->reMalloc(
         context_decoder_input_buf_, sizeof(T) * batchxbeam * max_input_len * hidden_units_, false));
@@ -865,7 +865,7 @@ void ParallelGpt::forward(std::unordered_map* outp
     PUSH_RANGE("initialize output and parent ids");
     cudaMemsetAsync(output_ids_buf_, 0, sizeof(int) * batch_size * beam_width * session_len, stream_);
     cudaMemsetAsync(parent_ids_buf_, 0, sizeof(int) * batch_size * beam_width * session_len, stream_);
-    cudaMemsetAsync(tiled_masked_tokens_, false, sizeof(bool) * batch_size * beam_width * memory_len, stream_);
+    cudaMemsetAsync(tiled_masked_tokens_, false, sizeof(bool) * batch_size * beam_width * session_len, stream_);
     cudaMemsetAsync(tiled_total_padding_count_, 0, sizeof(int) * batch_size * beam_width, stream_);
     if (beam_width > 1) {
         cudaMemsetAsync(cache_indirections_[0], 0, 2 * sizeof(int) * batch_size * beam_width * memory_len, stream_);
@@ -1180,7 +1180,7 @@ void ParallelGpt::forward(std::unordered_map* outp
     PUSH_RANGE("mask padding tokens");
     invokeMaskPaddingTokens(tiled_masked_tokens_,
                             input_tensors->at("input_lengths").getPtr(),
-                            memory_len,
+                            session_len,
                             max_input_length,
                             initial_step,
                             batch_size,
@@ -1316,8 +1316,8 @@ void ParallelGpt::forward(std::unordered_map* outp
                 {"masked_tokens",
                  Tensor(MEMORY_GPU,
                         TYPE_BOOL,
-                        {local_batch_size * beam_width, memory_len},
-                        tiled_masked_tokens_ + id_offset * memory_len)}});
+                        {local_batch_size * beam_width, session_len},
+                        tiled_masked_tokens_ + id_offset * session_len)}});
             if (beam_width > 1) {
                 decoder_input_tensors.insert({"cache_indirection",
                                               Tensor(MEMORY_GPU,

From 72319c6453772166dcfed22c505d1b18d626e08b Mon Sep 17 00:00:00 2001
From: Feng Li
Date: Mon, 4 Sep 2023 19:55:42 +0000
Subject: [PATCH 2/2] masked_tokens uses session length everywhere

---
 .../kernels/decoder_masked_multihead_attention.h         | 4 +++-
 .../decoder_masked_multihead_attention_template.hpp      | 5 +++--
 .../layers/attention_layers/DecoderSelfAttentionLayer.cc | 7 ++++++-
 .../models/multi_gpu_gpt/ParallelGptDecoder.cc           | 2 +-
 4 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention.h b/src/fastertransformer/kernels/decoder_masked_multihead_attention.h
index 5a768184c..0d3e67546 100644
--- a/src/fastertransformer/kernels/decoder_masked_multihead_attention.h
+++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention.h
@@ -80,8 +80,10 @@ struct Multihead_attention_params_base {
     int batch_size = 0;
     // The beam width
     int beam_width = 0;
-    // The sequence length.
+    // The cache length.
     int memory_max_len = 0;
+    // The whole sequence length, which includes context and output.
+    int session_len = 0;
     // The number of heads (H).
     int num_heads = 0;
     // The hidden dimension per head (Dh).
diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
index 8e7cb92a2..038f35ba4 100644
--- a/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
@@ -1219,6 +1219,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params
diff --git a/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc
@@ @@ void DecoderSelfAttentionLayer::forward(TensorMap* output_tens
     //      finished [batch_size] (optional)
     //      total_padding_tokens [batch_size] (optional)
     //      max_input_length [1] on cpu (optional)
-    //      masked_tokens [batch_size, memory_len], (optional)
+    //      masked_tokens [batch_size, session_len], (optional)
     //      cache_indirection [batch_size / beam_width, beam_width, memory_max_len] (optional)
     //      d_prefix_prompt_lengths [batch_size] (optional)
     //      max_prefix_prompt_length [1] on cpu (optional)
@@ -504,6 +507,7 @@ void DecoderSelfAttentionLayer::forward(TensorMap* output_tens
     const int batch_size = input_tensors->at("input_query").shape[0];
     const int beam_width = cache_indir != nullptr ? input_tensors->at("cache_indirection").shape[1] : 1;
     const int memory_max_len = output_tensors->at("key_cache").shape[3];
+    const int session_len = masked_tokens != nullptr ? input_tensors->at("masked_tokens").shape[1] : 0;
     const int* d_prefix_prompt_lengths = input_tensors->getPtr("d_prefix_prompt_lengths", nullptr);
     const int max_prefix_prompt_length = input_tensors->getVal("max_prefix_prompt_length", 0);
@@ -596,6 +600,7 @@ void DecoderSelfAttentionLayer::forward(TensorMap* output_tens
         rotary_embedding_dim_,
         neox_rotary_style_,
         memory_max_len,
+        session_len,
         d_prefix_prompt_lengths,
         max_prefix_prompt_length,
         input_tensors->getVal("max_input_length", 0),
diff --git a/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.cc b/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.cc
index 173c87b46..0fe65c94d 100644
--- a/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.cc
+++ b/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.cc
@@ -269,7 +269,7 @@ void ParallelGptDecoder::forward(std::unordered_map*
     //      cache_indirection [local_batch_size / beam_width, beam_width, memory_len]
     //          Here, local_batch_size contains the beam_width, so local_batch_size / beam_width
     //          is real local_batch_size. (optional.)
-    //      masked_tokens [local_batch_size, memory_len]
+    //      masked_tokens [local_batch_size, session_len]
     //      linear_bias_slopes [head_num], optional
     // output tensors: