From b1047255a79bc0c95ce059f50e732eb060b56c14 Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Thu, 6 Nov 2025 20:36:40 +0000 Subject: [PATCH 1/2] AWQ MoE vllm fakequant Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 443 +++++++++++++++++++- modelopt/torch/quantization/plugins/vllm.py | 53 +-- modelopt/torch/quantization/utils.py | 15 + 3 files changed, 486 insertions(+), 25 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index f987efcd6..3ddec4e16 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -36,6 +36,7 @@ is_quantized_column_parallel_linear, is_quantized_linear, is_quantized_row_parallel_linear, + is_vllm_fused_moe, quantizer_attr_names, weight_attr_names, ) @@ -302,6 +303,150 @@ def apply_pre_quant_scale_and_smooth( ) +@torch.no_grad() +def apply_pre_quant_scale_and_smooth_fusedmoe( + module, + projection_name: str, + pre_quant_scale: torch.Tensor, + input_quantizer, + weight_quantizer, +): + """Apply pre_quant_scale and smooth for vLLM FusedMoE projections. + + Args: + module: The FusedMoE module + projection_name: Either "w13" or "w2" + pre_quant_scale: The scale to apply (best_scale from AWQ search) + input_quantizer: The input quantizer for this projection + weight_quantizer: The weight quantizer for this projection + """ + # Get weight + weight_attr = f"{projection_name}_weight" + weight = getattr(module, weight_attr) + + # Validate inputs + assert pre_quant_scale is not None, "pre_quant_scale must be provided" + assert torch.all(pre_quant_scale > 0), "pre_quant_scale should be positive" + assert weight_quantizer.pre_quant_scale is None, ( + "weight_quantizer.pre_quant_scale should be None first!" 
+ ) + + # Convert to fp32 for numerical safety + pre_quant_scale = pre_quant_scale.to(torch.float32) + + # Reshape scale for broadcasting with weight + # Scale is always 1D [input_dim] for both linear and MoE weights + # For 2D weights [output_dim, input_dim]: broadcast [1, input_dim] + # For 3D weights [num_experts, output_dim, input_dim]: broadcast [1, 1, input_dim] + if weight.ndim == 3: + scale_reshaped = pre_quant_scale.view(1, 1, -1) + elif weight.ndim == 2: + scale_reshaped = pre_quant_scale.view(1, -1) + else: + raise ValueError(f"Unsupported weight dimensions: {weight.ndim}") + + # Fuse scale into weight + # AWQ formula: (input * 1/scale) @ (weight * scale) + # So we multiply weight by scale + weight.data = (weight.data * scale_reshaped).to(weight.dtype) + + # Reset and re-calibrate weight quantizer + weight_quantizer.reset_amax() + max_calibrate(weight_quantizer, lambda q: q(weight)) + + # Disable pre_quant_scale for weight quantizer (already folded into weight) + weight_quantizer._enable_pre_quant_scale = False + + # Handle input quantizer (if it was enabled and has amax) + if input_quantizer.amax is not None: + act_amax = input_quantizer.amax + + # Save per-channel amax for smoothing + input_quantizer._amax_for_smoothing = act_amax.cpu() + input_quantizer.reset_amax() + input_quantizer.axis = None + + # Set pre_quant_scale for input quantizer + # AWQ formula: (input * 1/scale) @ (weight * scale) + # Weight already has scale applied, so input needs 1/scale + inv_scale_for_input = 1.0 / pre_quant_scale + + # Input is always a regular activation tensor [batch, seq, hidden_dim] + # So we use 1D scale [hidden_dim] regardless of weight dimensions + # (Unlike weights which may be 3D for MoE: [num_experts, out_dim, in_dim]) + input_quantizer._enable_pre_quant_scale = True + input_quantizer.pre_quant_scale = inv_scale_for_input.to(weight.dtype) + + # Adjust amax to account for pre_quant_scale + # The input will be scaled by (1/scale), so the range changes + input_quantizer.amax = (act_amax * inv_scale_for_input).amax().to(weight.dtype) + + # Enable input quantizer for inference + input_quantizer.enable() + + +@torch.no_grad() +def postprocess_fusedmoe(module, name: str, projection_name: str, helper, get_scale_fn): + """Postprocess a single FusedMoE projection (w13 or w2) after AWQ-Lite calibration. + + Args: + module: The FusedMoE module + name: Module name (for logging) + projection_name: Either "w13" or "w2" + helper: AWQLiteFusedMoEHelper instance for this projection + get_scale_fn: Function to compute scale (takes act_scale, weight_scale, alpha, parallel_group) + """ + # Check if calibration was successful + if helper.num_cache_steps == 0: + helper.is_enabled = False + elif helper.num_search_steps == 0: + helper.is_enabled = False + warnings.warn( + f"awq_lite: {name}.{projection_name} did not receive data during search. " + "Falling back to max calibration." 
+ ) + + # Compute best parameters + if helper.is_enabled: + helper.loss.update({k: float(v) for k, v in helper.loss.items()}) + helper.best_alpha = min(helper.loss, key=helper.loss.get) + helper.best_scale = get_scale_fn( + helper.act_scale, + helper.weight_scale, + helper.best_alpha, + None, + ) + + # Apply pre_quant_scale and smooth using utility function + # Note: Only apply if input quantizer was enabled during calibration + if helper.is_input_quantized: + apply_pre_quant_scale_and_smooth_fusedmoe( + module, + projection_name, + helper.best_scale, + helper.input_quantizer, + helper.weight_quantizer, + ) + else: + # Weight-only quantization: just fuse scale into weight + weight_attr = f"{projection_name}_weight" + weight = getattr(module, weight_attr) + weight.data = (weight.data * helper.best_scale.to(torch.float32)).to(weight.dtype) + helper.weight_quantizer.reset_amax() + max_calibrate(helper.weight_quantizer, lambda q: q(weight)) + helper.weight_quantizer._enable_pre_quant_scale = False + else: + # Fall back to max calibration + warnings.warn( + f"awq_lite: Disabling for {name}.{projection_name}, quantizing with max calibration." + ) + weight = getattr(module, f"{projection_name}_weight") + max_calibrate(helper.weight_quantizer, lambda q: q(weight)) + + # Cleanup helper + helper.cleanup() + + @torch.no_grad() def smoothquant(model: nn.Module, forward_loop: ForwardLoop | None = None, alpha=1.0): """Smooth-Quant variant with per-channel weight scaling. @@ -479,6 +624,103 @@ def cleanup(self): delattr(module, "_if_calib") unpatch_forward_method(module, "_forward_no_awq") + def forward_fusedmoe_w13(self, hidden_states, router_logits): + """Patched forward for FusedMoE to collect w13 stats during cache mode.""" + if AWQLiteFusedMoEHelper.cache_mode and self.awq_lite_w13.is_enabled: + # Collect act_scale from original hidden_states (before routing) + hidden_states_local = ( + hidden_states.to_local() if hasattr(hidden_states, "to_local") else hidden_states + ) + self.awq_lite_w13.act_scale += get_act_scale( + self.w13_input_quantizer(hidden_states_local) + ) + self.awq_lite_w13.num_cache_steps += 1 + self.awq_lite_w13.num_tokens += ( + hidden_states_local.numel() / hidden_states_local.shape[-1] + ) + + # Collect input quantizer amax + if self.awq_lite_w13.is_input_quantized: + with set_quantizer_by_cfg_context( + self.w13_input_quantizer, {"*": {"enable": True}} + ): + max_calibrate(self.w13_input_quantizer, lambda q: q(hidden_states_local), False) + + # Call original forward + return self._forward_no_awq_w13(hidden_states, router_logits) + + class AWQLiteFusedMoEHelper: + """Helper for AWQ-Lite calibration on vLLM FusedMoE modules.""" + + cache_mode: bool = False + + def __init__(self, module, name, projection_name): + """Initialize helper for w13 or w2 projection. 
+ + Args: + module: The FusedMoE module + name: Module name + projection_name: Either "w13" or "w2" + """ + self.name = name + self.projection_name = projection_name + self.act_scale = 0.0 + self.num_cache_steps = 0 + self.num_search_steps = 0 + self.module = module + + # Get the appropriate weight and quantizers based on projection + if projection_name == "w13": + weight = module.w13_weight + self.input_quantizer = module.w13_input_quantizer + self.weight_quantizer = module.w13_weight_quantizer + else: # w2 + weight = module.w2_weight + self.input_quantizer = module.w2_input_quantizer + self.weight_quantizer = module.w2_weight_quantizer + + self.block_size = _get_awq_quantizer_block_size(weight, self.weight_quantizer) + self.weight_scale = get_weight_scale(weight, self.block_size) + self.loss = { + k.item(): torch.zeros((), device=weight.device, dtype=torch.float32) + for k in torch.arange(0, 1.0 + alpha_step, alpha_step) + } + self.best_scale = None + self.best_alpha = None + # Track if input quantizer was originally enabled + self.is_input_quantized = self.input_quantizer.is_enabled + self.num_tokens = 0 + self.is_enabled = True + + def setup(self): + """Setup for AWQ-Lite calibration. + + Similar to regular linear layers: + - Disable input_quantizer during calibration + - Set axis=-1 for per-channel calibration + - Will be temporarily enabled during cache mode for stats collection + - Stays disabled during search mode (manual pre_quant_scale per alpha) + - Re-enabled during postprocessing with optimal AWQ scale + """ + if self.input_quantizer.is_enabled: + self.input_quantizer.disable() + if self.input_quantizer.axis not in [None, -1]: + self.is_enabled = False + return + self.input_quantizer.axis = -1 + + # For w13: Patch forward to collect stats at forward level (before routing) + # For w2: Stats collected at kernel level (no patching needed) + if self.projection_name == "w13": + bind_forward_method(self.module, forward_fusedmoe_w13, "_forward_no_awq_w13") + + def cleanup(self): + """Cleanup after calibration.""" + if hasattr(self.module, "_if_calib"): + delattr(self.module, "_if_calib") + if self.projection_name == "w13": + unpatch_forward_method(self.module, "_forward_no_awq_w13") + def get_weight_scale(weight, block_size=None): org_shape = weight.shape slice_after_padding = None @@ -493,7 +735,11 @@ def get_weight_scale(weight, block_size=None): scale = scale.view(org_shape) if slice_after_padding is not None: scale = scale[..., slice_after_padding] - scale = scale.mean(0).to(torch.float32) + # Average across all dimensions except the last (input_dim) + # For 2D [output_dim, input_dim]: mean(0) → [input_dim] + # For 3D [num_experts, output_dim, input_dim]: mean([0,1]) → [input_dim] + dims_to_reduce = list(range(scale.ndim - 1)) + scale = scale.mean(dim=dims_to_reduce).to(torch.float32) return scale def get_act_scale(x): @@ -586,8 +832,150 @@ def forward(self, input, *args, **kwargs): module.awq_lite = AWQLiteHelper(module, name) module.awq_lite.setup() + # Setup vLLM FusedMoE modules + fused_moe_modules = [] + for name, module in model.named_modules(): + if is_vllm_fused_moe(module) and ( + module.w13_weight_quantizer.is_enabled or module.w2_weight_quantizer.is_enabled + ): + with enable_weight_access_and_writeback(module, model): + # Create helpers for both projections + module.awq_lite_w13 = AWQLiteFusedMoEHelper(module, name, "w13") + module.awq_lite_w2 = AWQLiteFusedMoEHelper(module, name, "w2") + module.awq_lite_w13.setup() + module.awq_lite_w2.setup() + 
fused_moe_modules.append(module) + + # Patch the global invoke_fused_moe_kernel function if we have FusedMoE modules + original_invoke_kernel = None + if fused_moe_modules: + try: + # Import vLLM package + import importlib + + vllm_fused_moe_package = importlib.import_module( + "vllm.model_executor.layers.fused_moe.fused_moe" + ) + + # Store original kernel + original_invoke_kernel = vllm_fused_moe_package.invoke_fused_moe_kernel + + def patched_invoke_fused_moe_kernel(A, B, C, *args, **kwargs): # noqa: N803 + """Patched kernel that handles AWQ-lite calibration for all FusedMoE modules.""" + # Find which module this call belongs to by checking B (weight tensor) + target_module = None + helper = None + input_q = None + weight_q = None + + for mod in fused_moe_modules: + if B is mod.w13_weight: + target_module = mod + helper = mod.awq_lite_w13 + input_q = mod.w13_input_quantizer + weight_q = mod.w13_weight_quantizer + break + elif B is mod.w2_weight: + target_module = mod + helper = mod.awq_lite_w2 + input_q = mod.w2_input_quantizer + weight_q = mod.w2_weight_quantizer + break + + # If not found or not enabled, use original kernel + if ( + target_module is None + or helper is None + or not helper.is_enabled + or A.numel() == 0 + ): + return original_invoke_kernel(A, B, C, *args, **kwargs) + + # Type assertions for mypy + assert helper is not None and input_q is not None and weight_q is not None + + # Compute actual output without quantization + weight_q.disable() + c_actual = torch.empty_like(C) + original_invoke_kernel(A, B, c_actual, *args, **kwargs) + weight_q.enable() + + if AWQLiteFusedMoEHelper.cache_mode: + # Cache mode: collect activation statistics + # For w13: Stats collected at forward level (skip here) + # For w2: Collect from intermediate activation A + if helper.projection_name == "w2": + a_local = A.to_local() if hasattr(A, "to_local") else A + helper.act_scale += get_act_scale(a_local) + helper.num_cache_steps += 1 + helper.num_tokens += a_local.numel() / a_local.shape[-1] + + # Collect input quantizer stats for w2 + if helper.is_input_quantized: + with set_quantizer_by_cfg_context(input_q, {"*": {"enable": True}}): + max_calibrate(input_q, lambda q: q(a_local), False) + + C[:] = c_actual + return + + # Search mode: try different alpha values + for alpha in helper.loss: + awq_scale = get_scale( + helper.act_scale, + helper.weight_scale, + alpha, + None, + ) + # Apply AWQ scaling: input * (1/scale) @ weight * scale + # For w13: input scaling applied at forward level (not here) + # For w2: input scaling applied at kernel level (here) + + # Reshape scale for broadcasting with weight + # awq_scale is always 1D [input_dim] + # For 3D weights [num_experts, output_dim, input_dim]: reshape to [1, 1, input_dim] + # For 2D weights [output_dim, input_dim]: reshape to [1, input_dim] + if B.ndim == 3: + scale_reshaped = awq_scale.view(1, 1, -1) + elif B.ndim == 2: + scale_reshaped = awq_scale.view(1, -1) + else: + raise ValueError(f"Unsupported weight dimensions: {B.ndim}") + + if helper.projection_name == "w2": + # w2: Apply input scaling at kernel level (on intermediate activations) + # Input A_local is 2D [num_tokens, hidden_dim], so use 1D scale + input_q._enable_pre_quant_scale = True + input_q.pre_quant_scale = (1 / awq_scale).to(B.dtype) + # For both: Apply weight scaling at kernel level + # Weight is 3D [num_experts, out_dim, in_dim], so use reshaped scale + weight_q._enable_pre_quant_scale = True + weight_q.pre_quant_scale = scale_reshaped.to(B.dtype) + + c_search = 
torch.empty_like(C) + original_invoke_kernel(A, B, c_search, *args, **kwargs) + + # Compute loss + loss = (c_search - c_actual).float().pow(2).mean() + helper.loss[alpha] += loss + + # Clean up temporary pre_quant_scale after search completes + # The last alpha iteration leaves _enable_pre_quant_scale=True, which would + # cause errors if any forward pass happens before postprocessing + weight_q._enable_pre_quant_scale = False + if helper.projection_name == "w2": + input_q._enable_pre_quant_scale = False + + helper.num_search_steps += 1 + C[:] = c_actual + + # Apply the patch globally + vllm_fused_moe_package.invoke_fused_moe_kernel = patched_invoke_fused_moe_kernel # type: ignore[attr-defined] + except ImportError: + warnings.warn("vLLM not installed, skipping FusedMoE AWQ-lite calibration") + # Collect activation scale values AWQLiteHelper.cache_mode = True + AWQLiteFusedMoEHelper.cache_mode = True print_rank_0("awq_lite: Caching activation statistics...") # Lets enable stats collection @@ -631,7 +1019,37 @@ def sync_act_scale_across_dp(module, data_parallel_group): module.parallel_state.data_parallel_group, ) + # Sync FusedMoE activation scales + for name, module in model.named_modules(): + if hasattr(module, "awq_lite_w13") and hasattr(module, "awq_lite_w2"): + module._if_calib = True + + for helper in [module.awq_lite_w13, module.awq_lite_w2]: + if helper.num_cache_steps > 0: + helper.act_scale = helper.act_scale / helper.num_cache_steps + + has_nan_local = torch.any(torch.isnan(helper.act_scale)) or torch.any( + torch.isnan(helper.weight_scale) + ) + has_nan = DistributedProcessGroup.get_dist_syncd_obj( + has_nan_local, + module.parallel_state.data_parallel_group, + lambda objs: any(objs), + ) + + if has_nan: + helper.is_enabled = False + elif module.parallel_state.data_parallel_group.is_initialized(): + dist.all_reduce( + helper.act_scale, + op=dist.ReduceOp.AVG, + group=module.parallel_state.data_parallel_group.group, + ) + else: + helper.is_enabled = False + AWQLiteHelper.cache_mode = False + AWQLiteFusedMoEHelper.cache_mode = False print_rank_0("awq_lite: Searching parameters...") with torch.no_grad(): forward_loop(model) @@ -675,6 +1093,29 @@ def postprocess(module, name): if not debug: delattr(module, "awq_lite") + # Restore original vLLM kernel and post-process FusedMoE modules + if fused_moe_modules and original_invoke_kernel is not None: + try: + import importlib + + vllm_fused_moe_package = importlib.import_module( + "vllm.model_executor.layers.fused_moe.fused_moe" + ) + vllm_fused_moe_package.invoke_fused_moe_kernel = original_invoke_kernel # type: ignore[attr-defined] + except ImportError: + pass + + for name, module in model.named_modules(): + if hasattr(module, "awq_lite_w13") and hasattr(module, "awq_lite_w2"): + for proj_name, helper_attr in [("w13", "awq_lite_w13"), ("w2", "awq_lite_w2")]: + helper = getattr(module, helper_attr) + + with enable_weight_access_and_writeback(module, model): + postprocess_fusedmoe(module, name, proj_name, helper, get_scale) + + if not debug: + delattr(module, helper_attr) + @torch.no_grad() def awq_clip( diff --git a/modelopt/torch/quantization/plugins/vllm.py b/modelopt/torch/quantization/plugins/vllm.py index 11954e69f..872d379b5 100644 --- a/modelopt/torch/quantization/plugins/vllm.py +++ b/modelopt/torch/quantization/plugins/vllm.py @@ -27,6 +27,13 @@ vllm_fused_moe_package = importlib.import_module("vllm.model_executor.layers.fused_moe.fused_moe") +def _assign_weight(target, weight): + if isinstance(target, torch.nn.Parameter): + 
target = torch.nn.Parameter(weight, requires_grad=target.requires_grad) + else: + target = weight + + class FakeQuantMethod: """A class that implements fake quantization methods for vLLM models. @@ -150,44 +157,42 @@ def invoke_fused_moe_quantized( **kwargs, ): if B is self.w13_weight: - # First layer of expert - A = self.w13_input_quantizer(A) # noqa: N806 if self.w13_weight_quantizer.is_enabled: original_weight = self.w13_weight - self.w13_weight = self.w13_weight_quantizer(self.w13_weight) - vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs) - self.w13_weight = original_weight + _assign_weight(self.w13_weight, self.w13_weight_quantizer(self.w13_weight)) + self._original_invoke_kernel(A, B, C, *args, **kwargs) + _assign_weight(self.w13_weight, original_weight) else: - vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs) + self._original_invoke_kernel(A, B, C, *args, **kwargs) if self.w13_output_quantizer.is_enabled: C[:] = self.w13_output_quantizer(C) elif B is self.w2_weight: A = self.w2_input_quantizer(A) # noqa: N806 if self.w2_weight_quantizer.is_enabled: original_weight = self.w2_weight - self.w2_weight = self.w2_weight_quantizer(self.w2_weight) - vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs) - self.w2_weight = original_weight + _assign_weight(self.w2_weight, self.w2_weight_quantizer(self.w2_weight)) + self._original_invoke_kernel(A, B, C, *args, **kwargs) + _assign_weight(self.w2_weight, original_weight) else: - vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs) + self._original_invoke_kernel(A, B, C, *args, **kwargs) if self.w2_output_quantizer.is_enabled: C[:] = self.w2_output_quantizer(C) else: raise ValueError("Cannot determine first or second layer of expert") def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): - # This is again due to the bad coding of vLLM - # fused_moe submodule is overwritten by the fused_moe function - # so we need to import the fused_moe module explicitly - assert vllm_fused_moe_package.invoke_fused_moe_kernel is not None - # This context manager will conflict with torch.compile - # with replace_function( - # vllm_fused_moe_package, - # "invoke_fused_moe_kernel", - # self.invoke_fused_moe_quantized, - # ): - self._invoke_fused_moe_quantized = self.invoke_fused_moe_quantized - self.invoke_fused_moe_quantized = self.invoke_fused_moe_quantized - output = super().forward(hidden_states, router_logits) - self.invoke_fused_moe_quantized = self._invoke_fused_moe_quantized + hidden_states = self.w13_input_quantizer(hidden_states) + + # Save the original kernel function + self._original_invoke_kernel = vllm_fused_moe_package.invoke_fused_moe_kernel + + # Patch the module-level function to use our quantized version + vllm_fused_moe_package.invoke_fused_moe_kernel = self.invoke_fused_moe_quantized # type: ignore[attr-defined] + + try: + output = super().forward(hidden_states, router_logits) + finally: + # Restore the original kernel function + vllm_fused_moe_package.invoke_fused_moe_kernel = self._original_invoke_kernel # type: ignore[attr-defined] + return output diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py index 6167daf23..39c95e27f 100644 --- a/modelopt/torch/quantization/utils.py +++ b/modelopt/torch/quantization/utils.py @@ -266,6 +266,21 @@ def is_quantized_row_parallel_linear(module): return is_quantized_linear(module) and getattr(module, "_is_row_parallel", False) +def is_vllm_fused_moe(module): + 
"""Check if a module is a vLLM FusedMoE module.""" + from .nn import QuantModule, TensorQuantizer + + return ( + isinstance(module, QuantModule) + and hasattr(module, "w13_weight") + and hasattr(module, "w2_weight") + and isinstance(getattr(module, "w13_input_quantizer", None), TensorQuantizer) + and isinstance(getattr(module, "w2_input_quantizer", None), TensorQuantizer) + and hasattr(module, "w13_weight_quantizer") + and hasattr(module, "w2_weight_quantizer") + ) + + def is_quantized_parallel_linear(module): """Check if a module is a quantized parallel linear module.""" return is_quantized_column_parallel_linear(module) or is_quantized_row_parallel_linear(module) From b7844871ba37642a9c695bdd17e4af1c51ed0948 Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Thu, 6 Nov 2025 20:43:06 +0000 Subject: [PATCH 2/2] clean up Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 91 ++++----------------- modelopt/torch/quantization/plugins/vllm.py | 5 +- 2 files changed, 17 insertions(+), 79 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 3ddec4e16..6890721d3 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -331,13 +331,9 @@ def apply_pre_quant_scale_and_smooth_fusedmoe( "weight_quantizer.pre_quant_scale should be None first!" ) - # Convert to fp32 for numerical safety pre_quant_scale = pre_quant_scale.to(torch.float32) - # Reshape scale for broadcasting with weight - # Scale is always 1D [input_dim] for both linear and MoE weights - # For 2D weights [output_dim, input_dim]: broadcast [1, input_dim] - # For 3D weights [num_experts, output_dim, input_dim]: broadcast [1, 1, input_dim] + # Reshape scale for weight broadcasting: 2D [1, in_dim] or 3D [1, 1, in_dim] if weight.ndim == 3: scale_reshaped = pre_quant_scale.view(1, 1, -1) elif weight.ndim == 2: @@ -345,43 +341,24 @@ def apply_pre_quant_scale_and_smooth_fusedmoe( else: raise ValueError(f"Unsupported weight dimensions: {weight.ndim}") - # Fuse scale into weight - # AWQ formula: (input * 1/scale) @ (weight * scale) - # So we multiply weight by scale + # Fuse scale into weight: (input * 1/scale) @ (weight * scale) weight.data = (weight.data * scale_reshaped).to(weight.dtype) - - # Reset and re-calibrate weight quantizer weight_quantizer.reset_amax() max_calibrate(weight_quantizer, lambda q: q(weight)) - - # Disable pre_quant_scale for weight quantizer (already folded into weight) weight_quantizer._enable_pre_quant_scale = False - # Handle input quantizer (if it was enabled and has amax) + # Setup input quantizer with inverse scale if input_quantizer.amax is not None: act_amax = input_quantizer.amax - - # Save per-channel amax for smoothing input_quantizer._amax_for_smoothing = act_amax.cpu() input_quantizer.reset_amax() input_quantizer.axis = None - # Set pre_quant_scale for input quantizer - # AWQ formula: (input * 1/scale) @ (weight * scale) - # Weight already has scale applied, so input needs 1/scale + # Input scale is 1D [in_dim] for both linear and MoE activations inv_scale_for_input = 1.0 / pre_quant_scale - - # Input is always a regular activation tensor [batch, seq, hidden_dim] - # So we use 1D scale [hidden_dim] regardless of weight dimensions - # (Unlike weights which may be 3D for MoE: [num_experts, out_dim, in_dim]) input_quantizer._enable_pre_quant_scale = True input_quantizer.pre_quant_scale = 
inv_scale_for_input.to(weight.dtype) - - # Adjust amax to account for pre_quant_scale - # The input will be scaled by (1/scale), so the range changes input_quantizer.amax = (act_amax * inv_scale_for_input).amax().to(weight.dtype) - - # Enable input quantizer for inference input_quantizer.enable() @@ -625,9 +602,8 @@ def cleanup(self): unpatch_forward_method(module, "_forward_no_awq") def forward_fusedmoe_w13(self, hidden_states, router_logits): - """Patched forward for FusedMoE to collect w13 stats during cache mode.""" + """Patched forward to collect w13 stats before routing.""" if AWQLiteFusedMoEHelper.cache_mode and self.awq_lite_w13.is_enabled: - # Collect act_scale from original hidden_states (before routing) hidden_states_local = ( hidden_states.to_local() if hasattr(hidden_states, "to_local") else hidden_states ) @@ -639,14 +615,12 @@ def forward_fusedmoe_w13(self, hidden_states, router_logits): hidden_states_local.numel() / hidden_states_local.shape[-1] ) - # Collect input quantizer amax if self.awq_lite_w13.is_input_quantized: with set_quantizer_by_cfg_context( self.w13_input_quantizer, {"*": {"enable": True}} ): max_calibrate(self.w13_input_quantizer, lambda q: q(hidden_states_local), False) - # Call original forward return self._forward_no_awq_w13(hidden_states, router_logits) class AWQLiteFusedMoEHelper: @@ -693,15 +667,7 @@ def __init__(self, module, name, projection_name): self.is_enabled = True def setup(self): - """Setup for AWQ-Lite calibration. - - Similar to regular linear layers: - - Disable input_quantizer during calibration - - Set axis=-1 for per-channel calibration - - Will be temporarily enabled during cache mode for stats collection - - Stays disabled during search mode (manual pre_quant_scale per alpha) - - Re-enabled during postprocessing with optimal AWQ scale - """ + """Setup for AWQ calibration: disable input quantizer and patch forward if needed.""" if self.input_quantizer.is_enabled: self.input_quantizer.disable() if self.input_quantizer.axis not in [None, -1]: @@ -709,8 +675,7 @@ def setup(self): return self.input_quantizer.axis = -1 - # For w13: Patch forward to collect stats at forward level (before routing) - # For w2: Stats collected at kernel level (no patching needed) + # Patch forward for w13 to collect stats before routing if self.projection_name == "w13": bind_forward_method(self.module, forward_fusedmoe_w13, "_forward_no_awq_w13") @@ -735,9 +700,7 @@ def get_weight_scale(weight, block_size=None): scale = scale.view(org_shape) if slice_after_padding is not None: scale = scale[..., slice_after_padding] - # Average across all dimensions except the last (input_dim) - # For 2D [output_dim, input_dim]: mean(0) → [input_dim] - # For 3D [num_experts, output_dim, input_dim]: mean([0,1]) → [input_dim] + # Average to 1D [input_dim] for both 2D and 3D weights dims_to_reduce = list(range(scale.ndim - 1)) scale = scale.mean(dim=dims_to_reduce).to(torch.float32) return scale @@ -861,8 +824,8 @@ def forward(self, input, *args, **kwargs): original_invoke_kernel = vllm_fused_moe_package.invoke_fused_moe_kernel def patched_invoke_fused_moe_kernel(A, B, C, *args, **kwargs): # noqa: N803 - """Patched kernel that handles AWQ-lite calibration for all FusedMoE modules.""" - # Find which module this call belongs to by checking B (weight tensor) + """Patched kernel for AWQ calibration of FusedMoE modules.""" + # Find module by weight tensor target_module = None helper = None input_q = None @@ -882,7 +845,6 @@ def patched_invoke_fused_moe_kernel(A, B, C, *args, **kwargs): 
# noqa: N803 weight_q = mod.w2_weight_quantizer break - # If not found or not enabled, use original kernel if ( target_module is None or helper is None @@ -891,26 +853,22 @@ def patched_invoke_fused_moe_kernel(A, B, C, *args, **kwargs): # noqa: N803 ): return original_invoke_kernel(A, B, C, *args, **kwargs) - # Type assertions for mypy assert helper is not None and input_q is not None and weight_q is not None - # Compute actual output without quantization + # Compute ground truth without quantization weight_q.disable() c_actual = torch.empty_like(C) original_invoke_kernel(A, B, c_actual, *args, **kwargs) weight_q.enable() if AWQLiteFusedMoEHelper.cache_mode: - # Cache mode: collect activation statistics - # For w13: Stats collected at forward level (skip here) - # For w2: Collect from intermediate activation A + # Cache mode: collect w2 stats (w13 stats collected at forward level) if helper.projection_name == "w2": a_local = A.to_local() if hasattr(A, "to_local") else A helper.act_scale += get_act_scale(a_local) helper.num_cache_steps += 1 helper.num_tokens += a_local.numel() / a_local.shape[-1] - # Collect input quantizer stats for w2 if helper.is_input_quantized: with set_quantizer_by_cfg_context(input_q, {"*": {"enable": True}}): max_calibrate(input_q, lambda q: q(a_local), False) @@ -920,20 +878,9 @@ def patched_invoke_fused_moe_kernel(A, B, C, *args, **kwargs): # noqa: N803 # Search mode: try different alpha values for alpha in helper.loss: - awq_scale = get_scale( - helper.act_scale, - helper.weight_scale, - alpha, - None, - ) - # Apply AWQ scaling: input * (1/scale) @ weight * scale - # For w13: input scaling applied at forward level (not here) - # For w2: input scaling applied at kernel level (here) + awq_scale = get_scale(helper.act_scale, helper.weight_scale, alpha, None) - # Reshape scale for broadcasting with weight - # awq_scale is always 1D [input_dim] - # For 3D weights [num_experts, output_dim, input_dim]: reshape to [1, 1, input_dim] - # For 2D weights [output_dim, input_dim]: reshape to [1, input_dim] + # Reshape scale for weight broadcasting if B.ndim == 3: scale_reshaped = awq_scale.view(1, 1, -1) elif B.ndim == 2: @@ -941,26 +888,20 @@ def patched_invoke_fused_moe_kernel(A, B, C, *args, **kwargs): # noqa: N803 else: raise ValueError(f"Unsupported weight dimensions: {B.ndim}") + # Apply AWQ scaling: input * (1/scale) @ weight * scale if helper.projection_name == "w2": - # w2: Apply input scaling at kernel level (on intermediate activations) - # Input A_local is 2D [num_tokens, hidden_dim], so use 1D scale input_q._enable_pre_quant_scale = True input_q.pre_quant_scale = (1 / awq_scale).to(B.dtype) - # For both: Apply weight scaling at kernel level - # Weight is 3D [num_experts, out_dim, in_dim], so use reshaped scale weight_q._enable_pre_quant_scale = True weight_q.pre_quant_scale = scale_reshaped.to(B.dtype) c_search = torch.empty_like(C) original_invoke_kernel(A, B, c_search, *args, **kwargs) - # Compute loss loss = (c_search - c_actual).float().pow(2).mean() helper.loss[alpha] += loss - # Clean up temporary pre_quant_scale after search completes - # The last alpha iteration leaves _enable_pre_quant_scale=True, which would - # cause errors if any forward pass happens before postprocessing + # Disable temporary scales weight_q._enable_pre_quant_scale = False if helper.projection_name == "w2": input_q._enable_pre_quant_scale = False diff --git a/modelopt/torch/quantization/plugins/vllm.py b/modelopt/torch/quantization/plugins/vllm.py index 872d379b5..d3dc2f20f 100644 --- 
a/modelopt/torch/quantization/plugins/vllm.py +++ b/modelopt/torch/quantization/plugins/vllm.py @@ -183,16 +183,13 @@ def invoke_fused_moe_quantized( def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): hidden_states = self.w13_input_quantizer(hidden_states) - # Save the original kernel function + # Temporarily patch kernel to apply quantization self._original_invoke_kernel = vllm_fused_moe_package.invoke_fused_moe_kernel - - # Patch the module-level function to use our quantized version vllm_fused_moe_package.invoke_fused_moe_kernel = self.invoke_fused_moe_quantized # type: ignore[attr-defined] try: output = super().forward(hidden_states, router_logits) finally: - # Restore the original kernel function vllm_fused_moe_package.invoke_fused_moe_kernel = self._original_invoke_kernel # type: ignore[attr-defined] return output
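
Note: below is a minimal, self-contained sketch of the AWQ-lite scale folding that the FusedMoE changes above rely on. The tensor names and shapes are illustrative assumptions, not code from vLLM or modelopt; it only demonstrates the (input * 1/scale) @ (weight * scale) identity for a 1D per-input-channel scale broadcast against 3D fused-expert weights, which is the same reshaping used in apply_pre_quant_scale_and_smooth_fusedmoe and the patched kernel.

import torch

# Illustrative shapes only (hypothetical sizes, not taken from the patch).
torch.manual_seed(0)
num_experts, out_dim, in_dim, num_tokens = 4, 8, 16, 32

x = torch.randn(num_tokens, in_dim)            # activations [num_tokens, in_dim]
w = torch.randn(num_experts, out_dim, in_dim)  # fused expert weights [E, out_dim, in_dim]
scale = torch.rand(in_dim) + 0.5               # 1D per-input-channel AWQ scale (positive)

# Fold the scale into the weight (broadcast to [1, 1, in_dim] for 3D weights)
# and apply the inverse scale on the activation side.
w_scaled = w * scale.view(1, 1, -1)
x_scaled = x * (1.0 / scale)

ref = torch.einsum("ti,eoi->teo", x, w)                # unscaled reference
out = torch.einsum("ti,eoi->teo", x_scaled, w_scaled)  # AWQ-scaled path
assert torch.allclose(ref, out, atol=1e-4)             # identical up to float rounding

In the patch itself this folding is done once at postprocess time: the best scale is multiplied into w13_weight/w2_weight in place, and 1/scale is installed as the input quantizer's pre_quant_scale, so the scaled and unscaled paths produce the same (pre-quantization) output at inference.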