Conversation

@qjia7 (Contributor) commented Oct 22, 2025

This pull request conditionally registers GQA with total_sequence_length either on the GPU or on the CPU, depending on whether graph capture is enabled. It resolves the issue that a MemcpyToHost node is generated when graph capture is enabled (see #25868). This is the last piece of functionality needed to support graph capture in the WebGPU EP in ORT.

The main changes ensure that when graph capture is enabled, sequence length information is read from GPU buffers instead of CPU memory, and shader code generation adapts accordingly. This enables more efficient execution and compatibility with graph-captured models.

In this PR, we still read the total sequence length from the seqlen_k tensor rather than the total_seqlen tensor, to stay consistent with other parts of the code. In a follow-up PR, we can refactor all call sites to use total_seqlen directly instead of seqlen_k when graph capture is enabled.
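As a rough sketch of what the conditional registration looks like, assuming ORT's KernelDefBuilder API and that total_sequence_length is GQA input index 6 (the exact code in this PR may differ):

```cpp
// Sketch: pin input 6 (total_sequence_length) to CPU memory only when graph
// capture is disabled. When graph capture is enabled the input stays on the
// GPU, so ORT no longer inserts a MemcpyToHost node before the kernel.
KernelCreateInfo CreateGroupQueryAttentionKernelInfo(bool enable_graph_capture) {
  KernelDefBuilder builder;
  builder.SetName("GroupQueryAttention")
      .SetDomain(kMSDomain)
      .SinceVersion(1)
      .Provider(kWebGpuExecutionProvider)
      .TypeConstraint("T", WebGpuSupportedFloatTypes());
  if (!enable_graph_capture) {
    builder.InputMemoryType(OrtMemTypeCPUInput, 6);  // read total_seqlen on the host
  }
  return KernelCreateInfo(
      builder.Build(),
      [](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status {
        out = std::make_unique<GroupQueryAttention>(info);
        return Status::OK();
      });
}
```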

@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label Oct 22, 2025
@guschmue guschmue requested a review from Copilot October 28, 2025 15:44

Copilot AI left a comment


Pull Request Overview

This PR enables conditional registration of the GroupQueryAttention (GQA) operator based on whether graph capture is enabled in the WebGPU execution provider. When graph capture is enabled, the operator reads total sequence length from GPU buffers instead of CPU memory, eliminating the need for a MemcpyToHost operation that was blocking graph capture support.

Key changes:

  • Modified GQA kernel registration to conditionally set InputMemoryType based on graph capture status
  • Updated flash attention shader templates and programs to support reading the sequence length from GPU buffers (see the sketch after this list)
  • Added validation logic to handle total_seqlen tensor when it resides on GPU during graph capture
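A hedged sketch of the shader-side switch described above (identifiers are illustrative; the real code lives in flash_attention.wgsl.template and flash_attention.cc):

```cpp
#include <ostream>

// Sketch: emit a WGSL helper whose body depends on use_seqlen_k. With graph
// capture enabled, the total sequence length is read from the seqlen_k GPU
// buffer (which, per the GQA contract, holds total_sequence_length - 1);
// otherwise it comes from a uniform that was populated on the CPU.
void EmitGetTotalSequenceLength(std::ostream& wgsl, bool use_seqlen_k) {
  if (use_seqlen_k) {
    wgsl << "fn get_total_sequence_length() -> u32 {\n"
            "  return u32(seqlen_k[0]) + 1u;\n"
            "}\n";
  } else {
    wgsl << "fn get_total_sequence_length() -> u32 {\n"
            "  return uniforms.total_sequence_length;\n"
            "}\n";
  }
}
```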

Reviewed Changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.

File | Description
onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc | Passes the enable_graph_capture flag to RegisterWebGpuContribKernels
onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h | Adds an enable_graph_capture parameter to the RegisterWebGpuContribKernels signature
onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc | Replaces static GQA registration with conditional registration via CreateGroupQueryAttentionKernelInfo
onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h | Declares the CreateGroupQueryAttentionKernelInfo function for conditional kernel creation
onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc | Implements conditional kernel registration and updates the ApplyFlashAttention signature to accept seqlen_k
onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template | Adds a get_total_sequence_length() function that reads from either a GPU buffer or uniforms based on the use_seqlen_k flag
onnxruntime/contrib_ops/webgpu/bert/flash_attention.h | Adds a use_seqlen_k member to the CopyKVCacheProgram and FlashAttentionProgram classes
onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | Implements the use_seqlen_k logic in shader code generation and removes the past_sequence_length uniform
onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h | Updates validation logic to skip CPU-specific checks when total_seqlen is on the GPU (see the sketch after this table)
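A minimal sketch of the validation guard in group_query_attention_helper.h, assuming the helper can query the tensor's location (names are approximate, not the exact ORT code):

```cpp
// Sketch: dereference total_seqlen on the host only when it actually resides
// in CPU memory. During WebGPU graph capture the tensor lives on the device,
// so the CPU-side range check is skipped and the kernel reads the value from
// the GPU buffer instead.
Status CheckTotalSequenceLength(const Tensor* total_seqlen, int present_kv_seqlen) {
  if (total_seqlen->Location().device.Type() != OrtDevice::CPU) {
    return Status::OK();  // value is on the GPU; nothing to check on the host
  }
  const int32_t total_sequence_length = *total_seqlen->Data<int32_t>();
  if (total_sequence_length < present_kv_seqlen) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "total_sequence_length must cover the present KV cache length");
  }
  return Status::OK();
}
```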


In onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc:

       .SetWorkgroupSize(tile_size)
-      .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia)
+      .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, use_seqlen_k)
       .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},

Copilot AI Oct 28, 2025


The removal of the past_sequence_length uniform variable from the uniform list may break shader code that expects this value at a specific index. The shader template now calculates past_sequence_length locally (lines 161-162), but any shader code that relies on the uniform variable order may fail. Verify that all shader references have been updated accordingly.

Suggested change:
-      .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
+      .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
+                            {static_cast<uint32_t>(parameters.past_sequence_length_)},
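For context on why the ordering matters: in ORT's WebGPU provider, values passed to AddUniformVariables are matched by position to the uniform names the program declares. A toy, self-contained analogue of that positional binding (all names and values here are made up):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Toy analogue of positional uniform binding: values are matched to declared
// names purely by index, so dropping one entry silently shifts every later
// value onto the wrong name unless the declaration list changes in lockstep.
int main() {
  std::vector<const char*> declared = {"sequence_length", "total_sequence_length"};
  std::vector<uint32_t> values = {128u, 2048u};  // must stay in declaration order
  for (size_t i = 0; i < declared.size(); ++i) {
    std::cout << "uniforms." << declared[i] << " = " << values[i] << "\n";
  }
}
```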

