
Commit 8471264

[Quantization Args] Add scale and zp dtype (#508)
* update * add back test * update * update * fix serialization * fix condition * update * update * update * update * update * remove torch * update * update * update tests * update * update * fix comment * update * updatE * update * update * update * update * update * update * update * update * updatE * update * update * update
1 parent 52792be commit 8471264
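In short, this commit lets a quantization scheme carry explicit dtypes for its scale and zero-point parameters instead of inferring them from the observed tensor dtype. A minimal usage sketch of the new fields (hypothetical values; assumes QuantizationArgs is importable from compressed_tensors.quantization and that the new TorchDtype fields accept plain torch.dtype objects):

import torch
from compressed_tensors.quantization import QuantizationArgs

# NVFP4-style weight args: FP4 values with FP8 (E4M3) scale and zero-point dtypes.
# If these fields are left unset, the validator change below defaults zp_dtype to
# FP8 E4M3 for 4-bit float types and to pytorch_dtype() otherwise.
weight_args = QuantizationArgs(
    num_bits=4,
    type="float",
    symmetric=True,
    group_size=16,
    scale_dtype=torch.float8_e4m3fn,
    zp_dtype=torch.float8_e4m3fn,
)
print(weight_args.scale_dtype, weight_args.zp_dtype)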

File tree

16 files changed: +201 −90 lines changed


.github/workflows/test-check.yaml

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ on:
 
 jobs:
   python-tests:
-    runs-on: ubuntu-22.04
+    runs-on: gcp-k8s-vllm-l4-solo
     env:
       HF_TOKEN: ${{ secrets.HF_RED_HAT_READ_ONLY }}
     steps:

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 8 additions & 1 deletion
@@ -90,7 +90,6 @@ def compress(
         desc = "Compressing with quantization"
         for name in tqdm(uncompressed_names, desc=desc, disable=(not show_progress)):
             value = model_state[name]
-
             # compress weights
             if name.endswith("weight"):
                 prefix = name.removesuffix("weight")
@@ -129,10 +128,18 @@ def compress(
             if name.endswith("zero_point") and self._skip_zp(name, names_to_scheme):
                 continue
 
+            if name.endswith("weight_scale") and self._skip_scale():
+                continue
+
             compressed_dict[name] = value.to(compression_device)
 
         return compressed_dict
 
+    def _skip_scale(self):
+        from compressed_tensors.compressors import NVFP4PackedCompressor
+
+        return isinstance(self, NVFP4PackedCompressor)
+
     def _skip_zp(
         self, name: str, names_to_scheme: Dict[str, QuantizationScheme]
     ) -> bool:
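For context: NVFP4PackedCompressor.compress_weight now writes weight_scale itself, cast to quantization_args.scale_dtype (see fp4_quantized.py below), so the generic loop skips the raw weight_scale entry for that compressor; otherwise the uncast scale from the state dict would overwrite the value the compressor already emitted.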

src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py

Lines changed: 3 additions & 2 deletions
@@ -26,7 +26,7 @@
 from torch import Tensor
 
 
-__all__ = ["pack_fp4_to_uint8", "unpack_fp4_from_uint8"]
+__all__ = ["pack_fp4_to_uint8", "unpack_fp4_from_uint8", "NVFP4PackedCompressor"]
 
 FLOAT_TO_E2M1 = [
     0.0,
@@ -103,6 +103,7 @@ def compress_weight(
         if device is not None:
             weight_packed = weight_packed.to(device)
         compressed_dict["weight_packed"] = weight_packed
+        compressed_dict["weight_scale"] = scale.to(quantization_args.scale_dtype)
         return compressed_dict
 
     def decompress_weight(
@@ -111,8 +112,8 @@
         quantization_args: Optional[QuantizationArgs] = None,
     ) -> torch.Tensor:
         weight = compressed_data["weight_packed"]
-        scale = compressed_data["weight_scale"]
         global_scale = compressed_data["weight_global_scale"]
+        scale = compressed_data["weight_scale"]
         m, n = weight.shape
         # TODO: use a user provided dequant dtype
         unpacked = unpack_fp4_from_uint8(weight, m, n * 2)

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 6 additions & 9 deletions
@@ -21,7 +21,7 @@
     DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
-    round_to_quantized_type,
+    round_to_quantized_type_args,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
@@ -466,20 +466,17 @@ def _quantize(
     # if a global scale is optionally provided, use it
     # to further scale the local `scale` parameter
     if global_scale is not None:
-        scale = scale.to(global_scale.dtype) / global_scale
+        scale = scale / global_scale
 
     scaled = x / scale
 
     if zero_point is not None:
         scaled += zero_point.to(x.dtype)
 
-    # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
-    clamped_value = torch.clamp(
-        scaled,
-        q_min,
-        q_max,
+    # clamp and round
+    quantized_value = round_to_quantized_type_args(
+        tensor=scaled, args=args, min=q_min, max=q_max
     )
-    quantized_value = round_to_quantized_type(clamped_value, args)
 
     if dtype is not None:
         quantized_value = quantized_value.to(dtype)
@@ -499,7 +496,7 @@ def _dequantize(
     # if a global scale is optionally provided, use it
     # to further scale the local `scale` parameter
     if global_scale is not None:
-        scale = scale.to(global_scale.dtype) / global_scale
+        scale = scale / global_scale
 
     dequant_value = x_q.to(scale.dtype)

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 7 additions & 11 deletions
@@ -24,7 +24,6 @@
     QuantizedKVCache,
 )
 from compressed_tensors.quantization import (
-    FP8_E4M3_DATA,
     ActivationOrdering,
     DynamicType,
     QuantizationArgs,
@@ -36,7 +35,7 @@
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
-from compressed_tensors.quantization.utils import is_fp4, strategy_cdiv
+from compressed_tensors.quantization.utils import strategy_cdiv
 from compressed_tensors.utils import (
     disable_hf_hook,
     get_execution_device,
@@ -250,20 +249,15 @@ def initialize_qparams(
 
     # 2. Identify quantization scale and zp dtype
     scale_dtype = observed_dtype
-
-    if is_fp4(quantization_args=quantization_args):
-        scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
-    else:
-        # TODO: consider erroring out in the future as if the dtype if not one of these,
-        # there is likely bug
+    if quantization_args.scale_dtype is None:
         if scale_dtype not in [
             torch.float16,
             torch.bfloat16,
             torch.float32,
             torch.float64,
         ]:
-            scale_dtype = torch.bfloat16
-        zp_dtype = quantization_args.pytorch_dtype()
+            scale_dtype = torch.float16
+        quantization_args.scale_dtype = scale_dtype
 
     # 3. Initializes scale/zp for the module
     init_scale = Parameter(
@@ -274,7 +268,9 @@
 
     if force_zero_point or not quantization_args.symmetric:
         init_zero_point = Parameter(
-            torch.zeros(expected_shape, device=device, dtype=zp_dtype),
+            torch.zeros(
+                expected_shape, device=device, dtype=quantization_args.zp_dtype
+            ),
             requires_grad=False,
         )
     register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)
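To summarize the new defaulting behavior, a small standalone sketch of the branch above (not a library call): when scale_dtype is unset, a standard float observed dtype is kept and anything else falls back to torch.float16; the zero-point dtype now comes straight from quantization_args.zp_dtype, which the QuantizationArgs validator fills in.

import torch

def resolve_scale_dtype(observed_dtype, configured=None):
    # mirrors initialize_qparams: keep an explicitly configured dtype, keep a
    # standard float observed dtype, otherwise default to float16
    if configured is not None:
        return configured
    if observed_dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
        return observed_dtype
    return torch.float16

assert resolve_scale_dtype(torch.bfloat16) == torch.bfloat16
assert resolve_scale_dtype(torch.float8_e4m3fn) == torch.float16
assert resolve_scale_dtype(torch.int8, configured=torch.float32) == torch.float32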

src/compressed_tensors/quantization/quant_args.py

Lines changed: 47 additions & 6 deletions
@@ -19,6 +19,7 @@
 import torch
 from compressed_tensors.utils import Aliasable
 from compressed_tensors.utils.helpers import deprecated
+from compressed_tensors.utils.type import TorchDtype
 from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 
 
@@ -30,7 +31,8 @@
     "QuantizationType",
     "QuantizationStrategy",
     "QuantizationArgs",
-    "round_to_quantized_type",
+    "round_to_quantized_type_args",
+    "round_to_quantized_type_dtype",
     "ActivationOrdering",
     "DynamicType",
 ]
@@ -174,6 +176,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
    block_structure: Optional[List[int]] = None
    dynamic: Union[DynamicType, bool] = False
    actorder: Union[ActivationOrdering, bool, None] = None
+    scale_dtype: Optional[TorchDtype] = None
+    zp_dtype: Optional[TorchDtype] = None
    observer: Optional[str] = Field(
        default=None,
        description=(
@@ -266,6 +270,7 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
        dynamic = model.dynamic
        observer = model.observer
        dynamic = model.dynamic
+        zp_dtype = model.zp_dtype
 
        # infer strategy
        if strategy is None:
@@ -353,9 +358,16 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
            # default to minmax for non-dynamic cases
            observer = "minmax"
 
+        if zp_dtype is None:
+            if model.num_bits == 4 and model.type == QuantizationType.FLOAT:
+                zp_dtype = FP8_E4M3_DATA.dtype
+            else:
+                zp_dtype = model.pytorch_dtype()
+
        # write back modified values
        model.strategy = strategy
        model.observer = observer
+        model.zp_dtype = zp_dtype
        return model
 
    def pytorch_dtype(self) -> torch.dtype:
@@ -381,18 +393,47 @@ def get_observer(self) -> str:
    model_config = ConfigDict(extra="forbid")
 
 
-def round_to_quantized_type(
-    tensor: torch.Tensor, args: QuantizationArgs
+def round_to_quantized_type_dtype(
+    tensor: torch.Tensor, dtype: torch.dtype
 ) -> torch.Tensor:
     """
-    Rounds each element of the input tensor to the nearest quantized representation,
-    keeping to original dtype
+    Rounds an input tensor to the nearest quantized representation given a dtype.
+    The original dtype is kept post-rounding.
 
     :param tensor: tensor to round
-    :param args: QuantizationArgs to pull appropriate dtype from
+    :param dtype: dtype to use for rounding
     :return: rounded tensor
     """
     original_dtype = tensor.dtype
+    if torch.is_floating_point(torch.tensor([], dtype=dtype)):
+        finfo = torch.finfo(dtype)
+        rounded = torch.clamp(tensor, finfo.min, finfo.max).to(dtype)
+    else:
+        iinfo = torch.iinfo(dtype)
+        rounded = torch.round(torch.clamp(tensor, iinfo.min, iinfo.max))
+
+    return rounded.to(original_dtype)
+
+
+def round_to_quantized_type_args(
+    tensor: torch.Tensor,
+    args: QuantizationArgs,
+    min: torch.Tensor,
+    max: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Rounds an input tensor to the nearest quantized representation given
+    quantization args. The original dtype is kept post-rounding.
+
+    :param tensor: tensor to round
+    :param args: quantization args to use for rounding
+    :param min: min value to use for clamping
+    :param max: max value to use for clamping
+    :return: rounded tensor
+    """
+
+    original_dtype = tensor.dtype
+    tensor = torch.clamp(tensor, min, max)
     if args.type == QuantizationType.FLOAT:
         if args.num_bits == 8:
             rounded = tensor.to(FP8_E4M3_DATA.dtype)
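A quick usage sketch of the new round_to_quantized_type_dtype helper (assuming the import path below, which the __all__ change suggests): float targets are clamped to the dtype's representable range and snapped to it by casting, integer targets are clamped and rounded, and the result always comes back in the input's original dtype.

import torch
from compressed_tensors.quantization.quant_args import round_to_quantized_type_dtype

x = torch.tensor([0.1234, 3.7, 500.0, -1000.0])

# float target: clamp to the fp8 e4m3 range (±448) and snap to representable
# fp8 values, returned as float32
fp8_rounded = round_to_quantized_type_dtype(x, torch.float8_e4m3fn)

# integer target: clamp to [-128, 127] and round to whole numbers, returned as float32
int8_rounded = round_to_quantized_type_dtype(x, torch.int8)

print(fp8_rounded, int8_rounded)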

src/compressed_tensors/quantization/quant_config.py

Lines changed: 0 additions & 1 deletion
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from collections import defaultdict
 from enum import Enum
 from typing import Annotated, Any, Dict, List, Optional, Set, Union

src/compressed_tensors/quantization/quant_scheme.py

Lines changed: 7 additions & 0 deletions
@@ -18,6 +18,7 @@
 
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.quant_args import (
+    FP8_E4M3_DATA,
     DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
@@ -160,6 +161,8 @@ def is_preset_scheme(name: str) -> bool:
         symmetric=True,
         dynamic=False,
         group_size=16,
+        scale_dtype=FP8_E4M3_DATA.dtype,
+        zp_dtype=FP8_E4M3_DATA.dtype,
     )
 )
 
@@ -173,6 +176,8 @@ def is_preset_scheme(name: str) -> bool:
         dynamic=False,
         group_size=16,
         observer="static_minmax",
+        scale_dtype=FP8_E4M3_DATA.dtype,
+        zp_dtype=FP8_E4M3_DATA.dtype,
     ),
     input_activations=QuantizationArgs(
         num_bits=4,
@@ -182,6 +187,8 @@ def is_preset_scheme(name: str) -> bool:
         dynamic=DynamicType.LOCAL,
         group_size=16,
         observer="static_minmax",
+        scale_dtype=FP8_E4M3_DATA.dtype,
+        zp_dtype=FP8_E4M3_DATA.dtype,
     ),
 )
