diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py
index 240e8eea49..14de3610b3 100644
--- a/test/prototype/mx_formats/test_kernels.py
+++ b/test/prototype/mx_formats/test_kernels.py
@@ -521,8 +521,9 @@ def test_triton_mxfp8_dim0_zeros():
 )
 @pytest.mark.parametrize("M", (256, 2048, 131072))
 @pytest.mark.parametrize("K", (256, 5120, 7168))
-def test_triton_mxfp8_dequant_dim0(M, K):
-    x = torch.zeros(M, K, dtype=torch.bfloat16, device="cuda")
+@pytest.mark.parametrize("orig_dtype", (torch.float32, torch.bfloat16))
+def test_triton_mxfp8_dequant_dim0(M, K, orig_dtype):
+    x = torch.zeros(M, K, dtype=orig_dtype, device="cuda")
     block_size = 32
     x_data, x_scales = triton_to_mxfp8_dim0_reference(x, block_size=32)
     hp_ref = to_dtype(
@@ -530,9 +531,9 @@ def test_triton_mxfp8_dequant_dim0(M, K):
         x_scales,
         torch.float8_e4m3fn,
         block_size,
-        torch.bfloat16,
+        orig_dtype,
     )
-    hp_t = triton_mxfp8_dequant_dim0(x_data, x_scales, torch.bfloat16, block_size)
+    hp_t = triton_mxfp8_dequant_dim0(x_data, x_scales, orig_dtype, block_size)
 
     torch.testing.assert_close(hp_t, hp_ref, rtol=0, atol=0)
 
diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py
index c085ed9740..671186c7c7 100644
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -1141,7 +1141,8 @@ def triton_to_mxfp8_dim0(
         * `scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim0
         """
         assert x.is_contiguous(), "`x` must be contiguous"
-        assert inner_block_size <= 32
+        assert inner_block_size <= 32, "inner_block_size must be <= 32"
+        assert x.dtype == torch.bfloat16, "only bfloat16 inputs are supported"
 
         # Reshape tensor to 2d if necessary and get shape
         x_orig_shape = x.shape
@@ -1279,12 +1280,13 @@ def triton_to_mxfp8_dim1_reference(
             scale_e8m0_dim1,
         )
 
+    @triton_op("torchao::triton_mxfp8_dequant_dim0", mutates_args={})
     def triton_mxfp8_dequant_dim0(
         e4m3_data: torch.Tensor,
         e8m0_scales: torch.Tensor,
         out_dtype: torch.dtype,
         scale_block_size: int = 32,
-    ) -> None:
+    ) -> torch.Tensor:
         assert scale_block_size == 32, "scale_block_size must be 32 for now"
         assert out_dtype in (torch.bfloat16, torch.float32), (
             "out_dtype must be bf16 or fp32"
@@ -1300,7 +1302,7 @@ def triton_mxfp8_dequant_dim0(
             triton.cdiv(e4m3_data.shape[0], META["ROW_TILE_SIZE"]),
             triton.cdiv(e4m3_data.shape[1], META["COL_TILE_SIZE"]),
         )
-        _dequant_mxfp8_kernel[grid](
+        wrap_triton(_dequant_mxfp8_kernel)[grid](
             e4m3_data,
             e8m0_scales.to(torch.uint8),
             out_buffer,
@@ -1371,8 +1373,8 @@ def _dequant_mxfp8_kernel(
 
     @triton.jit
     def _e8m0_to_fp32(scale_e8m0):
-        e8m0_exponent_bias = 127
         e8m0_nan_val = 255
+        e8m0_exponent_bias = 127
         s_offset = scale_e8m0.to(tl.int16) - e8m0_exponent_bias
         s_fp = tl.exp2(s_offset.to(tl.float32))
         s_fp = tl.where(scale_e8m0 != e8m0_nan_val, s_fp, float("nan"))
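
A minimal usage sketch (not part of the patch) of the round trip exercised by the updated test, assuming the import path torchao.prototype.mx_formats.kernels and a CUDA device; the return layout of triton_to_mxfp8_dim0 is assumed to be (data, scales) based on its docstring in this diff.

import torch

from torchao.prototype.mx_formats.kernels import (
    triton_mxfp8_dequant_dim0,
    triton_to_mxfp8_dim0,
)

# Quantize along dim0; per the new assert in this patch, the input must be bfloat16.
x = torch.randn(256, 256, dtype=torch.bfloat16, device="cuda")
x_data, x_scales = triton_to_mxfp8_dim0(x, inner_block_size=32)  # assumed (data, scales) return

# Dequantize back; after this patch the op is registered via @triton_op and
# returns a tensor in the requested out_dtype (bf16 or fp32).
hp = triton_mxfp8_dequant_dim0(x_data, x_scales, torch.float32, 32)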