[webgpu] Register GQA based on graph capture #26384
base: main
Conversation
Pull Request Overview
This PR enables conditional registration of the GroupQueryAttention (GQA) operator based on whether graph capture is enabled in the WebGPU execution provider. When graph capture is enabled, the operator reads total sequence length from GPU buffers instead of CPU memory, eliminating the need for a MemcpyToHost operation that was blocking graph capture support.
Key changes:
- Modified GQA kernel registration to conditionally set InputMemoryType based on graph capture status
- Updated flash attention shader templates and programs to support reading sequence length from GPU buffers
- Added validation logic to handle total_seqlen tensor when it resides on GPU during graph capture
Reviewed Changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc | Passes enable_graph_capture flag to RegisterWebGpuContribKernels |
| onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h | Adds enable_graph_capture parameter to RegisterWebGpuContribKernels signature |
| onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc | Replaces static GQA registration with conditional registration via CreateGroupQueryAttentionKernelInfo |
| onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h | Declares CreateGroupQueryAttentionKernelInfo function for conditional kernel creation |
| onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc | Implements conditional kernel registration and updates ApplyFlashAttention signature to accept seqlen_k |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template | Adds get_total_sequence_length() function that reads from either GPU buffer or uniforms based on use_seqlen_k flag |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.h | Adds use_seqlen_k member to CopyKVCacheProgram and FlashAttentionProgram classes |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | Implements use_seqlen_k logic in shader code generation and removes past_sequence_length uniform |
| onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h | Updates validation logic to skip CPU-specific checks when total_seqlen is on GPU |
```diff
  .SetWorkgroupSize(tile_size)
- .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia)
+ .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, use_seqlen_k)
  .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
```
Copilot AI · Oct 28, 2025
The removal of the past_sequence_length uniform variable from the uniform list may break shader code that expects this value at a specific index. The shader template now calculates past_sequence_length locally (line 161-162), but any shader code relying on the uniform variable order may fail. Verify that all shader references have been updated accordingly.
```diff
- .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
+ .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
+                       {static_cast<uint32_t>(parameters.past_sequence_length_)},
```
This pull request enables conditional registration of GQA depending on whether total_sequence_length resides on the GPU. It resolves the issue where a MemcpyToHost operation is generated when graph capture is enabled (refer to #25868). This is the last functional piece needed to support graph capture in the WebGPU EP in ORT.
The main changes ensure that when graph capture is enabled, sequence length information is read from GPU buffers instead of CPU memory, and shader code generation adapts accordingly. This enables more efficient execution and compatibility with graph-captured models.
In this PR, we still get the total sequence length from the `seqlen_k` tensor, not the `total_seqlen` tensor, to stay consistent with other parts of the code. In a follow-up PR, we can refactor all places to use `total_seqlen` directly instead of `seqlen_k` when graph capture is enabled.