diff --git a/.github/scripts/ci_test_xpu.sh b/.github/scripts/ci_test_xpu.sh
index d765696b40..79114d01c0 100644
--- a/.github/scripts/ci_test_xpu.sh
+++ b/.github/scripts/ci_test_xpu.sh
@@ -15,3 +15,5 @@ python3 -c "import torch; import torchao; print(f'Torch version: {torch.__versio
 pip install pytest expecttest parameterized accelerate hf_transfer 'modelscope!=1.15.0'

 pytest -v -s torchao/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
+
+pytest -v -s torchao/test/quantization/
diff --git a/test/quantization/test_gptq.py b/test/quantization/test_gptq.py
index 6f7ac10d45..34dafcdbc4 100644
--- a/test/quantization/test_gptq.py
+++ b/test/quantization/test_gptq.py
@@ -18,13 +18,15 @@
 from torchao._models.llama.tokenizer import get_tokenizer
 from torchao.quantization import Int4WeightOnlyConfig, quantize_
 from torchao.quantization.utils import compute_error
+from torchao.utils import auto_detect_device

 torch.manual_seed(0)

+_DEVICE = auto_detect_device()
+

 class TestGPTQ(TestCase):
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_gptq_quantizer_int4_weight_only(self):
         from torchao._models._eval import (
             LMEvalInputRecorder,
@@ -33,7 +35,6 @@ def test_gptq_quantizer_int4_weight_only(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer

         precision = torch.bfloat16
-        device = "cuda"
         checkpoint_path = Path(
             "../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"
         )
@@ -80,19 +81,19 @@ def test_gptq_quantizer_int4_weight_only(self):
         )

         model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
-        model = quantizer.quantize(model, *inputs).cuda()
+        model = quantizer.quantize(model, *inputs).to(_DEVICE)
         model.reset_caches()
-        with torch.device("cuda"):
+        with torch.device(_DEVICE):
             model.setup_caches(max_batch_size=1, max_seq_length=model.config.block_size)

         limit = 1
         result = TransformerEvalWrapper(
-            model.cuda(),
+            model.to(_DEVICE),
             tokenizer,
             model.config.block_size,
             prepare_inputs_for_model,
-            device,
+            _DEVICE,
         ).run_eval(
             ["wikitext"],
             limit,
         )
@@ -104,7 +105,6 @@ def test_gptq_quantizer_int4_weight_only(self):


 class TestMultiTensorFlow(TestCase):
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_multitensor_add_tensors(self):
         from torchao.quantization.GPTQ import MultiTensor

@@ -116,7 +116,6 @@ def test_multitensor_add_tensors(self):
         self.assertTrue(torch.equal(mt.values[0], tensor1))
         self.assertTrue(torch.equal(mt.values[1], tensor2))

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_multitensor_pad_unpad(self):
         from torchao.quantization.GPTQ import MultiTensor

@@ -127,7 +126,6 @@ def test_multitensor_pad_unpad(self):
         mt.unpad()
         self.assertEqual(mt.count, 1)

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_multitensor_inplace_operation(self):
         from torchao.quantization.GPTQ import MultiTensor

@@ -138,7 +136,6 @@ def test_multitensor_inplace_operation(self):


 class TestMultiTensorInputRecorder(TestCase):
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_multitensor_input_recorder(self):
         from torchao.quantization.GPTQ import MultiTensor, MultiTensorInputRecorder

@@ -159,7 +156,7 @@ def test_multitensor_input_recorder(self):
         self.assertTrue(isinstance(MT_input[2][2], MultiTensor))
         self.assertEqual(MT_input[3], torch.float)

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_gptq_with_input_recorder(self):
         from torchao.quantization.GPTQ import (
             Int4WeightOnlyGPTQQuantizer,
@@ -170,7 +167,7 @@ def test_gptq_with_input_recorder(self):

         config = ModelArgs(n_layer=2)

-        with torch.device("cuda"):
+        with torch.device(_DEVICE):
             model = Transformer(config)
             model.setup_caches(max_batch_size=2, max_seq_length=100)
             idx = torch.randint(1, 10000, (10, 2, 50)).to(torch.int32)
@@ -191,7 +188,14 @@ def test_gptq_with_input_recorder(self):

         args = input_recorder.get_recorded_inputs()

-        quantizer = Int4WeightOnlyGPTQQuantizer()
+        if _DEVICE.type == "xpu":
+            from torchao.dtypes import Int4XPULayout
+
+            quantizer = Int4WeightOnlyGPTQQuantizer(
+                device=torch.device("xpu"), layout=Int4XPULayout()
+            )
+        else:
+            quantizer = Int4WeightOnlyGPTQQuantizer()

         quantizer.quantize(model, *args)
diff --git a/test/quantization/test_moe_quant.py b/test/quantization/test_moe_quant.py
index 61000babc1..55a6a87e24 100644
--- a/test/quantization/test_moe_quant.py
+++ b/test/quantization/test_moe_quant.py
@@ -33,7 +33,13 @@
     quantize_,
 )
 from torchao.quantization.utils import compute_error
-from torchao.utils import is_sm_at_least_90
+from torchao.testing.utils import skip_if_no_cuda
+from torchao.utils import (
+    auto_detect_device,
+    is_sm_at_least_90,
+)
+
+_DEVICE = auto_detect_device()

 if torch.version.hip is not None:
     pytest.skip(
@@ -54,7 +60,7 @@ def _test_impl_moe_quant(
         base_class=AffineQuantizedTensor,
         tensor_impl_class=None,
         dtype=torch.bfloat16,
-        device="cuda",
+        device=_DEVICE,
         fullgraph=False,
     ):
         """
@@ -115,10 +121,8 @@ def _test_impl_moe_quant(
             ("multiple_tokens", 8, False),
         ]
     )
+    @skip_if_no_cuda()
     def test_int4wo_fake_dim(self, name, num_tokens, fullgraph):
-        if not torch.cuda.is_available():
-            self.skipTest("Need CUDA available")
-
         config = MoEQuantConfig(
             Int4WeightOnlyConfig(version=1),
             use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
@@ -138,6 +142,7 @@ def test_int4wo_fake_dim(self, name, num_tokens, fullgraph):
             ("multiple_tokens", 8, False),
         ]
     )
+    @skip_if_no_cuda()
     def test_int4wo_base(self, name, num_tokens, fullgraph):
         if not torch.cuda.is_available():
             self.skipTest("Need CUDA available")
@@ -160,10 +165,8 @@ def test_int4wo_base(self, name, num_tokens, fullgraph):
             ("multiple_tokens", 8, False),
         ]
     )
+    @skip_if_no_cuda()
     def test_int8wo_fake_dim(self, name, num_tokens, fullgraph):
-        if not torch.cuda.is_available():
-            self.skipTest("Need CUDA available")
-
         config = MoEQuantConfig(
             Int8WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE
         )
@@ -182,10 +185,8 @@ def test_int8wo_fake_dim(self, name, num_tokens, fullgraph):
             ("multiple_tokens", 8, False),
         ]
     )
+    @skip_if_no_cuda()
     def test_int8wo_base(self, name, num_tokens, fullgraph):
-        if not torch.cuda.is_available():
-            self.skipTest("Need CUDA available")
-
         config = MoEQuantConfig(Int8WeightOnlyConfig())
         tensor_impl_class = PlainAQTTensorImpl

@@ -202,6 +203,7 @@ def test_int8wo_base(self, name, num_tokens, fullgraph):
             ("multiple_tokens", 8, False),
         ]
     )
+    @skip_if_no_cuda()
     def test_int8wo_base_cpu(self, name, num_tokens, fullgraph):
         config = MoEQuantConfig(Int8WeightOnlyConfig())
         tensor_impl_class = PlainAQTTensorImpl
@@ -219,10 +221,8 @@ def test_int8wo_base_cpu(self, name, num_tokens, fullgraph):
             ("multiple_tokens", 32, False),
         ]
     )
+    @skip_if_no_cuda()
     def test_int8dq_fake_dim(self, name, num_tokens, fullgraph):
-        if not torch.cuda.is_available():
-            self.skipTest("Need CUDA available")
-
         config = MoEQuantConfig(
             Int8DynamicActivationInt8WeightConfig(),
             use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
@@ -242,10 +242,8 @@ def test_int8dq_fake_dim(self, name, num_tokens, fullgraph):
             ("multiple_tokens", 32, False),
         ]
     )
+    @skip_if_no_cuda()
     def test_int8dq_base(self, name, num_tokens, fullgraph):
-        if not torch.cuda.is_available():
-            self.skipTest("Need CUDA available")
-
         config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig())
         base_class = LinearActivationQuantizedTensor

@@ -263,9 +261,8 @@ def test_int8dq_base(self, name, num_tokens, fullgraph):
             ("multiple_tokens", 8, False),
         ]
     )
+    @skip_if_no_cuda()
     def test_fp8wo_fake_dim(self, name, num_tokens, fullgraph):
-        if not torch.cuda.is_available():
-            self.skipTest("Need CUDA available")
         if not is_sm_at_least_90():
             self.skipTest("Requires CUDA capability >= 9.0")

@@ -335,9 +332,8 @@ def test_fp8dq_fake_dim(self, name, num_tokens, fullgraph):
             ("multiple_tokens", 8, False),
         ]
     )
+    @skip_if_no_cuda()
     def test_fp8dq_base(self, name, num_tokens, fullgraph):
-        if not torch.cuda.is_available():
-            self.skipTest("Need CUDA available")
         if not is_sm_at_least_90():
             self.skipTest("Requires CUDA capability >= 9.0")

diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index f523cb091c..73b8009a81 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -98,12 +98,15 @@
 )
 from torchao.utils import (
     _is_fbgemm_gpu_genai_available,
+    auto_detect_device,
     is_fbcode,
     is_sm_at_least_89,
 )

 # TODO: put this in a common test utils file
 _CUDA_IS_AVAILABLE = torch.cuda.is_available()
+_GPU_IS_AVAILABLE = torch.accelerator.is_available()
+_DEVICE = auto_detect_device()


 class Sub(torch.nn.Module):
@@ -347,7 +350,7 @@ def _set_ptq_weight(
                 group_size,
             )
             q_weight = torch.ops.aten._convert_weight_to_int4pack(
-                q_weight.to("cuda"),
+                q_weight.to(_DEVICE),
                 qat_linear.inner_k_tiles,
             )
             ptq_linear.weight = q_weight
@@ -600,13 +603,15 @@ def _assert_close_4w(self, val, ref):
         print(mean_err)
         self.assertTrue(mean_err < 0.05)

-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(
+        not _GPU_IS_AVAILABLE, "skipping when cuda or xpu is not available"
+    )
     def test_qat_4w_primitives(self):
         n_bit = 4
         group_size = 32
         inner_k_tiles = 8
         scales_precision = torch.bfloat16
-        device = torch.device("cuda")
+        device = torch.device(_DEVICE)
         dtype = torch.bfloat16
         torch.manual_seed(self.SEED)
         x = torch.randn(100, 256, dtype=dtype, device=device)
@@ -699,11 +704,12 @@ def test_qat_4w_quantizer(self):

         group_size = 32
         inner_k_tiles = 8
-        device = torch.device("cuda")
+        device = torch.device(_DEVICE)
         dtype = torch.bfloat16
         torch.manual_seed(self.SEED)
         m = M().to(device).to(dtype)
         m2 = copy.deepcopy(m)
+
         qat_quantizer = Int4WeightOnlyQATQuantizer(
             groupsize=group_size,
             inner_k_tiles=inner_k_tiles,
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 577ca6789a..164cf6bad0 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -60,14 +60,17 @@
 )
 from torchao.quantization.quant_primitives import MappingType
 from torchao.quantization.utils import compute_error
-from torchao.testing.utils import skip_if_rocm
+from torchao.testing.utils import skip_if_rocm, skip_if_xpu
 from torchao.utils import (
+    auto_detect_device,
     is_sm_at_least_89,
     is_sm_at_least_90,
     torch_version_at_least,
     unwrap_tensor_subclass,
 )

+_DEVICE = auto_detect_device()
+
 try:
     import gemlite  # noqa: F401

@@ -258,7 +261,7 @@ def api(model):

         m2.load_state_dict(state_dict)
         m2 = m2.to(device="cuda")
-        example_inputs = map(lambda x: x.cuda(), example_inputs)
+        example_inputs = map(lambda x: x.to(_DEVICE), example_inputs)
         res = m2(*example_inputs)

         # TODO: figure out why ROCm has a larger error
@@ -290,12 +293,13 @@ def test_8da4w_quantizer_linear_bias(self):
         m(*example_inputs)

     @unittest.skip("skipping until we get checkpoints for gpt-fast")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_quantizer_int4_weight_only(self):
         from torchao._models._eval import TransformerEvalWrapper
         from torchao.quantization.linear_quant_modules import Int4WeightOnlyQuantizer

         precision = torch.bfloat16
-        device = "cuda"
+        device = _DEVICE
         checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
         model = Transformer.from_name(checkpoint_path.parent.name)
         checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
@@ -312,7 +316,7 @@ def test_quantizer_int4_weight_only(self):
         quantizer = Int4WeightOnlyQuantizer(
             groupsize,
         )
-        model = quantizer.quantize(model).cuda()
+        model = quantizer.quantize(model).to(_DEVICE)
         result = TransformerEvalWrapper(
             model,
             tokenizer,
@@ -328,11 +332,12 @@ def test_quantizer_int4_weight_only(self):
         )

     @unittest.skip("skipping until we get checkpoints for gpt-fast")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_eval_wrapper(self):
         from torchao._models._eval import TransformerEvalWrapper

         precision = torch.bfloat16
-        device = "cuda"
+        device = _DEVICE
         checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
         model = Transformer.from_name(checkpoint_path.parent.name)
         checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
@@ -361,11 +366,12 @@ def test_eval_wrapper(self):

     # EVAL IS CURRENTLY BROKEN FOR LLAMA 3, VERY LOW ACCURACY
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_eval_wrapper_llama3(self):
         from torchao._models._eval import TransformerEvalWrapper

         precision = torch.bfloat16
-        device = "cuda"
+        device = _DEVICE
         checkpoint_path = Path(
             ".../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth"
         )
@@ -534,7 +540,7 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq):
             assert "aten.mm.default" not in code[0]

     # TODO(#1690): move to new config names
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @common_utils.parametrize(
         "config",
         [
@@ -551,6 +557,7 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq):
             UIntXWeightOnlyConfig(dtype=torch.uint4),
         ],
     )
+    @skip_if_xpu("XPU enablement in progress")
     @skip_if_rocm("ROCm enablement in progress")
     def test_workflow_e2e_numerics(self, config):
         """
@@ -579,17 +586,17 @@ def test_workflow_e2e_numerics(self, config):
         # scale has to be moved to cuda here because the parametrization init
         # code happens before gating for cuda availability
         if isinstance(config, Float8StaticActivationFloat8WeightConfig):
-            config.scale = config.scale.to("cuda")
+            config.scale = config.scale.to(_DEVICE)

         dtype = torch.bfloat16
         if isinstance(config, GemliteUIntXWeightOnlyConfig):
             dtype = torch.float16

         # set up inputs
-        x = torch.randn(128, 128, device="cuda", dtype=dtype)
+        x = torch.randn(128, 128, device=_DEVICE, dtype=dtype)
         # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
         # is that expected?
-        m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().to(dtype)
+        m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).to(_DEVICE).to(dtype)
         m_q = copy.deepcopy(m_ref)

         # quantize
@@ -602,13 +609,13 @@ def test_workflow_e2e_numerics(self, config):
         sqnr = compute_error(y_ref, y_q)
         assert sqnr >= 16.5, f"SQNR {sqnr} is too low"

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_module_fqn_to_config_default(self):
         config1 = Int4WeightOnlyConfig(group_size=32, version=1)
         config2 = Int8WeightOnlyConfig()
         config = ModuleFqnToConfig({"_default": config1, "linear2": config2})
-        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
-        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device=_DEVICE, dtype=torch.bfloat16)
         quantize_(model, config)
         model(*example_inputs)
         assert isinstance(model.linear1.weight, AffineQuantizedTensor)
@@ -616,13 +623,13 @@ def test_module_fqn_to_config_default(self):
         assert isinstance(model.linear2.weight, AffineQuantizedTensor)
         assert isinstance(model.linear2.weight._layout, PlainLayout)

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_module_fqn_to_config_module_name(self):
         config1 = Int4WeightOnlyConfig(group_size=32, version=1)
         config2 = Int8WeightOnlyConfig()
         config = ModuleFqnToConfig({"linear1": config1, "linear2": config2})
-        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
-        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device=_DEVICE, dtype=torch.bfloat16)
         quantize_(model, config)
         model(*example_inputs)
         assert isinstance(model.linear1.weight, AffineQuantizedTensor)
@@ -756,25 +763,25 @@ def test_module_fqn_to_config_embedding_linear(self):
         assert isinstance(model.emb.weight, IntxUnpackedToInt8Tensor)
         assert isinstance(model.linear.weight, IntxUnpackedToInt8Tensor)

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_module_fqn_to_config_skip(self):
         config1 = Int4WeightOnlyConfig(group_size=32, version=1)
         config = ModuleFqnToConfig({"_default": config1, "linear2": None})
-        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
-        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device=_DEVICE, dtype=torch.bfloat16)
         quantize_(model, config)
         model(*example_inputs)
         assert isinstance(model.linear1.weight, AffineQuantizedTensor)
         assert isinstance(model.linear1.weight._layout, TensorCoreTiledLayout)
         assert not isinstance(model.linear2.weight, AffineQuantizedTensor)

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_int4wo_cuda_serialization(self):
         config = Int4WeightOnlyConfig(group_size=32, version=1)
-        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
+        model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16)
         # quantize in cuda
         quantize_(model, config)
-        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device=_DEVICE, dtype=torch.bfloat16)
         model(*example_inputs)
         with tempfile.NamedTemporaryFile() as ckpt:
             # save checkpoint in cuda
@@ -783,7 +790,7 @@ def test_int4wo_cuda_serialization(self):
             # This is what torchtune does: https://github.com/pytorch/torchtune/blob/v0.6.1/torchtune/training/checkpointing/_utils.py#L253
             sd = torch.load(ckpt.name, weights_only=False, map_location="cpu")
             for k, v in sd.items():
-                sd[k] = v.to("cuda")
+                sd[k] = v.to(_DEVICE)
             # load state_dict in cuda
             model.load_state_dict(sd, assign=True)

diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index bed8421671..c251d71915 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -30,6 +30,7 @@
     groupwise_affine_quantize_tensor_from_qparams,
 )
 from torchao.utils import (
+    auto_detect_device,
     check_cpu_version,
     check_xpu_version,
     is_fbcode,
@@ -38,6 +39,8 @@
 _SEED = 1234
 torch.manual_seed(_SEED)

+_DEVICE = auto_detect_device()
+

 # Helper function to run a function twice
 # and verify that the result is the same.
@@ -575,7 +578,7 @@ def test_choose_qparams_tensor_asym_eps(self):
     )
     def test_get_group_qparams_symmetric_memory(self):
         """Check the memory usage of the op"""
-        weight = torch.randn(1024, 1024).to(device="cuda")
+        weight = torch.randn(1024, 1024).to(device=_DEVICE)
         original_mem_use = torch.cuda.memory_allocated()
         n_bit = 4
         groupsize = 128
diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py
index a1dc40fdd3..aef3ea3ecf 100644
--- a/torchao/testing/utils.py
+++ b/torchao/testing/utils.py
@@ -98,6 +98,47 @@ def wrapper(*args, **kwargs):
     return decorator


+def skip_if_no_xpu(message=None):
+    """Decorator to skip tests when XPU is not available, with a custom message.
+
+    Args:
+        message (str, optional): Additional information about why the test is skipped.
+    """
+    import unittest
+
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            if not torch.xpu.is_available():
+                skip_message = "Skipping the test since XPU is not available"
+                if message:
+                    skip_message += f": {message}"
+                raise unittest.SkipTest(skip_message)
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def skip_if_xpu(message=None):
+    """
+    Decorator to skip tests if XPU is available.
+
+    Args:
+        message (str, optional): Additional information about why the test is skipped.
+    """
+
+    def decorator(func):
+        reason = "Skipping the test on XPU"
+        if message:
+            reason += f": {message}"
+
+        return unittest.skipIf(torch.xpu.is_available(), reason)(func)
+
+    return decorator
+
+
 def skip_if_no_cuda():
     import unittest

diff --git a/torchao/utils.py b/torchao/utils.py
index 5af3e00cfa..4ebd2d781c 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -147,6 +147,13 @@ def get_available_devices():
     return devices


+def auto_detect_device():
+    if torch.accelerator.is_available():
+        return torch.accelerator.current_accelerator()
+    else:
+        return torch.device("cpu")
+
+
 def get_compute_capability():
     if torch.cuda.is_available():
         capability = torch.cuda.get_device_capability()
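
Usage note (not part of the patch): a minimal sketch of how the helpers introduced in torchao/utils.py and torchao/testing/utils.py are intended to be combined in a device-agnostic test. The test class name, tensor shapes, and skip message below are made up for illustration; only auto_detect_device, skip_if_xpu, and torch.accelerator come from the diff above.

import unittest

import torch

from torchao.testing.utils import skip_if_xpu
from torchao.utils import auto_detect_device

# Resolves to the current accelerator device (e.g. cuda or xpu) when one is
# present, otherwise to the CPU device.
_DEVICE = auto_detect_device()


class ExampleDeviceAgnosticTest(unittest.TestCase):
    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
    @skip_if_xpu("hypothetical op not yet enabled on XPU")
    def test_matmul_on_detected_device(self):
        # Allocate inputs directly on the detected device instead of hard-coding "cuda".
        x = torch.randn(8, 8, device=_DEVICE)
        y = torch.randn(8, 8, device=_DEVICE)
        self.assertEqual((x @ y).shape, (8, 8))


if __name__ == "__main__":
    unittest.main()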