From 9b2cfe86ebe35dbd607f08d9e0f4af53a3d76bd8 Mon Sep 17 00:00:00 2001
From: Benji Beck
Date: Wed, 15 Oct 2025 12:41:38 -0700
Subject: [PATCH 1/5] Add CutlassSemiSparseFp8Tensor

Summary: Moving float8 cutlass sparse layout into its own class:
https://github.com/pytorch/ao/blob/main/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py

Differential Revision: D84467190
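
The class wraps the packed CUTLASS buffers (values, 2:4 metadata, scales)
and reports the logical dense shape. A minimal construction sketch with the
class this patch adds -- buffer contents and the metadata width here are
illustrative placeholders, not the real packed format:

    import torch

    # A (256, 64) packed value buffer stands in for a 2:4-pruned logical
    # (256, 128) float8 weight; __new__ reports (rows, 2 * packed_cols).
    sparse = torch.zeros(256, 64, dtype=torch.float8_e4m3fn, device="cuda")
    meta = torch.zeros(256, 8, dtype=torch.uint8, device="cuda")  # placeholder width
    scale = torch.ones(256, 1, dtype=torch.bfloat16, device="cuda")

    w = CutlassSemiSparseFp8Tensor(sparse, meta, scale)
    assert w.shape == (256, 128)  # logical shape, not the packed one
    assert w.dtype == torch.bfloat16  # wrapper dtype follows the scale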
---
 .../float8/cutlass_semi_sparse_fp8_tensor.py | 56 +++++++++++++++++++
 1 file changed, 56 insertions(+)
 create mode 100644 torchao/quantization/quantize_/workflows/float8/cutlass_semi_sparse_fp8_tensor.py

diff --git a/torchao/quantization/quantize_/workflows/float8/cutlass_semi_sparse_fp8_tensor.py b/torchao/quantization/quantize_/workflows/float8/cutlass_semi_sparse_fp8_tensor.py
new file mode 100644
index 0000000000..cc3118d1fb
--- /dev/null
+++ b/torchao/quantization/quantize_/workflows/float8/cutlass_semi_sparse_fp8_tensor.py
@@ -0,0 +1,56 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+from torchao.utils import TorchAOBaseTensor
+
+__all__ = ["CutlassSemiSparseFp8Tensor"]
+aten = torch.ops.aten
+
+class CutlassSemiSparseFp8Tensor(TorchAOBaseTensor):
+    tensor_data_names = ["sparse", "scale", "meta"]
+
+    def __new__(
+        cls,
+        sparse: torch.Tensor,
+        meta: torch.Tensor,
+        scale: torch.Tensor,
+    ):
+        kwargs = {}
+        kwargs["device"] = sparse.device
+        kwargs["dtype"] = scale.dtype
+        kwargs["requires_grad"] = False
+        shape = (sparse.shape[0], 2 * sparse.shape[-1])
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+
+    def __init__(
+        self,
+        sparse: torch.Tensor,
+        meta: torch.Tensor,
+        scale: torch.Tensor,
+    ):
+        super().__init__()
+        self.sparse = sparse
+        self.meta = meta
+        self.scale = scale
+
+    def _quantization_type(self):
+        return f"shape={self.shape}, device={self.device}, dtype={self.dtype}"
+
+
+    @classmethod
+    def from_hp(
+    ):
+        raise NotImplementedError("CutlassSemiSparseFp8Tensor.from_hp is not implemented yet")
+
+
+implements = CutlassSemiSparseFp8Tensor.implements
+implements_torch_function = CutlassSemiSparseFp8Tensor.implements_torch_function
+
+CutlassSemiSparseFp8Tensor.__module__ = "torchao.quantization"
+
+# Allow a model with CutlassSemiSparseFp8Tensor weights to be loaded with `weights_only=True`
+torch.serialization.add_safe_globals([CutlassSemiSparseFp8Tensor])

From fc80e43499ce659405fcb49c71f9aeb95996b5d7 Mon Sep 17 00:00:00 2001
From: Benji Beck
Date: Wed, 15 Oct 2025 15:18:44 -0700
Subject: [PATCH 2/5] Implement packing and linear

Signed-off-by: Benji Beck
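
A rough sketch of the packing entry point added here (shapes are
illustrative; assumes an SM90+ GPU, since the packing op is the sm9x
CUTLASS kernel):

    import torch

    from torchao.quantization.quantize_.workflows.float8.float8_semi_sparse_tensor import (
        Float8SemiSparseTensor,
    )

    w = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")
    # from_hp prunes to 2:4 via mask_creator, quantizes to e4m3 with a
    # row-wise scale (block_size [1, K]), then packs values + metadata
    # with to_sparse_semi_structured_cutlass_sm9x_f8.
    w_sparse = Float8SemiSparseTensor.from_hp(w, block_size=[1, w.shape[-1]])
    assert w_sparse.shape == w.shape  # logical shape is preserved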
@@ "Int4TilePackedTo4dTensor", "Float8Tensor", "Int4OpaqueTensor", + "Float8SemiSparseTensor", # smooth quant - subject to change "get_scale", "SmoothFakeDynQuantMixin", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 139b14cf3f..d9f3026913 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1336,6 +1336,7 @@ def _int8_weight_only_quantize_tensor(weight, config): if group_size is None: group_size = weight.shape[-1] block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size]) + # todo: support fp8 semi-sparse new_weight = to_affine_quantized_intx( weight, mapping_type, @@ -1584,6 +1585,7 @@ class Float8WeightOnlyConfig(AOBaseConfig): weight_dtype: torch.dtype = e4m3_dtype set_inductor_config: bool = True version: int = 2 + # todo: add packing format def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig") diff --git a/torchao/quantization/quantize_/common/packing_format.py b/torchao/quantization/quantize_/common/packing_format.py index c6546c55f9..9f547289f8 100644 --- a/torchao/quantization/quantize_/common/packing_format.py +++ b/torchao/quantization/quantize_/common/packing_format.py @@ -32,3 +32,4 @@ class PackingFormat(str, Enum): needed for the rest of the system to understand the specific format that's adopted. """ OPAQUE = "opaque" + # todo: add semi-sparse diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py index 4307637f8e..7166e244a6 100644 --- a/torchao/quantization/quantize_/workflows/__init__.py +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -1,3 +1,6 @@ +from .float8.float8_semi_sparse_tensor import ( + Float8SemiSparseTensor, +) from .float8.float8_tensor import ( Float8Tensor, QuantizeTensorToFloat8Kwargs, @@ -38,6 +41,7 @@ "Int4PlainInt32Tensor", "Int4TilePackedTo4dTensor", "Float8Tensor", + "Float8SemiSparseTensor", "QuantizeTensorToFloat8Kwargs", "Int4OpaqueTensor", "Int4ChooseQParamsAlgorithm", diff --git a/torchao/quantization/quantize_/workflows/float8/cutlass_semi_sparse_fp8_tensor.py b/torchao/quantization/quantize_/workflows/float8/cutlass_semi_sparse_fp8_tensor.py deleted file mode 100644 index cc3118d1fb..0000000000 --- a/torchao/quantization/quantize_/workflows/float8/cutlass_semi_sparse_fp8_tensor.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. 
-import torch -from torchao.utils import TorchAOBaseTensor - -__all__ = ["CutlassSemiSparseFp8Tensor"] -aten = torch.ops.aten - -class CutlassSemiSparseFp8Tensor(TorchAOBaseTensor): - tensor_data_names = ["sparse", "scale", "meta"] - - def __new__( - cls, - sparse: torch.Tensor, - meta: torch.Tensor, - scale: torch.Tensor, - ): - kwargs = {} - kwargs["device"] = sparse.device - kwargs["dtype"] = scale.dtype - kwargs["requires_grad"] = False - shape = (sparse.shape[0], 2 * sparse.shape[-1]) - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - - def __init__( - self, - sparse: torch.Tensor, - meta: torch.Tensor, - scale: torch.Tensor, - ): - super().__init__() - self.sparse = sparse - self.meta = meta - self.scale = scale - - def _quantization_type(self): - return f"shape={self.shape}, device={self.device}, dtype={self.dtype}" - - - @classmethod - def from_hp( - ): - raise NotImplementedError("CutlassSemiSparseFp8Tensor.from_hp is not implemented yet") - - -implements = CutlassSemiSparseFp8Tensor.implements -implements_torch_function = CutlassSemiSparseFp8Tensor.implements_torch_function - -CutlassSemiSparseFp8Tensor.__module__ = "torchao.quantization" - -# Allow a model with CutlassSemiSparseFp8Tensor weights to be loaded with `weights_only=True` -torch.serialization.add_safe_globals([CutlassSemiSparseFp8Tensor]) diff --git a/torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py new file mode 100644 index 0000000000..78e58cbf68 --- /dev/null +++ b/torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+from typing import List + +import torch + +from torchao.ops import to_sparse_semi_structured_cutlass_sm9x_f8 +from torchao.quantization.quant_primitives import ( + _choose_scale_float8, + _quantize_affine_float8, +) +from torchao.utils import TorchAOBaseTensor + +__all__ = ["Float8SemiSparseTensor"] +aten = torch.ops.aten + + +class Float8SemiSparseTensor(TorchAOBaseTensor): + tensor_data_names = ["sparse", "scale", "meta"] + + def __new__( + cls, + sparse: torch.Tensor, + meta: torch.Tensor, + scale: torch.Tensor, + ): + kwargs = {} + kwargs["device"] = sparse.device + kwargs["dtype"] = scale.dtype + kwargs["requires_grad"] = False + shape = (sparse.shape[0], 2 * sparse.shape[-1]) + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + sparse: torch.Tensor, + meta: torch.Tensor, + scale: torch.Tensor, + ): + super().__init__() + self.sparse = sparse + self.meta = meta + self.scale = scale + + def _quantization_type(self): + return f"shape={self.shape}, device={self.device}, dtype={self.dtype}" + + @classmethod + def from_hp( + cls, + w: torch.Tensor, + block_size: List[int], + ): + from torchao.sparsity.utils import mask_creator + + dense = w * mask_creator(w).bool() + + scale = _choose_scale_float8( + dense, + block_size=block_size, + float8_dtype=torch.float8_e4m3fn, + ) + + w_fp8 = _quantize_affine_float8( + dense, + scale=scale, + float8_dtype=torch.float8_e4m3fn, + ) + + sparse, meta = to_sparse_semi_structured_cutlass_sm9x_f8(w_fp8) + + return cls( + sparse, + meta, + scale, + ) + + +implements = Float8SemiSparseTensor.implements +implements_torch_function = Float8SemiSparseTensor.implements_torch_function + + +@implements(aten.linear.default) +@implements_torch_function(torch.nn.functional.linear) +def _(func, types, args, kwargs): + from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8 + + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + + input = input_tensor.qdata + input_scale = input_tensor.scale + weight = weight_tensor.sparse + weight_meta = weight_tensor.meta + weight_scale = weight_tensor.scale + out_dtype = input_tensor.dtype + + out = rowwise_scaled_linear_sparse_cutlass_f8f8( + input, input_scale, weight, weight_meta, weight_scale, bias, out_dtype + ) + + return out + + +Float8SemiSparseTensor.__module__ = "torchao.quantization" + +# Allow a model with Float8SemiSparseTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Float8SemiSparseTensor]) From 4199d0659aa99334204ec4a4d07a745d08927829 Mon Sep 17 00:00:00 2001 From: Benji Beck Date: Sun, 26 Oct 2025 12:27:14 -0700 Subject: [PATCH 3/5] Add test for fp8 semi-sparse vs. dense Signed-off-by: Benji Beck --- .../float8/test_float8_semi_sparse.py | 81 +++++-------------- .../float8/float8_semi_sparse_tensor.py | 67 +++++++++++---- .../workflows/float8/float8_tensor.py | 7 +- 3 files changed, 75 insertions(+), 80 deletions(-) diff --git a/test/quantization/quantize_/workflows/float8/test_float8_semi_sparse.py b/test/quantization/quantize_/workflows/float8/test_float8_semi_sparse.py index fe0eeddd55..d1bf7600de 100644 --- a/test/quantization/quantize_/workflows/float8/test_float8_semi_sparse.py +++ b/test/quantization/quantize_/workflows/float8/test_float8_semi_sparse.py @@ -4,9 +4,10 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. 
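
Condensed from the test diff below, the comparison being made (needs an
SM90+ GPU):

    import torch

    from torchao.float8.inference import Float8MMConfig
    from torchao.quantization.quantize_.workflows.float8.float8_semi_sparse_tensor import (
        Float8SemiSparseTensor,
    )
    from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
    from torchao.sparsity.sparse_api import apply_fake_sparsity

    K, N = 128, 256
    linear = torch.nn.Linear(K, N, dtype=torch.bfloat16, device="cuda")
    apply_fake_sparsity(linear)  # prune to 2:4 so both paths see the same weights

    mm_config = Float8MMConfig(use_fast_accum=True)
    x = torch.randn(32, K, dtype=torch.bfloat16, device="cuda")
    x_fp8 = Float8Tensor.from_hp(x, float8_dtype=torch.float8_e4m3fn, mm_config=mm_config)

    w_dense = Float8Tensor.from_hp(
        linear.weight.data, float8_dtype=torch.float8_e4m3fn, mm_config=mm_config
    )
    w_sparse = Float8SemiSparseTensor.from_hp(linear.weight.data, [1, K])

    dense_out = torch.nn.functional.linear(x_fp8, w_dense, linear.bias)
    sparse_out = torch.nn.functional.linear(x_fp8, w_sparse, linear.bias)
    torch.testing.assert_close(dense_out, sparse_out, atol=3e-1, rtol=3e-1)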
---
 .../float8/test_float8_semi_sparse.py         | 81 +++++-------------
 .../float8/float8_semi_sparse_tensor.py       | 67 +++++++++++----
 .../workflows/float8/float8_tensor.py         |  7 +-
 3 files changed, 75 insertions(+), 80 deletions(-)

diff --git a/test/quantization/quantize_/workflows/float8/test_float8_semi_sparse.py b/test/quantization/quantize_/workflows/float8/test_float8_semi_sparse.py
index fe0eeddd55..d1bf7600de 100644
--- a/test/quantization/quantize_/workflows/float8/test_float8_semi_sparse.py
+++ b/test/quantization/quantize_/workflows/float8/test_float8_semi_sparse.py
@@ -4,9 +4,10 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

-import tempfile
 import unittest
-
+from torchao.quantization.quantize_.workflows.float8.float8_semi_sparse_tensor import Float8SemiSparseTensor
+from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
+from torchao.float8.inference import Float8MMConfig
 import torch
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -14,39 +15,27 @@
     parametrize,
     run_tests,
 )
-
-from torchao.quantization import (
-    Float8WeightOnlyConfig,
-    quantize_,
-)
-from torchao.quantization.utils import compute_error
 from torchao.sparsity.sparse_api import apply_fake_sparsity
 from torchao.testing.utils import skip_if_rocm
-from torchao.utils import torch_version_at_least
+from torchao.utils import is_sm_at_least_90

-BF16_ACT_CONFIG = Float8WeightOnlyConfig(
-    group_size=128,
-    packing_format="cutlass_semi_sparse",
-)
-

-@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
+@unittest.skipIf(not is_sm_at_least_90(), "Need H100+ to run")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 class TestFloat8SemiSparseTensor(TestCase):
     def setUp(self):
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

     @skip_if_rocm("ROCm enablement in progress")
-    @parametrize("config", [BF16_ACT_CONFIG])
     @parametrize(
         "sizes",
         [
             ((128,), 256, 128),
             ((32, 128), 512, 128),
-            ((2, 32, 128), 256, 12),
+            ((2, 32, 128), 256, 128),
         ],
     )
-    def test_linear(self, config, sizes):
+    def test_sparse_vs_dense_fp8(self, sizes):
         dtype = torch.bfloat16
         device = "cuda"

@@ -55,52 +44,20 @@ def test_linear(self, config, sizes):
         linear = torch.nn.Linear(K, N, dtype=dtype, device=device)

         apply_fake_sparsity(linear)
-        original = linear(input)
-        quantize_(linear, config)
-        quantized = linear(input)
-        self.assertTrue(compute_error(original, quantized) > 20)
-
-        compiled_linear = torch.compile(linear)
-        quantized_and_compiled = compiled_linear(input)
-        self.assertTrue(compute_error(original, quantized_and_compiled) > 20)
-
-    @skip_if_rocm("ROCm enablement in progress")
-    @unittest.skip("Fix later")
-    @parametrize("config", [BF16_ACT_CONFIG])
-    def test_to_device(self, config):
-        for device in self.GPU_DEVICES:
-            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, config)
-            linear.to(device)
-
-            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, config)
-            linear.to(device=device)
-
-            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, config)
-            linear.to(device)
-
-    @skip_if_rocm("ROCm enablement in progress")
-    @parametrize("config", [BF16_ACT_CONFIG])
-    def test_module_path(self, config):
-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        quantize_(linear.cuda(), config)
-        self.assertEqual(
-            str(type(linear.weight)),
-            "<class 'torchao.quantization.Float8SemiSparseTensor'>",
+
+        mm_config = Float8MMConfig(use_fast_accum=True)
+        input_fp8 = Float8Tensor.from_hp(input, float8_dtype=torch.float8_e4m3fn, mm_config=mm_config)
+
+        weight_fp8 = Float8Tensor.from_hp(linear.weight.data, float8_dtype=torch.float8_e4m3fn, mm_config=mm_config)
+        dense_output = torch.nn.functional.linear(input_fp8, weight_fp8, linear.bias)
+
+        weight_sparse_fp8 = Float8SemiSparseTensor.from_hp(linear.weight.data, [1, K])
+        sparse_output = torch.nn.functional.linear(input_fp8, weight_sparse_fp8, linear.bias)
+
+        torch.testing.assert_close(
+            dense_output, sparse_output, atol=3e-1, rtol=3e-1
         )

-        with tempfile.NamedTemporaryFile() as f:
-            torch.save(linear.state_dict(), f)
-            f.seek(0)
-            state_dict = torch.load(f)
-            self.assertEqual(
-                str(type(state_dict["weight"])),
-                "<class 'torchao.quantization.Float8SemiSparseTensor'>",
-            )
-
-
 instantiate_parametrized_tests(TestFloat8SemiSparseTensor)

diff --git a/torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py
index 78e58cbf68..a7f4adeb8a 100644
--- a/torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py
+++ b/torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py
@@ -19,7 +19,7 @@


 class Float8SemiSparseTensor(TorchAOBaseTensor):
-    tensor_data_names = ["sparse", "scale", "meta"]
+    tensor_data_names = ["sparse", "meta", "scale"]

     def __new__(
         cls,
@@ -83,29 +83,66 @@ def from_hp(
 implements_torch_function = Float8SemiSparseTensor.implements_torch_function


-@implements(aten.linear.default)
-@implements_torch_function(torch.nn.functional.linear)
+@implements(aten.t.default)
 def _(func, types, args, kwargs):
-    from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8
-
-    input_tensor, weight_tensor, bias = (
-        args[0],
-        args[1],
-        args[2] if len(args) > 2 else None,
+    from torch.utils._python_dispatch import return_and_correct_aliasing
+
+    self = args[0]
+    new = Float8SemiSparseTensor(
+        sparse=self.sparse,
+        meta=self.meta,
+        scale=self.scale,
     )
+    return return_and_correct_aliasing(func, args, kwargs, new)
+

-    input = input_tensor.qdata
-    input_scale = input_tensor.scale
+def _linear_fp8_semi_sparse(input_tensor, weight_tensor, bias):
+    from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8
+    from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
+
+    if isinstance(input_tensor, Float8Tensor):
+        input = input_tensor.qdata
+        input_scale = input_tensor.scale
+        out_dtype = input_tensor.dtype
+    else:
+        input = input_tensor.qdata
+        input_scale = input_tensor.scale
+        out_dtype = input_tensor.dtype
+
     weight = weight_tensor.sparse
     weight_meta = weight_tensor.meta
     weight_scale = weight_tensor.scale
-    out_dtype = input_tensor.dtype
-
-    out = rowwise_scaled_linear_sparse_cutlass_f8f8(
+
+    # Reshape input_scale if needed: kernel expects scale to match input shape minus last dim
+    # For input [B, K], scale should be [B] not [B, 1]
+    if input_scale.dim() > input.dim() - 1:
+        input_scale = input_scale.squeeze(-1)
+
+    return rowwise_scaled_linear_sparse_cutlass_f8f8(
         input, input_scale, weight, weight_meta, weight_scale, bias, out_dtype
     )

-    return out
+
+@implements([aten.mm.default, aten.addmm.default])
+def _(func, types, args, kwargs):
+    if func == aten.addmm.default:
+        bias, input_tensor, weight_tensor = args
+    else:  # aten.mm.default
+        input_tensor, weight_tensor = args
+        bias = None
+
+    return _linear_fp8_semi_sparse(input_tensor, weight_tensor, bias)
+
+
+@implements(aten.linear.default)
+@implements_torch_function(torch.nn.functional.linear)
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[0],
+        args[1],
+        args[2] if len(args) > 2 else None,
+    )
+    return _linear_fp8_semi_sparse(input_tensor, weight_tensor, bias)

 Float8SemiSparseTensor.__module__ = "torchao.quantization"

diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
index 47395a15af..9814dc4c4f 100644
--- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -256,9 +256,10 @@ def _(func, types, args, kwargs):
         args[1],
         args[2] if len(args) > 2 else None,
     )
-    assert isinstance(weight_tensor, Float8Tensor), (
-        f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
-    )
+
+    # If weight is not Float8Tensor, return NotImplemented to allow weight's dispatch to handle it
+    if not isinstance(weight_tensor, Float8Tensor):
+        return NotImplemented

     act_quant_kwargs = weight_tensor.act_quant_kwargs
     # quantizing activation, if `act_quant_kwargs` is specified

From 450d4f65e2bcf67278ca968f97b6558f77b8f48d Mon Sep 17 00:00:00 2001
From: Benji Beck
Date: Sun, 26 Oct 2025 12:29:38 -0700
Subject: [PATCH 4/5] Clean up comments

Signed-off-by: Benji Beck
---
 torchao/quantization/quant_api.py                       | 2 --
 torchao/quantization/quantize_/common/packing_format.py | 1 -
 2 files changed, 3 deletions(-)

diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index d9f3026913..139b14cf3f 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1336,7 +1336,6 @@ def _int8_weight_only_quantize_tensor(weight, config):
     if group_size is None:
         group_size = weight.shape[-1]
     block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size])
-    # todo: support fp8 semi-sparse
     new_weight = to_affine_quantized_intx(
         weight,
         mapping_type,
@@ -1585,7 +1584,6 @@ class Float8WeightOnlyConfig(AOBaseConfig):
     weight_dtype: torch.dtype = e4m3_dtype
     set_inductor_config: bool = True
     version: int = 2
-    # todo: add packing format

     def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig")

diff --git a/torchao/quantization/quantize_/common/packing_format.py b/torchao/quantization/quantize_/common/packing_format.py
index 9f547289f8..c6546c55f9 100644
--- a/torchao/quantization/quantize_/common/packing_format.py
+++ b/torchao/quantization/quantize_/common/packing_format.py
@@ -32,4 +32,3 @@ class PackingFormat(str, Enum):
     needed for the rest of the system to understand the specific format that's adopted.
     """
     OPAQUE = "opaque"
-    # todo: add semi-sparse

From 960bc91e0b9f26282c104f353a3153f68b1d2be0 Mon Sep 17 00:00:00 2001
From: Benji Beck
Date: Sun, 26 Oct 2025 12:31:11 -0700
Subject: [PATCH 5/5] Apply ruff formatting fix

Signed-off-by: Benji Beck
---
 .../float8/test_float8_semi_sparse.py         | 33 ++++++++++++-------
 .../float8/float8_semi_sparse_tensor.py       | 16 +++++----
 .../workflows/float8/float8_tensor.py         |  2 +-
 3 files changed, 31 insertions(+), 20 deletions(-)

diff --git a/test/quantization/quantize_/workflows/float8/test_float8_semi_sparse.py b/test/quantization/quantize_/workflows/float8/test_float8_semi_sparse.py
index d1bf7600de..232414bbb6 100644
--- a/test/quantization/quantize_/workflows/float8/test_float8_semi_sparse.py
+++ b/test/quantization/quantize_/workflows/float8/test_float8_semi_sparse.py
@@ -5,9 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import unittest
-from torchao.quantization.quantize_.workflows.float8.float8_semi_sparse_tensor import Float8SemiSparseTensor
-from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
-from torchao.float8.inference import Float8MMConfig
+
 import torch
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -15,6 +13,12 @@
     parametrize,
     run_tests,
 )
+
+from torchao.float8.inference import Float8MMConfig
+from torchao.quantization.quantize_.workflows.float8.float8_semi_sparse_tensor import (
+    Float8SemiSparseTensor,
+)
+from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
 from torchao.sparsity.sparse_api import apply_fake_sparsity
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import is_sm_at_least_90
@@ -44,20 +48,25 @@ def test_sparse_vs_dense_fp8(self, sizes):
         linear = torch.nn.Linear(K, N, dtype=dtype, device=device)

         apply_fake_sparsity(linear)
-
+
         mm_config = Float8MMConfig(use_fast_accum=True)
-        input_fp8 = Float8Tensor.from_hp(input, float8_dtype=torch.float8_e4m3fn, mm_config=mm_config)
-
-        weight_fp8 = Float8Tensor.from_hp(linear.weight.data, float8_dtype=torch.float8_e4m3fn, mm_config=mm_config)
+        input_fp8 = Float8Tensor.from_hp(
+            input, float8_dtype=torch.float8_e4m3fn, mm_config=mm_config
+        )
+
+        weight_fp8 = Float8Tensor.from_hp(
+            linear.weight.data, float8_dtype=torch.float8_e4m3fn, mm_config=mm_config
+        )
         dense_output = torch.nn.functional.linear(input_fp8, weight_fp8, linear.bias)
-
+
         weight_sparse_fp8 = Float8SemiSparseTensor.from_hp(linear.weight.data, [1, K])
-        sparse_output = torch.nn.functional.linear(input_fp8, weight_sparse_fp8, linear.bias)
-
-        torch.testing.assert_close(
-            dense_output, sparse_output, atol=3e-1, rtol=3e-1
+        sparse_output = torch.nn.functional.linear(
+            input_fp8, weight_sparse_fp8, linear.bias
         )

+        torch.testing.assert_close(dense_output, sparse_output, atol=3e-1, rtol=3e-1)
+
+
 instantiate_parametrized_tests(TestFloat8SemiSparseTensor)

diff --git a/torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py
index a7f4adeb8a..4384cc0aff 100644
--- a/torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py
+++ b/torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py
@@ -86,7 +86,7 @@ def _(func, types, args, kwargs):
 @implements(aten.t.default)
 def _(func, types, args, kwargs):
     from torch.utils._python_dispatch import return_and_correct_aliasing
-
+
     self = args[0]
     new = Float8SemiSparseTensor(
         sparse=self.sparse,
@@ -98,8 +98,10 @@ def _(func, types, args, kwargs):

 def _linear_fp8_semi_sparse(input_tensor, weight_tensor, bias):
     from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8
-    from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
-
+    from torchao.quantization.quantize_.workflows.float8.float8_tensor import (
+        Float8Tensor,
+    )
+
     if isinstance(input_tensor, Float8Tensor):
         input = input_tensor.qdata
         input_scale = input_tensor.scale
@@ -108,16 +110,16 @@ def _linear_fp8_semi_sparse(input_tensor, weight_tensor, bias):
         input = input_tensor.qdata
         input_scale = input_tensor.scale
         out_dtype = input_tensor.dtype
-
+
     weight = weight_tensor.sparse
     weight_meta = weight_tensor.meta
     weight_scale = weight_tensor.scale
-
+
     # Reshape input_scale if needed: kernel expects scale to match input shape minus last dim
     # For input [B, K], scale should be [B] not [B, 1]
     if input_scale.dim() > input.dim() - 1:
         input_scale = input_scale.squeeze(-1)
-
+
     return rowwise_scaled_linear_sparse_cutlass_f8f8(
         input, input_scale, weight, weight_meta, weight_scale, bias, out_dtype
     )
@@ -130,7 +132,7 @@ def _(func, types, args, kwargs):
     else:  # aten.mm.default
         input_tensor, weight_tensor = args
         bias = None
-
+
     return _linear_fp8_semi_sparse(input_tensor, weight_tensor, bias)

diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
index 9814dc4c4f..97faa8ce06 100644
--- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -256,7 +256,7 @@ def _(func, types, args, kwargs):
         args[1],
         args[2] if len(args) > 2 else None,
     )
-
+
     # If weight is not Float8Tensor, return NotImplemented to allow weight's dispatch to handle it
     if not isinstance(weight_tensor, Float8Tensor):
         return NotImplemented
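
Note on the final hunk: with two tensor subclasses in one F.linear call, the
override that runs first may be Float8Tensor's (the activation side);
returning NotImplemented instead of asserting lets dispatch fall through to
the Float8SemiSparseTensor override added in PATCH 3. A condensed trace of
the intended path, not runnable code:

    # F.linear(x_fp8, w_sparse, bias)
    #   -> Float8Tensor's linear override: weight_tensor is not a
    #      Float8Tensor, so it returns NotImplemented
    #   -> Float8SemiSparseTensor's override runs _linear_fp8_semi_sparse
    #   -> rowwise_scaled_linear_sparse_cutlass_f8f8(
    #          x.qdata, x.scale, w.sparse, w.meta, w.scale, bias, out_dtype)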