tune_kernels.py Cannot Give Better Result Than Pytorch baseline #8

@Dr-Left

Description

My Environment

  • torch version: 2.8.0
  • CUDA version: 12.1.105 (Build cuda_12.1.r12.1/compiler.32688072_0)
  • Triton version: 3.2.0
  • Hardware: NVIDIA H100 NVL 94GB

Question

For many of the kernels, the PyTorch baseline has comparable or even better latency than the best Triton configs. What's wrong?
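
For reference, here is a minimal sketch of the kind of comparison being made, using triton.testing.do_bench. The shapes, dtypes, and the unfused PyTorch LoRA expression below are assumptions for illustration only, not the actual code in tune_kernels.py:

```python
# Minimal benchmarking sketch (assumed shapes/dtypes; not the actual tune_kernels.py code).
import torch
from triton.testing import do_bench

M, K, N, R = 4096, 4096, 4096, 16  # hypothetical GEMM sizes and LoRA rank
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
w = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
a = torch.randn(K, R, device="cuda", dtype=torch.bfloat16)
b = torch.randn(R, N, device="cuda", dtype=torch.bfloat16)

def torch_baseline():
    # Unfused PyTorch reference: base GEMM plus the low-rank LoRA path.
    return x @ w + (x @ a) @ b

baseline_ms = do_bench(torch_baseline)  # mean latency in ms
print(f"pytorch baseline: {baseline_ms:.3f} ms")

# A fused Triton kernel would be timed the same way, e.g.:
# fused_ms = do_bench(lambda: fused_lora_xw_sb(x, w, a, b))
```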

Tuning Summary:

PyTorch baseline latencies:

  • multi_lora_xw_sb: 0.279ms
  • multi_lora_dyw_dsa: 0.279ms
  • multi_lora_dyw_dsa_tma: 0.277ms
  • lora_XW_SB_TMA: 0.278ms
  • lora_dyw_dsa: 0.277ms
2025-11-07 14:40:30.252 | INFO     | __main__:main:550 - ================================================================================
2025-11-07 14:40:30.252 | INFO     | __main__:main:551 - TUNING SUMMARY
2025-11-07 14:40:30.252 | INFO     | __main__:main:552 - ================================================================================
2025-11-07 14:40:30.252 | INFO     | __main__:main:554 - Device short name: h100-nvl
2025-11-07 14:40:30.252 | INFO     | __main__:main:556 - fused_lora_xw_sb: 0.293ms
2025-11-07 14:40:30.252 | INFO     | __main__:main:557 -   Config: LoRATritonConfig(block_size_m=128, block_size_n=256, block_size_k=64, group_size_m=8, num_stages=3, num_warps=8, epilogue_subtile=None, loop_unroll_factor=None, flatten=None)
2025-11-07 14:40:30.252 | INFO     | __main__:main:556 - fused_lora_dyw_dsa: 0.284ms
2025-11-07 14:40:30.252 | INFO     | __main__:main:557 -   Config: LoRATritonConfig(block_size_m=128, block_size_n=256, block_size_k=64, group_size_m=8, num_stages=3, num_warps=8, epilogue_subtile=None, loop_unroll_factor=None, flatten=None)
2025-11-07 14:40:30.252 | INFO     | __main__:main:556 - fused_lora_dys_dyb: 0.040ms
2025-11-07 14:40:30.252 | INFO     | __main__:main:557 -   Config: LoRATritonConfig(block_size_m=128, block_size_n=None, block_size_k=128, group_size_m=8, num_stages=3, num_warps=8, epilogue_subtile=None, loop_unroll_factor=None, flatten=None)
2025-11-07 14:40:30.252 | INFO     | __main__:main:556 - fused_lora_xw_sb_tma: 0.272ms
2025-11-07 14:40:30.252 | INFO     | __main__:main:557 -   Config: LoRATritonConfig(block_size_m=128, block_size_n=256, block_size_k=64, group_size_m=8, num_stages=3, num_warps=8, epilogue_subtile=None, loop_unroll_factor=None, flatten=None)
2025-11-07 14:40:30.252 | INFO     | __main__:main:556 - fused_lora_dyw_dsa_tma: 0.272ms
2025-11-07 14:40:30.252 | INFO     | __main__:main:557 -   Config: LoRATritonConfig(block_size_m=128, block_size_n=256, block_size_k=64, group_size_m=8, num_stages=3, num_warps=8, epilogue_subtile=None, loop_unroll_factor=None, flatten=None)
2025-11-07 14:40:30.252 | INFO     | __main__:main:556 - fused_multi_lora_xw_sb: 0.511ms
2025-11-07 14:40:30.252 | INFO     | __main__:main:557 -   Config: LoRATritonConfig(block_size_m=128, block_size_n=256, block_size_k=64, group_size_m=8, num_stages=3, num_warps=8, epilogue_subtile=None, loop_unroll_factor=None, flatten=None)
2025-11-07 14:40:30.252 | INFO     | __main__:main:556 - fused_multi_lora_dyw_dsa: 0.259ms
2025-11-07 14:40:30.252 | INFO     | __main__:main:557 -   Config: LoRATritonConfig(block_size_m=128, block_size_n=256, block_size_k=64, group_size_m=8, num_stages=3, num_warps=8, epilogue_subtile=None, loop_unroll_factor=None, flatten=None)
2025-11-07 14:40:30.252 | INFO     | __main__:main:556 - fused_multi_lora_dys_dyb: 0.119ms
2025-11-07 14:40:30.252 | INFO     | __main__:main:557 -   Config: LoRATritonConfig(block_size_m=128, block_size_n=None, block_size_k=128, group_size_m=8, num_stages=4, num_warps=8, epilogue_subtile=None, loop_unroll_factor=None, flatten=None)

Full log output from running tune_kernels.py:
https://gist.github.com/Dr-Left/c1889749d27fabdb1ec966b0e4060d2b
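
To put "comparable or even better" in numbers, here is a quick ratio calculation from the summary above (baseline ms / fused ms, so values below 1.0 mean the fused Triton kernel is slower). The pairing of baseline names to fused kernels is my reading of the kernel names and may not match tune_kernels.py exactly:

```python
# Speedup = PyTorch baseline latency / fused Triton kernel latency (numbers from the summary above).
# NOTE: the baseline-to-kernel pairing is an assumption based on the names.
pairs = {
    "fused_lora_xw_sb_tma":     (0.278, 0.272),  # baseline: lora_XW_SB_TMA
    "fused_lora_dyw_dsa":       (0.277, 0.284),  # baseline: lora_dyw_dsa
    "fused_multi_lora_xw_sb":   (0.279, 0.511),  # baseline: multi_lora_xw_sb
    "fused_multi_lora_dyw_dsa": (0.279, 0.259),  # baseline: multi_lora_dyw_dsa
}
for name, (baseline_ms, fused_ms) in pairs.items():
    print(f"{name}: {baseline_ms / fused_ms:.2f}x speedup over PyTorch")
```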
