
Conversation

@mxinO
Contributor

@mxinO mxinO commented Nov 11, 2025

What does this PR do?

Type of change: Bug fix

Overview:

  1. Use make_block_ptr for loading blocks. This is safer and fixes an illegal memory access that occurred in rare cases (see the sketch after this list).
  2. The tile rows and columns can now be specified separately.
  3. Move the data type cast into the kernel to save memory for bf16/fp16 inputs.
  4. I benchmarked against the old kernel on H100 and B200; there is a significant speed-up for medium and large inputs (B200: 1.4x - 2x, H100: 1.7x - 2.8x). Results below.
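
Below is a rough sketch of what items 1-3 look like in Triton. The kernel name, arguments, and the trailing identity store are illustrative placeholders, not the actual PR code; the real kernel performs NVFP4 quantization where the store is.

```python
import triton
import triton.language as tl

@triton.jit
def _tile_kernel(  # hypothetical name
    x_ptr, out_ptr,
    M, N,
    stride_m, stride_n,
    TILE_M: tl.constexpr,  # tile rows and columns are separate knobs now
    TILE_N: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # make_block_ptr derives addresses from the tensor's shape and strides;
    # boundary_check zero-pads partial tiles at the edges, so no hand-built
    # offset/mask arithmetic is needed.
    x_block = tl.make_block_ptr(
        base=x_ptr,
        shape=(M, N),
        strides=(stride_m, stride_n),
        offsets=(pid_m * TILE_M, pid_n * TILE_N),
        block_shape=(TILE_M, TILE_N),
        order=(1, 0),
    )
    x = tl.load(x_block, boundary_check=(0, 1), padding_option="zero")

    # Cast to fp32 inside the kernel, one tile at a time, so bf16/fp16
    # inputs never need a full-size fp32 copy in global memory.
    x = x.to(tl.float32)

    # Placeholder identity store (out assumed fp32 here); the real kernel
    # quantizes x to NVFP4 before writing.
    out_block = tl.make_block_ptr(
        base=out_ptr,
        shape=(M, N),
        strides=(stride_m, stride_n),
        offsets=(pid_m * TILE_M, pid_n * TILE_N),
        block_shape=(TILE_M, TILE_N),
        order=(1, 0),
    )
    tl.store(out_block, x, boundary_check=(0, 1))
```

The kernel would be launched over a 2-D grid of (triton.cdiv(M, TILE_M), triton.cdiv(N, TILE_N)) programs, which is what lets TILE_M and TILE_N be tuned independently per shape.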

H100:

Shape: 512x512
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 35.32 µs
    new kernel: 38.49 µs
    speedup: 0.92x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 43.48 µs
    new kernel: 44.78 µs
    speedup: 0.97x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 43.25 µs
    new kernel: 43.69 µs
    speedup: 0.99x

Shape: 1024x1024
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 36.03 µs
    new kernel: 38.17 µs
    speedup: 0.94x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 44.24 µs
    new kernel: 43.78 µs
    speedup: 1.01x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 43.77 µs
    new kernel: 43.61 µs
    speedup: 1.00x

Shape: 4096x4096
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 87.02 µs
    new kernel: 80.88 µs
    speedup: 1.08x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 116.12 µs
    new kernel: 65.80 µs
    speedup: 1.76x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 114.39 µs
    new kernel: 65.30 µs
    speedup: 1.75x

Shape: 8192x8192
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 237.29 µs
    new kernel: 219.42 µs
    speedup: 1.08x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 349.76 µs
    new kernel: 138.66 µs
    speedup: 2.52x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 341.89 µs
    new kernel: 136.91 µs
    speedup: 2.50x

Shape: 8192x12288
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 338.65 µs
    new kernel: 312.70 µs
    speedup: 1.08x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 505.63 µs
    new kernel: 188.24 µs
    speedup: 2.69x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 492.97 µs
    new kernel: 186.88 µs
    speedup: 2.64x

Shape: 12288x12288
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 490.25 µs
    new kernel: 451.16 µs
    speedup: 1.09x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 736.04 µs
    new kernel: 261.94 µs
    speedup: 2.81x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 717.64 µs
    new kernel: 257.82 µs
    speedup: 2.78x

Shape: 32x4096
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 35.61 µs
    new kernel: 38.23 µs
    speedup: 0.93x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 43.00 µs
    new kernel: 43.85 µs
    speedup: 0.98x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 42.83 µs
    new kernel: 44.13 µs
    speedup: 0.97x

Shape: 1024x4096
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 38.12 µs
    new kernel: 41.28 µs
    speedup: 0.92x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 52.80 µs
    new kernel: 45.96 µs
    speedup: 1.15x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 51.56 µs
    new kernel: 45.30 µs
    speedup: 1.14x

Shape: 32x5000
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 41.70 µs
    new kernel: 38.03 µs
    speedup: 1.10x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 52.95 µs
    new kernel: 44.14 µs
    speedup: 1.20x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 52.57 µs
    new kernel: 44.38 µs
    speedup: 1.18x

Shape: 128x8200
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 48.03 µs
    new kernel: 38.38 µs
    speedup: 1.25x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 60.54 µs
    new kernel: 44.51 µs
    speedup: 1.36x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 60.08 µs
    new kernel: 43.59 µs
    speedup: 1.38x

B200:

Shape: 512x512
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 34.63 µs
    new kernel: 32.80 µs
    speedup: 1.06x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 42.26 µs
    new kernel: 40.92 µs
    speedup: 1.03x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 41.38 µs
    new kernel: 39.30 µs
    speedup: 1.05x

Shape: 1024x1024
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 35.07 µs
    new kernel: 33.93 µs
    speedup: 1.03x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 43.57 µs
    new kernel: 39.55 µs
    speedup: 1.10x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 43.72 µs
    new kernel: 38.96 µs
    speedup: 1.12x

Shape: 4096x4096
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 71.64 µs
    new kernel: 58.66 µs
    speedup: 1.22x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 81.67 µs
    new kernel: 57.98 µs
    speedup: 1.41x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 82.19 µs
    new kernel: 57.56 µs
    speedup: 1.43x

Shape: 8192x8192
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 176.85 µs
    new kernel: 135.78 µs
    speedup: 1.30x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 217.99 µs
    new kernel: 121.84 µs
    speedup: 1.79x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 215.47 µs
    new kernel: 117.41 µs
    speedup: 1.84x

Shape: 8192x12288
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 248.18 µs
    new kernel: 186.64 µs
    speedup: 1.33x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 306.25 µs
    new kernel: 163.28 µs
    speedup: 1.88x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 303.06 µs
    new kernel: 157.59 µs
    speedup: 1.92x

Shape: 12288x12288
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 354.23 µs
    new kernel: 262.99 µs
    speedup: 1.35x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 439.44 µs
    new kernel: 224.71 µs
    speedup: 1.96x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 434.23 µs
    new kernel: 217.62 µs
    speedup: 2.00x

Shape: 32x4096
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 35.90 µs
    new kernel: 34.88 µs
    speedup: 1.03x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 43.77 µs
    new kernel: 41.49 µs
    speedup: 1.05x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 43.22 µs
    new kernel: 41.79 µs
    speedup: 1.03x

Shape: 1024x4096
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 37.37 µs
    new kernel: 37.84 µs
    speedup: 0.99x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 49.69 µs
    new kernel: 43.85 µs
    speedup: 1.13x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 48.93 µs
    new kernel: 44.31 µs
    speedup: 1.10x

Shape: 32x5000
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 41.83 µs
    new kernel: 35.44 µs
    speedup: 1.18x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 53.23 µs
    new kernel: 40.64 µs
    speedup: 1.31x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 54.39 µs
    new kernel: 40.77 µs
    speedup: 1.33x

Shape: 128x8200
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 49.35 µs
    new kernel: 35.33 µs
    speedup: 1.40x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 60.89 µs
    new kernel: 41.46 µs
    speedup: 1.47x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 61.75 µs
    new kernel: 41.75 µs
    speedup: 1.48x

Testing

  1. Compared outputs against the old kernel: max abs diff = 0.
  2. Benchmarked speed (a sketch of the harness follows this list).
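
A minimal sketch of such a comparison harness, assuming hypothetical old_kernel/new_kernel wrappers around the two implementations (shapes and dtypes mirror the tables above):

```python
import torch
from triton.testing import do_bench

def compare(shape, dtype):
    x = torch.randn(shape, device="cuda").to(dtype)
    ref = old_kernel(x)  # hypothetical wrapper around the previous kernel
    out = new_kernel(x)  # hypothetical wrapper around this PR's kernel
    max_diff = (ref.float() - out.float()).abs().max().item()
    t_old = do_bench(lambda: old_kernel(x))  # latency in ms
    t_new = do_bench(lambda: new_kernel(x))
    print(f"{shape} {dtype}: max abs diff {max_diff:.3e}, "
          f"speedup {t_old / t_new:.2f}x")

for shape in [(512, 512), (4096, 4096), (128, 8200)]:
    for dtype in [torch.float32, torch.bfloat16, torch.float16]:
        compare(shape, dtype)
```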

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: No
  • Did you add or update any necessary documentation?: No
  • Did you update Changelog?: No

Additional Information

Bug [5612406]

@copy-pr-bot

copy-pr-bot bot commented Nov 11, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


Signed-off-by: mxin <mxin@nvidia.com>
@mxinO mxinO force-pushed the mxin/fp4-kernel-improve branch from f58e420 to 0cf5fb6 on November 11, 2025 06:25
@mxinO mxinO self-assigned this Nov 11, 2025
@codecov

codecov bot commented Nov 11, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 74.43%. Comparing base (be64f6b) to head (f576793).

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #533   +/-   ##
=======================================
  Coverage   74.43%   74.43%           
=======================================
  Files         182      182           
  Lines       18238    18238           
=======================================
  Hits        13576    13576           
  Misses       4662     4662           


Signed-off-by: mxin <mxin@nvidia.com>
@mxinO mxinO changed the title from "Improve NVFP4 Triton kernel" to "Optimize NVFP4 Triton kernel" Nov 13, 2025
@mxinO mxinO requested review from RalphMao and realAsma November 18, 2025 01:33
@mxinO mxinO marked this pull request as ready for review November 18, 2025 01:34
@mxinO mxinO requested a review from a team as a code owner November 18, 2025 01:34
@mxinO mxinO requested a review from cjluo-nv November 18, 2025 01:34
@cjluo-nv
Collaborator

Thanks @mxinO. Do you have a unit test covering this change?

Contributor

@realAsma realAsma left a comment

Nice!!

global_scale_safe = tl.where(global_scale > 0.0, global_scale, 1e-12)

# Load input data
x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
Contributor

Just curious: it looks like the old version had the proper mask in tl.load and tl.store. Why did it cause the nvbug?

Contributor Author

Yeah, the illegal memory access is hard to debug because the error message never points to the correct location. I didn't find the root cause; my guess is that it was an addressing issue. I changed the way blocks are loaded and the error is gone. It's a rare case that had never been seen before.
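
For reference, the two load styles under discussion look roughly like this (an illustrative fragment from inside a @triton.jit body, not the actual PR code; the int32-overflow remark is one plausible explanation only, since the root cause was never confirmed):

```python
# Old style: hand-computed offsets plus a mask. The mask guards the edges,
# but the offset arithmetic itself (int32 by default) is a separate source
# of addressing bugs on large tensors.
offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
offs = offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)

# New style: the block pointer carries shape/stride metadata, so Triton
# generates the addresses and boundary_check pads out-of-bounds elements.
x_block = tl.make_block_ptr(x_ptr, (M, N), (stride_m, stride_n),
                            (pid_m * TILE_M, pid_n * TILE_N),
                            (TILE_M, TILE_N), (1, 0))
x = tl.load(x_block, boundary_check=(0, 1),
            padding_option="zero").to(tl.float32)
```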

@mxinO
Contributor Author

mxinO commented Nov 19, 2025

Thanks @mxinO. Do you have a unit test covering this change?

We have tests covering the Triton kernel's correctness.
