From 9231d4f7dc97ba1388a2b9db7c89b18f32513406 Mon Sep 17 00:00:00 2001
From: "Xiao, Wang"
Date: Mon, 27 Oct 2025 00:58:36 -0700
Subject: [PATCH 1/2] Support mx_tensor and enable its tests on Intel GPU

---
 .../mx_formats/test_inference_workflow.py    |  35 ++-
 test/prototype/mx_formats/test_mx_tensor.py  | 203 ++++++++++++------
 torchao/prototype/mx_formats/kernels.py      |   8 +-
 3 files changed, 171 insertions(+), 75 deletions(-)

diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py
index 1c8c1bc207..f334bf1b61 100644
--- a/test/prototype/mx_formats/test_inference_workflow.py
+++ b/test/prototype/mx_formats/test_inference_workflow.py
@@ -36,6 +36,14 @@
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 
+devices = []
+if torch.cuda.is_available():
+    devices.append("cuda")
+
+if torch.xpu.is_available():
+    devices.append("xpu")
+
+
 # source: https://stackoverflow.com/a/22638709
 @pytest.fixture(autouse=True)
 def run_around_tests():
@@ -63,16 +71,22 @@ def cuda_kernel_profiler(kernel_pattern):
         result["found"] = any(kernel_pattern in name for name in kernel_names)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not (torch.cuda.is_available() or torch.xpu.is_available()),
+    reason="CUDA or XPU not available",
+)
 @pytest.mark.skipif(
     not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
 )
 @pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("compile", [True, False])
-@pytest.mark.parametrize("emulate", [True, False])
+@pytest.mark.parametrize(
+    "emulate", [True, False] if (not torch.xpu.is_available()) else [True]
+)
 @pytest.mark.parametrize("use_inference_mode", [True, False])
 @pytest.mark.parametrize("x_rank", [2, 3])
+@pytest.mark.parametrize("device", devices)
 @torch.no_grad()
 @skip_if_rocm(
     "ROCm float4 gemm require gfx950"
@@ -84,25 +98,31 @@ def test_inference_workflow_mx(
     emulate: bool,
     use_inference_mode: bool,
     x_rank: int,
+    device,
 ):
     """
     Smoke test for inference compile
     """
     # TODO(future): figure out why these CUDA capability conditions are not properly
     # applied when inside `pytest.mark.skipif` for this test
-    if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+    if (
+        elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
+    ) and torch.cuda.is_available():
         if not is_sm_at_least_89():
             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
         elif not is_sm_at_least_100() and not emulate:
             pytest.skip("CUDA capability >= 10.0 required for mxfp8 gemm")
-    elif elem_dtype == torch.float4_e2m1fn_x2:
+    elif (elem_dtype == torch.float4_e2m1fn_x2) and torch.cuda.is_available():
         if not is_sm_at_least_100() and not emulate:
             pytest.skip("CUDA capability >= 10.0 required for mxfp4 gemm")
         elif compile:
             # TODO(future PR): investigate and fix this
-            pytest.skip("mxfp4 + compile currently does not work, low SQNR")
+            pytest.skip("mxfp4 + compile currently does not work on CUDA, low SQNR")
 
-    m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
+    if (elem_dtype == torch.float4_e2m1fn_x2) and torch.xpu.is_available() and compile:
+        pytest.skip("mxfp4 + compile currently does not work on XPU, low SQNR")
+
+    m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device=device)
     m_mx = copy.deepcopy(m)
 
     if emulate:
@@ -120,10 +140,9 @@
     if compile:
         m_mx = torch.compile(m_mx, fullgraph=True)
 
-    x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
+    x = torch.randn(128, 32, device=device, dtype=torch.bfloat16)
     if x_rank == 3:
         x = x.unsqueeze(0)
-
     y_ref = m(x)
     if use_inference_mode:
         with torch.inference_mode():
diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py
index 0f22f2f8ae..093abd8d91 100644
--- a/test/prototype/mx_formats/test_mx_tensor.py
+++ b/test/prototype/mx_formats/test_mx_tensor.py
@@ -38,6 +38,14 @@
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 
+devices = []
+if torch.cuda.is_available():
+    devices.append("cuda")
+
+if torch.xpu.is_available():
+    devices.append("xpu")
+
+
 @pytest.fixture(autouse=True)
 def run_before_and_after_tests():
     # source: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test  # noqa: E501
@@ -81,35 +89,51 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
     assert data_mx.scale.shape == (*prev_dims, K // block_size)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not (torch.cuda.is_available() or torch.xpu.is_available()),
+    reason="CUDA or XPU not available",
+)
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
-def test_hello_world(elem_dtype):
-    data = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16)
+@pytest.mark.parametrize("device", devices)
+def test_hello_world(elem_dtype, device):
+    data = torch.randn(8, 8, device=device, dtype=torch.bfloat16)
     block_size = 4
     _test_mx(data, elem_dtype, block_size)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not (torch.cuda.is_available() or torch.xpu.is_available()),
+    reason="CUDA or XPU not available",
+)
 @pytest.mark.parametrize("scale_calculation_mode", [s for s in ScaleCalculationMode])
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
-def test_realistic_numerics(elem_dtype, scale_calculation_mode):
-    data = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
+@pytest.mark.parametrize("device", devices)
+def test_realistic_numerics(elem_dtype, scale_calculation_mode, device):
+    data = torch.randn(128, 128, device=device, dtype=torch.bfloat16)
     block_size = 32
     _test_mx(data, elem_dtype, block_size, scale_calculation_mode)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not (torch.cuda.is_available() or torch.xpu.is_available()),
+    reason="CUDA or XPU not available",
+)
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
-def test_all_zeros(elem_dtype):
-    data = torch.zeros(4, 4, device="cuda", dtype=torch.bfloat16)
+@pytest.mark.parametrize("device", devices)
+def test_all_zeros(elem_dtype, device):
+    data = torch.zeros(4, 4, device=device, dtype=torch.bfloat16)
     block_size = 4
     _test_mx(data, elem_dtype, block_size)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not (torch.cuda.is_available() or torch.xpu.is_available()),
+    reason="CUDA or XPU not available",
+)
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
-def test_some_zeros(elem_dtype):
-    data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
+@pytest.mark.parametrize("device", devices)
+def test_some_zeros(elem_dtype, device):
+    data = torch.randn(4, 4, device=device, dtype=torch.bfloat16)
     data[0, :] = 0.0
     data[:, 2] = 0.0
     block_size = 4
@@ -331,15 +355,19 @@ def test_to_mx_rceil():
     torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not (torch.cuda.is_available() or torch.xpu.is_available()),
+    reason="CUDA or XPU not available",
+)
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
-def test_exponent_nan_in(elem_dtype):
+@pytest.mark.parametrize("device", devices)
+def test_exponent_nan_in(elem_dtype, device):
     """
     If high precision block values has a NaN, the exponent block value is set to
     is NaN
     """
     tensor_hp = torch.tensor(
-        [float("nan"), 1, 2, 3, 4, 5, 6, 7], device="cuda", dtype=torch.bfloat16
+        [float("nan"), 1, 2, 3, 4, 5, 6, 7], device=device, dtype=torch.bfloat16
     )
     block_size = 4
     tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
@@ -347,10 +375,14 @@ def test_exponent_nan_in(elem_dtype):
     assert not torch.any(torch.isnan(tensor_mx.scale[1:]))
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not (torch.cuda.is_available() or torch.xpu.is_available()),
+    reason="CUDA or XPU not available",
+)
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 @pytest.mark.parametrize("pack_fp6", [False, True])
-def test_exponent_nan_out(elem_dtype, pack_fp6):
+@pytest.mark.parametrize("device", devices)
+def test_exponent_nan_out(elem_dtype, pack_fp6, device):
     """
     If block exponent value is NaN, the MX tensor block value is NaN
     """
@@ -358,25 +390,25 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
         pytest.skip("invalid configuration")
 
     scale_e8m0 = torch.tensor(
-        [float("nan"), 1.0], dtype=torch.float8_e8m0fnu, device="cuda"
+        [float("nan"), 1.0], dtype=torch.float8_e8m0fnu, device=device
     )
 
     block_size = 4
 
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
         data_bits = torch.tensor(
-            [0, 1, 2, 3, 4, 5, 6, 7], dtype=elem_dtype, device="cuda"
+            [0, 1, 2, 3, 4, 5, 6, 7], dtype=elem_dtype, device=device
         )  # noqa: E501
     elif elem_dtype in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2):
         data_bits = torch.tensor(
-            [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda"
+            [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device=device
        )  # noqa: E501
         if pack_fp6:
             data_bits = data_bits.reshape(-1, block_size)
             data_bits = pack_uint6(data_bits)
     elif elem_dtype == torch.float4_e2m1fn_x2:
         data_bits = torch.tensor(
-            [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda"
+            [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device=device
        )  # noqa: E501
         data_bits = pack_uint4(data_bits)
     else:
@@ -398,23 +430,31 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
     assert not torch.any(torch.isnan(tensor_hp.flatten()[4:]))
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not (torch.cuda.is_available() or torch.xpu.is_available()),
+    reason="CUDA or XPU not available",
+)
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
-def test_ranks(elem_dtype):
+@pytest.mark.parametrize("device", devices)
+def test_ranks(elem_dtype, device):
     """
     The reshaping logic works for various ranks
     """
     B = 4
     shapes = ((B * 4,), (B * 4, 4), (B * 4, 4, 4), (B * 4, 4, 4, 4))
     for s in shapes:
-        tensor_hp = torch.randn(*s, device="cuda", dtype=torch.bfloat16)
+        tensor_hp = torch.randn(*s, device=device, dtype=torch.bfloat16)
         _test_mx(tensor_hp, elem_dtype, B)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not (torch.cuda.is_available() or torch.xpu.is_available()),
+    reason="CUDA or XPU not available",
+)
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 @pytest.mark.parametrize("B", [1, 4, 32])
-def test_block_sizes(elem_dtype, B):
+@pytest.mark.parametrize("device", devices)
+def test_block_sizes(elem_dtype, B, device):
     """
     Smoke test for various block sizes
     """
@@ -422,19 +462,23 @@ def test_block_sizes(elem_dtype, B):
         pytest.skip("unsupported configuration")
     elif B % 4 != 0 and elem_dtype in [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]:
         pytest.skip("unsupported configuration")
-    tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16)
+    tensor_hp = torch.randn(B, device=device, dtype=torch.bfloat16)
     _test_mx(tensor_hp, elem_dtype, B)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not (torch.cuda.is_available() or torch.xpu.is_available()),
+    reason="CUDA or XPU not available",
+)
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
-def test_transpose(elem_dtype):
+@pytest.mark.parametrize("device", devices)
+def test_transpose(elem_dtype, device):
     """
     Verify that transposing an MX tensor works
     """
     M, K = 128, 256
     block_size = 32
-    tensor_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+    tensor_hp = torch.randn(M, K, device=device, dtype=torch.bfloat16)
     tensor_mx = MXTensor.to_mx(
         tensor_hp,
         elem_dtype,
@@ -449,18 +493,23 @@ def test_transpose(elem_dtype):
     torch.testing.assert_close(tensor_mx_dq_t, tensor_mx_t_dq, atol=0, rtol=0)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not (torch.cuda.is_available() or torch.xpu.is_available()),
+    reason="CUDA or XPU not available",
+)
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
-def test_view(elem_dtype):
-    x = torch.randn(1, 2, 4, device="cuda")
+@pytest.mark.parametrize("device", devices)
+def test_view(elem_dtype, device):
+    x = torch.randn(1, 2, 4, device=device)
     block_size = 4
     x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
     x_mx_2 = x_mx.view(2, 4)  # noqa: F841
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-def test_clone():
-    data = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16)
+@pytest.mark.skipif(not (torch.cuda.is_available()), reason="CUDA not available")
+@pytest.mark.parametrize("device", devices)
+def test_clone(device):
+    data = torch.randn(8, 8, device=device, dtype=torch.bfloat16)
     block_size = 4
     data_mx = MXTensor.to_mx(data, torch.float8_e4m3fn, block_size)
     data_mx_c = data_mx.clone()
@@ -472,11 +521,15 @@ def test_clone():
     )
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not (torch.cuda.is_available() or torch.xpu.is_available()),
+    reason="CUDA or XPU not available",
+)
 @pytest.mark.parametrize("elem_dtype", [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2])
 @pytest.mark.parametrize("pack_fp6", [False, True])
-def test_fp6_packing(elem_dtype, pack_fp6):
-    x = torch.randn(1, 2, 4, device="cuda")
+@pytest.mark.parametrize("device", devices)
+def test_fp6_packing(elem_dtype, pack_fp6, device):
+    x = torch.randn(1, 2, 4, device=device)
     block_size = 4
     x_mx = MXTensor.to_mx(x, elem_dtype, block_size, pack_fp6=pack_fp6)
     if pack_fp6:
@@ -487,24 +540,31 @@ def test_fp6_packing(elem_dtype, pack_fp6):
     assert x_mx.qdata.shape == expected_packed_shape
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not (torch.cuda.is_available() or torch.xpu.is_available()),
+    reason="CUDA or XPU not available",
+)
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 @pytest.mark.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
 @pytest.mark.parametrize("all_zeros", [False, True])
-def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
+@pytest.mark.parametrize("device", devices)
+def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros, device):
    """
     Verifies that compile does not change numerics of MX casts
     """
-    if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+    if (
+        elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
+        and torch.cuda.is_available()
+    ):
         if not is_sm_at_least_89():
             # separate ifs because flake8 is outsmarting me
             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
 
     shape = 4, 8
     if not all_zeros:
-        x = torch.randn(*shape, dtype=hp_dtype, device="cuda")
+        x = torch.randn(*shape, dtype=hp_dtype, device=device)
     else:
-        x = torch.zeros(*shape, dtype=hp_dtype, device="cuda")
+        x = torch.zeros(*shape, dtype=hp_dtype, device=device)
     block_size = 4
     to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)
@@ -540,28 +600,36 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
     torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    not is_sm_at_least_89(),
+    not (torch.cuda.is_available() or torch.xpu.is_available()),
+    reason="CUDA or XPU not available",
+)
+@pytest.mark.skipif(
+    not (is_sm_at_least_89() or torch.xpu.is_available()),
     reason="float8 in triton requires CUDA capability 8.9 or greater",
 )
-def test_to_mx_inductor_single_kernel():
+@pytest.mark.parametrize("device", devices)
+def test_to_mx_inductor_single_kernel(device):
     """
     Verify that inductor can fuse the cast of a high precision tensor to mx
     into a single kernel
     """
     # TODO(future PR): add fp4 and fp6 here
     # TODO(#1773): add swizzled scale format here
-    x = torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda")
+    x = torch.randn(2048, 2048, dtype=torch.bfloat16, device=device)
     block_size = 32
     to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)
     out, code = run_and_get_code(to_mx_c, x, torch.float8_e4m3fn, block_size)
     FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0])
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipIf(not is_sm_at_least_90(), "Need sm90+")
-def test_index_select():
+@pytest.mark.skipif(
+    not (torch.cuda.is_available() or torch.xpu.is_available()),
+    reason="CUDA or XPU not available",
+)
+@pytest.mark.skipIf(not (is_sm_at_least_90() or torch.xpu.is_available()), "Need sm90+")
+@pytest.mark.parametrize("device", devices)
+def test_index_select(device):
     """
     test that `x_0 = x[0]` works when `x` is a 3D `MXTensor`.
 
     This is useful when stitching checkpoints of `num_experts` 2D parameters into
@@ -570,7 +638,7 @@
     """
 
     E, K, N = 128, 256, 512
-    x = torch.randn(E, N, K, device="cuda", dtype=torch.bfloat16)
+    x = torch.randn(E, N, K, device=device, dtype=torch.bfloat16)
     x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn, 32)
 
     x_mx_1 = x_mx[1]
@@ -579,12 +647,16 @@
     )
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    not is_sm_at_least_89(),
+    not (torch.cuda.is_available() or torch.xpu.is_available()),
+    reason="CUDA or XPU not available",
+)
+@pytest.mark.skipif(
+    not (is_sm_at_least_89() or torch.xpu.is_available()),
     reason="float8 in triton requires CUDA capability 8.9 or greater",
 )
-def test_cast_to_float8_e4m3fn_saturation_behavior():
+@pytest.mark.parametrize("device", devices)
+def test_cast_to_float8_e4m3fn_saturation_behavior(device):
     # TODO(#1912): make the saturated cast work in eager mode and remove this
     # test
     max_val = torch.finfo(torch.float8_e4m3fn).max
@@ -596,7 +668,7 @@
             -1 * max_val,
         ],
         dtype=torch.bfloat16,
-        device="cuda",
+        device=device,
     )
 
     # create example data outside the representable range
@@ -606,7 +678,7 @@
             -1 * (max_val * 2),
         ],
         dtype=torch.bfloat16,
-        device="cuda",
+        device=device,
     )
 
     # verify that in eager mode PyTorch casting to float8 is unsaturated
@@ -666,7 +738,10 @@ def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
     )
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not (torch.cuda.is_available() or torch.xpu.is_available()),
+    reason="CUDA or XPU not available",
+)
 @pytest.mark.skipif(not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
 @pytest.mark.parametrize("transpose", [False, True])
 @pytest.mark.parametrize(
@@ -676,13 +751,14 @@
         (1, 128, 64),
     ),
 )
-def test_scale_shape_matches_qdata(transpose, shape):
+@pytest.mark.parametrize("device", devices)
+def test_scale_shape_matches_qdata(transpose, shape, device):
     if len(shape) == 3 and transpose:
         pytest.skip("transpose not yet implemented for 3D MXTensor")
 
     block_size = 32
 
-    x_hp = torch.randn(*shape, device="cuda")
+    x_hp = torch.randn(*shape, device=device)
     x = MXTensor.to_mx(
         x_hp,
         torch.float8_e4m3fn,
@@ -720,7 +796,7 @@ def test_scale_shape_matches_qdata(transpose, shape):
     )
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not (torch.cuda.is_available()), reason="CUDA not available")
 @pytest.mark.skipif(not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
 @pytest.mark.parametrize("elem_dtype", (torch.float8_e4m3fn, torch.float4_e2m1fn_x2))
 @pytest.mark.parametrize("transpose", [False, True])
 @pytest.mark.parametrize(
@@ -731,13 +807,14 @@
         (1, 128, 64),
     ),
 )
-def test_swizzle(elem_dtype, transpose, shape):
+@pytest.mark.parametrize("device", devices)
+def test_swizzle(elem_dtype, transpose, shape, device):
     if len(shape) == 3 and transpose:
         pytest.skip("transpose not yet implemented for 3D MXTensor")
 
     block_size = 32
 
-    x_hp = torch.randn(*shape, device="cuda")
+    x_hp = torch.randn(*shape, device=device)
     x = MXTensor.to_mx(
         x_hp,
         elem_dtype,
diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py
index 173d99f746..a7dea575ea 100644
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -552,7 +552,7 @@ def triton_f6_e2m3_to_bf16(x: torch.Tensor) -> torch.Tensor:
     output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16)
 
     assert x.is_contiguous()
-    assert x.is_cuda and output.is_cuda
+    assert (x.is_cuda and output.is_cuda) or (x.is_xpu and output.is_xpu)
 
     n_mx_blocks = x.shape[0]
     grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),)
@@ -588,7 +588,7 @@ def triton_f6_e3m2_to_bf16(x: torch.Tensor) -> torch.Tensor:
     output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16)
 
     assert x.is_contiguous()
-    assert x.is_cuda and output.is_cuda
+    assert (x.is_cuda and output.is_cuda) or (x.is_xpu and output.is_xpu)
 
     n_mx_blocks = x.shape[0]
     grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),)
@@ -628,7 +628,7 @@ def triton_f6_e2m3_to_scaled_bf16(
     output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16)
 
     assert x.is_contiguous()
-    assert x.is_cuda and output.is_cuda
+    assert (x.is_cuda and output.is_cuda) or (x.is_xpu and output.is_xpu)
 
     n_mx_blocks = x.shape[0]
     grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),)
@@ -671,7 +671,7 @@ def triton_f6_e3m2_to_scaled_bf16(
     output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16)
 
     assert x.is_contiguous()
-    assert x.is_cuda and output.is_cuda
+    assert (x.is_cuda and output.is_cuda) or (x.is_xpu and output.is_xpu)
 
     n_mx_blocks = x.numel() // packed_mx_block_size
     grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),)

From 64f415105a47cea00a546499e750233e329666b5 Mon Sep 17 00:00:00 2001
From: "Xiao, Wang"
Date: Mon, 27 Oct 2025 23:11:48 -0700
Subject: [PATCH 2/2] Support mx_tensor and enable its tests on Intel GPU

---
 torchao/prototype/mx_formats/inference_workflow.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py
index 8725c33b44..4aa26bf167 100644
--- a/torchao/prototype/mx_formats/inference_workflow.py
+++ b/torchao/prototype/mx_formats/inference_workflow.py
@@ -102,6 +102,9 @@ def _mx_inference_linear_transform(
     module: torch.nn.Module, config: MXFPInferenceConfig
 ):
     weight = module.weight
+    is_swizzled_scales = True
+    if "xpu" in weight.device.type:
+        is_swizzled_scales = False
 
     assert weight.dtype == torch.bfloat16, (
         f"Only supporting bf16 out dtype for now, got {weight.dtype}"
@@ -111,7 +114,7 @@
         block_size=config.block_size,
         gemm_kernel_choice=config.gemm_kernel_choice,
         pack_fp6=False,
-        is_swizzled_scales=True,
+        is_swizzled_scales=is_swizzled_scales,
     )
 
     # Convert weight to MX Tensor
@@ -122,7 +125,7 @@
         gemm_kernel_choice=config.gemm_kernel_choice,
         pack_fp6=False,
         # TODO act_quant_kwargs=act_quant_kwargs,
-        is_swizzled_scales=True,
+        is_swizzled_scales=is_swizzled_scales,
     )
 
     module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
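
Usage note (not part of the patches): a minimal sketch of the XPU path these changes enable — casting a bf16 tensor to an MX format on an Intel GPU, mirroring the device selection and the MXTensor.to_mx call used in the tests above. It assumes a PyTorch build where torch.xpu.is_available() can return True, and the import path torchao.prototype.mx_formats.mx_tensor is inferred from the files touched here, not confirmed by the patch.

    import torch

    # Module path inferred from the files in this series; adjust if MXTensor lives elsewhere.
    from torchao.prototype.mx_formats.mx_tensor import MXTensor

    # Mirror the `devices` selection the patch adds to the tests.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.xpu.is_available():
        device = "xpu"
    else:
        raise RuntimeError("this sketch needs a CUDA or XPU device")

    data = torch.randn(128, 128, device=device, dtype=torch.bfloat16)
    block_size = 32

    # Cast to MXFP8 (float8_e4m3fn elements plus e8m0 block scales), as the tests above do.
    data_mx = MXTensor.to_mx(data, torch.float8_e4m3fn, block_size)
    print(data_mx.qdata.shape, data_mx.scale.shape, data_mx.scale.dtype)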