@@ -7,6 +7,10 @@
 import thunder
 from thunder.tests.framework import requiresCUDA
 
+# NOTE: On SM120/121, TE defaults to using Float8BlockScaling,
+# which is currently unsupported in thunder, so we skip the tests for these SM architectures.
+from thunder.tests.utils import skip_on_sm120_and_sm121, is_sm120_orsm121
+
 transformer_engine_module = pytest.importorskip(
     "transformer_engine", reason="transformer_engine was not found, skipping the tests."
 )
@@ -33,10 +37,14 @@
 
 @requiresCUDA
 @pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids)
+@skip_on_sm120_and_sm121
 def test_te_linear_forward_backward(fp8_recipe: recipe.Recipe):
     if fp8_recipe and not (fp8_recipe.delayed() or is_mxfp8_supported):
         pytest.skip(msg_mxfp8)
 
+    if is_sm120_orsm121 and fp8_recipe is None:
+        pytest.skip("On SM120/121, default recipe is Float8BlockScaling which is not supported")
+
     # Test Description:
     # Verify that `torch.nn.functional.linear` is replaced with `te_linear_*`
     # and the output as well as the gradients match for thunder compiled code.
@@ -96,6 +104,7 @@ def fn(x, w1, w2):
 
 @requiresCUDA
 @pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids)
+@skip_on_sm120_and_sm121
 def test_te_linear_forward_backward_multiple_iteration(fp8_recipe: recipe.Recipe):
     if not fp8_recipe:
         pytest.skip(
@@ -277,6 +286,7 @@ def fn(x, w):
 
 
 @requiresCUDA
+@skip_on_sm120_and_sm121
 def test_te_with_autocast():
     from thunder.transforms.autocast import autocast
 
@@ -303,6 +313,7 @@ def foo(x, w):
 # NOTE: strict=False as it passes on Blackwell.
 @pytest.mark.xfail(strict=False, raises=(RuntimeError, TypeError), reason="Retain graph is not supported by TE")
 @requiresCUDA
+@skip_on_sm120_and_sm121
 def test_te_with_retain_graph():
     def foo(x, w):
         return thunder.torch.linear(x, w)
@@ -325,6 +336,7 @@ def foo(x, w):
 
 
 @requiresCUDA
+@skip_on_sm120_and_sm121
 def test_te_trace_metadata_propagation():
     # This test is to verify that we correctly propagate metadata `_include_te_fp8_autocast` on
     # trace using `from_trace`. `_include_te_fp8_autocast` is used to enable wrapping forward trace with `fp8_autocast`.
@@ -357,6 +369,7 @@ def transform_trace_post_optimization(self, computation_trace, **kwargs):
     assert any(bsym.sym.name.startswith("te_functional_linear") for bsym in fwd_traces[-1].bound_symbols)
 
 
+@skip_on_sm120_and_sm121
 def test_te_grad_computation_with_intermediate():
     # Test for issue - https://github.com/Lightning-AI/lightning-thunder/issues/1966
     def fn(x, w):
@@ -381,6 +394,7 @@ def fn(x, w):
 
 @requiresCUDA
 @pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids)
+@skip_on_sm120_and_sm121
 def test_te_trace_correctness(fp8_recipe: recipe.Recipe):
     if fp8_recipe and not (fp8_recipe.delayed() or is_mxfp8_supported):
         pytest.skip(msg_mxfp8)
@@ -451,6 +465,7 @@ def foo(x, w):
 @requiresCUDA
 @pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids)
 @pytest.mark.parametrize("compile_path", ["jit", "ThunderFX"])
+@skip_on_sm120_and_sm121
 def test_te_activation_checkpointing_trace(fp8_recipe: recipe.Recipe, compile_path: str):
     if fp8_recipe and not (fp8_recipe.delayed() or is_mxfp8_supported):
         pytest.skip(msg_mxfp8)
@@ -505,6 +520,7 @@ def fn(x, w, w2):
 @pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids)
 @pytest.mark.parametrize("compile_path", ["jit", "ThunderFX"])
 @pytest.mark.filterwarnings("ignore::FutureWarning")  # Coming from TE v2.3
+@skip_on_sm120_and_sm121
 def test_te_activation_checkpointing_correctness(fp8_recipe: recipe.Recipe, compile_path: str):
     if not fp8_recipe:
         pytest.skip(
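
For context, here is a minimal sketch of what the `skip_on_sm120_and_sm121` and `is_sm120_orsm121` helpers imported from `thunder.tests.utils` above could look like. This is an illustrative assumption, not the repository's actual implementation, which may differ in detail.

# Hypothetical sketch only: the real helpers in thunder/tests/utils.py may differ.
import pytest
import torch


def _device_is_sm120_or_sm121() -> bool:
    # SM120/121 correspond to CUDA compute capabilities 12.0 and 12.1.
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() in ((12, 0), (12, 1))


# Module-level flag, mirroring how the tests use `is_sm120_orsm121` as a plain boolean.
is_sm120_orsm121 = _device_is_sm120_or_sm121()

# Skip marker for tests that would hit TE's default Float8BlockScaling recipe.
skip_on_sm120_and_sm121 = pytest.mark.skipif(
    is_sm120_orsm121,
    reason="TE defaults to Float8BlockScaling on SM120/121, which is not supported by thunder",
)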