35 changes: 27 additions & 8 deletions test/prototype/mx_formats/test_inference_workflow.py
@@ -36,6 +36,14 @@
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)


devices = []
if torch.cuda.is_available():
    devices.append("cuda")

if torch.xpu.is_available():
    devices.append("xpu")


Collaborator
Suggest using the torchao utils function to get the available devices.

Collaborator Author
The get_available_devices (https://github.com/pytorch/ao/blob/main/torchao/utils.py#L139) interface can return the available devices; however, its output includes "cpu", which is not what we need here.
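A minimal sketch of that alternative, assuming get_available_devices() returns a list of device strings that includes "cpu" (the filtering shown here is illustrative and not part of this PR):

```python
# Sketch only: filter torchao's helper output instead of building the list by hand.
# Assumes get_available_devices() returns device strings such as "cpu", "cuda", "xpu".
from torchao.utils import get_available_devices

# Drop "cpu"; these tests target accelerator backends only.
devices = [d for d in get_available_devices() if d != "cpu"]
```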

# source: https://stackoverflow.com/a/22638709
@pytest.fixture(autouse=True)
def run_around_tests():
@@ -63,16 +71,22 @@ def cuda_kernel_profiler(kernel_pattern):
result["found"] = any(kernel_pattern in name for name in kernel_names)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
    not (torch.cuda.is_available() or torch.xpu.is_available()),
    reason="CUDA or XPU not available",
)
@pytest.mark.skipif(
    not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
)
@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("compile", [True, False])
@pytest.mark.parametrize("emulate", [True, False])
@pytest.mark.parametrize(
    "emulate", [True, False] if (not torch.xpu.is_available()) else [True]
)
@pytest.mark.parametrize("use_inference_mode", [True, False])
@pytest.mark.parametrize("x_rank", [2, 3])
@pytest.mark.parametrize("device", devices)
@torch.no_grad()
@skip_if_rocm(
"ROCm float4 gemm require gfx950"
@@ -84,25 +98,31 @@ def test_inference_workflow_mx(
    emulate: bool,
    use_inference_mode: bool,
    x_rank: int,
    device,
):
    """
    Smoke test for inference compile
    """
    # TODO(future): figure out why these CUDA capability conditions are not properly
    # applied when inside `pytest.mark.skipif` for this test
    if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    if (
        elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
    ) and torch.cuda.is_available():
        if not is_sm_at_least_89():
            pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
        elif not is_sm_at_least_100() and not emulate:
            pytest.skip("CUDA capability >= 10.0 required for mxfp8 gemm")
    elif elem_dtype == torch.float4_e2m1fn_x2:
    elif (elem_dtype == torch.float4_e2m1fn_x2) and torch.cuda.is_available():
        if not is_sm_at_least_100() and not emulate:
            pytest.skip("CUDA capability >= 10.0 required for mxfp4 gemm")
        elif compile:
            # TODO(future PR): investigate and fix this
            pytest.skip("mxfp4 + compile currently does not work, low SQNR")
            pytest.skip("mxfp4 + compile currently does not work on CUDA, low SQNR")

    m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
    if (elem_dtype == torch.float4_e2m1fn_x2) and torch.xpu.is_available() and compile:
        pytest.skip("mxfp4 + compile currently does not work on XPU, low SQNR")

    m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device=device)
    m_mx = copy.deepcopy(m)

    if emulate:
@@ -120,10 +140,9 @@
    if compile:
        m_mx = torch.compile(m_mx, fullgraph=True)

    x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
    x = torch.randn(128, 32, device=device, dtype=torch.bfloat16)
    if x_rank == 3:
        x = x.unsqueeze(0)

    y_ref = m(x)
    if use_inference_mode:
        with torch.inference_mode():