From 10967bbd26af4f2bffe547d226021dc80cb1d8dc Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 20 Oct 2025 10:55:13 +0000 Subject: [PATCH 01/10] code drop Signed-off-by: Pawel Gadzinski --- .../common/fused_attn/fused_attn.cpp | 354 ++++++++++--- .../fused_attn_f16_arbitrary_seqlen.cu | 467 ------------------ .../fused_attn_f16_arbitrary_seqlen.h | 45 -- .../fused_attn_f16_max512_seqlen.cu | 264 ---------- .../fused_attn/fused_attn_f16_max512_seqlen.h | 37 -- .../common/fused_attn/fused_attn_fp8.cu | 418 ---------------- .../common/fused_attn/fused_attn_fp8.h | 41 -- 7 files changed, 297 insertions(+), 1329 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 77cd8d235a..a36df51558 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -464,30 +464,84 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - fused_attn_max_512_fwd_qkvpacked(b, h, max_seqlen, d, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_QKV, input_Bias, - output_O, Aux_CTX_Tensors, input_cu_seqlens, input_rng_state, - wkspace, stream, handle); + // Unpack QKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = 2 * h * d; // For max512, layout is always BS3HD or SB3HD (3HD group) + + Tensor Q_view = *input_QKV; + Q_view.data.dptr = input_QKV->data.dptr; + + Tensor K_view = *input_QKV; + K_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + stride); + + Tensor V_view = *input_QKV; + V_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + 2 * stride); + + fused_attn_max_512_fwd(b, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, + input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens, + input_cu_seqlens, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. 
\n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - fused_attn_arbitrary_seqlen_fwd_qkvpacked( - b, h, max_seqlen, d, t, is_training, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, - input_rng_state, wkspace, stream, handle); + // Unpack QKV and call the non-packed function + // Create separate tensor views for Q, K, V from the packed QKV tensor + const auto QKV_type = input_QKV->data.dtype; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + size_t stride = 0; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { + stride = (typeToNumBits(QKV_type) * h * d) / 8; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { + stride = (typeToNumBits(QKV_type) * d) / 8; + } + + // Create tensor views for Q, K, V + Tensor Q_view = *input_QKV; + Q_view.data.dptr = input_QKV->data.dptr; + + Tensor K_view = *input_QKV; + K_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + stride); + + Tensor V_view = *input_QKV; + V_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + 2 * stride); + + fused_attn_arbitrary_seqlen_fwd( + b, h, h, max_seqlen, max_seqlen, d, d, t, t, 0, 0, 0, 0, 0, 0, is_training, + attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_Bias, + input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, + input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, input_rng_state, + wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - fused_attn_fp8_fwd_qkvpacked(b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, input_QKV, input_output_S, output_O, - Aux_CTX_Tensors, input_cu_seqlens, input_rng_state, wkspace, - stream, handle); + // Unpack QKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = 0; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { + stride = (typeToNumBits(QKV_type) * h * d) / 8; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { + stride = (typeToNumBits(QKV_type) * d) / 8; + } + + Tensor Q_view = *input_QKV; + Q_view.data.dptr = input_QKV->data.dptr; + + Tensor K_view = *input_QKV; + K_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + stride); + + Tensor V_view = *input_QKV; + V_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + 2 * stride); + + fused_attn_fp8_fwd(b, h, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, + input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens, + input_cu_seqlens, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. 
\n"); #endif @@ -549,9 +603,32 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - fused_attn_max_512_bwd_qkvpacked( - b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_QKV, - input_dO, output_S, output_dQKV, output_dBias, input_cu_seqlens, wkspace, stream, handle); + + // Unpack QKV and dQKV and call the non-packed function + size_t stride = 2 * h * d; + + Tensor Q_view = *input_QKV; + Q_view.data.dptr = input_QKV->data.dptr; + + Tensor K_view = *input_QKV; + K_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + stride); + + Tensor V_view = *input_QKV; + V_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + 2 * stride); + + Tensor dQ_view = *output_dQKV; + dQ_view.data.dptr = output_dQKV->data.dptr; + + Tensor dK_view = *output_dQKV; + dK_view.data.dptr = static_cast(static_cast(output_dQKV->data.dptr) + stride); + + Tensor dV_view = *output_dQKV; + dV_view.data.dptr = static_cast(static_cast(output_dQKV->data.dptr) + 2 * stride); + + fused_attn_max_512_bwd(b, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_dO, + output_S, &dQ_view, &dK_view, &dV_view, output_dBias, + input_cu_seqlens, input_cu_seqlens, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); #endif @@ -567,12 +644,44 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con if (softmax_type != NVTE_VANILLA_SOFTMAX) { input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); } - fused_attn_arbitrary_seqlen_bwd_qkvpacked( - b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size_left, window_size_right, deterministic, input_QKV, input_O, - input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQKV, output_dBias, - output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace, - stream, handle); + + // Unpack QKV and dQKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + size_t stride = 0; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { + stride = (typeToNumBits(QKV_type) * h * d) / 8; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { + stride = (typeToNumBits(QKV_type) * d) / 8; + } + + // Create tensor views for Q, K, V from input + Tensor Q_view = *input_QKV; + Q_view.data.dptr = input_QKV->data.dptr; + + Tensor K_view = *input_QKV; + K_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + stride); + + Tensor V_view = *input_QKV; + V_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + 2 * stride); + + // Create tensor views for dQ, dK, dV from output + Tensor dQ_view = *output_dQKV; + dQ_view.data.dptr = output_dQKV->data.dptr; + + Tensor dK_view = *output_dQKV; + dK_view.data.dptr = static_cast(static_cast(output_dQKV->data.dptr) + stride); + + Tensor dV_view = *output_dQKV; + dV_view.data.dptr = static_cast(static_cast(output_dQKV->data.dptr) + 2 * stride); + + fused_attn_arbitrary_seqlen_bwd( + b, h, h, max_seqlen, max_seqlen, d, d, t, t, attn_scale, dropout, qkv_layout, bias_type, + attn_mask_type, softmax_type, 
window_size_left, window_size_right, deterministic, + &Q_view, &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, + output_S, &dQ_view, &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, + input_cu_seqlens, input_cu_seqlens, input_cu_seqlens_padded, input_cu_seqlens_padded, + input_rng_state, wkspace, stream, handle); #else const char *err_msg = "cuDNN 8.9.0 is required for BF16/FP16 fused attention " @@ -584,10 +693,38 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - fused_attn_fp8_bwd_qkvpacked(b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, input_QKV, input_O, input_dO, input_M, input_ZInv, - input_S, input_output_dP, output_dQKV, input_cu_seqlens, - input_rng_state, wkspace, stream, handle); + + // Unpack QKV and dQKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = 0; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { + stride = (typeToNumBits(QKV_type) * h * d) / 8; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { + stride = (typeToNumBits(QKV_type) * d) / 8; + } + + Tensor Q_view = *input_QKV; + Q_view.data.dptr = input_QKV->data.dptr; + + Tensor K_view = *input_QKV; + K_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + stride); + + Tensor V_view = *input_QKV; + V_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + 2 * stride); + + Tensor dQ_view = *output_dQKV; + dQ_view.data.dptr = output_dQKV->data.dptr; + + Tensor dK_view = *output_dQKV; + dK_view.data.dptr = static_cast(static_cast(output_dQKV->data.dptr) + stride); + + Tensor dV_view = *output_dQKV; + dV_view.data.dptr = static_cast(static_cast(output_dQKV->data.dptr) + 2 * stride); + + fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_O, input_dO, + input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view, &dV_view, + input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. 
\n"); #endif @@ -684,33 +821,73 @@ void nvte_fused_attn_fwd_kvpacked( if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - fused_attn_max_512_fwd_kvpacked( - b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, - input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); + // Unpack KV and call the non-packed function + size_t stride = 2 * h_q * d; // For max512, KV layout is BS2HD or SB2HD + + Tensor K_view = *input_KV; + K_view.data.dptr = input_KV->data.dptr; + + Tensor V_view = *input_KV; + V_view.data.dptr = static_cast(static_cast(input_KV->data.dptr) + stride); + + fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, + &V_view, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8903) - fused_attn_arbitrary_seqlen_fwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v, - page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, - window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset, output_O, - Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, - input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, - wkspace, stream, handle); + // Unpack KV and call the non-packed function + const auto Q_type = input_Q->data.dtype; + size_t stride = 0; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { + stride = (typeToNumBits(Q_type) * h_kv * d) / 8; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { + stride = (typeToNumBits(Q_type) * d) / 8; + } + + // Create tensor views for K, V from input_KV + Tensor K_view = *input_KV; + K_view.data.dptr = input_KV->data.dptr; + + Tensor V_view = *input_KV; + V_view.data.dptr = static_cast(static_cast(input_KV->data.dptr) + stride); + + fused_attn_arbitrary_seqlen_fwd( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, num_pages_k, num_pages_v, + page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, + attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, input_Q, &K_view, &V_view, input_Bias, + input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, + input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, + input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. 
\n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - fused_attn_fp8_fwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, input_Q, input_KV, input_output_S, output_O, Aux_CTX_Tensors, - input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); + // Unpack KV and call the non-packed function + const auto Q_type = input_Q->data.dtype; + size_t stride = 0; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { + stride = (typeToNumBits(Q_type) * h_kv * d) / 8; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { + stride = (typeToNumBits(Q_type) * d) / 8; + } + + Tensor K_view = *input_KV; + K_view.data.dptr = input_KV->data.dptr; + + Tensor V_view = *input_KV; + V_view.data.dptr = static_cast(static_cast(input_KV->data.dptr) + stride); + + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, + input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif @@ -782,10 +959,26 @@ void nvte_fused_attn_bwd_kvpacked( if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - fused_attn_max_512_bwd_kvpacked( - b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, input_Q, input_KV, input_dO, output_S, output_dQ, output_dKV, output_dBias, - input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle); + + // Unpack KV and dKV and call the non-packed function + size_t stride = 2 * h_q * d; + + Tensor K_view = *input_KV; + K_view.data.dptr = input_KV->data.dptr; + + Tensor V_view = *input_KV; + V_view.data.dptr = static_cast(static_cast(input_KV->data.dptr) + stride); + + Tensor dK_view = *output_dKV; + dK_view.data.dptr = output_dKV->data.dptr; + + Tensor dV_view = *output_dKV; + dV_view.data.dptr = static_cast(static_cast(output_dKV->data.dptr) + stride); + + fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, + input_dO, output_S, output_dQ, &dK_view, &dV_view, output_dBias, + input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. 
\n"); #endif @@ -801,13 +994,38 @@ void nvte_fused_attn_bwd_kvpacked( if (softmax_type != NVTE_VANILLA_SOFTMAX) { input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); } - fused_attn_arbitrary_seqlen_bwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout, + + // Unpack KV and dKV and call the non-packed function + const auto Q_type = input_Q->data.dtype; + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + size_t stride = 0; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { + stride = (typeToNumBits(Q_type) * h_kv * d) / 8; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { + stride = (typeToNumBits(Q_type) * d) / 8; + } + + // Create tensor views for K, V from input_KV + Tensor K_view = *input_KV; + K_view.data.dptr = input_KV->data.dptr; + + Tensor V_view = *input_KV; + V_view.data.dptr = static_cast(static_cast(input_KV->data.dptr) + stride); + + // Create tensor views for dK, dV from output_dKV + Tensor dK_view = *output_dKV; + dK_view.data.dptr = output_dKV->data.dptr; + + Tensor dV_view = *output_dKV; + dV_view.data.dptr = static_cast(static_cast(output_dKV->data.dptr) + stride); + + fused_attn_arbitrary_seqlen_bwd( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, - input_Q, input_KV, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQ, - output_dKV, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, - handle); + input_Q, &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, + output_dQ, &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, + wkspace, stream, handle); #else const char *err_msg = "cuDNN 8.9.3 is required for BF16/FP16 fused attention " @@ -819,11 +1037,33 @@ void nvte_fused_attn_bwd_kvpacked( const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - fused_attn_fp8_bwd_kvpacked(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, input_KV, input_O, - input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, - output_dKV, input_cu_seqlens_q, input_cu_seqlens_kv, - input_rng_state, wkspace, stream, handle); + + // Unpack KV and dKV and call the non-packed function + const auto Q_type = input_Q->data.dtype; + size_t stride = 0; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { + stride = (typeToNumBits(Q_type) * h_kv * d) / 8; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { + stride = (typeToNumBits(Q_type) * d) / 8; + } + + Tensor K_view = *input_KV; + K_view.data.dptr = input_KV->data.dptr; + + Tensor V_view = *input_KV; + V_view.data.dptr = static_cast(static_cast(input_KV->data.dptr) + stride); + + Tensor dK_view = *output_dKV; + dK_view.data.dptr = output_dKV->data.dptr; + + Tensor dV_view = *output_dKV; + dV_view.data.dptr = static_cast(static_cast(output_dKV->data.dptr) + stride); + + fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, 
d, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_O, + input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, + &dK_view, &dV_view, input_cu_seqlens_q, input_cu_seqlens_kv, + input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index ba0f845789..cec0bfda27 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -999,473 +999,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl( } // namespace fused_attn using namespace transformer_engine::fused_attn; -void fused_attn_arbitrary_seqlen_fwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - const auto QKV_type = input_QKV->data.dtype; - void *devPtrQKV = input_QKV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void *devPtrQ = static_cast(devPtrQKV); - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - - void *devPtrBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - bias_b = input_Bias->data.shape[0]; - bias_h = input_Bias->data.shape[1]; - } - void *devPtrSoftmaxOffset = nullptr; - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; - } - - void *devPtrO = output_O->data.dptr; - void *devPtrS = nullptr; - void *devPtrCuSeqlens = cu_seqlens->data.dptr; - void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr; - - size_t max_batch_size = 0; - size_t max_tokens = 0; - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - max_batch_size = get_max_batch_size(batch); - max_tokens = get_max_tokens(num_tokens); - } - - size_t i = 0; - if (Aux_CTX_Tensors->size == 0) { - const auto cudnn_runtime_version = cudnnGetVersion(); - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_S->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = 
convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_bias->data.dptr = nullptr; - output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen}; - output_bias->data.dtype = QKV_type; - } - - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_softmax_offset->data.dptr = nullptr; - output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1}; - output_softmax_offset->data.dtype = DType::kFloat32; - } - - Aux_CTX_Tensors->size = i; - } else if (Aux_CTX_Tensors->size >= 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_rng_state->data.dptr = rng_state->data.dptr; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_bias->data.dptr = devPtrBias; - } - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_softmax_offset->data.dptr = devPtrSoftmaxOffset; - } - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_fwd_impl( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, - max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS, - devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr, - nullptr, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} - -void fused_attn_arbitrary_seqlen_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, - Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, - const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - const auto QKV_type = input_QKV->data.dtype; - void 
*devPtrQKV = input_QKV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void *devPtrQ = devPtrQKV; - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - - void *devPtrO = input_O->data.dptr; - void *devPtrdO = input_dO->data.dptr; - void *devPtrBias = nullptr; - void *devPtrdBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - devPtrdBias = output_dBias->data.dptr; - bias_b = output_dBias->data.shape[0]; - bias_h = output_dBias->data.shape[1]; - } - - size_t max_batch_size = 0; - size_t max_tokens = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - max_batch_size = get_max_batch_size(batch); - max_tokens = get_max_tokens(num_tokens); - } - - void *devPtrdQKV = output_dQKV->data.dptr; - void *devPtrdQ = devPtrdQKV; - void *devPtrdK = static_cast(static_cast(devPtrdQKV) + stride); - void *devPtrdV = static_cast(static_cast(devPtrdQKV) + 2 * stride); - - void *devPtrSoftmaxStats = nullptr; - devPtrSoftmaxStats = output_S->data.dptr; - void *devPtrSoftmaxOffset = nullptr; - void *devPtrdSoftmaxOffset = nullptr; - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; - devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr; - } - - void *devPtrCuSeqlens = cu_seqlens->data.dptr; - void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr; - - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_bwd_impl( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, - max_batch_size, max_tokens, max_tokens, bias_b, bias_h, attn_scale, p_dropout, qkv_layout, - bias_type, mask_type, softmax_type, window_size_left, window_size_right, deterministic, - devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, - devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} -void fused_attn_arbitrary_seqlen_fwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, - size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, - size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type 
bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - const auto QKV_type = input_Q->data.dtype; - void *devPtrQ = input_Q->data.dptr; - void *devPtrKV = input_KV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void *devPtrK = devPtrKV; - void *devPtrV = static_cast(static_cast(devPtrKV) + stride); - - void *devPtrBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - bias_b = input_Bias->data.shape[0]; - bias_h = input_Bias->data.shape[1]; - } - void *devPtrSoftmaxOffset = nullptr; - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; - } - - void *devPtrO = output_O->data.dptr; - void *devPtrS = nullptr; - - void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; - void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; - void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; - void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; - void *devPtrPageTableK = page_table_k->data.dptr; - void *devPtrPageTableV = page_table_v->data.dptr; - - size_t max_batch_size = 0; - size_t max_tokens_q = 0; - size_t max_tokens_kv = 0; - if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) { - max_batch_size = get_max_batch_size(batch); - } - if (q_format == NVTE_QKV_Format::NVTE_THD) { - max_tokens_q = get_max_tokens(num_tokens_q); - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - max_tokens_kv = get_max_tokens(num_tokens_kv); - } - - size_t i = 0; - if (Aux_CTX_Tensors->size == 0) { - const auto cudnn_runtime_version = cudnnGetVersion(); - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_bias->data.dptr = nullptr; - output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; - output_bias->data.dtype = QKV_type; - } - - 
if (softmax_type != NVTE_VANILLA_SOFTMAX) { - Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_softmax_offset->data.dptr = nullptr; - output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1}; - output_softmax_offset->data.dtype = DType::kFloat32; - } - - Aux_CTX_Tensors->size = i; - } else if (Aux_CTX_Tensors->size >= 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_rng_state->data.dptr = rng_state->data.dptr; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_bias->data.dptr = devPtrBias; - } - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_softmax_offset->data.dptr = devPtrSoftmaxOffset; - } - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_fwd_impl( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, - max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, - page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS, - devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, - devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} - -void fused_attn_arbitrary_seqlen_bwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, - Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - const auto QKV_type = input_Q->data.dtype; - void *devPtrQ = input_Q->data.dptr; - void *devPtrKV = input_KV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if 
(layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void *devPtrK = devPtrKV; - void *devPtrV = static_cast(static_cast(devPtrKV) + stride); - - void *devPtrO = input_O->data.dptr; - void *devPtrdO = input_dO->data.dptr; - void *devPtrBias = nullptr; - void *devPtrdBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - devPtrdBias = output_dBias->data.dptr; - bias_b = output_dBias->data.shape[0]; - bias_h = output_dBias->data.shape[1]; - } - - size_t max_batch_size = 0; - size_t max_tokens_q = 0; - size_t max_tokens_kv = 0; - NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) { - max_batch_size = get_max_batch_size(batch); - } - if (q_format == NVTE_QKV_Format::NVTE_THD) { - max_tokens_q = get_max_tokens(num_tokens_q); - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - max_tokens_kv = get_max_tokens(num_tokens_kv); - } - - void *devPtrdQ = output_dQ->data.dptr; - void *devPtrdKV = output_dKV->data.dptr; - void *devPtrdK = devPtrdKV; - void *devPtrdV = static_cast(static_cast(devPtrdKV) + stride); - - void *devPtrSoftmaxStats = nullptr; - devPtrSoftmaxStats = output_S->data.dptr; - void *devPtrSoftmaxOffset = nullptr; - void *devPtrdSoftmaxOffset = nullptr; - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; - devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr; - } - - void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; - void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; - void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; - void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; - - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_bwd_impl( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, - max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, - deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, - devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, - devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, - devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} - void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, diff --git 
a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index b9658b0530..f22b11044c 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -18,51 +18,6 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) -void fused_attn_arbitrary_seqlen_fwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -void fused_attn_arbitrary_seqlen_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, - Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, - const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -void fused_attn_arbitrary_seqlen_fwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, - size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, - size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -void fused_attn_arbitrary_seqlen_bwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, - Tensor *output_dBias, Tensor 
*output_dSoftmaxOffset, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); - void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu index 89528fa3c4..1028df6452 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu @@ -1215,150 +1215,6 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv } // namespace fused_attn using namespace transformer_engine::fused_attn; -void fused_attn_max_512_fwd_qkvpacked( - size_t batch, size_t num_head, size_t max_seqlen, size_t head_dim, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - // QKV shape is [b, s, 3, h, d] - void *devPtrQKV = input_QKV->data.dptr; - const auto stride = 2 * num_head * head_dim; - - void *devPtrQ = static_cast(devPtrQKV); - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - - void *devPtrBias = static_cast(input_Bias->data.dptr); - - void *devPtrO = output_O->data.dptr; - - void *devPtrS = nullptr; - - if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 1; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_head, max_seqlen, max_seqlen}; - output_S->data.dtype = input_QKV->data.dtype; - } else if (Aux_CTX_Tensors->size == 1) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void *devPtrCuSeqlen = cu_seqlens->data.dptr; - - const DType rng_state_type = rng_state->data.dtype; - NVTE_CHECK(rng_state_type == DType::kInt64); - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - static_cast(static_cast(rng_state->data.dptr) + 1); - - const DType QKV_type = input_QKV->data.dtype; - size_t workspace_size = 0; - - fused_attn_max_512_fwd_impl( - batch, num_head, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, - devPtrCuSeqlen, devPtrCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr, - &workspace_size, get_cudnn_dtype(QKV_type), stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} - -void fused_attn_max_512_fwd_kvpacked(size_t batch, 
size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t head_dim, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens, - const Tensor *kv_cu_seqlens, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - NVTE_CHECK(bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || - bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS, - "NVTE_PRE_SCALE_BIAS is not implemented in fused_attn_max_512."); - - // Q shape is [b, s, h, d] - void *devPtrQ = input_Q->data.dptr; - - // KV shape is [b, s, 2, h, d] - const auto stride = 2 * num_head * head_dim; - void *devPtrK = input_KV->data.dptr; - void *devPtrV = static_cast(static_cast(devPtrK) + stride); - - void *devPtrBias = input_Bias->data.dptr; - - void *devPtrO = output_O->data.dptr; - - void *devPtrS = nullptr; - - const DType q_type = input_Q->data.dtype; - const DType kv_type = input_KV->data.dtype; - NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV."); - - if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 1; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen}; - output_S->data.dtype = q_type; - } else if (Aux_CTX_Tensors->size == 1) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void *devQCuSeqlen = q_cu_seqlens->data.dptr; - void *devKVCuSeqlen = kv_cu_seqlens->data.dptr; - - const DType rng_state_type = rng_state->data.dtype; - NVTE_CHECK(rng_state_type == DType::kInt64); - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - static_cast(static_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_max_512_fwd_impl( - batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, is_training, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, - devQCuSeqlen, devKVCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr, - &workspace_size, get_cudnn_dtype(q_type), stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, @@ -1429,126 +1285,6 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, } } -void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen, - size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_QKV, - const Tensor *input_dO, Tensor *output_S, Tensor *output_dQKV, - Tensor *output_dBias, const Tensor *cu_seqlens, - Tensor *workspace, cudaStream_t 
stream, - cudnnHandle_t handle) { - using namespace transformer_engine; - - // QKV shape is [b, s, 3, h, d] - void *devPtrQKV = input_QKV->data.dptr; - - auto stride = 2 * num_head * head_dim; - void *devPtrQ = devPtrQKV; - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - - void *devPtrdO = input_dO->data.dptr; - - // dQKV shape is [b, s, 3, h, d] - void *devPtrdQKV = output_dQKV->data.dptr; - void *devPtrdQ = devPtrdQKV; - void *devPtrdK = static_cast(static_cast(devPtrdQKV) + stride); - void *devPtrdV = static_cast(static_cast(devPtrdQKV) + 2 * stride); - - void *devPtrdBias = output_dBias->data.dptr; - - void *devPtrS = output_S->data.dptr; - - // devPtrdS reuses the memory of devPtrS - void *devPtrdS = devPtrS; - - void *devPtrCuSeqlens = cu_seqlens->data.dptr; - - const auto qkv_type = input_QKV->data.dtype; - size_t workspace_size = 0; - - fused_attn_max_512_bwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim, attn_scale, - p_dropout, qkv_layout, mask_type, bias_type, devPtrQ, devPtrK, - devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdS, - devPtrdBias, devPtrCuSeqlens, devPtrCuSeqlens, workspace->data.dptr, - &workspace_size, get_cudnn_dtype(qkv_type), stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} - -void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_dO, Tensor *output_S, Tensor *output_dQ, - Tensor *output_dKV, Tensor *output_dBias, - const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - // Q shape is [b, s, h, d] - // KV shape is [b, s, 2, h, d] - auto stride = 2 * num_head * head_dim; - void *devPtrQ = input_Q->data.dptr; - void *devPtrK = input_KV->data.dptr; - void *devPtrV = static_cast(static_cast(devPtrK) + stride); - - void *devPtrdO = input_dO->data.dptr; - - // dQ shape is [b, s, h, d] - // dKV shape is [b, s, 2, h, d] - void *devPtrdQ = output_dQ->data.dptr; - void *devPtrdK = output_dKV->data.dptr; - void *devPtrdV = static_cast(static_cast(devPtrdK) + stride); - - void *devPtrdBias = output_dBias->data.dptr; - - void *devPtrS = output_S->data.dptr; - - // devPtrdS reuses the memory of devPtrS - void *devPtrdS = devPtrS; - - void *devPtrQCuSeqlens = q_cu_seqlens->data.dptr; - void *devPtrKVCuSeqlens = kv_cu_seqlens->data.dptr; - - const auto q_type = input_Q->data.dtype; - const auto kv_type = input_KV->data.dtype; - NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV."); - size_t workspace_size = 0; - - fused_attn_max_512_bwd_impl( - batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout, - mask_type, bias_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV, - devPtrdO, devPtrdS, devPtrdBias, devPtrQCuSeqlens, devPtrKVCuSeqlens, workspace->data.dptr, - &workspace_size, get_cudnn_dtype(q_type), 
stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h index 171fe846ce..57b7afcf43 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h @@ -18,25 +18,6 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8901) -void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen, - size_t head_size, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_QKV, const Tensor *input_Bias, - Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t head_dim, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens, - const Tensor *kv_cu_seqlens, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, @@ -47,24 +28,6 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, const Tensor *kv_cu_seqlens, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); -void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen, - size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_QKV, - const Tensor *input_dO, Tensor *output_S, Tensor *output_dQKV, - Tensor *output_dBias, const Tensor *cu_seqlens, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_dO, Tensor *output_S, Tensor *output_dQ, - Tensor *output_dKV, Tensor *output_dBias, - const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim, float attn_scale, float p_dropout, 
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 21c544491a..3c50e7ab89 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2405,424 +2405,6 @@ void fused_attn_fp8_bwd_impl_v1( } // namespace fused_attn #if (CUDNN_VERSION >= 8900) -// fused attention FWD FP8 with packed QKV -void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t max_seqlen, - size_t head_dim, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor* input_QKV, Tensor* input_output_S, Tensor* output_O, - NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens, - const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, - cudnnHandle_t handle) { - using namespace transformer_engine; - const DType QKV_type = input_QKV->data.dtype; - const DType O_type = output_O->data.dtype; - void* devPtrQKV = input_QKV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void* devPtrQ = static_cast(devPtrQKV); - void* devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void* devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - void* devPtrDescaleQ = input_QKV->scale_inv.dptr; - void* devPtrDescaleK = input_QKV->scale_inv.dptr; - void* devPtrDescaleV = input_QKV->scale_inv.dptr; - - void* devPtrO = output_O->data.dptr; - void* devPtrAmaxO = output_O->amax.dptr; - void* devPtrScaleO = output_O->scale.dptr; - - void* devPtrM = nullptr; - void* devPtrZInv = nullptr; - if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 3; - Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_M->data.dptr = nullptr; - output_M->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - output_M->data.dtype = DType::kFloat32; - output_ZInv->data.dptr = nullptr; - output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - output_ZInv->data.dtype = DType::kFloat32; - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - devPtrM = output_M->data.dptr; - devPtrZInv = output_ZInv->data.dptr; - output_rng_state->data.dptr = rng_state->data.dptr; - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void* devPtrAmaxS = input_output_S->amax.dptr; - void* devPtrScaleS = input_output_S->scale.dptr; - void* devPtrDescaleS = input_output_S->scale_inv.dptr; - - void* devPtrcuSeqlens = - reinterpret_cast(reinterpret_cast(cu_seqlens->data.dptr)); - void* devPtrDropoutSeed = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr)); - void* devPtrDropoutOffset = - 
reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { - fused_attn::fused_attn_fp8_fwd_impl_v1( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, - devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, - devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlens, devPtrcuSeqlens, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_fwd_impl( - batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout, - qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, - devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, - devPtrAmaxS, devPtrcuSeqlens, devPtrcuSeqlens, devPtrDropoutSeed, devPtrDropoutOffset, - get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); - } else { - NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n"); - } - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } -} -// fused attention BWD FP8 with packed QKV -void fused_attn_fp8_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_M, - const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, - const Tensor* output_dQKV, const Tensor* cu_seqlens, const Tensor* rng_state, Tensor* workspace, - cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - const DType QKV_type = input_QKV->data.dtype; - const DType dO_type = input_dO->data.dtype; - const DType dQKV_type = output_dQKV->data.dtype; - void* devPtrQKV = input_QKV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void* devPtrQ = devPtrQKV; - void* devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void* devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - void* devPtrDescaleQ = input_QKV->scale_inv.dptr; - void* devPtrDescaleK = input_QKV->scale_inv.dptr; - void* devPtrDescaleV = input_QKV->scale_inv.dptr; - - void* devPtrO = input_O->data.dptr; - const DType O_type = input_O->data.dtype; - void* devPtrDescaleO = nullptr; - if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) { - devPtrDescaleO = input_O->scale_inv.dptr; - } - void* devPtrdO = input_dO->data.dptr; - void* devPtrDescaledO = 
input_dO->scale_inv.dptr; - - void* devPtrM = input_M->data.dptr; - void* devPtrZInv = input_ZInv->data.dptr; - - void* devPtrScaleS = input_S->scale.dptr; - void* devPtrDescaleS = input_S->scale_inv.dptr; - void* devPtrAmaxdP = input_output_dP->amax.dptr; - void* devPtrScaledP = input_output_dP->scale.dptr; - void* devPtrDescaledP = input_output_dP->scale_inv.dptr; - - void* devPtrdQKV = output_dQKV->data.dptr; - void* devPtrdQ = devPtrdQKV; - void* devPtrdK = static_cast(static_cast(devPtrdQKV) + stride); - void* devPtrdV = static_cast(static_cast(devPtrdQKV) + 2 * stride); - void* devPtrAmaxdQ = output_dQKV->amax.dptr; - void* devPtrAmaxdK = output_dQKV->amax.dptr; - void* devPtrAmaxdV = output_dQKV->amax.dptr; - void* devPtrScaledQ = output_dQKV->scale.dptr; - void* devPtrScaledK = output_dQKV->scale.dptr; - void* devPtrScaledV = output_dQKV->scale.dptr; - - void* devPtrcuSeqlens = - reinterpret_cast(reinterpret_cast(cu_seqlens->data.dptr)); - void* devPtrDropoutSeed = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr)); - void* devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { - fused_attn::fused_attn_fp8_bwd_impl_v1( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, - devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, - devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, - devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, - devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlens, devPtrcuSeqlens, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), - workspace->data.dptr, &workspace_size, stream, handle); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_bwd_impl( - batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout, - devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, - devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, - devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, - devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlens, - devPtrcuSeqlens, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); - } else { - NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. 
\n"); - } - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } -} -// fused attention FWD FP8 with packed KV -void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor* input_Q, - const Tensor* input_KV, Tensor* input_output_S, Tensor* output_O, - NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, - const Tensor* cu_seqlens_kv, const Tensor* rng_state, - Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - const DType QKV_type = input_Q->data.dtype; - const DType O_type = output_O->data.dtype; - void* devPtrQ = input_Q->data.dptr; - void* devPtrKV = input_KV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void* devPtrK = devPtrKV; - void* devPtrV = static_cast(static_cast(devPtrKV) + stride); - void* devPtrDescaleQ = input_Q->scale_inv.dptr; - void* devPtrDescaleK = input_KV->scale_inv.dptr; - void* devPtrDescaleV = input_KV->scale_inv.dptr; - - void* devPtrO = output_O->data.dptr; - void* devPtrAmaxO = output_O->amax.dptr; - void* devPtrScaleO = output_O->scale.dptr; - - void* devPtrM = nullptr; - void* devPtrZInv = nullptr; - if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 3; - Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_M->data.dptr = nullptr; - output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - output_M->data.dtype = DType::kFloat32; - output_ZInv->data.dptr = nullptr; - output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - output_ZInv->data.dtype = DType::kFloat32; - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - devPtrM = output_M->data.dptr; - devPtrZInv = output_ZInv->data.dptr; - output_rng_state->data.dptr = rng_state->data.dptr; - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void* devPtrAmaxS = input_output_S->amax.dptr; - void* devPtrScaleS = input_output_S->scale.dptr; - void* devPtrDescaleS = input_output_S->scale_inv.dptr; - - void* devPtrcuSeqlensQ = - reinterpret_cast(reinterpret_cast(cu_seqlens_q->data.dptr)); - void* devPtrcuSeqlensKV = - reinterpret_cast(reinterpret_cast(cu_seqlens_kv->data.dptr)); - void* devPtrDropoutSeed = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr)); - void* devPtrDropoutOffset = 
- reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { - fused_attn::fused_attn_fp8_fwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, - devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, - devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_fwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, - p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, - devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, - devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, - devPtrDropoutOffset, get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); - } else { - NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n"); - } - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } -} -// fused attention BWD FP8 with packed KV -void fused_attn_fp8_bwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO, - const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, - const Tensor* output_dQ, const Tensor* output_dKV, const Tensor* cu_seqlens_q, - const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, - cudnnHandle_t handle) { - using namespace transformer_engine; - const DType QKV_type = input_Q->data.dtype; - const DType dO_type = input_dO->data.dtype; - const DType dQKV_type = output_dQ->data.dtype; - void* devPtrQ = input_Q->data.dptr; - void* devPtrKV = input_KV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void* devPtrK = devPtrKV; - void* devPtrV = static_cast(static_cast(devPtrKV) + stride); - void* devPtrDescaleQ = input_Q->scale_inv.dptr; - void* devPtrDescaleK = input_KV->scale_inv.dptr; - void* devPtrDescaleV = input_KV->scale_inv.dptr; - - void* devPtrO = input_O->data.dptr; - const DType O_type = input_O->data.dtype; - void* devPtrDescaleO = nullptr; - if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) { - 
devPtrDescaleO = input_O->scale_inv.dptr; - } - void* devPtrdO = input_dO->data.dptr; - void* devPtrDescaledO = input_dO->scale_inv.dptr; - - void* devPtrM = input_M->data.dptr; - void* devPtrZInv = input_ZInv->data.dptr; - - void* devPtrScaleS = input_S->scale.dptr; - void* devPtrDescaleS = input_S->scale_inv.dptr; - void* devPtrAmaxdP = input_output_dP->amax.dptr; - void* devPtrScaledP = input_output_dP->scale.dptr; - void* devPtrDescaledP = input_output_dP->scale_inv.dptr; - - void* devPtrdQ = output_dQ->data.dptr; - void* devPtrdKV = output_dKV->data.dptr; - void* devPtrdK = devPtrdKV; - void* devPtrdV = static_cast(static_cast(devPtrdKV) + stride); - void* devPtrAmaxdQ = output_dQ->amax.dptr; - void* devPtrAmaxdK = output_dKV->amax.dptr; - void* devPtrAmaxdV = output_dKV->amax.dptr; - void* devPtrScaledQ = output_dQ->scale.dptr; - void* devPtrScaledK = output_dKV->scale.dptr; - void* devPtrScaledV = output_dKV->scale.dptr; - - void* devPtrcuSeqlensQ = - reinterpret_cast(reinterpret_cast(cu_seqlens_q->data.dptr)); - void* devPtrcuSeqlensKV = - reinterpret_cast(reinterpret_cast(cu_seqlens_kv->data.dptr)); - void* devPtrDropoutSeed = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr)); - void* devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { - fused_attn::fused_attn_fp8_bwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, - devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, - devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, - devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, - devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), - workspace->data.dptr, &workspace_size, stream, handle); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_bwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout, - qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, - devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, - devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, - devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, - devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, - get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); - } else { - NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. 
\n"); - } - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } -} // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 3daf45d162..c2efa25829 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -13,47 +13,6 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) -// fused attention FWD FP8 with packed QKV -void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t max_seqlen, - size_t head_dim, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_QKV, Tensor *input_output_S, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, - cudnnHandle_t handle); - -// fused attention BWD FP8 with packed QKV -void fused_attn_fp8_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_M, - const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, - const Tensor *output_dQKV, const Tensor *cu_seqlens, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); - -// fused attention FWD FP8 with packed KV -void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_Q, - const Tensor *input_KV, Tensor *input_output_S, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -// fused attention BWD FP8 with packed KV -void fused_attn_fp8_bwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, - const Tensor *output_dQ, const Tensor *output_dKV, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, - cudnnHandle_t handle); - // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, From 9085883eeb9f0157430115508c34c705f8eb3d13 Mon Sep 17 00:00:00 
2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 20 Oct 2025 11:19:06 +0000 Subject: [PATCH 02/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/fused_attn/fused_attn.cpp | 250 +++++++++--------- 1 file changed, 129 insertions(+), 121 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index a36df51558..862443013c 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -467,16 +467,17 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, // Unpack QKV and call the non-packed function const auto QKV_type = input_QKV->data.dtype; size_t stride = 2 * h * d; // For max512, layout is always BS3HD or SB3HD (3HD group) - + Tensor Q_view = *input_QKV; Q_view.data.dptr = input_QKV->data.dptr; - + Tensor K_view = *input_QKV; - K_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + stride); - + K_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + stride); + Tensor V_view = *input_QKV; - V_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + 2 * stride); - + V_view.data.dptr = + static_cast(static_cast(input_QKV->data.dptr) + 2 * stride); + fused_attn_max_512_fwd(b, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens, @@ -496,24 +497,24 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { stride = (typeToNumBits(QKV_type) * d) / 8; } - + // Create tensor views for Q, K, V Tensor Q_view = *input_QKV; Q_view.data.dptr = input_QKV->data.dptr; - + Tensor K_view = *input_QKV; - K_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + stride); - + K_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + stride); + Tensor V_view = *input_QKV; - V_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + 2 * stride); - + V_view.data.dptr = + static_cast(static_cast(input_QKV->data.dptr) + 2 * stride); + fused_attn_arbitrary_seqlen_fwd( - b, h, h, max_seqlen, max_seqlen, d, d, t, t, 0, 0, 0, 0, 0, 0, is_training, - attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, - input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, input_rng_state, - wkspace, stream, handle); + b, h, h, max_seqlen, max_seqlen, d, d, t, t, 0, 0, 0, 0, 0, 0, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, + window_size_right, &Q_view, &K_view, &V_view, input_Bias, input_SoftmaxOffset, output_O, + Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, input_cu_seqlens_padded, + input_cu_seqlens_padded, nullptr, nullptr, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. 
\n"); @@ -528,16 +529,17 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { stride = (typeToNumBits(QKV_type) * d) / 8; } - + Tensor Q_view = *input_QKV; Q_view.data.dptr = input_QKV->data.dptr; - + Tensor K_view = *input_QKV; - K_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + stride); - + K_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + stride); + Tensor V_view = *input_QKV; - V_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + 2 * stride); - + V_view.data.dptr = + static_cast(static_cast(input_QKV->data.dptr) + 2 * stride); + fused_attn_fp8_fwd(b, h, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens, @@ -603,32 +605,34 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - + // Unpack QKV and dQKV and call the non-packed function size_t stride = 2 * h * d; - + Tensor Q_view = *input_QKV; Q_view.data.dptr = input_QKV->data.dptr; - + Tensor K_view = *input_QKV; - K_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + stride); - + K_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + stride); + Tensor V_view = *input_QKV; - V_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + 2 * stride); - + V_view.data.dptr = + static_cast(static_cast(input_QKV->data.dptr) + 2 * stride); + Tensor dQ_view = *output_dQKV; dQ_view.data.dptr = output_dQKV->data.dptr; - + Tensor dK_view = *output_dQKV; - dK_view.data.dptr = static_cast(static_cast(output_dQKV->data.dptr) + stride); - + dK_view.data.dptr = static_cast(static_cast(output_dQKV->data.dptr) + stride); + Tensor dV_view = *output_dQKV; - dV_view.data.dptr = static_cast(static_cast(output_dQKV->data.dptr) + 2 * stride); - + dV_view.data.dptr = + static_cast(static_cast(output_dQKV->data.dptr) + 2 * stride); + fused_attn_max_512_bwd(b, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_dO, - output_S, &dQ_view, &dK_view, &dV_view, output_dBias, - input_cu_seqlens, input_cu_seqlens, wkspace, stream, handle); + bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_dO, output_S, + &dQ_view, &dK_view, &dV_view, output_dBias, input_cu_seqlens, + input_cu_seqlens, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. 
\n"); #endif @@ -644,7 +648,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con if (softmax_type != NVTE_VANILLA_SOFTMAX) { input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); } - + // Unpack QKV and dQKV and call the non-packed function const auto QKV_type = input_QKV->data.dtype; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); @@ -654,34 +658,35 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { stride = (typeToNumBits(QKV_type) * d) / 8; } - + // Create tensor views for Q, K, V from input Tensor Q_view = *input_QKV; Q_view.data.dptr = input_QKV->data.dptr; - + Tensor K_view = *input_QKV; - K_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + stride); - + K_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + stride); + Tensor V_view = *input_QKV; - V_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + 2 * stride); - + V_view.data.dptr = + static_cast(static_cast(input_QKV->data.dptr) + 2 * stride); + // Create tensor views for dQ, dK, dV from output Tensor dQ_view = *output_dQKV; dQ_view.data.dptr = output_dQKV->data.dptr; - + Tensor dK_view = *output_dQKV; - dK_view.data.dptr = static_cast(static_cast(output_dQKV->data.dptr) + stride); - + dK_view.data.dptr = static_cast(static_cast(output_dQKV->data.dptr) + stride); + Tensor dV_view = *output_dQKV; - dV_view.data.dptr = static_cast(static_cast(output_dQKV->data.dptr) + 2 * stride); - + dV_view.data.dptr = + static_cast(static_cast(output_dQKV->data.dptr) + 2 * stride); + fused_attn_arbitrary_seqlen_bwd( b, h, h, max_seqlen, max_seqlen, d, d, t, t, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, - &Q_view, &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, - output_S, &dQ_view, &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, - input_cu_seqlens, input_cu_seqlens, input_cu_seqlens_padded, input_cu_seqlens_padded, - input_rng_state, wkspace, stream, handle); + attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, &Q_view, + &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, &dQ_view, + &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens, + input_cu_seqlens_padded, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); #else const char *err_msg = "cuDNN 8.9.0 is required for BF16/FP16 fused attention " @@ -693,7 +698,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - + // Unpack QKV and dQKV and call the non-packed function const auto QKV_type = input_QKV->data.dtype; size_t stride = 0; @@ -702,29 +707,32 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { stride = (typeToNumBits(QKV_type) * d) / 8; } - + Tensor Q_view = *input_QKV; Q_view.data.dptr = input_QKV->data.dptr; - + Tensor K_view = *input_QKV; - K_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + stride); - + K_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + 
stride); + Tensor V_view = *input_QKV; - V_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + 2 * stride); - + V_view.data.dptr = + static_cast(static_cast(input_QKV->data.dptr) + 2 * stride); + Tensor dQ_view = *output_dQKV; dQ_view.data.dptr = output_dQKV->data.dptr; - + Tensor dK_view = *output_dQKV; - dK_view.data.dptr = static_cast(static_cast(output_dQKV->data.dptr) + stride); - + dK_view.data.dptr = static_cast(static_cast(output_dQKV->data.dptr) + stride); + Tensor dV_view = *output_dQKV; - dV_view.data.dptr = static_cast(static_cast(output_dQKV->data.dptr) + 2 * stride); - + dV_view.data.dptr = + static_cast(static_cast(output_dQKV->data.dptr) + 2 * stride); + fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_O, input_dO, input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view, &dV_view, - input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, handle); + input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, + handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif @@ -823,16 +831,16 @@ void nvte_fused_attn_fwd_kvpacked( #if (CUDNN_VERSION >= 8901) // Unpack KV and call the non-packed function size_t stride = 2 * h_q * d; // For max512, KV layout is BS2HD or SB2HD - + Tensor K_view = *input_KV; K_view.data.dptr = input_KV->data.dptr; - + Tensor V_view = *input_KV; - V_view.data.dptr = static_cast(static_cast(input_KV->data.dptr) + stride); - - fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, - &V_view, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + V_view.data.dptr = static_cast(static_cast(input_KV->data.dptr) + stride); + + fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, + input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. 
\n"); @@ -847,22 +855,22 @@ void nvte_fused_attn_fwd_kvpacked( } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { stride = (typeToNumBits(Q_type) * d) / 8; } - + // Create tensor views for K, V from input_KV Tensor K_view = *input_KV; K_view.data.dptr = input_KV->data.dptr; - + Tensor V_view = *input_KV; - V_view.data.dptr = static_cast(static_cast(input_KV->data.dptr) + stride); - + V_view.data.dptr = static_cast(static_cast(input_KV->data.dptr) + stride); + fused_attn_arbitrary_seqlen_fwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, num_pages_k, num_pages_v, - page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, - attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, input_Q, &K_view, &V_view, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, - input_page_table_v, input_rng_state, wkspace, stream, handle); + page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, + window_size_right, input_Q, &K_view, &V_view, input_Bias, input_SoftmaxOffset, output_O, + Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, + wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); @@ -877,13 +885,13 @@ void nvte_fused_attn_fwd_kvpacked( } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { stride = (typeToNumBits(Q_type) * d) / 8; } - + Tensor K_view = *input_KV; K_view.data.dptr = input_KV->data.dptr; - + Tensor V_view = *input_KV; - V_view.data.dptr = static_cast(static_cast(input_KV->data.dptr) + stride); - + V_view.data.dptr = static_cast(static_cast(input_KV->data.dptr) + stride); + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, @@ -959,26 +967,26 @@ void nvte_fused_attn_bwd_kvpacked( if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - + // Unpack KV and dKV and call the non-packed function size_t stride = 2 * h_q * d; - + Tensor K_view = *input_KV; K_view.data.dptr = input_KV->data.dptr; - + Tensor V_view = *input_KV; - V_view.data.dptr = static_cast(static_cast(input_KV->data.dptr) + stride); - + V_view.data.dptr = static_cast(static_cast(input_KV->data.dptr) + stride); + Tensor dK_view = *output_dKV; dK_view.data.dptr = output_dKV->data.dptr; - + Tensor dV_view = *output_dKV; - dV_view.data.dptr = static_cast(static_cast(output_dKV->data.dptr) + stride); - - fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, - input_dO, output_S, output_dQ, &dK_view, &dV_view, output_dBias, - input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle); + dV_view.data.dptr = static_cast(static_cast(output_dKV->data.dptr) + stride); + + fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, + 
bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_dO, output_S, + output_dQ, &dK_view, &dV_view, output_dBias, input_cu_seqlens_q, + input_cu_seqlens_kv, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); #endif @@ -994,7 +1002,7 @@ void nvte_fused_attn_bwd_kvpacked( if (softmax_type != NVTE_VANILLA_SOFTMAX) { input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); } - + // Unpack KV and dKV and call the non-packed function const auto Q_type = input_Q->data.dtype; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); @@ -1004,21 +1012,21 @@ void nvte_fused_attn_bwd_kvpacked( } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { stride = (typeToNumBits(Q_type) * d) / 8; } - + // Create tensor views for K, V from input_KV Tensor K_view = *input_KV; K_view.data.dptr = input_KV->data.dptr; - + Tensor V_view = *input_KV; - V_view.data.dptr = static_cast(static_cast(input_KV->data.dptr) + stride); - + V_view.data.dptr = static_cast(static_cast(input_KV->data.dptr) + stride); + // Create tensor views for dK, dV from output_dKV Tensor dK_view = *output_dKV; dK_view.data.dptr = output_dKV->data.dptr; - + Tensor dV_view = *output_dKV; - dV_view.data.dptr = static_cast(static_cast(output_dKV->data.dptr) + stride); - + dV_view.data.dptr = static_cast(static_cast(output_dKV->data.dptr) + stride); + fused_attn_arbitrary_seqlen_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, @@ -1037,7 +1045,7 @@ void nvte_fused_attn_bwd_kvpacked( const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - + // Unpack KV and dKV and call the non-packed function const auto Q_type = input_Q->data.dtype; size_t stride = 0; @@ -1046,24 +1054,24 @@ void nvte_fused_attn_bwd_kvpacked( } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { stride = (typeToNumBits(Q_type) * d) / 8; } - + Tensor K_view = *input_KV; K_view.data.dptr = input_KV->data.dptr; - + Tensor V_view = *input_KV; - V_view.data.dptr = static_cast(static_cast(input_KV->data.dptr) + stride); - + V_view.data.dptr = static_cast(static_cast(input_KV->data.dptr) + stride); + Tensor dK_view = *output_dKV; dK_view.data.dptr = output_dKV->data.dptr; - + Tensor dV_view = *output_dKV; - dV_view.data.dptr = static_cast(static_cast(output_dKV->data.dptr) + stride); - + dV_view.data.dptr = static_cast(static_cast(output_dKV->data.dptr) + stride); + fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_O, - input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, - &dK_view, &dV_view, input_cu_seqlens_q, input_cu_seqlens_kv, - input_rng_state, wkspace, stream, handle); + input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, &dK_view, + &dV_view, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, + stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. 
\n"); #endif From 6b6e78a6b97222d242ddfe83a6ce6ea13f8defde Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 20 Oct 2025 14:35:21 +0000 Subject: [PATCH 03/10] fix Signed-off-by: Pawel Gadzinski --- .../common/fused_attn/fused_attn.cpp | 346 +++++++++--------- .../fused_attn_f16_arbitrary_seqlen.cu | 4 +- 2 files changed, 168 insertions(+), 182 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 862443013c..32325e2978 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -15,6 +15,76 @@ #include "fused_attn_fp8.h" #include "utils.h" +namespace { +// Helper function to create a tensor view with modified shape and optional pointer offset +transformer_engine::Tensor make_tensor_view(const transformer_engine::Tensor* source, + const std::vector& shape, + size_t offset_bytes = 0) { + transformer_engine::Tensor view = *source; + if (offset_bytes > 0) { + view.data.dptr = static_cast(static_cast(source->data.dptr) + offset_bytes); + } + view.data.shape = shape; + view.nvte_tensor = 0; // Mark as unmanaged/local tensor view + return view; +} + +// Helper function to calculate stride for packed QKV tensor unpacking +size_t calculate_qkv_stride(NVTE_QKV_Layout_Group layout_group, + transformer_engine::DType dtype, + size_t h, size_t d) { + size_t stride = 0; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { + stride = (transformer_engine::typeToNumBits(dtype) * h * d) / 8; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { + stride = (transformer_engine::typeToNumBits(dtype) * d) / 8; + } + return stride; +} + +// Helper function to determine unpacked shape for QKV packed tensor +std::vector calculate_qkv_unpacked_shape(const transformer_engine::Tensor* qkv_tensor, + size_t h, size_t d) { + std::vector unpacked_shape; + if (qkv_tensor->data.shape.size() == 4) { + // T3HD or TH3D (4D) -> THD (3D): remove dimension "3" at position 1 + unpacked_shape = {qkv_tensor->data.shape[0], h, d}; + } else { + // BS3HD/SB3HD or BSH3D/SBH3D (5D) -> BSHD/SBHD (4D): remove dimension "3" at position 2 + unpacked_shape = {qkv_tensor->data.shape[0], qkv_tensor->data.shape[1], h, d}; + } + return unpacked_shape; +} + +// Helper function to calculate stride for packed KV tensor unpacking +size_t calculate_kv_stride(NVTE_QKV_Layout_Group layout_group, + transformer_engine::DType dtype, + size_t h_kv, size_t d) { + size_t stride = 0; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { + stride = (transformer_engine::typeToNumBits(dtype) * h_kv * d) / 8; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { + stride = (transformer_engine::typeToNumBits(dtype) * d) / 8; + } + return stride; +} + +// Helper function to determine unpacked shape for KV packed tensor +std::vector calculate_kv_unpacked_shape(const transformer_engine::Tensor* kv_tensor, + NVTE_QKV_Layout_Group layout_group, + NVTE_QKV_Format kv_format, + size_t t_kv, size_t h_kv, size_t d) { + std::vector unpacked_kv_shape; + if (kv_format == NVTE_QKV_Format::NVTE_THD) { + unpacked_kv_shape = {t_kv, h_kv, d}; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD || + layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { + unpacked_kv_shape = {kv_tensor->data.shape[0], kv_tensor->data.shape[1], h_kv, d}; + } + return unpacked_kv_shape; +} +} // namespace + // map NVTE_QKV_Layout to NVTE_QKV_Layout_Group NVTE_QKV_Layout_Group 
nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { switch (qkv_layout) { @@ -467,16 +537,12 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, // Unpack QKV and call the non-packed function const auto QKV_type = input_QKV->data.dtype; size_t stride = 2 * h * d; // For max512, layout is always BS3HD or SB3HD (3HD group) + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); - Tensor Q_view = *input_QKV; - Q_view.data.dptr = input_QKV->data.dptr; - - Tensor K_view = *input_QKV; - K_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + stride); - - Tensor V_view = *input_QKV; - V_view.data.dptr = - static_cast(static_cast(input_QKV->data.dptr) + 2 * stride); + // Create tensor views for Q, K, V + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); fused_attn_max_512_fwd(b, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, @@ -488,26 +554,14 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) // Unpack QKV and call the non-packed function - // Create separate tensor views for Q, K, V from the packed QKV tensor const auto QKV_type = input_QKV->data.dtype; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = (typeToNumBits(QKV_type) * h * d) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = (typeToNumBits(QKV_type) * d) / 8; - } + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); // Create tensor views for Q, K, V - Tensor Q_view = *input_QKV; - Q_view.data.dptr = input_QKV->data.dptr; - - Tensor K_view = *input_QKV; - K_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + stride); - - Tensor V_view = *input_QKV; - V_view.data.dptr = - static_cast(static_cast(input_QKV->data.dptr) + 2 * stride); + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); fused_attn_arbitrary_seqlen_fwd( b, h, h, max_seqlen, max_seqlen, d, d, t, t, 0, 0, 0, 0, 0, 0, is_training, attn_scale, @@ -523,22 +577,13 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, #if (CUDNN_VERSION >= 8900) // Unpack QKV and call the non-packed function const auto QKV_type = input_QKV->data.dtype; - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = (typeToNumBits(QKV_type) * h * d) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = (typeToNumBits(QKV_type) * d) / 8; - } - - Tensor Q_view = *input_QKV; - Q_view.data.dptr = input_QKV->data.dptr; - - Tensor K_view = *input_QKV; - K_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + stride); + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); - Tensor V_view = *input_QKV; - V_view.data.dptr = - static_cast(static_cast(input_QKV->data.dptr) + 2 * stride); + 
// Create tensor views for Q, K, V + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); fused_attn_fp8_fwd(b, h, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, @@ -607,27 +652,18 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); // Unpack QKV and dQKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; size_t stride = 2 * h * d; + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); - Tensor Q_view = *input_QKV; - Q_view.data.dptr = input_QKV->data.dptr; - - Tensor K_view = *input_QKV; - K_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + stride); - - Tensor V_view = *input_QKV; - V_view.data.dptr = - static_cast(static_cast(input_QKV->data.dptr) + 2 * stride); + // Create tensor views for Q, K, V and dQ, dK, dV + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); - Tensor dQ_view = *output_dQKV; - dQ_view.data.dptr = output_dQKV->data.dptr; - - Tensor dK_view = *output_dQKV; - dK_view.data.dptr = static_cast(static_cast(output_dQKV->data.dptr) + stride); - - Tensor dV_view = *output_dQKV; - dV_view.data.dptr = - static_cast(static_cast(output_dQKV->data.dptr) + 2 * stride); + Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); + Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); + Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); fused_attn_max_512_bwd(b, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_dO, output_S, @@ -651,35 +687,17 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con // Unpack QKV and dQKV and call the non-packed function const auto QKV_type = input_QKV->data.dtype; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = (typeToNumBits(QKV_type) * h * d) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = (typeToNumBits(QKV_type) * d) / 8; - } - - // Create tensor views for Q, K, V from input - Tensor Q_view = *input_QKV; - Q_view.data.dptr = input_QKV->data.dptr; - - Tensor K_view = *input_QKV; - K_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + stride); + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); - Tensor V_view = *input_QKV; - V_view.data.dptr = - static_cast(static_cast(input_QKV->data.dptr) + 2 * stride); + // Create tensor views for Q, K, V and dQ, dK, dV + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); - // Create tensor views for dQ, dK, dV from output - Tensor dQ_view = *output_dQKV; - dQ_view.data.dptr = output_dQKV->data.dptr; - - Tensor dK_view = *output_dQKV; - dK_view.data.dptr = static_cast(static_cast(output_dQKV->data.dptr) + stride); - 
- Tensor dV_view = *output_dQKV; - dV_view.data.dptr = - static_cast(static_cast(output_dQKV->data.dptr) + 2 * stride); + Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); + Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); + Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); fused_attn_arbitrary_seqlen_bwd( b, h, h, max_seqlen, max_seqlen, d, d, t, t, attn_scale, dropout, qkv_layout, bias_type, @@ -701,32 +719,17 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con // Unpack QKV and dQKV and call the non-packed function const auto QKV_type = input_QKV->data.dtype; - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = (typeToNumBits(QKV_type) * h * d) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = (typeToNumBits(QKV_type) * d) / 8; - } - - Tensor Q_view = *input_QKV; - Q_view.data.dptr = input_QKV->data.dptr; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); - Tensor K_view = *input_QKV; - K_view.data.dptr = static_cast(static_cast(input_QKV->data.dptr) + stride); + // Create tensor views for Q, K, V and dQ, dK, dV + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); - Tensor V_view = *input_QKV; - V_view.data.dptr = - static_cast(static_cast(input_QKV->data.dptr) + 2 * stride); - - Tensor dQ_view = *output_dQKV; - dQ_view.data.dptr = output_dQKV->data.dptr; - - Tensor dK_view = *output_dQKV; - dK_view.data.dptr = static_cast(static_cast(output_dQKV->data.dptr) + stride); - - Tensor dV_view = *output_dQKV; - dV_view.data.dptr = - static_cast(static_cast(output_dQKV->data.dptr) + 2 * stride); + Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); + Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); + Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_O, input_dO, @@ -832,11 +835,18 @@ void nvte_fused_attn_fwd_kvpacked( // Unpack KV and call the non-packed function size_t stride = 2 * h_q * d; // For max512, KV layout is BS2HD or SB2HD - Tensor K_view = *input_KV; - K_view.data.dptr = input_KV->data.dptr; + // Create tensor views for K, V + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + std::vector unpacked_kv_shape; + if (kv_format == NVTE_QKV_Format::NVTE_THD) { + unpacked_kv_shape = {t_kv, h_kv, d}; + } else { + // BS2HD or SB2HD -> BSHD or SBHD + unpacked_kv_shape = {input_KV->data.shape[0], input_KV->data.shape[1], h_kv, d}; + } - Tensor V_view = *input_KV; - V_view.data.dptr = static_cast(static_cast(input_KV->data.dptr) + stride); + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, @@ -849,19 +859,13 @@ void nvte_fused_attn_fwd_kvpacked( #if (CUDNN_VERSION >= 8903) // Unpack KV and call the non-packed function const auto Q_type = input_Q->data.dtype; - size_t stride = 0; - if (layout_group == 
NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = (typeToNumBits(Q_type) * h_kv * d) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = (typeToNumBits(Q_type) * d) / 8; - } - - // Create tensor views for K, V from input_KV - Tensor K_view = *input_KV; - K_view.data.dptr = input_KV->data.dptr; + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); - Tensor V_view = *input_KV; - V_view.data.dptr = static_cast(static_cast(input_KV->data.dptr) + stride); + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); fused_attn_arbitrary_seqlen_fwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, num_pages_k, num_pages_v, @@ -879,18 +883,13 @@ void nvte_fused_attn_fwd_kvpacked( #if (CUDNN_VERSION >= 8900) // Unpack KV and call the non-packed function const auto Q_type = input_Q->data.dtype; - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = (typeToNumBits(Q_type) * h_kv * d) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = (typeToNumBits(Q_type) * d) / 8; - } - - Tensor K_view = *input_KV; - K_view.data.dptr = input_KV->data.dptr; + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); - Tensor V_view = *input_KV; - V_view.data.dptr = static_cast(static_cast(input_KV->data.dptr) + stride); + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, @@ -971,17 +970,21 @@ void nvte_fused_attn_bwd_kvpacked( // Unpack KV and dKV and call the non-packed function size_t stride = 2 * h_q * d; - Tensor K_view = *input_KV; - K_view.data.dptr = input_KV->data.dptr; - - Tensor V_view = *input_KV; - V_view.data.dptr = static_cast(static_cast(input_KV->data.dptr) + stride); + // Create tensor views for K, V + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + std::vector unpacked_kv_shape; + if (kv_format == NVTE_QKV_Format::NVTE_THD) { + unpacked_kv_shape = {t_kv, h_kv, d}; + } else { + // BS2HD or SB2HD -> BSHD or SBHD + unpacked_kv_shape = {input_KV->data.shape[0], input_KV->data.shape[1], h_kv, d}; + } - Tensor dK_view = *output_dKV; - dK_view.data.dptr = output_dKV->data.dptr; + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); - Tensor dV_view = *output_dKV; - dV_view.data.dptr = static_cast(static_cast(output_dKV->data.dptr) + stride); + Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); + Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_dO, output_S, @@ -1006,26 +1009,17 @@ void nvte_fused_attn_bwd_kvpacked( // Unpack KV and dKV and call the non-packed function const auto Q_type = input_Q->data.dtype; 
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = (typeToNumBits(Q_type) * h_kv * d) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = (typeToNumBits(Q_type) * d) / 8; - } - - // Create tensor views for K, V from input_KV - Tensor K_view = *input_KV; - K_view.data.dptr = input_KV->data.dptr; - - Tensor V_view = *input_KV; - V_view.data.dptr = static_cast(static_cast(input_KV->data.dptr) + stride); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); - // Create tensor views for dK, dV from output_dKV - Tensor dK_view = *output_dKV; - dK_view.data.dptr = output_dKV->data.dptr; + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); - Tensor dV_view = *output_dKV; - dV_view.data.dptr = static_cast(static_cast(output_dKV->data.dptr) + stride); + // Create tensor views for dK, dV + Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); + Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); fused_attn_arbitrary_seqlen_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, attn_scale, dropout, qkv_layout, @@ -1048,24 +1042,16 @@ void nvte_fused_attn_bwd_kvpacked( // Unpack KV and dKV and call the non-packed function const auto Q_type = input_Q->data.dtype; - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = (typeToNumBits(Q_type) * h_kv * d) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = (typeToNumBits(Q_type) * d) / 8; - } - - Tensor K_view = *input_KV; - K_view.data.dptr = input_KV->data.dptr; - - Tensor V_view = *input_KV; - V_view.data.dptr = static_cast(static_cast(input_KV->data.dptr) + stride); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); - Tensor dK_view = *output_dKV; - dK_view.data.dptr = output_dKV->data.dptr; + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); - Tensor dV_view = *output_dKV; - dV_view.data.dptr = static_cast(static_cast(output_dKV->data.dptr) + stride); + Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); + Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_O, diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index cec0bfda27..5f80633aea 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -1039,8 +1039,8 @@ void fused_attn_arbitrary_seqlen_fwd( void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; - void *devPtrPageTableK = 
page_table_k->data.dptr; - void *devPtrPageTableV = page_table_v->data.dptr; + void *devPtrPageTableK = page_table_k ? page_table_k->data.dptr : nullptr; + void *devPtrPageTableV = page_table_v ? page_table_v->data.dptr : nullptr; size_t max_batch_size = 0; size_t max_tokens_q = 0; From 6f9b4660877000f8ae5a8bc9a0cd1c999084a0a9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 20 Oct 2025 15:20:55 +0000 Subject: [PATCH 04/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/fused_attn/fused_attn.cpp | 36 +++++++++---------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 32325e2978..7602110598 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -17,12 +17,12 @@ namespace { // Helper function to create a tensor view with modified shape and optional pointer offset -transformer_engine::Tensor make_tensor_view(const transformer_engine::Tensor* source, - const std::vector& shape, +transformer_engine::Tensor make_tensor_view(const transformer_engine::Tensor *source, + const std::vector &shape, size_t offset_bytes = 0) { transformer_engine::Tensor view = *source; if (offset_bytes > 0) { - view.data.dptr = static_cast(static_cast(source->data.dptr) + offset_bytes); + view.data.dptr = static_cast(static_cast(source->data.dptr) + offset_bytes); } view.data.shape = shape; view.nvte_tensor = 0; // Mark as unmanaged/local tensor view @@ -30,9 +30,8 @@ transformer_engine::Tensor make_tensor_view(const transformer_engine::Tensor* so } // Helper function to calculate stride for packed QKV tensor unpacking -size_t calculate_qkv_stride(NVTE_QKV_Layout_Group layout_group, - transformer_engine::DType dtype, - size_t h, size_t d) { +size_t calculate_qkv_stride(NVTE_QKV_Layout_Group layout_group, transformer_engine::DType dtype, + size_t h, size_t d) { size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { stride = (transformer_engine::typeToNumBits(dtype) * h * d) / 8; @@ -43,8 +42,8 @@ size_t calculate_qkv_stride(NVTE_QKV_Layout_Group layout_group, } // Helper function to determine unpacked shape for QKV packed tensor -std::vector calculate_qkv_unpacked_shape(const transformer_engine::Tensor* qkv_tensor, - size_t h, size_t d) { +std::vector calculate_qkv_unpacked_shape(const transformer_engine::Tensor *qkv_tensor, + size_t h, size_t d) { std::vector unpacked_shape; if (qkv_tensor->data.shape.size() == 4) { // T3HD or TH3D (4D) -> THD (3D): remove dimension "3" at position 1 @@ -57,9 +56,8 @@ std::vector calculate_qkv_unpacked_shape(const transformer_engine::Tenso } // Helper function to calculate stride for packed KV tensor unpacking -size_t calculate_kv_stride(NVTE_QKV_Layout_Group layout_group, - transformer_engine::DType dtype, - size_t h_kv, size_t d) { +size_t calculate_kv_stride(NVTE_QKV_Layout_Group layout_group, transformer_engine::DType dtype, + size_t h_kv, size_t d) { size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { stride = (transformer_engine::typeToNumBits(dtype) * h_kv * d) / 8; @@ -70,10 +68,10 @@ size_t calculate_kv_stride(NVTE_QKV_Layout_Group layout_group, } // Helper function to determine unpacked shape for KV packed tensor -std::vector calculate_kv_unpacked_shape(const transformer_engine::Tensor* kv_tensor, - 
NVTE_QKV_Layout_Group layout_group, - NVTE_QKV_Format kv_format, - size_t t_kv, size_t h_kv, size_t d) { +std::vector calculate_kv_unpacked_shape(const transformer_engine::Tensor *kv_tensor, + NVTE_QKV_Layout_Group layout_group, + NVTE_QKV_Format kv_format, size_t t_kv, size_t h_kv, + size_t d) { std::vector unpacked_kv_shape; if (kv_format == NVTE_QKV_Format::NVTE_THD) { unpacked_kv_shape = {t_kv, h_kv, d}; @@ -861,7 +859,7 @@ void nvte_fused_attn_fwd_kvpacked( const auto Q_type = input_Q->data.dtype; NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); - std::vector unpacked_kv_shape = + std::vector unpacked_kv_shape = calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); @@ -885,7 +883,7 @@ void nvte_fused_attn_fwd_kvpacked( const auto Q_type = input_Q->data.dtype; NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); - std::vector unpacked_kv_shape = + std::vector unpacked_kv_shape = calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); @@ -1011,7 +1009,7 @@ void nvte_fused_attn_bwd_kvpacked( NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); - std::vector unpacked_kv_shape = + std::vector unpacked_kv_shape = calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); @@ -1044,7 +1042,7 @@ void nvte_fused_attn_bwd_kvpacked( const auto Q_type = input_Q->data.dtype; NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); - std::vector unpacked_kv_shape = + std::vector unpacked_kv_shape = calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); From 8ca92522d17e3c37265e33ea7aff2d0eecdb03fb Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 5 Nov 2025 14:06:02 +0000 Subject: [PATCH 05/10] fix Signed-off-by: Pawel Gadzinski --- transformer_engine/common/fused_attn/fused_attn.cpp | 8 ++++++++ .../common/include/transformer_engine/fused_attn.h | 12 ++++++++++++ 2 files changed, 20 insertions(+) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 7602110598..4f0f9504eb 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -482,6 +482,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } // NVTE fused attention FWD with packed QKV +// DEPRECATED: This API is deprecated. +// Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead. void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, @@ -595,6 +597,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, } } // NVTE fused attention BWD with packed QKV +// DEPRECATED: This API is deprecated. +// Please use nvte_fused_attn_bwd with separate Q, K, V tensors instead. 
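Note: the deprecated packed wrappers do not copy any data; they expose Q, K and V as shape-adjusted views whose data pointers are byte offsets into the same packed allocation, then forward to the non-packed implementation. A minimal standalone sketch of that view pattern is shown below, using an illustrative SimpleTensor rather than the real transformer_engine::Tensor.

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Illustrative stand-in for transformer_engine::Tensor.
struct SimpleTensor {
  void *dptr = nullptr;
  std::vector<size_t> shape;
};

// Mirrors the make_tensor_view helper added in this series: same storage,
// new logical shape, optional byte offset; no data is copied.
SimpleTensor make_view(const SimpleTensor &src, std::vector<size_t> shape,
                       size_t offset_bytes = 0) {
  SimpleTensor view = src;
  if (offset_bytes > 0) {
    view.dptr = static_cast<uint8_t *>(src.dptr) + offset_bytes;
  }
  view.shape = std::move(shape);
  return view;
}

// Example: a BS3HD-packed QKV buffer exposed as three BSHD views.
// The fused kernels still derive element strides from qkv_layout; the views only
// fix the base pointer and the logical (unpacked) shape.
void unpack_qkv(const SimpleTensor &qkv, size_t b, size_t s, size_t h, size_t d,
                size_t bytes_per_elem, SimpleTensor &q, SimpleTensor &k, SimpleTensor &v) {
  const size_t stride = bytes_per_elem * h * d;  // offset between the Q, K and V groups
  const std::vector<size_t> unpacked = {b, s, h, d};
  q = make_view(qkv, unpacked, 0);
  k = make_view(qkv, unpacked, stride);
  v = make_view(qkv, unpacked, 2 * stride);
}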
void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, @@ -742,6 +746,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con } } // NVTE fused attention FWD with packed KV +// DEPRECATED: This API is deprecated. +// Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead. void nvte_fused_attn_fwd_kvpacked( const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, @@ -901,6 +907,8 @@ void nvte_fused_attn_fwd_kvpacked( } } // NVTE fused attention BWD with packed KV +// DEPRECATED: This API is deprecated. +// Please use nvte_fused_attn_bwd with separate Q, K, V tensors instead. void nvte_fused_attn_bwd_kvpacked( const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index a150978c4a..ae2f4e055e 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -215,6 +215,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( int64_t window_size_right); /*! \brief Compute dot product attention with packed QKV input. + * + * \warning This API is **deprecated**. + * Please use `nvte_fused_attn_fwd` with separate Q, K, V tensors instead. * * Computes: * - P = Q * Transpose(K) + Bias @@ -275,6 +278,9 @@ void nvte_fused_attn_fwd_qkvpacked( int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed QKV input. + * + * \warning This API is **deprecated**. + * Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead. * * Support Matrix: \verbatim @@ -334,6 +340,9 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con bool deterministic, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with packed KV input. + * + * \warning This API is **deprecated**. + * Please use `nvte_fused_attn_fwd` with separate Q, K, V tensors instead. * * Computes: * - P = Q * Transpose(K) + Bias @@ -405,6 +414,9 @@ void nvte_fused_attn_fwd_kvpacked( NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed KV input. + * + * \warning This API is **deprecated**. + * Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead. 
* * Support Matrix: \verbatim From 87c7fbf14c19d5948c80c11aeff543fea7e9e05e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 5 Nov 2025 14:31:58 +0000 Subject: [PATCH 06/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/fused_attn/fused_attn.cpp | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index a979542110..23ff4d7927 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -564,11 +564,12 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); fused_attn_arbitrary_seqlen_fwd( - b, h, h, max_seqlen, max_seqlen, d, d, t, t, 0, 0, 0, 0, 0, 0, is_training, return_max_logit, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, - window_size_right, &Q_view, &K_view, &V_view, input_Bias, input_SoftmaxOffset, output_O, - Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, input_cu_seqlens_padded, - input_cu_seqlens_padded, nullptr, nullptr, input_rng_state, wkspace, stream, handle); + b, h, h, max_seqlen, max_seqlen, d, d, t, t, 0, 0, 0, 0, 0, 0, is_training, + return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_Bias, + input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, + input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, input_rng_state, + wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); @@ -874,12 +875,12 @@ void nvte_fused_attn_fwd_kvpacked( fused_attn_arbitrary_seqlen_fwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, num_pages_k, num_pages_v, - page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, return_max_logit, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, - window_size_right, input_Q, &K_view, &V_view, input_Bias, input_SoftmaxOffset, output_O, - Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, - input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, - wkspace, stream, handle); + page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, + return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, input_Q, &K_view, &V_view, input_Bias, + input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, + input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, + input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. 
\n"); From ec7531de42f6ded2a0fe51e0e37c9bce876122dc Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 5 Nov 2025 15:51:36 +0000 Subject: [PATCH 07/10] fix Signed-off-by: Pawel Gadzinski --- .../common/fused_attn/fused_attn.cpp | 30 +++++-------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 23ff4d7927..f16e707501 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -536,7 +536,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, #if (CUDNN_VERSION >= 8901) // Unpack QKV and call the non-packed function const auto QKV_type = input_QKV->data.dtype; - size_t stride = 2 * h * d; // For max512, layout is always BS3HD or SB3HD (3HD group) + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); // Create tensor views for Q, K, V @@ -656,7 +656,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con // Unpack QKV and dQKV and call the non-packed function const auto QKV_type = input_QKV->data.dtype; - size_t stride = 2 * h * d; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); // Create tensor views for Q, K, V and dQ, dK, dV @@ -839,17 +839,10 @@ void nvte_fused_attn_fwd_kvpacked( if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) // Unpack KV and call the non-packed function - size_t stride = 2 * h_q * d; // For max512, KV layout is BS2HD or SB2HD - - // Create tensor views for K, V NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - std::vector unpacked_kv_shape; - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - unpacked_kv_shape = {t_kv, h_kv, d}; - } else { - // BS2HD or SB2HD -> BSHD or SBHD - unpacked_kv_shape = {input_KV->data.shape[0], input_KV->data.shape[1], h_kv, d}; - } + size_t stride = calculate_kv_stride(layout_group, input_Q->data.dtype, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); @@ -976,17 +969,10 @@ void nvte_fused_attn_bwd_kvpacked( Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); // Unpack KV and dKV and call the non-packed function - size_t stride = 2 * h_q * d; - - // Create tensor views for K, V NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - std::vector unpacked_kv_shape; - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - unpacked_kv_shape = {t_kv, h_kv, d}; - } else { - // BS2HD or SB2HD -> BSHD or SBHD - unpacked_kv_shape = {input_KV->data.shape[0], input_KV->data.shape[1], h_kv, d}; - } + size_t stride = calculate_kv_stride(layout_group, input_Q->data.dtype, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); From af1fb164262b88d8d71b204dd95c773d40c7ddef Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 7 Nov 2025 09:08:57 +0000 Subject: [PATCH 08/10] depracted compile time warning + 
\warning -> \deprecated Signed-off-by: Pawel Gadzinski --- .../include/transformer_engine/fused_attn.h | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 61b3c0dfb5..aeb10d3124 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -217,8 +217,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( /*! \brief Compute dot product attention with packed QKV input. * - * \warning This API is **deprecated**. - * Please use `nvte_fused_attn_fwd` with separate Q, K, V tensors instead. + * \deprecated Please use `nvte_fused_attn_fwd` with separate Q, K, V tensors instead. * * Computes: * - P = Q * Transpose(K) + Bias @@ -271,6 +270,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ +[[deprecated("nvte_fused_attn_fwd_qkvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate Q, K, V tensors instead.")]] void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, @@ -284,8 +284,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, /*! \brief Compute the backward of the dot product attention with packed QKV input. * - * \warning This API is **deprecated**. - * Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead. + * \deprecated Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead. * * Support Matrix: \verbatim @@ -333,6 +332,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ +[[deprecated("nvte_fused_attn_bwd_qkvpacked() is deprecated. Please use nvte_fused_attn_bwd() with separate Q, K, V tensors instead.")]] void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, @@ -346,8 +346,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con /*! \brief Compute dot product attention with packed KV input. * - * \warning This API is **deprecated**. - * Please use `nvte_fused_attn_fwd` with separate Q, K, V tensors instead. + * \deprecated Please use `nvte_fused_attn_fwd` with separate Q, K, V tensors instead. * * Computes: * - P = Q * Transpose(K) + Bias @@ -408,6 +407,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ +[[deprecated("nvte_fused_attn_fwd_kvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate Q, K, V tensors instead.")]] void nvte_fused_attn_fwd_kvpacked( const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, @@ -421,8 +421,7 @@ void nvte_fused_attn_fwd_kvpacked( /*! \brief Compute the backward of the dot product attention with packed KV input. * - * \warning This API is **deprecated**. 
- * Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead. + * \deprecated Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead. * * Support Matrix: \verbatim @@ -476,6 +475,7 @@ void nvte_fused_attn_fwd_kvpacked( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ +[[deprecated("nvte_fused_attn_bwd_kvpacked() is deprecated. Please use nvte_fused_attn_bwd() with separate Q, K, V tensors instead.")]] void nvte_fused_attn_bwd_kvpacked( const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, From 44cc48832f1f4a90b7b598275901dfa7e6f74bf5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 Nov 2025 09:10:40 +0000 Subject: [PATCH 09/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../include/transformer_engine/fused_attn.h | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index aeb10d3124..bc626d42f3 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -270,7 +270,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ -[[deprecated("nvte_fused_attn_fwd_qkvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate Q, K, V tensors instead.")]] +[[deprecated( + "nvte_fused_attn_fwd_qkvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate " + "Q, K, V tensors instead.")]] void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, @@ -332,7 +334,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ -[[deprecated("nvte_fused_attn_bwd_qkvpacked() is deprecated. Please use nvte_fused_attn_bwd() with separate Q, K, V tensors instead.")]] +[[deprecated( + "nvte_fused_attn_bwd_qkvpacked() is deprecated. Please use nvte_fused_attn_bwd() with separate " + "Q, K, V tensors instead.")]] void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, @@ -407,7 +411,9 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ -[[deprecated("nvte_fused_attn_fwd_kvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate Q, K, V tensors instead.")]] +[[deprecated( + "nvte_fused_attn_fwd_kvpacked() is deprecated. 
Please use nvte_fused_attn_fwd() with separate " + "Q, K, V tensors instead.")]] void nvte_fused_attn_fwd_kvpacked( const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, @@ -475,7 +481,9 @@ void nvte_fused_attn_fwd_kvpacked( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ -[[deprecated("nvte_fused_attn_bwd_kvpacked() is deprecated. Please use nvte_fused_attn_bwd() with separate Q, K, V tensors instead.")]] +[[deprecated( + "nvte_fused_attn_bwd_kvpacked() is deprecated. Please use nvte_fused_attn_bwd() with separate " + "Q, K, V tensors instead.")]] void nvte_fused_attn_bwd_kvpacked( const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, From 4f99a55ce0336eb03a6da1c83703a2ae29b4fadd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 Nov 2025 20:06:37 +0000 Subject: [PATCH 10/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/fused_attn/fused_attn.cpp | 18 ++++++++---------- .../include/transformer_engine/fused_attn.h | 18 ++++++++---------- 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 286c0c6d69..ac6fefdc6a 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -623,16 +623,14 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, // NVTE fused attention BWD with packed QKV // DEPRECATED: This API is deprecated. // Please use nvte_fused_attn_bwd with separate Q, K, V tensors instead. 
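Note: [[deprecated("...")]] is the standard C++14 attribute, so any translation unit that calls one of these packed entry points now gets a compile-time diagnostic carrying the message, while the call still compiles and links. A minimal illustration with a placeholder function name, not the actual TE symbol:

// deprecation_demo.cpp: builds with e.g. g++ -std=c++14 -c deprecation_demo.cpp
[[deprecated(
    "legacy_packed_attn() is deprecated. Please use the non-packed entry point "
    "instead.")]]
void legacy_packed_attn() {}

void caller() {
  legacy_packed_attn();  // warning (e.g. -Wdeprecated-declarations) with the message above
}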
-void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, - const NVTETensor S, NVTETensor dP, - const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, - NVTETensor dBias, NVTETensor dSoftmaxOffset, - const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, - size_t max_seqlen, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { +void nvte_fused_attn_bwd_qkvpacked( + const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, + NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, + NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, + size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); using namespace transformer_engine; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index c6a0e18945..298dc63900 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -340,16 +340,14 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, [[deprecated( "nvte_fused_attn_bwd_qkvpacked() is deprecated. Please use nvte_fused_attn_bwd() with separate " "Q, K, V tensors instead.")]] -void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, - const NVTETensor S, NVTETensor dP, - const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, - NVTETensor dBias, NVTETensor dSoftmaxOffset, - const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, - size_t max_seqlen, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_bwd_qkvpacked( + const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, + NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, + NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, + size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with packed KV input. *