Problem Description
I wrote a simple test script, and gemm_a8w8 always gives me incorrect results compared with the native torch path.
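In essence, both operands are quantized to FP8 with per-tensor scales, gemm_a8w8 is run on the quantized tensors, and the result is compared against a dequantized F.linear reference and against torch._scaled_mm. A condensed sketch of that comparison is below (the full, instrumented script is under Steps to Reproduce; this sketch assumes the same aiter.per_tensor_quant and aiter.gemm_a8w8 signatures used there):

```python
# Condensed sketch of the failing comparison; the full script with diagnostics follows below.
import torch
import torch.nn.functional as F
import aiter
from aiter import dtypes, gemm_a8w8

dtype = torch.bfloat16
x = torch.randn((256, 1024), dtype=dtype, device="cuda")    # activations (M, K)
w = torch.randn((4096, 1024), dtype=dtype, device="cuda")   # weights (N, K)

xq, xs = aiter.per_tensor_quant(x, quant_dtype=dtypes.fp8)  # FP8 tensor + per-tensor scale
wq, ws = aiter.per_tensor_quant(w, quant_dtype=dtypes.fp8)

# Reference: dequantize and compute x @ w.T in float, then cast back to bf16
ref = F.linear(xq.float() * xs.float(), wq.float() * ws.float()).to(dtype)
out = gemm_a8w8(xq, wq, x_scale=xs, w_scale=ws, bias=None, dtype=dtype)

print(torch.allclose(out, ref, rtol=1e-2, atol=1e-2))       # False here, expected True
```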
Operating System
NAME="Ubuntu" VERSION="24.04.3 LTS (Noble Numbat)"
CPU
model name : Intel(R) Xeon(R) Platinum 8468
GPU
Marketing Name: AMD Instinct MI300X
ROCm Version
7.0.2
ROCm Component
No response
Steps to Reproduce
"""
Unit test to compare gemm_a8w8_CK output with torch._scaled_mm.
Uses random tensors and quantizes them to FP8 for testing.
"""
import torch
import aiter
from aiter import dtypes
def main():
device = "cuda"
dtype = torch.bfloat16
# Create random tensors (similar to aiter test)
M, N, K = 256, 4096, 1024
print(f"Testing with M={M}, N={N}, K={K}")
x = torch.randn((M, K), dtype=dtype, device=device)
weight = torch.randn((N, K), dtype=dtype, device=device) # (N, K) for gemm_a8w8_CK
# Quantize to FP8 using per-tensor quantization
x_quant, x_scale = aiter.per_tensor_quant(x, quant_dtype=dtypes.fp8)
weight_quant, w_scale = aiter.per_tensor_quant(weight, quant_dtype=dtypes.fp8)
print(f"\nInput shapes:")
print(f" x_quant: {x_quant.shape}, dtype: {x_quant.dtype}, strides: {x_quant.stride()}")
print(f" weight_quant: {weight_quant.shape}, dtype: {weight_quant.dtype}, strides: {weight_quant.stride()}")
print(f" x_scale: {x_scale.shape}, dtype: {x_scale.dtype}, strides: {x_scale.stride()}, value: {x_scale}")
print(f" w_scale: {w_scale.shape}, dtype: {w_scale.dtype}, strides: {w_scale.stride()}, value: {w_scale}")
# Reference computation: dequantized GEMM (like aiter's run_torch)
# gemm_a8w8_CK computes: x @ weight.T (like F.linear)
import torch.nn.functional as F
x_dequant = x_quant.float() * x_scale.float()
weight_dequant = weight_quant.float() * w_scale.float()
ref_output = F.linear(x_dequant, weight_dequant)
ref_output = ref_output.to(dtype)
print(f"\nReference output: {ref_output.shape}, range: [{ref_output.min().item():.4f}, {ref_output.max().item():.4f}]")
# Run gemm_a8w8_CK
from aiter import gemm_a8w8
ck_output = gemm_a8w8(
x_quant,
weight_quant,
x_scale=x_scale,
w_scale=w_scale,
bias=None,
dtype=dtype,
)
torch.cuda.synchronize()
print(f"gemm_a8w8_CK output: {ck_output.shape}, range: [{ck_output.min().item():.4f}, {ck_output.max().item():.4f}]")
# Run torch._scaled_mm
# torch._scaled_mm computes: A @ B where B is column-major
# For equivalent computation: x @ weight.T
# We need weight in (K, N) column-major format
# weight_quant is (N, K), so weight_quant.t() gives (K, N) column-major
weight_for_torch = weight_quant.t() # (K, N) column-major view
print(f"Weight for torch._scaled_mm: {weight_for_torch.shape}, strides: {weight_for_torch.stride()}")
print(f"scale_a: {x_scale.shape}, scale_b: {w_scale.shape}")
# For per-tensor scaling, scales should be singletons
torch_output = torch._scaled_mm(
x_quant,
weight_for_torch,
out_dtype=dtype,
scale_a=x_scale,
scale_b=w_scale,
bias=None,
)
torch.cuda.synchronize()
if isinstance(torch_output, tuple):
torch_output = torch_output[0]
print(f"torch._scaled_mm output: {torch_output.shape}, range: [{torch_output.min().item():.4f}, {torch_output.max().item():.4f}]")
# Compare gemm_a8w8_CK vs reference
print("\n" + "=" * 60)
print("Comparison: gemm_a8w8_CK vs Reference")
print("=" * 60)
abs_diff_ck = torch.abs(ck_output.float() - ref_output.float())
print(f" Max absolute difference: {abs_diff_ck.max().item():.6e}")
print(f" Mean absolute difference: {abs_diff_ck.mean().item():.6e}")
for rtol, atol in [(1e-2, 1e-2), (5e-2, 5e-2)]:
allclose = torch.allclose(ck_output, ref_output, rtol=rtol, atol=atol)
print(f" torch.allclose(rtol={rtol}, atol={atol}): {allclose}")
# Compare torch._scaled_mm vs reference
print("\n" + "=" * 60)
print("Comparison: torch._scaled_mm vs Reference")
print("=" * 60)
abs_diff_torch = torch.abs(torch_output.float() - ref_output.float())
print(f" Max absolute difference: {abs_diff_torch.max().item():.6e}")
print(f" Mean absolute difference: {abs_diff_torch.mean().item():.6e}")
for rtol, atol in [(1e-2, 1e-2), (5e-2, 5e-2)]:
allclose = torch.allclose(torch_output, ref_output, rtol=rtol, atol=atol)
print(f" torch.allclose(rtol={rtol}, atol={atol}): {allclose}")
# Compare gemm_a8w8_CK vs torch._scaled_mm
print("\n" + "=" * 60)
print("Comparison: gemm_a8w8_CK vs torch._scaled_mm")
print("=" * 60)
abs_diff = torch.abs(ck_output.float() - torch_output.float())
print(f" Max absolute difference: {abs_diff.max().item():.6e}")
print(f" Mean absolute difference: {abs_diff.mean().item():.6e}")
for rtol, atol in [(1e-2, 1e-2), (5e-2, 5e-2)]:
allclose = torch.allclose(ck_output, torch_output, rtol=rtol, atol=atol)
print(f" torch.allclose(rtol={rtol}, atol={atol}): {allclose}")
# Sample values
print("\nSample values (first 5 elements):")
ck_flat = ck_output.flatten()[:5]
torch_flat = torch_output.flatten()[:5]
ref_flat = ref_output.flatten()[:5]
for i in range(5):
print(f" [{i}] CK: {ck_flat[i].item():10.4f}, torch: {torch_flat[i].item():10.4f}, ref: {ref_flat[i].item():10.4f}")
# Final verdict
print("\n" + "=" * 60)
passed = torch.allclose(ck_output, torch_output, rtol=1e-2, atol=1e-2)
if passed:
print("✓ TEST PASSED: gemm_a8w8_CK matches torch._scaled_mm within rtol=1e-2, atol=1e-2")
else:
print("✗ TEST FAILED: gemm_a8w8_CK does not match torch._scaled_mm within tolerance")
print("=" * 60)
return 0 if passed else 1
if __name__ == "__main__":
exit(main())
Output (truncated):

```
============================================================
Comparison: torch._scaled_mm vs Reference
============================================================
 Max absolute difference: 5.000000e-01
 Mean absolute difference: 1.440744e-05
 torch.allclose(rtol=0.01, atol=0.01): True
 torch.allclose(rtol=0.05, atol=0.05): True
============================================================
Comparison: gemm_a8w8_CK vs torch._scaled_mm
============================================================
 Max absolute difference: inf
 Mean absolute difference: inf
 torch.allclose(rtol=0.01, atol=0.01): False
 torch.allclose(rtol=0.05, atol=0.05): False

Sample values (first 5 elements):
 [0] CK: -8.3750, torch: -8.3750, ref: -8.3750
 [1] CK: 4384.0000, torch: 23.2500, ref: 23.2500
 [2] CK: -10752.0000, torch: -57.2500, ref: -57.2500
 [3] CK: 107.0000, torch: 0.5000, ref: 0.5000
 [4] CK: -4864.0000, torch: -26.8750, ref: -26.8750
```
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
No response
Additional Information
No response