Skip to content

[Issue]: Does gemm_a8w8 support per-tensor quant? #1765

@byjiang1996

Description

@byjiang1996

Problem Description

I write a simple test script. gemm_a8w8 always give me incorrect results compared with torch native

Operating System

NAME="Ubuntu" VERSION="24.04.3 LTS (Noble Numbat)"

CPU

model name : Intel(R) Xeon(R) Platinum 8468

GPU

Marketing Name: AMD Instinct MI300X

ROCm Version

7.0.2

ROCm Component

No response

Steps to Reproduce

"""
Unit test to compare gemm_a8w8_CK output with torch._scaled_mm.

Uses random tensors and quantizes them to FP8 for testing.
"""

import torch
import aiter
from aiter import dtypes


def main():
    device = "cuda"
    dtype = torch.bfloat16
    
    # Create random tensors (similar to aiter test)
    M, N, K = 256, 4096, 1024
    print(f"Testing with M={M}, N={N}, K={K}")
    
    x = torch.randn((M, K), dtype=dtype, device=device)
    weight = torch.randn((N, K), dtype=dtype, device=device)  # (N, K) for gemm_a8w8_CK
    
    # Quantize to FP8 using per-tensor quantization
    x_quant, x_scale = aiter.per_tensor_quant(x, quant_dtype=dtypes.fp8)
    weight_quant, w_scale = aiter.per_tensor_quant(weight, quant_dtype=dtypes.fp8)
    
    print(f"\nInput shapes:")
    print(f"  x_quant: {x_quant.shape}, dtype: {x_quant.dtype}, strides: {x_quant.stride()}")
    print(f"  weight_quant: {weight_quant.shape}, dtype: {weight_quant.dtype}, strides: {weight_quant.stride()}")
    print(f"  x_scale: {x_scale.shape}, dtype: {x_scale.dtype}, strides: {x_scale.stride()}, value: {x_scale}")
    print(f"  w_scale: {w_scale.shape}, dtype: {w_scale.dtype}, strides: {w_scale.stride()}, value: {w_scale}")

    # Reference computation: dequantized GEMM (like aiter's run_torch)
    # gemm_a8w8_CK computes: x @ weight.T (like F.linear)
    import torch.nn.functional as F
    x_dequant = x_quant.float() * x_scale.float()
    weight_dequant = weight_quant.float() * w_scale.float()
    ref_output = F.linear(x_dequant, weight_dequant)
    ref_output = ref_output.to(dtype)
    print(f"\nReference output: {ref_output.shape}, range: [{ref_output.min().item():.4f}, {ref_output.max().item():.4f}]")

    # Run gemm_a8w8_CK
    from aiter import gemm_a8w8
    
    ck_output = gemm_a8w8(
        x_quant,
        weight_quant,
        x_scale=x_scale,
        w_scale=w_scale,
        bias=None,
        dtype=dtype,
    )
    torch.cuda.synchronize()
    print(f"gemm_a8w8_CK output: {ck_output.shape}, range: [{ck_output.min().item():.4f}, {ck_output.max().item():.4f}]")

    # Run torch._scaled_mm
    # torch._scaled_mm computes: A @ B where B is column-major
    # For equivalent computation: x @ weight.T
    # We need weight in (K, N) column-major format
    # weight_quant is (N, K), so weight_quant.t() gives (K, N) column-major
    weight_for_torch = weight_quant.t()  # (K, N) column-major view
    
    print(f"Weight for torch._scaled_mm: {weight_for_torch.shape}, strides: {weight_for_torch.stride()}")
    print(f"scale_a: {x_scale.shape}, scale_b: {w_scale.shape}")
    
    # For per-tensor scaling, scales should be singletons
    torch_output = torch._scaled_mm(
        x_quant,
        weight_for_torch,
        out_dtype=dtype,
        scale_a=x_scale,
        scale_b=w_scale,
        bias=None,
    )
    torch.cuda.synchronize()

    if isinstance(torch_output, tuple):
        torch_output = torch_output[0]

    print(f"torch._scaled_mm output: {torch_output.shape}, range: [{torch_output.min().item():.4f}, {torch_output.max().item():.4f}]")

    # Compare gemm_a8w8_CK vs reference
    print("\n" + "=" * 60)
    print("Comparison: gemm_a8w8_CK vs Reference")
    print("=" * 60)
    
    abs_diff_ck = torch.abs(ck_output.float() - ref_output.float())
    print(f"  Max absolute difference: {abs_diff_ck.max().item():.6e}")
    print(f"  Mean absolute difference: {abs_diff_ck.mean().item():.6e}")
    
    for rtol, atol in [(1e-2, 1e-2), (5e-2, 5e-2)]:
        allclose = torch.allclose(ck_output, ref_output, rtol=rtol, atol=atol)
        print(f"  torch.allclose(rtol={rtol}, atol={atol}): {allclose}")

    # Compare torch._scaled_mm vs reference
    print("\n" + "=" * 60)
    print("Comparison: torch._scaled_mm vs Reference")
    print("=" * 60)
    
    abs_diff_torch = torch.abs(torch_output.float() - ref_output.float())
    print(f"  Max absolute difference: {abs_diff_torch.max().item():.6e}")
    print(f"  Mean absolute difference: {abs_diff_torch.mean().item():.6e}")
    
    for rtol, atol in [(1e-2, 1e-2), (5e-2, 5e-2)]:
        allclose = torch.allclose(torch_output, ref_output, rtol=rtol, atol=atol)
        print(f"  torch.allclose(rtol={rtol}, atol={atol}): {allclose}")

    # Compare gemm_a8w8_CK vs torch._scaled_mm
    print("\n" + "=" * 60)
    print("Comparison: gemm_a8w8_CK vs torch._scaled_mm")
    print("=" * 60)
    
    abs_diff = torch.abs(ck_output.float() - torch_output.float())
    print(f"  Max absolute difference: {abs_diff.max().item():.6e}")
    print(f"  Mean absolute difference: {abs_diff.mean().item():.6e}")
    
    for rtol, atol in [(1e-2, 1e-2), (5e-2, 5e-2)]:
        allclose = torch.allclose(ck_output, torch_output, rtol=rtol, atol=atol)
        print(f"  torch.allclose(rtol={rtol}, atol={atol}): {allclose}")

    # Sample values
    print("\nSample values (first 5 elements):")
    ck_flat = ck_output.flatten()[:5]
    torch_flat = torch_output.flatten()[:5]
    ref_flat = ref_output.flatten()[:5]
    for i in range(5):
        print(f"  [{i}] CK: {ck_flat[i].item():10.4f}, torch: {torch_flat[i].item():10.4f}, ref: {ref_flat[i].item():10.4f}")

    # Final verdict
    print("\n" + "=" * 60)
    passed = torch.allclose(ck_output, torch_output, rtol=1e-2, atol=1e-2)
    if passed:
        print("✓ TEST PASSED: gemm_a8w8_CK matches torch._scaled_mm within rtol=1e-2, atol=1e-2")
    else:
        print("✗ TEST FAILED: gemm_a8w8_CK does not match torch._scaled_mm within tolerance")
    print("=" * 60)

    return 0 if passed else 1


if __name__ == "__main__":
    exit(main())
============================================================
Comparison: torch._scaled_mm vs Reference
============================================================
  Max absolute difference: 5.000000e-01
  Mean absolute difference: 1.440744e-05
  torch.allclose(rtol=0.01, atol=0.01): True
  torch.allclose(rtol=0.05, atol=0.05): True

============================================================
Comparison: gemm_a8w8_CK vs torch._scaled_mm
============================================================
  Max absolute difference: inf
  Mean absolute difference: inf
  torch.allclose(rtol=0.01, atol=0.01): False
  torch.allclose(rtol=0.05, atol=0.05): False

Sample values (first 5 elements):
  [0] CK:    -8.3750, torch:    -8.3750, ref:    -8.3750
  [1] CK:  4384.0000, torch:    23.2500, ref:    23.2500
  [2] CK: -10752.0000, torch:   -57.2500, ref:   -57.2500
  [3] CK:   107.0000, torch:     0.5000, ref:     0.5000
  [4] CK: -4864.0000, torch:   -26.8750, ref:   -26.8750

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions