
Commit 014f555

Skip TE test on SM120+ as Float8BlockScaling is currently unsupported in thunder (#2475)
Co-authored-by: Riccardo Felluga <11768013+riccardofelluga@users.noreply.github.com>
1 parent: e00b1a2 · commit: 014f555

3 files changed: +48 −0 lines
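For orientation: SM120 and SM121 correspond to CUDA compute capability (12, 0) and (12, 1), which is exactly what the new helper in thunder/tests/utils.py checks. A minimal standalone sketch of that check (the function name here is illustrative; it mirrors is_sm120_orsm121() added in the diff below):

# Illustrative sketch only; mirrors is_sm120_orsm121() from thunder/tests/utils.py below.
import torch

def running_on_sm120_or_sm121() -> bool:
    # SM120/SM121 devices report CUDA compute capability (12, 0) / (12, 1).
    return torch.cuda.is_available() and torch.cuda.get_device_capability() in ((12, 0), (12, 1))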

thunder/tests/test_transformer_engine_executor.py
Lines changed: 15 additions & 0 deletions

@@ -5,6 +5,10 @@
 import thunder
 from thunder.tests.framework import requiresCUDA

+# NOTE: On SM120/121, TE defaults to using Float8BlockScaling,
+# which is currently unsupported in thunder, so we skip the tests for these SM architectures.
+from thunder.tests.utils import skip_on_sm120_and_sm121, is_sm120_orsm121
+
 pytest.importorskip("transformer_engine", reason="transformer_engine was not found, skipping the tests.")
 from thunder.executors.transformer_engineex import transformer_engine_ex
 from transformer_engine.common import recipe

@@ -32,6 +36,9 @@ def test_te_linear_forward_backward(fp8_recipe: recipe.Recipe):
     if fp8_recipe and not (fp8_recipe.delayed() or is_mxfp8_supported):
         pytest.skip(msg_mxfp8)

+    if is_sm120_orsm121() and fp8_recipe is None:
+        pytest.skip("On SM120/121, default recipe is Float8BlockScaling which is not supported")
+
     # Test Description:
     # Verify that `torch.nn.functional.linear` is replaced with `te_linear_*`
     # and the output as well as the gradients match for thunder compiled code.

@@ -89,6 +96,9 @@ def test_te_linear_forward_backward_multiple_iteration(fp8_recipe):
     if fp8_recipe and not (fp8_recipe.delayed() or is_mxfp8_supported):
         pytest.skip(msg_mxfp8)

+    if is_sm120_orsm121() and fp8_recipe is None:
+        pytest.skip("On SM120/121, default recipe is Float8BlockScaling which is not supported")
+
     # Test Description:
     # In this test, we verify whether a model using TransformerEngine Linear
     # and transformer_engine executor converge to same state.

@@ -161,6 +171,7 @@ def thunder_model(x):


 @requiresCUDA
+@skip_on_sm120_and_sm121
 def test_te_linear_invalid_inputs():
     def assert_not_transformed(x, w):
         def fn(x, w):

@@ -185,6 +196,7 @@ def fn(x, w):


 @requiresCUDA
+@skip_on_sm120_and_sm121
 def test_te_with_autocast():
     from thunder.transforms.autocast import autocast


@@ -215,6 +227,7 @@ def foo(x, w):
     reason="See https://github.com/Lightning-AI/lightning-thunder/issues/2221",
 )
 @requiresCUDA
+@skip_on_sm120_and_sm121
 def test_te_with_retain_graph():
     def foo(x, w):
         return thunder.torch.linear(x, w)

@@ -236,6 +249,7 @@ def foo(x, w):


 @requiresCUDA
+@skip_on_sm120_and_sm121
 def test_te_trace_metadata_propagation():
     # This test is to verify that we correctly propagate metadata `_include_te_fp8_autocast` on
     # trace using `from_trace`. `_include_te_fp8_autocast` is used to enable wrapping forward trace with `fp8_autocast`.

@@ -267,6 +281,7 @@ def transform_trace_post_optimization(self, computation_trace, **kwargs):
     assert any(bsym.sym.name.startswith("te_linear") for bsym in fwd_traces[-1].bound_symbols)


+@skip_on_sm120_and_sm121
 def test_te_grad_computation_with_intermediate():
     # Test for issue - https://github.com/Lightning-AI/lightning-thunder/issues/1966
     def fn(x, w):

thunder/tests/test_transformer_engine_v2_executor.py
Lines changed: 16 additions & 0 deletions

@@ -7,6 +7,10 @@
 import thunder
 from thunder.tests.framework import requiresCUDA

+# NOTE: On SM120/121, TE defaults to using Float8BlockScaling,
+# which is currently unsupported in thunder, so we skip the tests for these SM architectures.
+from thunder.tests.utils import skip_on_sm120_and_sm121, is_sm120_orsm121
+
 transformer_engine_module = pytest.importorskip(
     "transformer_engine", reason="transformer_engine was not found, skipping the tests."
 )

@@ -33,10 +37,14 @@

 @requiresCUDA
 @pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids)
+@skip_on_sm120_and_sm121
 def test_te_linear_forward_backward(fp8_recipe: recipe.Recipe):
     if fp8_recipe and not (fp8_recipe.delayed() or is_mxfp8_supported):
         pytest.skip(msg_mxfp8)

+    if is_sm120_orsm121() and fp8_recipe is None:
+        pytest.skip("On SM120/121, default recipe is Float8BlockScaling which is not supported")
+
     # Test Description:
     # Verify that `torch.nn.functional.linear` is replaced with `te_linear_*`
     # and the output as well as the gradients match for thunder compiled code.

@@ -96,6 +104,7 @@ def fn(x, w1, w2):

 @requiresCUDA
 @pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids)
+@skip_on_sm120_and_sm121
 def test_te_linear_forward_backward_multiple_iteration(fp8_recipe: recipe.Recipe):
     if not fp8_recipe:
         pytest.skip(

@@ -277,6 +286,7 @@ def fn(x, w):


 @requiresCUDA
+@skip_on_sm120_and_sm121
 def test_te_with_autocast():
     from thunder.transforms.autocast import autocast


@@ -303,6 +313,7 @@ def foo(x, w):
 # NOTE: strict=False as it passes on Blackwell.
 @pytest.mark.xfail(strict=False, raises=(RuntimeError, TypeError), reason="Retain graph is not supported by TE")
 @requiresCUDA
+@skip_on_sm120_and_sm121
 def test_te_with_retain_graph():
     def foo(x, w):
         return thunder.torch.linear(x, w)

@@ -325,6 +336,7 @@ def foo(x, w):


 @requiresCUDA
+@skip_on_sm120_and_sm121
 def test_te_trace_metadata_propagation():
     # This test is to verify that we correctly propagate metadata `_include_te_fp8_autocast` on
     # trace using `from_trace`. `_include_te_fp8_autocast` is used to enable wrapping forward trace with `fp8_autocast`.

@@ -357,6 +369,7 @@ def transform_trace_post_optimization(self, computation_trace, **kwargs):
     assert any(bsym.sym.name.startswith("te_functional_linear") for bsym in fwd_traces[-1].bound_symbols)


+@skip_on_sm120_and_sm121
 def test_te_grad_computation_with_intermediate():
     # Test for issue - https://github.com/Lightning-AI/lightning-thunder/issues/1966
     def fn(x, w):

@@ -381,6 +394,7 @@ def fn(x, w):

 @requiresCUDA
 @pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids)
+@skip_on_sm120_and_sm121
 def test_te_trace_correctness(fp8_recipe: recipe.Recipe):
     if fp8_recipe and not (fp8_recipe.delayed() or is_mxfp8_supported):
         pytest.skip(msg_mxfp8)

@@ -451,6 +465,7 @@ def foo(x, w):
 @requiresCUDA
 @pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids)
 @pytest.mark.parametrize("compile_path", ["jit", "ThunderFX"])
+@skip_on_sm120_and_sm121
 def test_te_activation_checkpointing_trace(fp8_recipe: recipe.Recipe, compile_path: str):
     if fp8_recipe and not (fp8_recipe.delayed() or is_mxfp8_supported):
         pytest.skip(msg_mxfp8)

@@ -505,6 +520,7 @@ def fn(x, w, w2):
 @pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids)
 @pytest.mark.parametrize("compile_path", ["jit", "ThunderFX"])
 @pytest.mark.filterwarnings("ignore::FutureWarning")  # Coming from TE v2.3
+@skip_on_sm120_and_sm121
 def test_te_activation_checkpointing_correctness(fp8_recipe: recipe.Recipe, compile_path: str):
     if not fp8_recipe:
         pytest.skip(

thunder/tests/utils.py
Lines changed: 17 additions & 0 deletions

@@ -1,4 +1,6 @@
 import torch
+import pytest
+import functools


 def is_output_differentiable(x):

@@ -36,3 +38,18 @@ def filter_differentiable_outputs(outputs):
         outputs = [outputs]

     return list(filter(is_output_differentiable, outputs))
+
+
+def is_sm120_orsm121():
+    return torch.cuda.get_device_capability() in ((12, 1), (12, 0))
+
+
+def skip_on_sm120_and_sm121(fn):
+    @functools.wraps(fn)
+    def wrapped_fn(*args, **kwargs):
+        if is_sm120_orsm121():
+            pytest.skip("Skipped on SM120/121")
+        else:
+            fn(*args, **kwargs)
+
+    return wrapped_fn
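As a usage sketch of the two new helpers (the tests below are hypothetical and not part of this commit; a CUDA device is assumed): the decorator skips a test unconditionally on SM120/121, while the predicate allows a conditional skip, e.g. only when the default (None) recipe would be used.

# Hypothetical usage example; assumes a CUDA device is available and that the helpers
# above are importable from thunder.tests.utils.
import pytest

from thunder.tests.utils import is_sm120_orsm121, skip_on_sm120_and_sm121


@skip_on_sm120_and_sm121  # unconditional skip on SM120/121
def test_unconditional_skip_example():
    assert True


@pytest.mark.parametrize("fp8_recipe", [None])
def test_conditional_skip_example(fp8_recipe):
    # Conditional skip, matching the inline checks added to the TE test files above:
    # only the default-recipe case is unsupported on SM120/121.
    if is_sm120_orsm121() and fp8_recipe is None:
        pytest.skip("On SM120/121, default recipe is Float8BlockScaling which is not supported")
    assert True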
