diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
index 786e0cf59f..f3bfa60664 100644
--- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
+++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -30,6 +30,7 @@
     _is_fbgemm_gpu_genai_available,
     is_sm_at_least_89,
     is_sm_at_least_90,
+    is_sm_at_least_100,
     torch_version_at_least,
 )
 
@@ -49,6 +50,28 @@ def forward(self, x):
         return x
 
 
+class ToyConvModel(torch.nn.Module):
+    def __init__(
+        self, dim, in_channels, out_channels, kernel_size, bias, padding, dtype, device
+    ):
+        super().__init__()
+        convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+        self.conv = convs[dim](
+            in_channels,
+            out_channels,
+            kernel_size,
+            bias=bias,
+            padding=padding,
+            dtype=dtype,
+            device=device,
+        )
+        # the fbgemm fp8 conv kernel expects channels-last weights
+        if dim == 3:
+            self.conv = self.conv.to(memory_format=torch.channels_last_3d)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
 # TODO: move tests in test_affine_quantized_float.py here after we migrated all implementations
 @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -148,6 +171,85 @@ def test_fp8_linear_variants(
             f"Quantization error is too high got a SQNR of {error}"
         )
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_100(), "Requires GPU with compute capability >= 10.0"
+    )
+    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
+    @common_utils.parametrize("compile", [True, False])
+    @common_utils.parametrize("granularity", [PerTensor()])
+    @common_utils.parametrize("inference_mode", [True, False])
+    @common_utils.parametrize(
+        "kernel_preference",
+        [KernelPreference.AUTO],
+    )
+    # only test 3D conv for now
+    # Inputs are (N, C_in, C_out, D, H, W)
+    @common_utils.parametrize(
+        "sizes",
+        [
+            (4, 16, 64, 32, 32, 32),
+        ],
+    )
+    def test_fp8_conv_variants(
+        self,
+        dtype: torch.dtype,
+        compile: bool,
+        granularity,
+        inference_mode: bool,
+        kernel_preference: KernelPreference,
+        sizes: Tuple,
+    ):
+        if (not _is_fbgemm_gpu_genai_available()) or (not is_sm_at_least_100()):
+            self.skipTest(
+                "Requires fbgemm_gpu_genai and SM version >= 10.0 to run "
+                "the fp8 conv test"
+            )
+
+        dim = 3
+        N, C_in, C_out, D, H, W = sizes
+        kernel_size = 3
+
+        # Note: the input is in channels_last_3d memory format
+        input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda")
+        input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+
+        # Create a conv3d model with the requested dtype
+        model = ToyConvModel(
+            dim,
+            C_in,
+            C_out,
+            kernel_size,
+            bias=False,
+            padding=0,
+            dtype=dtype,
+            device="cuda",
+        ).eval()
+
+        quantized_model = copy.deepcopy(model)
+
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=granularity,
+            kernel_preference=kernel_preference,
+        )
+
+        _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+
+        quantize_(quantized_model, config, filter_fn=_is_conv3d)
+
+        if compile:
+            quantized_model = torch.compile(quantized_model, fullgraph=True)
+
+        inference_mode_ctx = torch.inference_mode() if inference_mode else nullcontext()
+        with inference_mode_ctx:
+            output_original = model(input_tensor)
+            output_quantized = quantized_model(input_tensor)
+
+        error = compute_error(output_original, output_quantized)
+        assert error > 20, (
f"Quantization error is too high got a SQNR of {error}" + ) + @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) @unittest.skipIf( not is_sm_at_least_90(), diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index ae8210a41a..39d2dc450f 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1813,7 +1813,12 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): _check_hardware_support(granularity) activation_granularity, weight_granularity = granularity - if not _fp8_mm_compat(weight): + if weight.dim() == 5: + # weights for conv3d + assert isinstance(activation_granularity, PerTensor) and isinstance( + weight_granularity, PerTensor + ), "5D tensor only supports per tensor activation and weight quantization" + elif not _fp8_mm_compat(weight): # TODO(future PR): this should really throw an exception instead of silently # not doing what the user asked return weight diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index 47395a15af..7cc6c195e9 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -39,6 +39,7 @@ _is_fbgemm_gpu_genai_available, fill_defaults, is_sm_at_least_90, + is_sm_at_least_100, ) __all__ = [ @@ -261,7 +262,7 @@ def _(func, types, args, kwargs): ) act_quant_kwargs = weight_tensor.act_quant_kwargs - # quantizing activation, if `act_quant_kwargs` is specified + # quantize activation, if `act_quant_kwargs` is specified if act_quant_kwargs is not None: input_tensor = _choose_quant_func_and_quantize_tensor( input_tensor, act_quant_kwargs @@ -418,6 +419,125 @@ def _(func, types, args, kwargs): return res +def _quantize_and_scaled_conv3d( + input_tensor, + weight_tensor, + bias, + stride, + padding, + dilation, +): + assert isinstance(weight_tensor, Float8Tensor), ( + f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}" + ) + + assert input_tensor.dim() == 5 and weight_tensor.dim() == 5, ( + "Only support 3D conv currently" + ) + assert _is_fbgemm_gpu_genai_available(), ( + "quantized fp8 conv3d requires fbgemm_gpu_genai to be available" + ) + act_quant_kwargs = weight_tensor.act_quant_kwargs + # quantize activation, if `act_quant_kwargs` is specified + if act_quant_kwargs is not None: + input_tensor = _choose_quant_func_and_quantize_tensor( + input_tensor, act_quant_kwargs + ) + + if isinstance(input_tensor, Float8Tensor): + kernel_choice = None + if weight_tensor.kernel_preference == KernelPreference.AUTO: + if _is_fbgemm_gpu_genai_available() and is_sm_at_least_100(): + kernel_choice = "fbgemm" + else: + raise NotImplementedError( + f"No available kernel choice for {weight_tensor.kernel_preference}" + ) + elif weight_tensor.kernel_preference == KernelPreference.FBGEMM: + kernel_choice = "fbgemm" + else: + raise NotImplementedError( + f"No available kernel choice for {weight_tensor.kernel_preference}" + ) + + assert kernel_choice == "fbgemm", "Only fbgemm kernel choice is supported currently" + # move C_in to last dim + # after permute: (N, D, H, W, C_in) + act_qdata = input_tensor.qdata.permute([0, 2, 3, 4, 1]) + + # move C_in to last dim + # after permute: (C_out, K1, K2, K3, C_in) + weight_qdata = weight_tensor.qdata.permute([0, 2, 3, 4, 1]) + + assert act_qdata.is_contiguous() and weight_qdata.is_contiguous(), ( + 
"Please make sure both activation and weights are in the `channels_last_3d` memory_format" + ) + + act_scale = input_tensor.scale + weight_scale = weight_tensor.scale + output = torch.ops.fbgemm.f8f8bf16_conv( + act_qdata, + weight_qdata, + act_scale * weight_scale, + padding, + stride, + dilation, + ) + # output shape after permute: N, C_out, D_out, H_out, W_out + output = output.permute([0, 4, 1, 2, 3]) + return output + + +@implements(aten.convolution.default) +def _(func, types, args, kwargs): + ( + input_tensor, + weight_tensor, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) = args + assert not transposed, "transposed conv is not supported currently" + assert tuple(output_padding) == (0, 0, 0), ( + f"Only (0, 0, 0) is supported for `output_padding`, got: f{output_padding}" + ) + assert groups == 1, f"Only 1 is supported for `groups`, got: {groups}" + return _quantize_and_scaled_conv3d( + input_tensor, + weight_tensor, + bias, + stride, + padding, + dilation, + ) + + +@implements(aten.conv3d.default) +def _(func, types, args, kwargs): + ( + input_tensor, + weight_tensor, + bias, + stride, + padding, + dilation, + groups, + ) = fill_defaults(args, 7, [None, [1, 1, 1], [0, 0, 0], [1, 1, 1], 1]) + assert groups == 1, f"Only 1 is supported for `groups`, got: {groups}" + return _quantize_and_scaled_conv3d( + input_tensor, + weight_tensor, + bias, + stride, + padding, + dilation, + ) + + @implements(aten.slice.Tensor) def _(func, types, args, kwargs): """Supports slicing for 1d, 2d, and 3d tensors diff --git a/torchao/utils.py b/torchao/utils.py index 5af3e00cfa..02013c5197 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -32,6 +32,7 @@ "is_MI300", "is_sm_at_least_89", "is_sm_at_least_90", + "is_sm_at_least_100", "is_package_at_least", "DummyModule", # Deprecated