Fix trtllm-gen attention illegal memory access #2002
Conversation
Summary of Changes (Gemini Code Assist): Hello @Tom-Zheng, I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request addresses a critical illegal memory access bug affecting the trtllm-gen attention kernels.
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
Walkthrough: The `paged_run` wrapper in the `TrtllmGenDecodeModule` now passes `float_workspace_buffer` instead of `int_workspace_buffer` to the `trtllm_paged_attention_decode` kernel, changing the workspace buffer data source for the decode computation.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~8 minutes
Code Review
This pull request addresses a critical illegal memory access bug in the trtllm-gen attention kernels by correctly using float_workspace_buffer instead of the uninitialized int_workspace_buffer in flashinfer/decode.py. The change is correct and effectively resolves the issue for the decode path.
However, a similar vulnerability exists in flashinfer/prefill.py where int_workspace_buffer is also used for the trtllm-gen backend's paged attention. To ensure comprehensive stability and prevent similar memory errors, it is highly recommended to apply the same fix to the prefill path in this pull request.
Actionable comments posted: 0
🧹 Nitpick comments (2)
flashinfer/decode.py (2)
1928-1936: Op name mismatch: ragged_run vs paged_run. The custom op registers as `..._ragged_run` (line 1928) while the fake op registers as `..._paged_run` (line 2004). This inconsistency can confuse tooling/tests and complicate mock paths.
Unify the op name (ragged_run or paged_run) across both decorators for clarity and easier stubbing; a naming sketch follows below.
Also applies to: 2004-2012
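As an illustration of the suggested alignment, here is a minimal sketch using PyTorch's generic `torch.library.custom_op` API. This is not flashinfer's actual registration helper, and the op name `mylib::paged_run` is hypothetical; the point is that the real implementation and its fake variant share one fully-qualified name.

```python
import torch

# Hypothetical example (not flashinfer's decorators): the real op and the fake
# op are registered under the SAME fully-qualified name, so tests and mocks
# resolve a single identifier.
@torch.library.custom_op("mylib::paged_run", mutates_args=())
def paged_run(q: torch.Tensor) -> torch.Tensor:
    # Placeholder computation standing in for the real kernel launch.
    return q.clone()

@paged_run.register_fake
def _(q: torch.Tensor) -> torch.Tensor:
    # The fake op only describes output metadata for tracing/compilation.
    return torch.empty_like(q)
```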
1988-2002: Fix is applied correctly; capacity guard is optional but recommended. The critical change from int_workspace_buffer to float_workspace_buffer (line 1991) is confirmed in the trtllm-gen decode path. This prevents illegal memory access with the kernel's zero-initialized workspace.
The optional capacity guard suggestion is good practice to fail fast on buffer mismatches:
```python
ws_cap = float_workspace_buffer.numel() * float_workspace_buffer.element_size()
assert ws_cap >= workspace_size, f"buffer {ws_cap} bytes < required {workspace_size} bytes"
```
Minor note: the custom op is registered as `"..._ragged_run"` (line 1930) but the fake op uses `"..._paged_run"` (line 1935); consider aligning the names if applicable.
Thanks for spotting the bug and working on the bugfix!
/bot run
[CANCELED] Pipeline #37527615: canceled
…benchmark (#2018)

## 📌 Description
[tests/attention/test_trtllm_gen_attention.py](https://github.com/flashinfer-ai/flashinfer/blob/v0.5.0rc2/tests/attention/test_trtllm_gen_attention.py#L1021-L1076) was failing and was therefore marked xfail. PR #2002 fixed the underlying root cause, so this PR removes the `xfail` marker so that these long-seqlen cases stay covered moving forward.

Additionally, PR #2002 revealed a bug in the microbenchmark script: [trtllm_batch_decode_with_kv_cache](https://github.com/flashinfer-ai/flashinfer/blob/v0.5.0rc2/flashinfer/decode.py#L2082-L2083) explicitly requires the workspace to be zeroed before first use:
```
workspace_buffer : torch.Tensor. Must be initialized to 0 for its first use.
```
while the microbenchmark code does not zero it out, causing undefined behavior such as IMAs that depend on the order in which backends are tested. This PR fixes the issue by explicitly calling `workspace_buffer.zero_()` between testing different backends (see the sketch after this description).

## 🔍 Related Issues

## 🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## 🧪 Tests
- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit
* **Bug Fixes**
  * Improved stability of performance benchmarks by properly resetting the workspace buffer between backend invocations.
* **Tests**
  * Enabled a previously skipped test for long sequence length handling.
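As a rough illustration of the benchmark fix described above, the loop shape would be something like the following; `run_backend` and the buffer size are placeholders, not the actual microbenchmark code.

```python
import torch

def run_backend(backend: str, workspace: torch.Tensor) -> None:
    # Stand-in for launching a decode/prefill benchmark with the given backend;
    # pretend the kernel leaves bookkeeping state behind in the workspace.
    workspace[:16].fill_(1)

# Shared workspace; size is an illustrative assumption.
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8)

for backend in ("trtllm-gen", "fa2"):
    # Reset between backends so the "must be initialized to 0 for its first use"
    # contract holds and results do not depend on backend ordering.
    workspace_buffer.zero_()
    run_backend(backend, workspace_buffer)
```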
📌 Description
This PR fixes an illegal memory access in the trtllm-gen attention kernels. It changes the workspace buffer from `int_workspace_buffer` to `float_workspace_buffer`. `int_workspace_buffer` is a fixed-size buffer that is not initialized to zero, so it should not be used here.
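To illustrate why an uninitialized buffer is unsafe here (a generic PyTorch example, not flashinfer code): `torch.empty` returns arbitrary memory contents, while `torch.zeros` satisfies a "must start at zero" contract.

```python
import torch

# torch.empty gives uninitialized storage, analogous to the fixed-size
# int_workspace_buffer; a kernel that assumes zeroed bookkeeping fields can
# read garbage offsets from it and access memory out of bounds.
uninitialized = torch.empty(1024, dtype=torch.uint8)

# torch.zeros satisfies the "initialized to 0 for its first use" contract,
# analogous to how float_workspace_buffer is expected to be prepared.
zeroed = torch.zeros(1024, dtype=torch.uint8)

print(int(zeroed.sum()))  # 0: safe for kernels that require a zeroed workspace
```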
🔍 Related Issues
Issue #1928
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.
🧪 Tests
- All tests are passing (`unittest`, etc.).
Reviewer Notes
Summary by CodeRabbit