diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 9c6e9b33d5..ac6fefdc6a 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -15,6 +15,74 @@ #include "fused_attn_fp8.h" #include "utils.h" +namespace { +// Helper function to create a tensor view with modified shape and optional pointer offset +transformer_engine::Tensor make_tensor_view(const transformer_engine::Tensor *source, + const std::vector<size_t> &shape, + size_t offset_bytes = 0) { + transformer_engine::Tensor view = *source; + if (offset_bytes > 0) { + view.data.dptr = static_cast<void *>(static_cast<char *>(source->data.dptr) + offset_bytes); + } + view.data.shape = shape; + view.nvte_tensor = 0; // Mark as unmanaged/local tensor view + return view; +} + +// Helper function to calculate stride for packed QKV tensor unpacking +size_t calculate_qkv_stride(NVTE_QKV_Layout_Group layout_group, transformer_engine::DType dtype, + size_t h, size_t d) { + size_t stride = 0; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { + stride = (transformer_engine::typeToNumBits(dtype) * h * d) / 8; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { + stride = (transformer_engine::typeToNumBits(dtype) * d) / 8; + } + return stride; +} + +// Helper function to determine unpacked shape for QKV packed tensor +std::vector<size_t> calculate_qkv_unpacked_shape(const transformer_engine::Tensor *qkv_tensor, + size_t h, size_t d) { + std::vector<size_t> unpacked_shape; + if (qkv_tensor->data.shape.size() == 4) { + // T3HD or TH3D (4D) -> THD (3D): remove dimension "3" at position 1 + unpacked_shape = {qkv_tensor->data.shape[0], h, d}; + } else { + // BS3HD/SB3HD or BSH3D/SBH3D (5D) -> BSHD/SBHD (4D): remove dimension "3" at position 2 + unpacked_shape = {qkv_tensor->data.shape[0], qkv_tensor->data.shape[1], h, d}; + } + return unpacked_shape; +} + +// Helper function to calculate stride for packed KV tensor unpacking +size_t calculate_kv_stride(NVTE_QKV_Layout_Group layout_group, transformer_engine::DType dtype, + size_t h_kv, size_t d) { + size_t stride = 0; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { + stride = (transformer_engine::typeToNumBits(dtype) * h_kv * d) / 8; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { + stride = (transformer_engine::typeToNumBits(dtype) * d) / 8; + } + return stride; +} + +// Helper function to determine unpacked shape for KV packed tensor +std::vector<size_t> calculate_kv_unpacked_shape(const transformer_engine::Tensor *kv_tensor, + NVTE_QKV_Layout_Group layout_group, + NVTE_QKV_Format kv_format, size_t t_kv, size_t h_kv, + size_t d) { + std::vector<size_t> unpacked_kv_shape; + if (kv_format == NVTE_QKV_Format::NVTE_THD) { + unpacked_kv_shape = {t_kv, h_kv, d}; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD || + layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { + unpacked_kv_shape = {kv_tensor->data.shape[0], kv_tensor->data.shape[1], h_kv, d}; + } + return unpacked_kv_shape; +} +} // namespace + // map NVTE_QKV_Layout to NVTE_QKV_Layout_Group NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { switch (qkv_layout) { @@ -436,6 +504,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } // NVTE fused attention FWD with packed QKV +// DEPRECATED: This API is deprecated. +// Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead.
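+// +// Migration sketch (illustrative; argument lists abbreviated): a caller holding a packed QKV +// buffer can build three tensor views over the same memory at byte offsets 0, stride, and +// 2 * stride, with stride as computed by calculate_qkv_stride() above, and pass them to +//   nvte_fused_attn_fwd(Q, K, V, /* remaining arguments as before; see the public header */); +// The backend branches below perform exactly this unpacking via make_tensor_view().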
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, @@ -487,30 +557,62 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - fused_attn_max_512_fwd_qkvpacked(b, h, max_seqlen, d, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_QKV, input_Bias, - output_O, Aux_CTX_Tensors, input_cu_seqlens, input_rng_state, - wkspace, stream, handle); + // Unpack QKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector<size_t> unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + fused_attn_max_512_fwd(b, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, + input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens, + input_cu_seqlens, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - fused_attn_arbitrary_seqlen_fwd_qkvpacked( - b, h, max_seqlen, d, t, is_training, return_max_logit, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV, - input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, - input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); + // Unpack QKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector<size_t> unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + fused_attn_arbitrary_seqlen_fwd( + b, h, h, max_seqlen, max_seqlen, d, d, t, t, 0, 0, 0, 0, 0, 0, is_training, + return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_Bias, + input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, + input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, input_rng_state, + wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. 
\n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - fused_attn_fp8_fwd_qkvpacked(b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, input_QKV, input_output_S, output_O, - Aux_CTX_Tensors, input_cu_seqlens, input_rng_state, wkspace, - stream, handle); + // Unpack QKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector<size_t> unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + fused_attn_fp8_fwd(b, h, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, + input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens, + input_cu_seqlens, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif @@ -519,6 +621,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, } } // NVTE fused attention BWD with packed QKV +// DEPRECATED: This API is deprecated. +// Please use nvte_fused_attn_bwd with separate Q, K, V tensors instead. void nvte_fused_attn_bwd_qkvpacked( const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, @@ -570,9 +674,25 @@ void nvte_fused_attn_bwd_qkvpacked( if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - fused_attn_max_512_bwd_qkvpacked( - b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_QKV, - input_dO, output_S, output_dQKV, output_dBias, input_cu_seqlens, wkspace, stream, handle); + + // Unpack QKV and dQKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector<size_t> unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V and dQ, dK, dV + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); + Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); + Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); + + fused_attn_max_512_bwd(b, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_dO, output_S, + &dQ_view, &dK_view, &dV_view, output_dBias, input_cu_seqlens, + input_cu_seqlens, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. 
\n"); #endif @@ -588,12 +708,27 @@ void nvte_fused_attn_bwd_qkvpacked( if (softmax_type != NVTE_VANILLA_SOFTMAX) { input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); } - fused_attn_arbitrary_seqlen_bwd_qkvpacked( - b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size_left, window_size_right, deterministic, input_QKV, input_O, - input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQKV, output_dBias, - output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace, - stream, handle); + + // Unpack QKV and dQKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector<size_t> unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V and dQ, dK, dV + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); + Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); + Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); + + fused_attn_arbitrary_seqlen_bwd( + b, h, h, max_seqlen, max_seqlen, d, d, t, t, attn_scale, dropout, qkv_layout, bias_type, + attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, &Q_view, + &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, &dQ_view, + &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens, + input_cu_seqlens_padded, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); #else const char *err_msg = "cuDNN 8.9.0 is required for BF16/FP16 fused attention " @@ -605,10 +740,26 @@ void nvte_fused_attn_bwd_qkvpacked( const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - fused_attn_fp8_bwd_qkvpacked(b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, input_QKV, input_O, input_dO, input_M, input_ZInv, - input_S, input_output_dP, output_dQKV, input_cu_seqlens, - input_rng_state, wkspace, stream, handle); + + // Unpack QKV and dQKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector<size_t> unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V and dQ, dK, dV + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); + Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); + Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); + + fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_O, input_dO, + input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view, &dV_view, + input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, 
stream, + handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif @@ -617,6 +768,8 @@ void nvte_fused_attn_bwd_qkvpacked( } } // NVTE fused attention FWD with packed KV +// DEPRECATED: This API is deprecated. +// Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead. void nvte_fused_attn_fwd_kvpacked( const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, @@ -706,21 +859,40 @@ void nvte_fused_attn_fwd_kvpacked( if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - fused_attn_max_512_fwd_kvpacked( - b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, - input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); + // Unpack KV and call the non-packed function + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, input_Q->data.dtype, h_kv, d); + std::vector<size_t> unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, + input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. 
\n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8903) - fused_attn_arbitrary_seqlen_fwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v, + // Unpack KV and call the non-packed function + const auto Q_type = input_Q->data.dtype; + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); + std::vector<size_t> unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + fused_attn_arbitrary_seqlen_fwd( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset, - output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, + window_size_left, window_size_right, input_Q, &K_view, &V_view, input_Bias, + input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); #else @@ -729,10 +901,20 @@ void nvte_fused_attn_fwd_kvpacked( #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - fused_attn_fp8_fwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, input_Q, input_KV, input_output_S, output_O, Aux_CTX_Tensors, - input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); + // Unpack KV and call the non-packed function + const auto Q_type = input_Q->data.dtype; + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); + std::vector<size_t> unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, + input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif @@ -741,6 +923,8 @@ void nvte_fused_attn_fwd_kvpacked( } } // NVTE fused attention BWD with packed KV +// DEPRECATED: This API is deprecated. +// Please use nvte_fused_attn_bwd with separate Q, K, V tensors instead. 
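+// +// Note (illustrative): the branches below write dK and dV through views into the packed dKV +// buffer at byte offsets 0 and stride (see make_tensor_view()/calculate_kv_stride()), so the +// packed gradient layout is preserved; callers migrating to nvte_fused_attn_bwd supply their +// own dK/dV tensors or equivalent views.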
void nvte_fused_attn_bwd_kvpacked( const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, @@ -806,10 +990,23 @@ void nvte_fused_attn_bwd_kvpacked( if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - fused_attn_max_512_bwd_kvpacked( - b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, input_Q, input_KV, input_dO, output_S, output_dQ, output_dKV, output_dBias, - input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle); + + // Unpack KV and dKV and call the non-packed function + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, input_Q->data.dtype, h_kv, d); + std::vector<size_t> unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); + Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); + + fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_dO, output_S, + output_dQ, &dK_view, &dV_view, output_dBias, input_cu_seqlens_q, + input_cu_seqlens_kv, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); #endif @@ -825,13 +1022,29 @@ void nvte_fused_attn_bwd_kvpacked( if (softmax_type != NVTE_VANILLA_SOFTMAX) { input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); } - fused_attn_arbitrary_seqlen_bwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout, + + // Unpack KV and dKV and call the non-packed function + const auto Q_type = input_Q->data.dtype; + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); + std::vector<size_t> unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + // Create tensor views for dK, dV + Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); + Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); + + fused_attn_arbitrary_seqlen_bwd( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, - input_Q, input_KV, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQ, - output_dKV, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, - handle); + input_Q, &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, + output_dQ, &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, + input_cu_seqlens_kv, input_cu_seqlens_q_padded, 
input_cu_seqlens_kv_padded, input_rng_state, + wkspace, stream, handle); #else const char *err_msg = "cuDNN 8.9.3 is required for BF16/FP16 fused attention " @@ -843,11 +1056,25 @@ void nvte_fused_attn_bwd_kvpacked( const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - fused_attn_fp8_bwd_kvpacked(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, input_KV, input_O, - input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, - output_dKV, input_cu_seqlens_q, input_cu_seqlens_kv, - input_rng_state, wkspace, stream, handle); + + // Unpack KV and dKV and call the non-packed function + const auto Q_type = input_Q->data.dtype; + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); + std::vector<size_t> unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); + Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); + + fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_O, + input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, &dK_view, + &dV_view, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, + stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. 
\n"); #endif diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 950ced61bb..14468b543a 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -1037,532 +1037,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl( } // namespace fused_attn using namespace transformer_engine::fused_attn; -void fused_attn_arbitrary_seqlen_fwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - bool is_training, bool return_max_logit, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_QKV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, - Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, - const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - const auto QKV_type = input_QKV->data.dtype; - void *devPtrQKV = input_QKV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void *devPtrQ = static_cast(devPtrQKV); - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - - void *devPtrBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - bias_b = input_Bias->data.shape[0]; - bias_h = input_Bias->data.shape[1]; - } - void *devPtrSoftmaxOffset = nullptr; - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; - } - - void *devPtrO = output_O->data.dptr; - void *devPtrS1 = nullptr; - void *devPtrS2 = nullptr; - void *devPtrCuSeqlens = cu_seqlens->data.dptr; - void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr; - - size_t max_batch_size = 0; - size_t max_tokens = 0; - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - max_batch_size = get_max_batch_size(batch); - max_tokens = get_max_tokens(num_tokens); - } - - size_t i = 0; - if (Aux_CTX_Tensors->size == 0) { - const auto cudnn_runtime_version = cudnnGetVersion(); - if (return_max_logit) { - Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_Max->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_Max->data.shape = {max_tokens, num_attn_heads, 1}; - } else { - output_Max->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - } - output_Max->data.dtype = DType::kFloat32; - Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_Sum_Exp->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_Sum_Exp->data.shape = {max_tokens, num_attn_heads, 1}; - } else { - 
output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - } - output_Sum_Exp->data.dtype = DType::kFloat32; - } else { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_S->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - } - output_S->data.dtype = DType::kFloat32; - } - - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_bias->data.dptr = nullptr; - output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen}; - output_bias->data.dtype = QKV_type; - } - - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_softmax_offset->data.dptr = nullptr; - output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1}; - output_softmax_offset->data.dtype = DType::kFloat32; - } - - Aux_CTX_Tensors->size = i; - } else if (Aux_CTX_Tensors->size >= 2) { - if (return_max_logit) { - Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - devPtrS1 = output_Max->data.dptr; - Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - devPtrS2 = output_Sum_Exp->data.dptr; - } else { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - devPtrS1 = output_S->data.dptr; - } - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_rng_state->data.dptr = rng_state->data.dptr; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_bias->data.dptr = devPtrBias; - } - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_softmax_offset->data.dptr = devPtrSoftmaxOffset; - } - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_fwd_impl( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, - max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training, - return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, - devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlens, devPtrCuSeqlens, nullptr, nullptr, devPtrSeqOffsets, devPtrSeqOffsets, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} - 
-void fused_attn_arbitrary_seqlen_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, - Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, - const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - const auto QKV_type = input_QKV->data.dtype; - void *devPtrQKV = input_QKV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void *devPtrQ = devPtrQKV; - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - - void *devPtrO = input_O->data.dptr; - void *devPtrdO = input_dO->data.dptr; - void *devPtrBias = nullptr; - void *devPtrdBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - devPtrdBias = output_dBias->data.dptr; - bias_b = output_dBias->data.shape[0]; - bias_h = output_dBias->data.shape[1]; - } - - size_t max_batch_size = 0; - size_t max_tokens = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - max_batch_size = get_max_batch_size(batch); - max_tokens = get_max_tokens(num_tokens); - } - - void *devPtrdQKV = output_dQKV->data.dptr; - void *devPtrdQ = devPtrdQKV; - void *devPtrdK = static_cast(static_cast(devPtrdQKV) + stride); - void *devPtrdV = static_cast(static_cast(devPtrdQKV) + 2 * stride); - - void *devPtrSoftmaxStats = nullptr; - devPtrSoftmaxStats = output_S->data.dptr; - void *devPtrSoftmaxOffset = nullptr; - void *devPtrdSoftmaxOffset = nullptr; - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; - devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr; - } - - void *devPtrCuSeqlens = cu_seqlens->data.dptr; - void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr; - - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_bwd_impl( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, - max_batch_size, max_tokens, max_tokens, bias_b, bias_h, attn_scale, p_dropout, qkv_layout, - bias_type, mask_type, softmax_type, window_size_left, window_size_right, deterministic, - devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, - devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, - get_cudnn_fe_dtype(QKV_type), 
workspace->data.dptr, &workspace_size, stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} -void fused_attn_arbitrary_seqlen_fwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, - size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, - size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, - const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - const auto QKV_type = input_Q->data.dtype; - void *devPtrQ = input_Q->data.dptr; - void *devPtrKV = input_KV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void *devPtrK = devPtrKV; - void *devPtrV = static_cast(static_cast(devPtrKV) + stride); - - void *devPtrBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - bias_b = input_Bias->data.shape[0]; - bias_h = input_Bias->data.shape[1]; - } - void *devPtrSoftmaxOffset = nullptr; - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; - } - - void *devPtrO = output_O->data.dptr; - void *devPtrS1 = nullptr; - void *devPtrS2 = nullptr; - - void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; - void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; - void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; - void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; - void *devPtrPageTableK = page_table_k->data.dptr; - void *devPtrPageTableV = page_table_v->data.dptr; - - size_t max_batch_size = 0; - size_t max_tokens_q = 0; - size_t max_tokens_kv = 0; - if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) { - max_batch_size = get_max_batch_size(batch); - } - if (q_format == NVTE_QKV_Format::NVTE_THD) { - max_tokens_q = get_max_tokens(num_tokens_q); - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - max_tokens_kv = get_max_tokens(num_tokens_kv); - } - - size_t i = 0; - if (Aux_CTX_Tensors->size == 0) { - const auto 
cudnn_runtime_version = cudnnGetVersion(); - if (return_max_logit) { - Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_Max->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_Max->data.shape = {max_tokens_q, num_attn_heads, 1}; - } else { - output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - } - output_Max->data.dtype = DType::kFloat32; - Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_Sum_Exp->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_Sum_Exp->data.shape = {max_tokens_q, num_attn_heads, 1}; - } else { - output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - } - output_Sum_Exp->data.dtype = DType::kFloat32; - } else { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - } - output_S->data.dtype = DType::kFloat32; - } - - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_bias->data.dptr = nullptr; - output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; - output_bias->data.dtype = QKV_type; - } - - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_softmax_offset->data.dptr = nullptr; - output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1}; - output_softmax_offset->data.dtype = DType::kFloat32; - } - - Aux_CTX_Tensors->size = i; - } else if (Aux_CTX_Tensors->size >= 2) { - if (return_max_logit) { - Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - devPtrS1 = output_Max->data.dptr; - Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - devPtrS2 = output_Sum_Exp->data.dptr; - } else { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - devPtrS1 = output_S->data.dptr; - } - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_rng_state->data.dptr = rng_state->data.dptr; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_bias->data.dptr = devPtrBias; - } - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_softmax_offset->data.dptr = devPtrSoftmaxOffset; - } - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_fwd_impl( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, - max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, - page_size_v, max_pages_per_seq_k, 
max_pages_per_seq_v, bias_b, bias_h, is_training, - return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, - devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, - devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} - -void fused_attn_arbitrary_seqlen_bwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, - Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - const auto QKV_type = input_Q->data.dtype; - void *devPtrQ = input_Q->data.dptr; - void *devPtrKV = input_KV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void *devPtrK = devPtrKV; - void *devPtrV = static_cast(static_cast(devPtrKV) + stride); - - void *devPtrO = input_O->data.dptr; - void *devPtrdO = input_dO->data.dptr; - void *devPtrBias = nullptr; - void *devPtrdBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - devPtrdBias = output_dBias->data.dptr; - bias_b = output_dBias->data.shape[0]; - bias_h = output_dBias->data.shape[1]; - } - - size_t max_batch_size = 0; - size_t max_tokens_q = 0; - size_t max_tokens_kv = 0; - NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) { - max_batch_size = get_max_batch_size(batch); - } - if (q_format == NVTE_QKV_Format::NVTE_THD) { - max_tokens_q = get_max_tokens(num_tokens_q); - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - max_tokens_kv = get_max_tokens(num_tokens_kv); - } - - void *devPtrdQ = output_dQ->data.dptr; - void *devPtrdKV = output_dKV->data.dptr; - void *devPtrdK = devPtrdKV; - void *devPtrdV = 
static_cast(static_cast(devPtrdKV) + stride); - - void *devPtrSoftmaxStats = nullptr; - devPtrSoftmaxStats = output_S->data.dptr; - void *devPtrSoftmaxOffset = nullptr; - void *devPtrdSoftmaxOffset = nullptr; - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; - devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr; - } - - void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; - void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; - void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; - void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; - - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_bwd_impl( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, - max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, - deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, - devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, - devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, - devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} - void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, @@ -1604,8 +1078,8 @@ void fused_attn_arbitrary_seqlen_fwd( void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; - void *devPtrPageTableK = page_table_k->data.dptr; - void *devPtrPageTableV = page_table_v->data.dptr; + void *devPtrPageTableK = page_table_k ? page_table_k->data.dptr : nullptr; + void *devPtrPageTableV = page_table_v ? 
page_table_v->data.dptr : nullptr; size_t max_batch_size = 0; size_t max_tokens_q = 0; diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index a3181c6295..872b798bb4 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -18,53 +18,6 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) -void fused_attn_arbitrary_seqlen_fwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - bool is_training, bool return_max_logit, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_QKV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, - Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, - const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); - -void fused_attn_arbitrary_seqlen_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, - Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, - const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -void fused_attn_arbitrary_seqlen_fwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, - size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, - size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, - const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -void fused_attn_arbitrary_seqlen_bwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, 
-    const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
-    Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q,
-    const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
-    const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
-    cudaStream_t stream, cudnnHandle_t handle);
-
 void fused_attn_arbitrary_seqlen_fwd(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu
index 89528fa3c4..1028df6452 100644
--- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu
@@ -1215,150 +1215,6 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
 }  // namespace fused_attn
 
 using namespace transformer_engine::fused_attn;
-void fused_attn_max_512_fwd_qkvpacked(
-    size_t batch, size_t num_head, size_t max_seqlen, size_t head_dim, bool is_training,
-    float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
-    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state,
-    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
-  using namespace transformer_engine;
-
-  // QKV shape is [b, s, 3, h, d]
-  void *devPtrQKV = input_QKV->data.dptr;
-  const auto stride = 2 * num_head * head_dim;
-
-  void *devPtrQ = static_cast<void *>(devPtrQKV);
-  void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
-  void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
-
-  void *devPtrBias = static_cast<void *>(input_Bias->data.dptr);
-
-  void *devPtrO = output_O->data.dptr;
-
-  void *devPtrS = nullptr;
-
-  if (Aux_CTX_Tensors->size == 0) {
-    Aux_CTX_Tensors->size = 1;
-    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-    output_S->data.dptr = nullptr;
-    output_S->data.shape = {batch, num_head, max_seqlen, max_seqlen};
-    output_S->data.dtype = input_QKV->data.dtype;
-  } else if (Aux_CTX_Tensors->size == 1) {
-    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-    devPtrS = output_S->data.dptr;
-  } else {
-    NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
-  }
-
-  void *devPtrCuSeqlen = cu_seqlens->data.dptr;
-
-  const DType rng_state_type = rng_state->data.dtype;
-  NVTE_CHECK(rng_state_type == DType::kInt64);
-  void *devPtrDropoutSeed = rng_state->data.dptr;
-  void *devPtrDropoutOffset =
-      static_cast<void *>(static_cast<uint64_t *>(rng_state->data.dptr) + 1);
-
-  const DType QKV_type = input_QKV->data.dtype;
-  size_t workspace_size = 0;
-
-  fused_attn_max_512_fwd_impl(
-      batch, num_head, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout,
-      qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias,
-      devPtrCuSeqlen, devPtrCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr,
-      &workspace_size, get_cudnn_dtype(QKV_type), stream, handle);
-
-  if (workspace_size > 0) {
-    if (workspace->data.dptr == nullptr) {
-      workspace->data.shape = {workspace_size};
-      workspace->data.dtype = DType::kByte;
-      return;
-    }
-  } else if (workspace_size == 0) {
-    workspace->data.shape = {1};
-    workspace->data.dtype = DType::kByte;
-    return;
-  } else {
-    NVTE_ERROR("Unexpected workspace_size.");
-  }
-}
-
-void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
-                                     size_t kv_max_seqlen, size_t head_dim, bool is_training,
-                                     float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
-                                     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-                                     const Tensor *input_Q, const Tensor *input_KV,
-                                     const Tensor *input_Bias, Tensor *output_O,
-                                     NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens,
-                                     const Tensor *kv_cu_seqlens, const Tensor *rng_state,
-                                     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
-  using namespace transformer_engine;
-
-  NVTE_CHECK(bias_type == NVTE_Bias_Type::NVTE_NO_BIAS ||
-                 bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS,
-             "NVTE_PRE_SCALE_BIAS is not implemented in fused_attn_max_512.");
-
-  // Q shape is [b, s, h, d]
-  void *devPtrQ = input_Q->data.dptr;
-
-  // KV shape is [b, s, 2, h, d]
-  const auto stride = 2 * num_head * head_dim;
-  void *devPtrK = input_KV->data.dptr;
-  void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrK) + stride);
-
-  void *devPtrBias = input_Bias->data.dptr;
-
-  void *devPtrO = output_O->data.dptr;
-
-  void *devPtrS = nullptr;
-
-  const DType q_type = input_Q->data.dtype;
-  const DType kv_type = input_KV->data.dtype;
-  NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV.");
-
-  if (Aux_CTX_Tensors->size == 0) {
-    Aux_CTX_Tensors->size = 1;
-    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-    output_S->data.dptr = nullptr;
-    output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen};
-    output_S->data.dtype = q_type;
-  } else if (Aux_CTX_Tensors->size == 1) {
-    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-    devPtrS = output_S->data.dptr;
-  } else {
-    NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
-  }
-
-  void *devQCuSeqlen = q_cu_seqlens->data.dptr;
-  void *devKVCuSeqlen = kv_cu_seqlens->data.dptr;
-
-  const DType rng_state_type = rng_state->data.dtype;
-  NVTE_CHECK(rng_state_type == DType::kInt64);
-  void *devPtrDropoutSeed = rng_state->data.dptr;
-  void *devPtrDropoutOffset =
-      static_cast<void *>(static_cast<uint64_t *>(rng_state->data.dptr) + 1);
-
-  size_t workspace_size = 0;
-
-  fused_attn_max_512_fwd_impl(
-      batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, is_training, attn_scale, p_dropout,
-      qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias,
-      devQCuSeqlen, devKVCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr,
-      &workspace_size, get_cudnn_dtype(q_type), stream, handle);
-
-  if (workspace_size > 0) {
-    if (workspace->data.dptr == nullptr) {
-      workspace->data.shape = {workspace_size};
-      workspace->data.dtype = DType::kByte;
-      return;
-    }
-  } else if (workspace_size == 0) {
-    workspace->data.shape = {1};
-    workspace->data.dtype = DType::kByte;
-    return;
-  } else {
-    NVTE_ERROR("Unexpected workspace_size.");
-  }
-}
 
 void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
                             size_t kv_max_seqlen, size_t head_dim, bool is_training,
                             float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
@@ -1429,126 +1285,6 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
   }
 }
 
-void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen,
-                                      size_t head_dim, float attn_scale, float p_dropout,
-                                      NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                                      NVTE_Mask_Type mask_type, const Tensor *input_QKV,
-                                      const Tensor *input_dO, Tensor *output_S, Tensor *output_dQKV,
-                                      Tensor *output_dBias, const Tensor *cu_seqlens,
-                                      Tensor *workspace, cudaStream_t stream,
-                                      cudnnHandle_t handle) {
-  using namespace transformer_engine;
-
-  // QKV shape is [b, s, 3, h, d]
-  void *devPtrQKV = input_QKV->data.dptr;
-
-  auto stride = 2 * num_head * head_dim;
-  void *devPtrQ = devPtrQKV;
-  void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
-  void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
-
-  void *devPtrdO = input_dO->data.dptr;
-
-  // dQKV shape is [b, s, 3, h, d]
-  void *devPtrdQKV = output_dQKV->data.dptr;
-  void *devPtrdQ = devPtrdQKV;
-  void *devPtrdK = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + stride);
-  void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + 2 * stride);
-
-  void *devPtrdBias = output_dBias->data.dptr;
-
-  void *devPtrS = output_S->data.dptr;
-
-  // devPtrdS reuses the memory of devPtrS
-  void *devPtrdS = devPtrS;
-
-  void *devPtrCuSeqlens = cu_seqlens->data.dptr;
-
-  const auto qkv_type = input_QKV->data.dtype;
-  size_t workspace_size = 0;
-
-  fused_attn_max_512_bwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim, attn_scale,
-                              p_dropout, qkv_layout, mask_type, bias_type, devPtrQ, devPtrK,
-                              devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdS,
-                              devPtrdBias, devPtrCuSeqlens, devPtrCuSeqlens, workspace->data.dptr,
-                              &workspace_size, get_cudnn_dtype(qkv_type), stream, handle);
-
-  if (workspace_size > 0) {
-    if (workspace->data.dptr == nullptr) {
-      workspace->data.shape = {workspace_size};
-      workspace->data.dtype = DType::kByte;
-      return;
-    }
-  } else if (workspace_size == 0) {
-    workspace->data.shape = {1};
-    workspace->data.dtype = DType::kByte;
-    return;
-  } else {
-    NVTE_ERROR("Unexpected workspace_size.");
-  }
-}
-
-void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
-                                     size_t kv_max_seqlen, size_t head_dim, float attn_scale,
-                                     float p_dropout, NVTE_QKV_Layout qkv_layout,
-                                     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-                                     const Tensor *input_Q, const Tensor *input_KV,
-                                     const Tensor *input_dO, Tensor *output_S, Tensor *output_dQ,
-                                     Tensor *output_dKV, Tensor *output_dBias,
-                                     const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
-                                     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
-  using namespace transformer_engine;
-
-  // Q shape is [b, s, h, d]
-  // KV shape is [b, s, 2, h, d]
-  auto stride = 2 * num_head * head_dim;
-  void *devPtrQ = input_Q->data.dptr;
-  void *devPtrK = input_KV->data.dptr;
-  void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrK) + stride);
-
-  void *devPtrdO = input_dO->data.dptr;
-
-  // dQ shape is [b, s, h, d]
-  // dKV shape is [b, s, 2, h, d]
-  void *devPtrdQ = output_dQ->data.dptr;
-  void *devPtrdK = output_dKV->data.dptr;
-  void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdK) + stride);
-
-  void *devPtrdBias = output_dBias->data.dptr;
-
-  void *devPtrS = output_S->data.dptr;
-
-  // devPtrdS reuses the memory of devPtrS
-  void *devPtrdS = devPtrS;
-
-  void *devPtrQCuSeqlens = q_cu_seqlens->data.dptr;
-  void *devPtrKVCuSeqlens = kv_cu_seqlens->data.dptr;
-
-  const auto q_type = input_Q->data.dtype;
-  const auto kv_type = input_KV->data.dtype;
-  NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV.");
-  size_t workspace_size = 0;
-
-  fused_attn_max_512_bwd_impl(
-      batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout,
-      mask_type, bias_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV,
-      devPtrdO, devPtrdS, devPtrdBias, devPtrQCuSeqlens, devPtrKVCuSeqlens, workspace->data.dptr,
-      &workspace_size, get_cudnn_dtype(q_type), stream, handle);
-
-  if (workspace_size > 0) {
-    if (workspace->data.dptr == nullptr) {
-      workspace->data.shape = {workspace_size};
-      workspace->data.dtype = DType::kByte;
-      return;
-    }
-  } else if (workspace_size == 0) {
-    workspace->data.shape = {1};
-    workspace->data.dtype = DType::kByte;
-    return;
-  } else {
-    NVTE_ERROR("Unexpected workspace_size.");
-  }
-}
 
 void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen,
                             size_t kv_max_seqlen, size_t head_dim, float attn_scale,
                             float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h
index 171fe846ce..57b7afcf43 100644
--- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h
+++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h
@@ -18,25 +18,6 @@ namespace transformer_engine {
 
 #if (CUDNN_VERSION >= 8901)
 
-void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen,
-                                      size_t head_size, bool is_training, float attn_scale,
-                                      float p_dropout, NVTE_QKV_Layout qkv_layout,
-                                      NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-                                      const Tensor *input_QKV, const Tensor *input_Bias,
-                                      Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
-                                      const Tensor *cu_seqlens, const Tensor *rng_state,
-                                      Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
-
-void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
-                                     size_t kv_max_seqlen, size_t head_dim, bool is_training,
-                                     float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
-                                     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-                                     const Tensor *input_Q, const Tensor *input_KV,
-                                     const Tensor *input_Bias, Tensor *output_O,
-                                     NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens,
-                                     const Tensor *kv_cu_seqlens, const Tensor *rng_state,
-                                     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
-
 void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
                             size_t kv_max_seqlen, size_t head_dim, bool is_training,
                             float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
@@ -47,24 +28,6 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
                             const Tensor *kv_cu_seqlens, const Tensor *rng_state,
                             Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
 
-void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen,
-                                      size_t head_dim, float attn_scale, float p_dropout,
-                                      NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                                      NVTE_Mask_Type mask_type, const Tensor *input_QKV,
-                                      const Tensor *input_dO, Tensor *output_S, Tensor *output_dQKV,
-                                      Tensor *output_dBias, const Tensor *cu_seqlens,
-                                      Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
-
-void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
-                                     size_t kv_max_seqlen, size_t head_dim, float attn_scale,
-                                     float p_dropout, NVTE_QKV_Layout qkv_layout,
-                                     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-                                     const Tensor *input_Q, const Tensor *input_KV,
-                                     const Tensor *input_dO, Tensor *output_S, Tensor *output_dQ,
-                                     Tensor *output_dKV, Tensor *output_dBias,
-                                     const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
-                                     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
-
 void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen,
                             size_t kv_max_seqlen, size_t head_dim, float attn_scale,
                             float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu
index 7b85be972c..5d806290a9 100644
--- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu
@@ -2407,424 +2407,6 @@ void fused_attn_fp8_bwd_impl_v1(
 }  // namespace fused_attn
 
 #if (CUDNN_VERSION >= 8900)
-// fused attention FWD FP8 with packed QKV
-void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t max_seqlen,
-                                  size_t head_dim, bool is_training, float attn_scale,
-                                  float p_dropout, NVTE_QKV_Layout qkv_layout,
-                                  NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-                                  const Tensor* input_QKV, Tensor* input_output_S, Tensor* output_O,
-                                  NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens,
-                                  const Tensor* rng_state, Tensor* workspace, cudaStream_t stream,
-                                  cudnnHandle_t handle) {
-  using namespace transformer_engine;
-  const DType QKV_type = input_QKV->data.dtype;
-  const DType O_type = output_O->data.dtype;
-  void* devPtrQKV = input_QKV->data.dptr;
-  NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
-  size_t stride = 0;
-  if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
-    stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8;
-  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
-    stride = (typeToNumBits(QKV_type) * head_dim) / 8;
-  }
-  void* devPtrQ = static_cast<void*>(devPtrQKV);
-  void* devPtrK = static_cast<void*>(static_cast<int8_t*>(devPtrQKV) + stride);
-  void* devPtrV = static_cast<void*>(static_cast<int8_t*>(devPtrQKV) + 2 * stride);
-  void* devPtrDescaleQ = input_QKV->scale_inv.dptr;
-  void* devPtrDescaleK = input_QKV->scale_inv.dptr;
-  void* devPtrDescaleV = input_QKV->scale_inv.dptr;
-
-  void* devPtrO = output_O->data.dptr;
-  void* devPtrAmaxO = output_O->amax.dptr;
-  void* devPtrScaleO = output_O->scale.dptr;
-
-  void* devPtrM = nullptr;
-  void* devPtrZInv = nullptr;
-  if (Aux_CTX_Tensors->size == 0) {
-    Aux_CTX_Tensors->size = 3;
-    Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-    Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
-    Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
-    output_M->data.dptr = nullptr;
-    output_M->data.shape = {batch, num_attn_heads, max_seqlen, 1};
-    output_M->data.dtype = DType::kFloat32;
-    output_ZInv->data.dptr = nullptr;
-    output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen, 1};
-    output_ZInv->data.dtype = DType::kFloat32;
-    output_rng_state->data.dptr = nullptr;
-    output_rng_state->data.shape = {2};
-    output_rng_state->data.dtype = DType::kInt64;
-  } else if (Aux_CTX_Tensors->size == 3) {
-    Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-    Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
-    Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
-    devPtrM = output_M->data.dptr;
-    devPtrZInv = output_ZInv->data.dptr;
-    output_rng_state->data.dptr = rng_state->data.dptr;
-  } else {
-    NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
-  }
-
-  void* devPtrAmaxS = input_output_S->amax.dptr;
-  void* devPtrScaleS = input_output_S->scale.dptr;
-  void* devPtrDescaleS = input_output_S->scale_inv.dptr;
-
-  void* devPtrcuSeqlens =
-      reinterpret_cast<void*>(reinterpret_cast<int32_t*>(cu_seqlens->data.dptr));
-  void* devPtrDropoutSeed =
-      reinterpret_cast<void*>(reinterpret_cast<uint64_t*>(rng_state->data.dptr));
-  void* devPtrDropoutOffset =
-      reinterpret_cast<void*>(reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
-
-  size_t workspace_size = 0;
-
-  NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
-  if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
-    fused_attn::fused_attn_fp8_fwd_impl_v1(
-        batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training,
-        attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM,
-        devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS,
-        devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlens, devPtrcuSeqlens,
-        devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
-        get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle);
-  } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
-    fused_attn::fused_attn_fp8_fwd_impl(
-        batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout,
-        qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ,
-        devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO,
-        devPtrAmaxS, devPtrcuSeqlens, devPtrcuSeqlens, devPtrDropoutSeed, devPtrDropoutOffset,
-        get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
-  } else {
-    NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n");
-  }
-
-  if (workspace_size > 0) {
-    if (workspace->data.dptr == nullptr) {
-      workspace->data.shape = {workspace_size};
-      workspace->data.dtype = DType::kByte;
-      return;
-    }
-  } else if (workspace_size == 0) {
-    workspace->data.shape = {1};
-    workspace->data.dtype = DType::kByte;
-    return;
-  }
-}
-// fused attention BWD FP8 with packed QKV
-void fused_attn_fp8_bwd_qkvpacked(
-    size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale,
-    float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-    const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_M,
-    const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP,
-    const Tensor* output_dQKV, const Tensor* cu_seqlens, const Tensor* rng_state, Tensor* workspace,
-    cudaStream_t stream, cudnnHandle_t handle) {
-  using namespace transformer_engine;
-  const DType QKV_type = input_QKV->data.dtype;
-  const DType dO_type = input_dO->data.dtype;
-  const DType dQKV_type = output_dQKV->data.dtype;
-  void* devPtrQKV = input_QKV->data.dptr;
-  NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
-  size_t stride = 0;
-  if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
-    stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8;
-  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
-    stride = (typeToNumBits(QKV_type) * head_dim) / 8;
-  }
-  void* devPtrQ = devPtrQKV;
-  void* devPtrK = static_cast<void*>(static_cast<int8_t*>(devPtrQKV) + stride);
-  void* devPtrV = static_cast<void*>(static_cast<int8_t*>(devPtrQKV) + 2 * stride);
-  void* devPtrDescaleQ = input_QKV->scale_inv.dptr;
-  void* devPtrDescaleK = input_QKV->scale_inv.dptr;
-  void* devPtrDescaleV = input_QKV->scale_inv.dptr;
-
-  void* devPtrO = input_O->data.dptr;
-  const DType O_type = input_O->data.dtype;
-  void* devPtrDescaleO = nullptr;
-  if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) {
-    devPtrDescaleO = input_O->scale_inv.dptr;
-  }
-  void* devPtrdO = input_dO->data.dptr;
-  void* devPtrDescaledO = input_dO->scale_inv.dptr;
-
-  void* devPtrM = input_M->data.dptr;
-  void* devPtrZInv = input_ZInv->data.dptr;
-
-  void* devPtrScaleS = input_S->scale.dptr;
-  void* devPtrDescaleS = input_S->scale_inv.dptr;
-  void* devPtrAmaxdP = input_output_dP->amax.dptr;
-  void* devPtrScaledP = input_output_dP->scale.dptr;
-  void* devPtrDescaledP = input_output_dP->scale_inv.dptr;
-
-  void* devPtrdQKV = output_dQKV->data.dptr;
-  void* devPtrdQ = devPtrdQKV;
-  void* devPtrdK = static_cast<void*>(static_cast<int8_t*>(devPtrdQKV) + stride);
-  void* devPtrdV = static_cast<void*>(static_cast<int8_t*>(devPtrdQKV) + 2 * stride);
-  void* devPtrAmaxdQ = output_dQKV->amax.dptr;
-  void* devPtrAmaxdK = output_dQKV->amax.dptr;
-  void* devPtrAmaxdV = output_dQKV->amax.dptr;
-  void* devPtrScaledQ = output_dQKV->scale.dptr;
-  void* devPtrScaledK = output_dQKV->scale.dptr;
-  void* devPtrScaledV = output_dQKV->scale.dptr;
-
-  void* devPtrcuSeqlens =
-      reinterpret_cast<void*>(reinterpret_cast<int32_t*>(cu_seqlens->data.dptr));
-  void* devPtrDropoutSeed =
-      reinterpret_cast<void*>(reinterpret_cast<uint64_t*>(rng_state->data.dptr));
-  void* devPtrDropoutOffset =
-      reinterpret_cast<void*>(reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
-
-  size_t workspace_size = 0;
-
-  NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
-  if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
-    fused_attn::fused_attn_fp8_bwd_impl_v1(
-        batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale,
-        p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv,
-        devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK,
-        devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP,
-        devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP,
-        devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlens, devPtrcuSeqlens,
-        devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
-        get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type),
-        workspace->data.dptr, &workspace_size, stream, handle);
-  } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
-    fused_attn::fused_attn_fp8_bwd_impl(
-        batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout,
-        devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK,
-        devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO,
-        devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK,
-        devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlens,
-        devPtrcuSeqlens, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_dtype(QKV_type),
-        workspace->data.dptr, &workspace_size, stream, handle);
-  } else {
-    NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n");
-  }
-
-  if (workspace_size > 0) {
-    if (workspace->data.dptr == nullptr) {
-      workspace->data.shape = {workspace_size};
-      workspace->data.dtype = DType::kByte;
-      return;
-    }
-  } else if (workspace_size == 0) {
-    workspace->data.shape = {1};
-    workspace->data.dtype = DType::kByte;
-    return;
-  }
-}
-// fused attention FWD FP8 with packed KV
-void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
-                                 size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
-                                 bool is_training, float attn_scale, float p_dropout,
-                                 NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                                 NVTE_Mask_Type mask_type, const Tensor* input_Q,
-                                 const Tensor* input_KV, Tensor* input_output_S, Tensor* output_O,
-                                 NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q,
-                                 const Tensor* cu_seqlens_kv, const Tensor* rng_state,
-                                 Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) {
-  using namespace transformer_engine;
-  const DType QKV_type = input_Q->data.dtype;
-  const DType O_type = output_O->data.dtype;
-  void* devPtrQ = input_Q->data.dptr;
-  void* devPtrKV = input_KV->data.dptr;
-  NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
-  size_t stride = 0;
-  if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
-    stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8;
-  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
-    stride = (typeToNumBits(QKV_type) * head_dim) / 8;
-  }
-  void* devPtrK = devPtrKV;
-  void* devPtrV = static_cast<void*>(static_cast<int8_t*>(devPtrKV) + stride);
-  void* devPtrDescaleQ = input_Q->scale_inv.dptr;
-  void* devPtrDescaleK = input_KV->scale_inv.dptr;
-  void* devPtrDescaleV = input_KV->scale_inv.dptr;
-
-  void* devPtrO = output_O->data.dptr;
-  void* devPtrAmaxO = output_O->amax.dptr;
-  void* devPtrScaleO = output_O->scale.dptr;
-
-  void* devPtrM = nullptr;
-  void* devPtrZInv = nullptr;
-  if (Aux_CTX_Tensors->size == 0) {
-    Aux_CTX_Tensors->size = 3;
-    Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-    Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
-    Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
-    output_M->data.dptr = nullptr;
-    output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
-    output_M->data.dtype = DType::kFloat32;
-    output_ZInv->data.dptr = nullptr;
-    output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
-    output_ZInv->data.dtype = DType::kFloat32;
-    output_rng_state->data.dptr = nullptr;
-    output_rng_state->data.shape = {2};
-    output_rng_state->data.dtype = DType::kInt64;
-  } else if (Aux_CTX_Tensors->size == 3) {
-    Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-    Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
-    Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
-    devPtrM = output_M->data.dptr;
-    devPtrZInv = output_ZInv->data.dptr;
-    output_rng_state->data.dptr = rng_state->data.dptr;
-  } else {
-    NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
-  }
-
-  void* devPtrAmaxS = input_output_S->amax.dptr;
-  void* devPtrScaleS = input_output_S->scale.dptr;
-  void* devPtrDescaleS = input_output_S->scale_inv.dptr;
-
-  void* devPtrcuSeqlensQ =
-      reinterpret_cast<void*>(reinterpret_cast<int32_t*>(cu_seqlens_q->data.dptr));
-  void* devPtrcuSeqlensKV =
-      reinterpret_cast<void*>(reinterpret_cast<int32_t*>(cu_seqlens_kv->data.dptr));
-  void* devPtrDropoutSeed =
-      reinterpret_cast<void*>(reinterpret_cast<uint64_t*>(rng_state->data.dptr));
-  void* devPtrDropoutOffset =
-      reinterpret_cast<void*>(reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
-
-  size_t workspace_size = 0;
-
-  NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
-  if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
-    fused_attn::fused_attn_fp8_fwd_impl_v1(
-        batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, is_training,
-        attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM,
-        devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS,
-        devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
-        devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
-        get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle);
-  } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
-    fused_attn::fused_attn_fp8_fwd_impl(
-        batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale,
-        p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO,
-        devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO,
-        devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed,
-        devPtrDropoutOffset, get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size,
-        stream, handle);
-  } else {
-    NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n");
-  }
-
-  if (workspace_size > 0) {
-    if (workspace->data.dptr == nullptr) {
-      workspace->data.shape = {workspace_size};
-      workspace->data.dtype = DType::kByte;
-      return;
-    }
-  } else if (workspace_size == 0) {
-    workspace->data.shape = {1};
-    workspace->data.dtype = DType::kByte;
-    return;
-  }
-}
-// fused attention BWD FP8 with packed KV
-void fused_attn_fp8_bwd_kvpacked(
-    size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
-    size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout,
-    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-    const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO,
-    const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP,
-    const Tensor* output_dQ, const Tensor* output_dKV, const Tensor* cu_seqlens_q,
-    const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, cudaStream_t stream,
-    cudnnHandle_t handle) {
-  using namespace transformer_engine;
-  const DType QKV_type = input_Q->data.dtype;
-  const DType dO_type = input_dO->data.dtype;
-  const DType dQKV_type = output_dQ->data.dtype;
-  void* devPtrQ = input_Q->data.dptr;
-  void* devPtrKV = input_KV->data.dptr;
-  NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
-  size_t stride = 0;
-  if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
-    stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8;
-  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
-    stride = (typeToNumBits(QKV_type) * head_dim) / 8;
-  }
-  void* devPtrK = devPtrKV;
-  void* devPtrV = static_cast<void*>(static_cast<int8_t*>(devPtrKV) + stride);
-  void* devPtrDescaleQ = input_Q->scale_inv.dptr;
-  void* devPtrDescaleK = input_KV->scale_inv.dptr;
-  void* devPtrDescaleV = input_KV->scale_inv.dptr;
-
-  void* devPtrO = input_O->data.dptr;
-  const DType O_type = input_O->data.dtype;
-  void* devPtrDescaleO = nullptr;
-  if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) {
-    devPtrDescaleO = input_O->scale_inv.dptr;
-  }
-  void* devPtrdO = input_dO->data.dptr;
-  void* devPtrDescaledO = input_dO->scale_inv.dptr;
-
-  void* devPtrM = input_M->data.dptr;
-  void* devPtrZInv = input_ZInv->data.dptr;
-
-  void* devPtrScaleS = input_S->scale.dptr;
-  void* devPtrDescaleS = input_S->scale_inv.dptr;
-  void* devPtrAmaxdP = input_output_dP->amax.dptr;
-  void* devPtrScaledP = input_output_dP->scale.dptr;
-  void* devPtrDescaledP = input_output_dP->scale_inv.dptr;
-
-  void* devPtrdQ = output_dQ->data.dptr;
-  void* devPtrdKV = output_dKV->data.dptr;
-  void* devPtrdK = devPtrdKV;
-  void* devPtrdV = static_cast<void*>(static_cast<int8_t*>(devPtrdKV) + stride);
-  void* devPtrAmaxdQ = output_dQ->amax.dptr;
-  void* devPtrAmaxdK = output_dKV->amax.dptr;
-  void* devPtrAmaxdV = output_dKV->amax.dptr;
-  void* devPtrScaledQ = output_dQ->scale.dptr;
-  void* devPtrScaledK = output_dKV->scale.dptr;
-  void* devPtrScaledV = output_dKV->scale.dptr;
-
-  void* devPtrcuSeqlensQ =
-      reinterpret_cast<void*>(reinterpret_cast<int32_t*>(cu_seqlens_q->data.dptr));
-  void* devPtrcuSeqlensKV =
-      reinterpret_cast<void*>(reinterpret_cast<int32_t*>(cu_seqlens_kv->data.dptr));
-  void* devPtrDropoutSeed =
-      reinterpret_cast<void*>(reinterpret_cast<uint64_t*>(rng_state->data.dptr));
-  void* devPtrDropoutOffset =
-      reinterpret_cast<void*>(reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
-
-  size_t workspace_size = 0;
-
-  NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
-  if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
-    fused_attn::fused_attn_fp8_bwd_impl_v1(
-        batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale,
-        p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv,
-        devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK,
-        devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP,
-        devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP,
-        devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
-        devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
-        get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type),
-        workspace->data.dptr, &workspace_size, stream, handle);
-  } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
-    fused_attn::fused_attn_fp8_bwd_impl(
-        batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout,
-        qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ,
-        devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO,
-        devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP,
-        devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK,
-        devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset,
-        get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
-  } else {
-    NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. 
\n"); - } - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } -} // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 3daf45d162..c2efa25829 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -13,47 +13,6 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) -// fused attention FWD FP8 with packed QKV -void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t max_seqlen, - size_t head_dim, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_QKV, Tensor *input_output_S, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, - cudnnHandle_t handle); - -// fused attention BWD FP8 with packed QKV -void fused_attn_fp8_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_M, - const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, - const Tensor *output_dQKV, const Tensor *cu_seqlens, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); - -// fused attention FWD FP8 with packed KV -void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_Q, - const Tensor *input_KV, Tensor *input_output_S, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -// fused attention BWD FP8 with packed KV -void fused_attn_fp8_bwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, - const Tensor *output_dQ, const Tensor *output_dKV, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, - cudnnHandle_t handle); - // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, diff --git 
a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 40e6a0b4bc..298dc63900 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -217,6 +217,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( int64_t window_size_right, bool return_max_logit, bool cuda_graph); /*! \brief Compute dot product attention with packed QKV input. + * + * \deprecated Please use `nvte_fused_attn_fwd` with separate Q, K, V tensors instead. * * Computes: * - P = Q * Transpose(K) + Bias @@ -270,6 +272,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ +[[deprecated( + "nvte_fused_attn_fwd_qkvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate " + "Q, K, V tensors instead.")]] void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, @@ -282,6 +287,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed QKV input. + * + * \deprecated Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead. * * Support Matrix: \verbatim @@ -330,6 +337,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ +[[deprecated( + "nvte_fused_attn_bwd_qkvpacked() is deprecated. Please use nvte_fused_attn_bwd() with separate " + "Q, K, V tensors instead.")]] void nvte_fused_attn_bwd_qkvpacked( const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, @@ -340,6 +350,8 @@ void nvte_fused_attn_bwd_qkvpacked( NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with packed KV input. + * + * \deprecated Please use `nvte_fused_attn_fwd` with separate Q, K, V tensors instead. * * Computes: * - P = Q * Transpose(K) + Bias @@ -401,6 +413,9 @@ void nvte_fused_attn_bwd_qkvpacked( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ +[[deprecated( + "nvte_fused_attn_fwd_kvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate " + "Q, K, V tensors instead.")]] void nvte_fused_attn_fwd_kvpacked( const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, @@ -413,6 +428,8 @@ void nvte_fused_attn_fwd_kvpacked( int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed KV input. + * + * \deprecated Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead. * * Support Matrix: \verbatim @@ -467,6 +484,9 @@ void nvte_fused_attn_fwd_kvpacked( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ +[[deprecated( + "nvte_fused_attn_bwd_kvpacked() is deprecated. 
Please use nvte_fused_attn_bwd() with separate "
+    "Q, K, V tensors instead.")]]
 void nvte_fused_attn_bwd_kvpacked(
     const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
     const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,