introduce new int8 quantization API #3241
base: main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3241
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
quant_min=self.int8_min,
quant_max=self.int8_max,
nit: we can omit these two args if they are the same as the defaults (-128, 127)
)
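As a side note on the default range, a minimal check (not part of this PR) that -128/127 is exactly the int8 range PyTorch reports:

```python
import torch

# torch.iinfo reports the representable range of an integer dtype; for int8
# this matches the default quant_min/quant_max of -128 and 127.
int8_info = torch.iinfo(torch.int8)
assert (int8_info.min, int8_info.max) == (-128, 127)
```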
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_quantization_shapes(self, dtype):
this seems to be a combination of two tests, one for dynamic quant and one for static quant; can you use something like this:
@common_utils.parametrize("mode", ["dynamic", "weight-only"])
also I feel it might be better not to add static quant in this PR, and instead add both the tensor support and the config support for static quant in a separate PR
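For illustration, a minimal sketch of the suggested mode-parametrized test; the test body, model shapes, and config selection are placeholders, not the code in this PR:

```python
import torch
from torch.testing._internal import common_utils


class TestInt8Tensor(common_utils.TestCase):
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    @common_utils.parametrize("mode", ["dynamic", "weight-only"])
    def test_quantization_shapes(self, dtype, mode):
        # Placeholder body: select the dynamic-quant or weight-only config based
        # on `mode`, quantize the linear, then check output shape and dtype.
        x = torch.randn(2, 16, dtype=dtype)
        linear = torch.nn.Linear(16, 32, dtype=dtype)
        # ... apply the mode-specific quantization config to `linear` here ...
        y = linear(x)
        self.assertEqual(y.shape, torch.Size([2, 32]))
        self.assertEqual(y.dtype, dtype)


common_utils.instantiate_parametrized_tests(TestInt8Tensor)

if __name__ == "__main__":
    common_utils.run_tests()
```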
Okay, I wasn't sure about removing the static flags before (although they're not fully implemented), but a smaller PR is always better, I feel. I will remove static_scale and all of the related support.
if act_quant_kwargs is not None and act_quant_kwargs.static_scale is not None:
    # INT8 × INT8 (static)
    scale = act_quant_kwargs.static_scale
    zero_point = torch.zeros_like(scale, dtype=torch.int8)
I think the user should specify static_zero_point as well
but again, it's better to do this in a separate PR, since the current state is only half of the static quant feature (no config)
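To make this concrete, a hypothetical sketch of activation-quant kwargs carrying both a static scale and a static zero point; `QuantizeTensorToInt8Kwargs` and its fields are assumed names for illustration, not the API in this PR:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class QuantizeTensorToInt8Kwargs:  # hypothetical container, for illustration only
    static_scale: Optional[torch.Tensor] = None
    static_zero_point: Optional[torch.Tensor] = None


def _activation_qparams(act_quant_kwargs):
    """Return (scale, zero_point) for static activation quant, if provided."""
    if act_quant_kwargs is not None and act_quant_kwargs.static_scale is not None:
        scale = act_quant_kwargs.static_scale
        if act_quant_kwargs.static_zero_point is not None:
            zero_point = act_quant_kwargs.static_zero_point
        else:
            # Fallback kept for illustration; the review suggests requiring
            # a user-provided static_zero_point instead.
            zero_point = torch.zeros_like(scale, dtype=torch.int8)
        return scale, zero_point
    return None, None
```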
# Cast fp16 scale to float
intermediate_dtype = (
    torch.float if x_scales.dtype == torch.half else x_scales.dtype
)
# Note: CUDA doesn't support int32/int64 matmul, so we convert to float
# Error message is NotImplementedError: "addmm_cuda" not implemented for 'Int'
# This may introduce minor numerical differences compared to int arithmetic
y_dot = torch.mm(tmp.to(intermediate_dtype), w_vals_t.to(intermediate_dtype))

# Apply activation scale
is_per_tensor_act = x_scales.numel() == 1
if is_per_tensor_act:
    y_dot.mul_(x_scales.to(intermediate_dtype))
else:
    # For block-wise activation scale, reshape to match y_dot
    x_scales_reshaped = x_scales.view(y_dot.shape[0], -1)
    y_dot.mul_(x_scales_reshaped.to(intermediate_dtype))

# Apply weight scale
is_per_tensor_weight = w_scales.numel() == 1
if is_per_tensor_weight:
    result = y_dot.mul_(w_scales.to(intermediate_dtype))
else:
    # Per-row weight scale - transpose and broadcast
    w_scales_broadcast = w_scales.t().expand_as(y_dot)
    result = y_dot.mul_(w_scales_broadcast.to(intermediate_dtype))

# Reshape back to original shape
result = result.view(*x_vals.shape[:-1], result.shape[-1])
result = result.to(activation_tensor.dtype)
this should follow:
ao/torchao/dtypes/uintx/plain_layout.py
Line 281 in e9c7bea
def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
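For context, an approximate sketch of the int8-activation × int8-weight pattern referenced above: do the integer matmul on the raw int8 values first, then apply the activation and weight scales. The helper name and variable names are illustrative, and the real torchao implementation routes the matmul through its own kernels:

```python
import torch


def int8_act_int8_weight_linear_sketch(x_vals_int8, x_scales, w_vals_int8, w_scales, bias=None):
    # Flatten leading dims so a single 2D matmul can be used.
    tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
    # plain_layout.py uses torchao's integer matmul kernels here; this sketch
    # emulates them with a plain int32 matmul, which works on CPU.
    y_dot_int32 = torch.mm(tmp.to(torch.int32), w_vals_int8.t().to(torch.int32))
    # Rescale: per-token activation scale times per-output-channel weight scale.
    y = y_dot_int32.to(x_scales.dtype) * x_scales.reshape(-1, 1) * w_scales.reshape(1, -1)
    y = y.reshape(*x_vals_int8.shape[:-1], y.shape[-1])
    if bias is not None:
        y = y + bias
    return y
```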
    result = result.view(*x_vals.shape[:-1], result.shape[-1])
    result = result.to(activation_tensor.dtype)
else:
    # FP × INT8 (weight-only)
this should follow
ao/torchao/dtypes/uintx/plain_layout.py
Line 250 in e9c7bea
def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):
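And an approximate sketch of the fp-activation × int8-weight (weight-only) pattern: upcast the int8 weight to the activation dtype, apply its per-channel scale, and run a regular linear. The helper name is illustrative and the actual torchao implementation may order these steps differently:

```python
import torch
import torch.nn.functional as F


def fp_act_int8_weight_linear_sketch(input_tensor, w_vals_int8, w_scales, bias=None):
    # Dequantize the weight: per-output-channel scale applied to the int8 values.
    w_dequant = w_vals_int8.to(input_tensor.dtype) * w_scales.reshape(-1, 1).to(input_tensor.dtype)
    # Regular floating-point linear on the dequantized weight.
    return F.linear(input_tensor, w_dequant, bias)
```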
I think we should
- split the static quant support into a separate PR
- follow what https://github.com/pytorch/ao/blob/main/torchao/dtypes/uintx/plain_layout.py is doing for the quantized linear implementation
this should be a refactor PR, not a refactor + some extra modifications + some feature implementations, I think
Summary:
Introduce a new tensor subclass API. Main features:
Int8Tensor: the main API, which handles quantization and dequantization operations. This API is integrated into the existing global configs (Int8WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig) via version, and is not defined as the default.
Related Issue/PR:
This is a reopened PR for #3038
Test plan:
test/quantization/quantize_/workflows/int8/test_int8_tensor.py
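For reference, a minimal usage sketch based on the summary above; the exact version value that selects Int8Tensor is an assumption, since it is not spelled out in this description:

```python
import torch
from torchao.quantization import quantize_, Int8WeightOnlyConfig

model = torch.nn.Sequential(torch.nn.Linear(128, 256, dtype=torch.bfloat16))
# Assumption: the new Int8Tensor path is opted into via the config's `version`
# argument (the value shown here is a guess), rather than being the default.
quantize_(model, Int8WeightOnlyConfig(version=2))
```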