diff --git a/examples/awq/llama_example.py b/examples/awq/llama_example.py index d06a2ccb91..e31304b293 100644 --- a/examples/awq/llama_example.py +++ b/examples/awq/llama_example.py @@ -50,7 +50,9 @@ def tokenize(sample): # Configure the quantization algorithm to run. recipe = [ - AWQModifier(ignore=["lm_head"], scheme="W4A16_ASYM", targets=["Linear"]), + AWQModifier( + ignore=["lm_head"], scheme="W4A16_ASYM", targets=["Linear"], duo_scaling="both" + ), ] # Apply algorithms. diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 98e53b4e00..57c29d592a 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -3,15 +3,20 @@ from typing import Literal import torch -from compressed_tensors.quantization import disable_quantization +from compressed_tensors.quantization import ( + QuantizationStrategy, + disable_quantization, + forward_quantize, +) from compressed_tensors.utils import ( align_modules, get_execution_device, match_named_modules, + patch_attrs, update_offload_parameter, ) from loguru import logger -from pydantic import ConfigDict, PrivateAttr, model_validator +from pydantic import ConfigDict, PrivateAttr from torch.nn import Module from tqdm import tqdm @@ -22,9 +27,13 @@ ResolvedMapping, get_layer_mappings_from_architecture, ) -from llmcompressor.modifiers.quantization.calibration import update_weight_zp_scale +from llmcompressor.modifiers.quantization.calibration import ( + call_observer, + update_weight_zp_scale, +) from llmcompressor.modifiers.quantization.quantization import QuantizationMixin from llmcompressor.modifiers.utils.hooks import HooksMixin +from llmcompressor.observers.base import Observer from llmcompressor.pipelines.cache import IntermediatesCache from llmcompressor.utils.fsdp.helpers import get_fsdp_parent from llmcompressor.utils.helpers import calibration_forward_context @@ -133,11 +142,6 @@ class AWQModifier(Modifier, QuantizationMixin): duo_scaling: bool | Literal["both"] = True n_grid: int = 20 - # Private vars set during validation - _num_bits: int | None = PrivateAttr(default=None) - _symmetric: bool | None = PrivateAttr(default=None) - _group_size: int | None = PrivateAttr(default=None) - # Private vars set during initialization, cleared during finalization _resolved_mappings: list[ResolvedMapping] = PrivateAttr(default_factory=list) # Cache list of forward input args for each parent module, one dict for each batch @@ -149,74 +153,6 @@ class AWQModifier(Modifier, QuantizationMixin): default_factory=dict ) - # NOTE: different name chosen to avoid collision with - # QuantizationMixin.validate_model_after, which must be called first - @model_validator(mode="after") - def validate_awq_after(model: "AWQModifier") -> "AWQModifier": - """ - Confirm only one configuration for group_size, symmetric, and num_bits, - as AWQ algorithm depends on it - Confirm no activation quantization, as AWQ only works with WNA16 - """ - config = model.resolve_quantization_config() - - num_bits_set = set( - group.weights.num_bits - for group in config.config_groups.values() - if group.weights is not None - ) - assert ( - len(num_bits_set) == 1 - ), "In AWQ, all config groups must use the same configuration for num_bits" - - model._num_bits = next(iter(num_bits_set)) - - symmetric_set = set( - group.weights.symmetric - for group in config.config_groups.values() - if group.weights is not None - ) - assert ( - len(symmetric_set) == 1 - ), "In AWQ, all config groups must use the same configuration 
for symmetric" - - model._symmetric = next(iter(symmetric_set)) - - group_size_set = set( - group.weights.group_size - for group in config.config_groups.values() - if group.weights is not None - ) - assert ( - len(group_size_set) == 1 - ), "In AWQ, all config groups must use the same configuration for group_size" - - model._group_size = next(iter(group_size_set)) - if model._group_size is None: - model._group_size = -1 - - in_num_bits_set = set( - group.input_activations.num_bits - for group in config.config_groups.values() - if group.input_activations is not None - ) - assert len(in_num_bits_set) == 0 or in_num_bits_set == {16}, ( - "AWQ activations must be 16-bit precision, " - f"input activations {in_num_bits_set} not allowed" - ) - - out_num_bits_set = set( - group.output_activations.num_bits - for group in config.config_groups.values() - if group.output_activations is not None - ) - assert len(out_num_bits_set) == 0 or out_num_bits_set == {16}, ( - "AWQ activations must be 16-bit precision, " - f"output activations {out_num_bits_set} not allowed" - ) - - return model - def on_initialize(self, state: State, **kwargs) -> bool: """ Initialize AWQ on the given state @@ -398,7 +334,7 @@ def _setup_activation_cache_hooks(self) -> None: """ def cache_parent_kwargs_hook( - module: torch.nn.Module, + module: Module, args: tuple[torch.Tensor, ...], kwargs, ): @@ -407,7 +343,7 @@ def cache_parent_kwargs_hook( def create_cache_smooth_activations_hook_fn(smooth_name): def cache_smooth_activations_hook( - _module: torch.nn.Module, + _module: Module, args: tuple[torch.Tensor, ...], _output: torch.Tensor, ): @@ -469,28 +405,7 @@ def _apply_smoothing(self, model: Module) -> None: calibration_forward_context(model), HooksMixin.disable_hooks(), ): - # [STEP 1]: Compute per-channel mean of normalised weights - # All layer weights are concatted together - weight = torch.cat([bl.weight for bl in balance_layers], dim=0) - org_shape = weight.shape - # The weights are reshaped to be organised by quantization group - if self._group_size > 0: - weight = weight.view(-1, self._group_size) - # Calculates the relative magnitude of the weights within - # each of the quantization groups, and rescales each group - # individually so that each group has weights on a 0-1 scale. 
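
For reference, a minimal pure-torch sketch of the per-channel mean described in the comments above and computed by the removed lines that follow (the PR moves this logic into a `_compute_layer_means` helper later in the diff); the layer shapes and the group size of 4 are illustrative, not values from the PR:

import torch

torch.manual_seed(0)
group_size = 4
# two "balance layers" that share the same input dimension (8 columns)
weights = [torch.randn(6, 8), torch.randn(10, 8)]

w = torch.cat(weights, dim=0)                 # (16, 8): all layers stacked
org_shape = w.shape
w = w.abs().view(-1, group_size)              # regroup by quantization group
w = w / (w.amax(dim=1, keepdim=True) + 1e-6)  # 0-1 range within each group
w = w.view(org_shape)
w_mean = w.mean(0)                            # averaged over rows: one value per input channel
print(w_mean.shape)                           # torch.Size([8])
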
- weight.abs_() - weight.div_(weight.amax(dim=1, keepdim=True) + 1e-6) - if self._group_size > 0: - # Resizes the rescaled weight matrix back up to - # its original dimensions - weight = weight.view(org_shape) - # Gets the average rescaled magnitude for each output channel - w_mean = weight.mean(0) - del weight - - # [STEP 3]: Compute output of module - # could cache from hook, rather than recomputing here + # Compute output of unquantized module fp16_outputs = self._run_samples(parent_module) if len(fp16_outputs) == 0 or all(f.numel() == 0 for f in fp16_outputs): logger.info( @@ -515,15 +430,10 @@ def _apply_smoothing(self, model: Module) -> None: del self._smooth_activation_means[mapping.smooth_name] continue - x_mean = self._smooth_activation_means[mapping.smooth_name][0] - - # [STEP 4]: Compute loss - best_scales = self._compute_best_scale( - x_mean, w_mean, parent_module, balance_layers, fp16_outputs - ) + best_scales = self._compute_best_scale(mapping, fp16_outputs) @torch.no_grad() - def _smooth(module): + def _smooth(module: Module): scales = best_scales.to(module.weight.device) if module in balance_layers: update_offload_parameter( @@ -576,42 +486,49 @@ def _run_samples(self, module: Module) -> list[torch.Tensor]: module(**batch_kwargs) for batch_kwargs in self._parent_args_cache[module] ] return [ - # If Tuple, assume that first argument is the input + # If tuple, assume that first argument is the input output[0] if isinstance(output, tuple) else output for output in outputs ] def _compute_best_scale( self, - x_mean: torch.Tensor, - w_mean: torch.Tensor, - parent_module: torch.nn.Module, - linears2scale: list[torch.nn.Linear], + mapping: ResolvedMapping, fp16_outputs: list[torch.Tensor], ) -> torch.Tensor: """ - Compute loss and select best scales + Select best scales for a given mapping in a grid search + Best scales are those that minimize MSE loss of quantized weight + outputs compared to fp16_outputs L(s) = || Q(W * s) (s^-1 * X) - W * X || Q: weight quantization function | _pseudo_quantize_tensor(W * s) X: inputs from calib dataset | X W: original weights in FP16 | layer s: per channel scaling factor | s^-1 * X + + :param mapping: best scales will be found for thi ResolvedMapping. + :param fp16_outputs: output of mapping.parent in unquantized case, + one tensor for each batch. 
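
As a quick illustration of the objective in the docstring above: applying s to the weight columns and 1/s to the activations leaves the product unchanged, so the grid search only trades quantization error between the scaled weights and the inversely scaled activations. A pure-torch sketch with made-up shapes:

import torch

torch.manual_seed(0)
W = torch.randn(16, 8)   # linear weight, (out_features, in_features)
x = torch.randn(4, 8)    # calibration activations
s = torch.rand(8) + 0.5  # a candidate per-input-channel scale vector

baseline = x @ W.t()               # W * X
smoothed = (x / s) @ (W * s).t()   # (s^-1 * X) through (W * s), no quantization
print(torch.allclose(baseline, smoothed, atol=1e-5))  # True, up to float rounding
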
+ :return: tensor of best scales, one for each channel """ history = [] best_ratio = -1 + best_duo_scaling = -1 best_scales = None best_error = float("inf") org_sd = { k: v.cpu() - for k, v in parent_module.state_dict().items() + for k, v in mapping.parent.state_dict().items() if v.device != torch.device("meta") } - device = get_execution_device(parent_module) - x_mean = x_mean.view(-1).to(device) - w_mean = w_mean.view(-1).to(device) + device = get_execution_device(mapping.parent) + + x_mean = self._smooth_activation_means[mapping.smooth_name][0].to(device) + if self.duo_scaling: + w_mean = self._compute_layer_means(mapping.balance_layers).to(device) match self.duo_scaling: # if self.duo_scaling is "both", perform half the grid search with @@ -622,52 +539,81 @@ def _compute_best_scale( case _: n_grid = self.n_grid duo_scalings = [self.duo_scaling] - for grid_idx, use_duo_scaling in product(range(n_grid), duo_scalings): - # create new scales - ratio = grid_idx / n_grid - - # NOTE: s^-1 * x is fused here, according to paper - if use_duo_scaling: - scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp( - min=1e-4 - ) - else: - scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1) - scales = scales / (scales.max() * scales.min()).sqrt() - _scalesview = scales.view(1, -1).to(device) - - # avoid scaling values that overflow - scales[torch.isinf(scales)] = 1 - scales[torch.isnan(scales)] = 1 - - # Q(W * s) - for linear in linears2scale: - linear.weight.mul_(_scalesview) - update_offload_parameter( - linear, - "weight", - _pseudo_quantize_tensor( - w=linear.weight.data, - symmetric=self._symmetric, - bit_width=self._num_bits, - group_size=self._group_size, - )[0] - / _scalesview, + + # Where appropriate, replace observers with memoryless_minmax + # for duration of grid search + balance_layers_to_patch = [ + balance_layer + for balance_layer in mapping.balance_layers + if hasattr(balance_layer, "quantization_scheme") + and hasattr(balance_layer.quantization_scheme, "weights") + ] + with patch_attrs( + balance_layers_to_patch, + "weight_observer", + [ + Observer.load_from_registry( + "memoryless_minmax", + base_name="weight", + args=balance_layer.quantization_scheme.weights, + module=balance_layer, ) + for balance_layer in balance_layers_to_patch + ], + ): + for grid_idx, use_duo_scaling in product(range(n_grid), duo_scalings): + # create new scales + ratio = grid_idx / n_grid + + # NOTE: s^-1 * x is fused here, according to paper + if use_duo_scaling: + scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp( + min=1e-4 + ) + else: + scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1) + scales = scales / (scales.max() * scales.min()).sqrt() + _scalesview = scales.view(1, -1).to(device) + + # avoid scaling values that overflow + scales[torch.isinf(scales)] = 1 + scales[torch.isnan(scales)] = 1 + + # Q(W * s) + for balance_layer in balance_layers_to_patch: + if not hasattr(balance_layer, "quantization_scheme") or not hasattr( + balance_layer.quantization_scheme, "weights" + ): + continue + + balance_layer.weight.mul_(_scalesview) + call_observer(balance_layer, "weight", balance_layer.weight) + update_offload_parameter( + balance_layer, + "weight", + forward_quantize( + balance_layer, + balance_layer.weight.data, + "weight", + balance_layer.quantization_scheme.weights, + ) + / _scalesview, + ) - # W * X - int_w_outputs = self._run_samples(parent_module) + # W * X + int_w_outputs = self._run_samples(mapping.parent) - # compute mean squared error (L2 norm) - loss = 
self._compute_loss(fp16_outputs, int_w_outputs, device) + # compute mean squared error (L2 norm) + loss = self._compute_loss(fp16_outputs, int_w_outputs, device) - history.append(loss) - if loss < best_error: - best_error = loss - best_ratio = ratio - best_scales = scales.clone() + history.append(loss) + if loss < best_error: + best_error = loss + best_duo_scaling = use_duo_scaling + best_ratio = ratio + best_scales = scales.clone() - parent_module.load_state_dict(org_sd, strict=False) + mapping.parent.load_state_dict(org_sd, strict=False) if best_ratio == -1: logger.debug(history) @@ -682,6 +628,8 @@ def _compute_best_scale( torch.isnan(best_scales).sum() == 0 ), f"Nan found in scales: {best_scales}" + print("BEST CONFIGURATION", best_duo_scaling, best_ratio) + return best_scales.detach().cpu() @torch.no_grad() @@ -690,20 +638,16 @@ def _compute_loss( fp16_outputs: list[torch.Tensor], int_w_outputs: list[torch.Tensor], device: torch.device, - ) -> torch.Tensor: + ) -> float: loss = 0.0 num_elements = 0 # Compute the MSE loss for each batch for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs): - batch_loss = ( - (fp16_batch.to(device) - int_w_batch.to(device)) - .view(-1) - .float() - .pow(2) - .sum() - .item() - ) + batch_loss = torch.nn.functional.mse_loss( + fp16_batch.to(device), int_w_batch.to(device) + ).item() + loss += batch_loss num_elements += fp16_batch.numel() @@ -720,48 +664,64 @@ def _assert_all_activations_consumed(self): if len(self._smooth_activation_means) != 0: raise RuntimeError("Some cached activations were not used") + @staticmethod + def _compute_layer_means(layers: list[Module]) -> torch.Tensor: + """ + Compute per-channel mean of normalised weights for all passed in layers. + Layers with group-wise quantization will be normalized against the group + abs max instead of the abs max of the channel. + + To minimize memory requirements, layers are reduced to a running total + of sums and counts when calculating mean + """ + # TODO: allow for block-wise layer means as well + + group_size = None -def _pseudo_quantize_tensor( - w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1 -): - org_w_shape = w.shape - if group_size > 0: - assert org_w_shape[-1] % group_size == 0, ( - f"org_w_shape ({org_w_shape[-1]}) must be a multiple " - + f"of group_size ({group_size})!" 
- ) - w = w.reshape(-1, group_size) - assert w.dim() == 2 - assert torch.isnan(w).sum() == 0 - - # zero point quantization - if not symmetric: - max_val = w.amax(dim=1, keepdim=True) - min_val = w.amin(dim=1, keepdim=True) - max_int = 2**bit_width - 1 - min_int = 0 - scales = (max_val - min_val).clamp(min=1e-5) / max_int - zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int) - w = ( - torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros - ) * scales - zeros = (zeros - 2 ** (bit_width - 1)).view(org_w_shape[0], -1) - else: - max_val = w.abs().amax(dim=1, keepdim=True) - max_val = max_val.clamp(min=1e-5) - max_int = 2 ** (bit_width - 1) - 1 - min_int = -(2 ** (bit_width - 1)) - scales = max_val / max_int - zeros = None - w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales - - assert torch.isnan(scales).sum() == 0 - assert torch.isnan(w).sum() == 0 - - scales = scales.view(org_w_shape[0], -1) - w = w.reshape(org_w_shape) - - return w, scales, zeros + # to calculate mean without having to carry full population + weight_total_count = 0 + weight_total_sum = None + + for layer in layers: + if not hasattr(layer, "weight"): + continue + + weight = layer.weight + org_shape = weight.shape + + # If group-wise, calculate abs max based on group + # abs max, rather than channel + if (group_size := _infer_group_size(layer)) > 0: + weight = weight.view(-1, group_size) + + weight.abs_() + weight.div_(weight.amax(dim=1, keepdim=True) + 1e-6) + + # Reshape back to original dimensions + weight = weight.view(org_shape) + + # Gets the average rescaled magnitude for each output channel + weight_total_count += weight.size(0) + weight_sum = weight.sum(0, dtype=torch.float64) + if weight_total_sum is None: + weight_total_sum = weight_sum + else: + weight_total_sum += weight_sum + + return weight_total_sum / weight_total_count + + +def _infer_group_size(layer: Module) -> int: + """ + Returns group_size of layer if applicable, otherwise -1 + """ + if ( + hasattr(layer, "quantization_scheme") + and hasattr(layer.quantization_scheme, "weights") + and layer.quantization_scheme.weights.strategy == QuantizationStrategy.GROUP + ): + return layer.quantization_scheme.weights.group_size + return -1 def _accumulate_mean( diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index da974b25b8..d71a2b2190 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -78,7 +78,8 @@ def call_observer( base_name is "weight", then the module's weight tensor will be used """ with align_module_device(module): - value = module.weight if base_name == "weight" else value + if value is None and base_name == "weight": + value = module.weight observer: Observer = getattr(module, f"{base_name}_observer") if should_calculate_gparam: diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py index 950ab0f51a..32bf9a490d 100644 --- a/tests/llmcompressor/modifiers/awq/test_base.py +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -1,7 +1,12 @@ import pytest import torch -from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationScheme, + QuantizationStrategy, +) from pydantic import ValidationError +from torch.testing import assert_close from llmcompressor.modifiers.awq import AWQMapping, AWQModifier 
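
A minimal pure-torch sketch of the running-tally mean that `_compute_layer_means` above implements; plain `nn.Linear` modules with illustrative sizes stand in for offloaded balance layers, and only the channel-wise (ungrouped) normalisation is shown:

import torch

torch.manual_seed(0)
layers = [torch.nn.Linear(8, n_out) for n_out in (4, 6, 10)]


def normalise(w: torch.Tensor) -> torch.Tensor:
    # rescale each output row to a 0-1 range by its abs max
    w = w.abs()
    return w / (w.amax(dim=1, keepdim=True) + 1e-6)


# reference: concatenate every layer along the output dimension, then average
reference = torch.cat([normalise(l.weight.detach()) for l in layers], dim=0).mean(0)

# running tally: accumulate column sums and row counts layer by layer
total_sum, total_rows = torch.zeros(8, dtype=torch.float64), 0
for layer in layers:
    w = normalise(layer.weight.detach())
    total_sum += w.sum(0, dtype=torch.float64)
    total_rows += w.size(0)
running = (total_sum / total_rows).to(reference.dtype)

print(torch.allclose(reference, running, atol=1e-6))  # True

Accumulating sums and counts avoids materialising the concatenation of every balance layer, which is the memory saving the helper's docstring refers to.
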
from llmcompressor.modifiers.awq.base import get_lowest_common_parent @@ -114,63 +119,6 @@ def test_set_resolved_mappings(): @pytest.mark.unit def test_validate(): - with pytest.raises(ValidationError): - AWQModifier(scheme="W8A8") - - with pytest.raises(ValidationError): - AWQModifier( - config_groups={ - "group_0": QuantizationScheme( - targets=["Linear"], - weights=QuantizationArgs( - num_bits=4, - group_size=64, - ), - ), - "group_1": QuantizationScheme( - targets=["Linear"], - weights=QuantizationArgs( - num_bits=4, - group_size=128, - ), - ), - } - ) - - with pytest.raises(ValidationError): - AWQModifier( - config_groups={ - "group_0": QuantizationScheme( - targets=["Linear"], - weights=QuantizationArgs( - num_bits=4, - group_size=128, - ), - ), - "group_1": QuantizationScheme( - targets=["Linear"], - weights=QuantizationArgs( - num_bits=8, - group_size=128, - ), - ), - } - ) - - # valid configuration - AWQModifier( - config_groups={ - "group_0": QuantizationScheme( - targets=["Linear"], - weights=QuantizationArgs(num_bits=4, group_size=128, symmetric=False), - ), - "group_1": QuantizationScheme( - targets=["Linear"], - weights=QuantizationArgs(num_bits=4, group_size=128, symmetric=False), - ), - } - ) - AWQModifier(scheme="W4A16", duo_scaling="both") with pytest.raises(ValidationError): AWQModifier(scheme="W4A16", duo_scaling="Both") @@ -234,3 +182,68 @@ def test_get_lowest_common_parent(): ["embed_tokens", "decoder.self_attn.v_proj"], model ) assert parent_name == "" and parent == model + + +@torch.no_grad +@pytest.mark.unit +@pytest.mark.parametrize( + "n_balance_layers, group_size, n_input_features", + [ + (5, None, 32), + (4, 10, 40), + ], +) +def test_compute_layer_means(n_balance_layers, group_size, n_input_features): + """ + Confirm our logic to compute duo_scaling layer means via a running tally + matches the original memory-intensive AutoAWQ implementation, which concats + all balance layers into a single tensor before reducing to mean + Large models were prone to fail at this step. + """ + balance_layers = [ + torch.nn.Linear(n_input_features, 10) for _ in range(n_balance_layers) + ] + for balance_layer in balance_layers: + setattr( + balance_layer, + "quantization_scheme", + QuantizationScheme( + targets=["Linear"], + weights=QuantizationArgs( + strategy=( + QuantizationStrategy.GROUP + if group_size is not None + else QuantizationStrategy.CHANNEL + ), + group_size=group_size, + ), + ), + ) + + def _auto_awq_compute_layer_means(layers: list[torch.nn.Module]) -> torch.Tensor: + """ + Original AutoAwq implementation + """ + # [STEP 1]: Compute per-channel mean of normalised weights + # All layer weights are concatted together + weight = torch.cat([bl.weight for bl in balance_layers], dim=0) + org_shape = weight.shape + # The weights are reshaped to be organised by quantization group + if group_size is not None: + weight = weight.view(-1, group_size) + # Calculates the relative magnitude of the weights within + # each of the quantization groups, and rescales each group + # individually so that each group has weights on a 0-1 scale. + weight.abs_() + weight.div_(weight.amax(dim=1, keepdim=True) + 1e-6) + weight = weight.view(org_shape) + # Gets the average rescaled magnitude for each output channel + return weight.mean(0) + + w_mean_auto_awq = _auto_awq_compute_layer_means(balance_layers) + + w_mean_awq = AWQModifier._compute_layer_means(balance_layers).to( + w_mean_auto_awq.dtype + ) + + assert_close(w_mean_auto_awq, w_mean_awq)
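
To tie the pieces together, a compact toy version of the grid search these tests exercise; the 4-bit symmetric round-to-nearest helper, the tensor shapes, and the simple duo/non-duo enumeration below are illustrative stand-ins for `forward_quantize`, real calibrated layers, and `duo_scaling="both"`:

import torch

torch.manual_seed(0)
W = torch.randn(16, 8)
X = torch.randn(32, 8)
x_mean = X.abs().mean(dim=0)  # stand-in for the cached per-channel activation means
w_mean = (W.abs() / (W.abs().amax(dim=1, keepdim=True) + 1e-6)).mean(dim=0)
fp16_out = X @ W.t()


def fake_quant(w: torch.Tensor, bits: int = 4) -> torch.Tensor:
    # symmetric round-to-nearest quantize/dequantize, one scale per output row
    qmax = 2 ** (bits - 1) - 1
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / qmax
    return (w / scale).round().clamp(-qmax - 1, qmax) * scale


best_err, best_scales = float("inf"), None
n_grid = 20
for use_duo in (True, False):
    for grid_idx in range(n_grid):
        ratio = grid_idx / n_grid
        if use_duo:
            s = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(min=1e-4)
        else:
            s = x_mean.pow(ratio).clamp(min=1e-4)
        s = s / (s.max() * s.min()).sqrt()
        int_out = (X / s) @ fake_quant(W * s).t()  # Q(W * s) applied to s^-1 * X
        err = torch.nn.functional.mse_loss(fp16_out, int_out).item()
        if err < best_err:
            best_err, best_scales = err, s.clone()

print(best_err, best_scales.shape)  # lowest loss and the winning per-channel scales
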