Conversation

Contributor

@namgyu-youn namgyu-youn commented Oct 24, 2025

Summary:
Introduce a new tensor subclass API. The main features are:

  • Int8Tensor: the main API, which handles quantization and dequantization operations
  • Utility operation functions: tensor slicing and index selection

This API is integrated into the existing configs (Int8WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig) behind their version field, and is not enabled as the default.
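
For illustration, a minimal usage sketch of how the new path might be opted into (the version value used here is an assumption based on the description above; the exact argument may differ):

import torch
from torchao.quantization import quantize_, Int8WeightOnlyConfig

model = torch.nn.Sequential(torch.nn.Linear(128, 256)).to(torch.bfloat16)
# Assumed: a non-default `version` selects the new Int8Tensor-based workflow;
# omitting it keeps the existing default behavior.
quantize_(model, Int8WeightOnlyConfig(version=2))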

Related Issue/PR:
This is a reopened PR for #3038.

Test plan:
test/quantization/quantize_/workflows/int8/test_int8_tensor.py


pytorch-bot bot commented Oct 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3241

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed label Oct 24, 2025
Comment on lines 140 to 141
quant_min=self.int8_min,
quant_max=self.int8_max,
Contributor

nit: we can omit these two args if they are the same as the defaults (-128, 127)

)

@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_quantization_shapes(self, dtype):
Contributor

@jerryzh168 jerryzh168 Oct 24, 2025

This seems to be a combination of two tests, one for dynamic quant and one for static quant. Can you use something like this:

@common_utils.parametrize("mode", ["dynamic", "weight-only"])

Also, I feel it might be better not to add static quant in this PR, and instead add both the tensor support and the config support for static quant in a separate PR.
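
A minimal sketch of what the suggested parametrization might look like in the same test class (the test body, shapes, and assertions here are assumptions for illustration, not the actual test from the PR):

@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
@common_utils.parametrize("mode", ["dynamic", "weight-only"])
def test_quantization_shapes(self, dtype, mode):
    # Assumed: pick the config from the mode instead of branching on static flags
    config = (
        Int8DynamicActivationInt8WeightConfig()
        if mode == "dynamic"
        else Int8WeightOnlyConfig()
    )
    linear = torch.nn.Linear(128, 256, dtype=dtype)
    quantize_(linear, config)
    x = torch.randn(32, 128, dtype=dtype)
    self.assertEqual(linear(x).shape, (32, 256))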

Contributor Author

Okay, I wasn't sure about removing the static flags before (even though they are not fully implemented), but I feel a smaller PR is always better. I will remove static_scale and all of the related support.

if act_quant_kwargs is not None and act_quant_kwargs.static_scale is not None:
    # INT8 × INT8 (static)
    scale = act_quant_kwargs.static_scale
    zero_point = torch.zeros_like(scale, dtype=torch.int8)
Contributor

@jerryzh168 jerryzh168 Oct 24, 2025

I think the user should specify static_zero_point as well.

But again, it's better to do this in a separate PR, since the current state is only half of the static quant feature (there is no config support for it).
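
For illustration, a hedged sketch of what this suggestion might look like, assuming act_quant_kwargs gains a static_zero_point field (that field does not exist in the current diff):

if act_quant_kwargs is not None and act_quant_kwargs.static_scale is not None:
    # INT8 × INT8 (static): take both qparams from the user instead of
    # hard-coding a zero zero_point
    scale = act_quant_kwargs.static_scale
    zero_point = act_quant_kwargs.static_zero_point.to(torch.int8)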

Comment on lines 196 to 225
# Cast fp16 scale to float
intermediate_dtype = (
    torch.float if x_scales.dtype == torch.half else x_scales.dtype
)
# Note: CUDA doesn't support int32/int64 matmul, so we convert to float
# Error message is NotImplementedError: "addmm_cuda" not implemented for 'Int'
# This may introduce minor numerical differences compared to int arithmetic
y_dot = torch.mm(tmp.to(intermediate_dtype), w_vals_t.to(intermediate_dtype))

# Apply activation scale
is_per_tensor_act = x_scales.numel() == 1
if is_per_tensor_act:
    y_dot.mul_(x_scales.to(intermediate_dtype))
else:
    # For block-wise activation scale, reshape to match y_dot
    x_scales_reshaped = x_scales.view(y_dot.shape[0], -1)
    y_dot.mul_(x_scales_reshaped.to(intermediate_dtype))

# Apply weight scale
is_per_tensor_weight = w_scales.numel() == 1
if is_per_tensor_weight:
    result = y_dot.mul_(w_scales.to(intermediate_dtype))
else:
    # Per-row weight scale - transpose and broadcast
    w_scales_broadcast = w_scales.t().expand_as(y_dot)
    result = y_dot.mul_(w_scales_broadcast.to(intermediate_dtype))

# Reshape back to original shape
result = result.view(*x_vals.shape[:-1], result.shape[-1])
result = result.to(activation_tensor.dtype)
Contributor

this should follow:

def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
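
For reference, a hedged sketch of the structure being pointed to, loosely following _linear_int8_act_int8_weight_impl in torchao/dtypes/uintx/plain_layout.py (the qdata/scale attribute names for the new Int8Tensor and the use of int_scaled_matmul are assumptions; the real implementation may differ):

import torch
from torchao.kernel import int_scaled_matmul

def _int8_act_int8_weight_sketch(input_tensor, weight_tensor, bias=None):
    # int8 data and scales for activation and weight
    x_vals_int8, x_scales = input_tensor.qdata, input_tensor.scale
    w_vals_int8_t = weight_tensor.qdata.contiguous().t()
    w_scales = weight_tensor.scale

    tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
    # fp16 scales are upcast so the scaled matmul accumulates in fp32
    intermediate_dtype = (
        torch.float if x_scales.dtype == torch.half else x_scales.dtype
    )
    y_dot_scaled = int_scaled_matmul(
        tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype)
    ).to(x_scales.dtype)

    # apply per-row weight scale, restore shape and output dtype
    y = (y_dot_scaled * w_scales).reshape(
        *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1]
    ).to(input_tensor.dtype)
    return y if bias is None else y + bias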

    result = result.view(*x_vals.shape[:-1], result.shape[-1])
    result = result.to(activation_tensor.dtype)
else:
    # FP × INT8 (weight-only)
Contributor

this should follow:

def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):
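
Similarly, a hedged sketch of the weight-only path, loosely following _linear_fp_act_int8_weight_impl in plain_layout.py (again, the qdata/scale attribute names are assumptions for the new subclass):

import torch

def _fp_act_int8_weight_sketch(input_tensor, weight_tensor, bias=None):
    # matmul in the activation dtype, then apply the per-row weight scale
    w_vals_int8_t = weight_tensor.qdata.t()
    w_scales = weight_tensor.scale
    orig_shape = input_tensor.shape

    m = torch.mm(
        input_tensor.reshape(-1, orig_shape[-1]),
        w_vals_int8_t.to(input_tensor.dtype),
    )
    y = (m * w_scales.to(m.dtype)).reshape(*orig_shape[:-1], m.shape[-1])
    return y if bias is None else y + bias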

Contributor

@jerryzh168 jerryzh168 left a comment

I think we should

  1. split the static quant support into a separate PR
  2. follow what https://github.com/pytorch/ao/blob/main/torchao/dtypes/uintx/plain_layout.py is doing for quantized linear implementation

I think this should be a refactor PR, not a refactor plus extra modifications and new feature implementations.

