Optimize NVFP4 Triton kernel #533
base: main
Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
Signed-off-by: mxin <mxin@nvidia.com>
f58e420 to 0cf5fb6
Codecov Report
✅ All modified and coverable lines are covered by tests.

@@           Coverage Diff           @@
##             main     #533   +/-   ##
=======================================
  Coverage   74.43%   74.43%
=======================================
  Files         182      182
  Lines       18238    18238
=======================================
  Hits        13576    13576
  Misses       4662     4662
Signed-off-by: mxin <mxin@nvidia.com>
Thanks @mxinO. Do you have a unit test covering this change?
realAsma left a comment
Nice!!
global_scale_safe = tl.where(global_scale > 0.0, global_scale, 1e-12)

# Load input data
x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
Just curious: it looks like the old version already had the proper mask in tl.load and tl.store. Why did it cause the nvbug?
Yeah, illegal memory accesses are hard to debug because the error message never points to the actual location. I didn't find the root cause; my guess is that it was an addressing issue, so I changed the way the data is loaded and that fixed it. The bug is a rare case that had never been seen before.
We have tests covering the Triton kernel's correctness.
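For context, here is a minimal, self-contained sketch of the masked-load pattern being discussed. This is not the PR's actual NVFP4 kernel: the kernel name, the trivial scale-and-copy math, and the launcher are illustrative assumptions; only the masking/guarding idiom mirrors the quoted lines above.

```python
# Illustrative sketch only -- not the NVFP4 kernel from this PR.
import torch
import triton
import triton.language as tl


@triton.jit
def _scaled_copy_kernel(x_ptr, out_ptr, global_scale, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements  # keep out-of-range lanes from forming bad addresses
    # Guard a possibly-zero global scale before dividing (mirrors global_scale_safe above)
    global_scale_safe = tl.where(global_scale > 0.0, global_scale, 1e-12)
    # Masked load: lanes past the end read `other` instead of touching memory
    x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32)
    y = x / global_scale_safe
    tl.store(out_ptr + offs, y, mask=mask)


def scaled_copy(x: torch.Tensor, global_scale: float) -> torch.Tensor:
    x_flat = x.contiguous().view(-1)
    out = torch.empty_like(x_flat, dtype=torch.float32)
    grid = (triton.cdiv(x_flat.numel(), 1024),)
    _scaled_copy_kernel[grid](x_flat, out, global_scale, x_flat.numel(), BLOCK_SIZE=1024)
    return out.view(x.shape)
```

The explicit bounds mask plus an `other` fill value is the standard way to keep a 1-D grid from reading or writing past the end of the buffer.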
What does this PR do?
Type of change: Bug fix
Overview: Optimize the NVFP4 Triton kernel and fix an illegal-memory-access bug (see Bug under Additional Information). Benchmark comparisons of the old and new kernels on H100 and B200 follow.
H100:
| Shape | dtype | max abs diff | old kernel | new kernel | speedup |
|---|---|---|---|---|---|
| 512x512 | torch.float32 | 0.000e+00 | 35.32 µs | 38.49 µs | 0.92x |
| 512x512 | torch.bfloat16 | 0.000e+00 | 43.48 µs | 44.78 µs | 0.97x |
| 512x512 | torch.float16 | 0.000e+00 | 43.25 µs | 43.69 µs | 0.99x |
| 1024x1024 | torch.float32 | 0.000e+00 | 36.03 µs | 38.17 µs | 0.94x |
| 1024x1024 | torch.bfloat16 | 0.000e+00 | 44.24 µs | 43.78 µs | 1.01x |
| 1024x1024 | torch.float16 | 0.000e+00 | 43.77 µs | 43.61 µs | 1.00x |
| 4096x4096 | torch.float32 | 0.000e+00 | 87.02 µs | 80.88 µs | 1.08x |
| 4096x4096 | torch.bfloat16 | 0.000e+00 | 116.12 µs | 65.80 µs | 1.76x |
| 4096x4096 | torch.float16 | 0.000e+00 | 114.39 µs | 65.30 µs | 1.75x |
| 8192x8192 | torch.float32 | 0.000e+00 | 237.29 µs | 219.42 µs | 1.08x |
| 8192x8192 | torch.bfloat16 | 0.000e+00 | 349.76 µs | 138.66 µs | 2.52x |
| 8192x8192 | torch.float16 | 0.000e+00 | 341.89 µs | 136.91 µs | 2.50x |
| 8192x12288 | torch.float32 | 0.000e+00 | 338.65 µs | 312.70 µs | 1.08x |
| 8192x12288 | torch.bfloat16 | 0.000e+00 | 505.63 µs | 188.24 µs | 2.69x |
| 8192x12288 | torch.float16 | 0.000e+00 | 492.97 µs | 186.88 µs | 2.64x |
| 12288x12288 | torch.float32 | 0.000e+00 | 490.25 µs | 451.16 µs | 1.09x |
| 12288x12288 | torch.bfloat16 | 0.000e+00 | 736.04 µs | 261.94 µs | 2.81x |
| 12288x12288 | torch.float16 | 0.000e+00 | 717.64 µs | 257.82 µs | 2.78x |
| 32x4096 | torch.float32 | 0.000e+00 | 35.61 µs | 38.23 µs | 0.93x |
| 32x4096 | torch.bfloat16 | 0.000e+00 | 43.00 µs | 43.85 µs | 0.98x |
| 32x4096 | torch.float16 | 0.000e+00 | 42.83 µs | 44.13 µs | 0.97x |
| 1024x4096 | torch.float32 | 0.000e+00 | 38.12 µs | 41.28 µs | 0.92x |
| 1024x4096 | torch.bfloat16 | 0.000e+00 | 52.80 µs | 45.96 µs | 1.15x |
| 1024x4096 | torch.float16 | 0.000e+00 | 51.56 µs | 45.30 µs | 1.14x |
| 32x5000 | torch.float32 | 0.000e+00 | 41.70 µs | 38.03 µs | 1.10x |
| 32x5000 | torch.bfloat16 | 0.000e+00 | 52.95 µs | 44.14 µs | 1.20x |
| 32x5000 | torch.float16 | 0.000e+00 | 52.57 µs | 44.38 µs | 1.18x |
| 128x8200 | torch.float32 | 0.000e+00 | 48.03 µs | 38.38 µs | 1.25x |
| 128x8200 | torch.bfloat16 | 0.000e+00 | 60.54 µs | 44.51 µs | 1.36x |
| 128x8200 | torch.float16 | 0.000e+00 | 60.08 µs | 43.59 µs | 1.38x |

B200:

| Shape | dtype | max abs diff | old kernel | new kernel | speedup |
|---|---|---|---|---|---|
| 512x512 | torch.float32 | 0.000e+00 | 34.63 µs | 32.80 µs | 1.06x |
| 512x512 | torch.bfloat16 | 0.000e+00 | 42.26 µs | 40.92 µs | 1.03x |
| 512x512 | torch.float16 | 0.000e+00 | 41.38 µs | 39.30 µs | 1.05x |
| 1024x1024 | torch.float32 | 0.000e+00 | 35.07 µs | 33.93 µs | 1.03x |
| 1024x1024 | torch.bfloat16 | 0.000e+00 | 43.57 µs | 39.55 µs | 1.10x |
| 1024x1024 | torch.float16 | 0.000e+00 | 43.72 µs | 38.96 µs | 1.12x |
| 4096x4096 | torch.float32 | 0.000e+00 | 71.64 µs | 58.66 µs | 1.22x |
| 4096x4096 | torch.bfloat16 | 0.000e+00 | 81.67 µs | 57.98 µs | 1.41x |
| 4096x4096 | torch.float16 | 0.000e+00 | 82.19 µs | 57.56 µs | 1.43x |
| 8192x8192 | torch.float32 | 0.000e+00 | 176.85 µs | 135.78 µs | 1.30x |
| 8192x8192 | torch.bfloat16 | 0.000e+00 | 217.99 µs | 121.84 µs | 1.79x |
| 8192x8192 | torch.float16 | 0.000e+00 | 215.47 µs | 117.41 µs | 1.84x |
| 8192x12288 | torch.float32 | 0.000e+00 | 248.18 µs | 186.64 µs | 1.33x |
| 8192x12288 | torch.bfloat16 | 0.000e+00 | 306.25 µs | 163.28 µs | 1.88x |
| 8192x12288 | torch.float16 | 0.000e+00 | 303.06 µs | 157.59 µs | 1.92x |
| 12288x12288 | torch.float32 | 0.000e+00 | 354.23 µs | 262.99 µs | 1.35x |
| 12288x12288 | torch.bfloat16 | 0.000e+00 | 439.44 µs | 224.71 µs | 1.96x |
| 12288x12288 | torch.float16 | 0.000e+00 | 434.23 µs | 217.62 µs | 2.00x |
| 32x4096 | torch.float32 | 0.000e+00 | 35.90 µs | 34.88 µs | 1.03x |
| 32x4096 | torch.bfloat16 | 0.000e+00 | 43.77 µs | 41.49 µs | 1.05x |
| 32x4096 | torch.float16 | 0.000e+00 | 43.22 µs | 41.79 µs | 1.03x |
| 1024x4096 | torch.float32 | 0.000e+00 | 37.37 µs | 37.84 µs | 0.99x |
| 1024x4096 | torch.bfloat16 | 0.000e+00 | 49.69 µs | 43.85 µs | 1.13x |
| 1024x4096 | torch.float16 | 0.000e+00 | 48.93 µs | 44.31 µs | 1.10x |
| 32x5000 | torch.float32 | 0.000e+00 | 41.83 µs | 35.44 µs | 1.18x |
| 32x5000 | torch.bfloat16 | 0.000e+00 | 53.23 µs | 40.64 µs | 1.31x |
| 32x5000 | torch.float16 | 0.000e+00 | 54.39 µs | 40.77 µs | 1.33x |
| 128x8200 | torch.float32 | 0.000e+00 | 49.35 µs | 35.33 µs | 1.40x |
| 128x8200 | torch.bfloat16 | 0.000e+00 | 60.89 µs | 41.46 µs | 1.47x |
| 128x8200 | torch.float16 | 0.000e+00 | 61.75 µs | 41.75 µs | 1.48x |

Testing
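The per-shape, per-dtype comparisons above could be produced with a harness along these lines. This is a sketch under assumptions: `nvfp4_quantize_old` and `nvfp4_quantize_new` are hypothetical stand-ins for the two kernel entry points, and `triton.testing.do_bench` supplies the timing.

```python
# Sketch of a comparison harness; nvfp4_quantize_old / nvfp4_quantize_new are
# hypothetical names standing in for the old and new kernel entry points.
import torch
import triton


def bench_us(fn, x):
    # do_bench reports runtime in milliseconds; convert to microseconds
    return triton.testing.do_bench(lambda: fn(x)) * 1e3


shapes = [(512, 512), (1024, 1024), (4096, 4096), (8192, 8192), (128, 8200)]
dtypes = [torch.float32, torch.bfloat16, torch.float16]

for shape in shapes:
    for dtype in dtypes:
        x = torch.randn(shape, device="cuda", dtype=dtype)
        ref, new = nvfp4_quantize_old(x), nvfp4_quantize_new(x)  # hypothetical entry points
        diff = (ref.float() - new.float()).abs().max().item()
        t_old, t_new = bench_us(nvfp4_quantize_old, x), bench_us(nvfp4_quantize_new, x)
        print(f"Shape: {shape} dtype: {dtype} max abs diff: {diff:.3e} "
              f"old kernel: {t_old:.2f} µs new kernel: {t_new:.2f} µs "
              f"speedup: {t_old / t_new:.2f}x")
```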
Before your PR is "Ready for review"
Additional Information
Bug [5612406]