diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index f987efcd6..6890721d3 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -36,6 +36,7 @@
     is_quantized_column_parallel_linear,
     is_quantized_linear,
     is_quantized_row_parallel_linear,
+    is_vllm_fused_moe,
     quantizer_attr_names,
     weight_attr_names,
 )
@@ -302,6 +303,127 @@ def apply_pre_quant_scale_and_smooth(
     )
 
 
+@torch.no_grad()
+def apply_pre_quant_scale_and_smooth_fusedmoe(
+    module,
+    projection_name: str,
+    pre_quant_scale: torch.Tensor,
+    input_quantizer,
+    weight_quantizer,
+):
+    """Apply pre_quant_scale and smooth for vLLM FusedMoE projections.
+
+    Args:
+        module: The FusedMoE module
+        projection_name: Either "w13" or "w2"
+        pre_quant_scale: The scale to apply (best_scale from AWQ search)
+        input_quantizer: The input quantizer for this projection
+        weight_quantizer: The weight quantizer for this projection
+    """
+    # Get weight
+    weight_attr = f"{projection_name}_weight"
+    weight = getattr(module, weight_attr)
+
+    # Validate inputs
+    assert pre_quant_scale is not None, "pre_quant_scale must be provided"
+    assert torch.all(pre_quant_scale > 0), "pre_quant_scale should be positive"
+    assert weight_quantizer.pre_quant_scale is None, (
+        "weight_quantizer.pre_quant_scale should be None first!"
+    )
+
+    pre_quant_scale = pre_quant_scale.to(torch.float32)
+
+    # Reshape scale for weight broadcasting: 2D [1, in_dim] or 3D [1, 1, in_dim]
+    if weight.ndim == 3:
+        scale_reshaped = pre_quant_scale.view(1, 1, -1)
+    elif weight.ndim == 2:
+        scale_reshaped = pre_quant_scale.view(1, -1)
+    else:
+        raise ValueError(f"Unsupported weight dimensions: {weight.ndim}")
+
+    # Fuse scale into weight: (input * 1/scale) @ (weight * scale)
+    weight.data = (weight.data * scale_reshaped).to(weight.dtype)
+    weight_quantizer.reset_amax()
+    max_calibrate(weight_quantizer, lambda q: q(weight))
+    weight_quantizer._enable_pre_quant_scale = False
+
+    # Setup input quantizer with inverse scale
+    if input_quantizer.amax is not None:
+        act_amax = input_quantizer.amax
+        input_quantizer._amax_for_smoothing = act_amax.cpu()
+        input_quantizer.reset_amax()
+        input_quantizer.axis = None
+
+        # Input scale is 1D [in_dim] for both linear and MoE activations
+        inv_scale_for_input = 1.0 / pre_quant_scale
+        input_quantizer._enable_pre_quant_scale = True
+        input_quantizer.pre_quant_scale = inv_scale_for_input.to(weight.dtype)
+        input_quantizer.amax = (act_amax * inv_scale_for_input).amax().to(weight.dtype)
+        input_quantizer.enable()
+
+
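+# In brief, the smoothing above relies on the identity
+#     X @ W == (X * (1 / s)) @ (W * s)
+# for a positive per-input-channel scale s: s is folded into the weight and 1/s
+# into the input quantizer's pre_quant_scale, so the unquantized product is
+# unchanged while activation outliers are rebalanced into the weights.
+
+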
+@torch.no_grad()
+def postprocess_fusedmoe(module, name: str, projection_name: str, helper, get_scale_fn):
+    """Postprocess a single FusedMoE projection (w13 or w2) after AWQ-Lite calibration.
+
+    Args:
+        module: The FusedMoE module
+        name: Module name (for logging)
+        projection_name: Either "w13" or "w2"
+        helper: AWQLiteFusedMoEHelper instance for this projection
+        get_scale_fn: Function to compute scale (takes act_scale, weight_scale, alpha, parallel_group)
+    """
+    # Check if calibration was successful
+    if helper.num_cache_steps == 0:
+        helper.is_enabled = False
+    elif helper.num_search_steps == 0:
+        helper.is_enabled = False
+        warnings.warn(
+            f"awq_lite: {name}.{projection_name} did not receive data during search. "
+            "Falling back to max calibration."
+        )
+
+    # Compute best parameters
+    if helper.is_enabled:
+        helper.loss.update({k: float(v) for k, v in helper.loss.items()})
+        helper.best_alpha = min(helper.loss, key=helper.loss.get)
+        helper.best_scale = get_scale_fn(
+            helper.act_scale,
+            helper.weight_scale,
+            helper.best_alpha,
+            None,
+        )
+
+        # Apply pre_quant_scale and smooth using utility function
+        # Note: Only apply if input quantizer was enabled during calibration
+        if helper.is_input_quantized:
+            apply_pre_quant_scale_and_smooth_fusedmoe(
+                module,
+                projection_name,
+                helper.best_scale,
+                helper.input_quantizer,
+                helper.weight_quantizer,
+            )
+        else:
+            # Weight-only quantization: just fuse scale into weight
+            weight_attr = f"{projection_name}_weight"
+            weight = getattr(module, weight_attr)
+            weight.data = (weight.data * helper.best_scale.to(torch.float32)).to(weight.dtype)
+            helper.weight_quantizer.reset_amax()
+            max_calibrate(helper.weight_quantizer, lambda q: q(weight))
+            helper.weight_quantizer._enable_pre_quant_scale = False
+    else:
+        # Fall back to max calibration
+        warnings.warn(
+            f"awq_lite: Disabling for {name}.{projection_name}, quantizing with max calibration."
+        )
+        weight = getattr(module, f"{projection_name}_weight")
+        max_calibrate(helper.weight_quantizer, lambda q: q(weight))
+
+    # Cleanup helper
+    helper.cleanup()
+
+
 @torch.no_grad()
 def smoothquant(model: nn.Module, forward_loop: ForwardLoop | None = None, alpha=1.0):
     """Smooth-Quant variant with per-channel weight scaling.
@@ -479,6 +601,91 @@ def cleanup(self):
             delattr(module, "_if_calib")
             unpatch_forward_method(module, "_forward_no_awq")
 
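+    # For FusedMoE, w13 activation statistics must be gathered before vLLM routes
+    # tokens to experts, so they are collected in this patched module forward;
+    # w2 statistics are collected later, inside the patched
+    # invoke_fused_moe_kernel, where the intermediate expert activations are
+    # materialized.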
+    def forward_fusedmoe_w13(self, hidden_states, router_logits):
+        """Patched forward to collect w13 stats before routing."""
+        if AWQLiteFusedMoEHelper.cache_mode and self.awq_lite_w13.is_enabled:
+            hidden_states_local = (
+                hidden_states.to_local() if hasattr(hidden_states, "to_local") else hidden_states
+            )
+            self.awq_lite_w13.act_scale += get_act_scale(
+                self.w13_input_quantizer(hidden_states_local)
+            )
+            self.awq_lite_w13.num_cache_steps += 1
+            self.awq_lite_w13.num_tokens += (
+                hidden_states_local.numel() / hidden_states_local.shape[-1]
+            )
+
+            if self.awq_lite_w13.is_input_quantized:
+                with set_quantizer_by_cfg_context(
+                    self.w13_input_quantizer, {"*": {"enable": True}}
+                ):
+                    max_calibrate(self.w13_input_quantizer, lambda q: q(hidden_states_local), False)
+
+        return self._forward_no_awq_w13(hidden_states, router_logits)
+
+    class AWQLiteFusedMoEHelper:
+        """Helper for AWQ-Lite calibration on vLLM FusedMoE modules."""
+
+        cache_mode: bool = False
+
+        def __init__(self, module, name, projection_name):
+            """Initialize helper for w13 or w2 projection.
+
+            Args:
+                module: The FusedMoE module
+                name: Module name
+                projection_name: Either "w13" or "w2"
+            """
+            self.name = name
+            self.projection_name = projection_name
+            self.act_scale = 0.0
+            self.num_cache_steps = 0
+            self.num_search_steps = 0
+            self.module = module
+
+            # Get the appropriate weight and quantizers based on projection
+            if projection_name == "w13":
+                weight = module.w13_weight
+                self.input_quantizer = module.w13_input_quantizer
+                self.weight_quantizer = module.w13_weight_quantizer
+            else:  # w2
+                weight = module.w2_weight
+                self.input_quantizer = module.w2_input_quantizer
+                self.weight_quantizer = module.w2_weight_quantizer
+
+            self.block_size = _get_awq_quantizer_block_size(weight, self.weight_quantizer)
+            self.weight_scale = get_weight_scale(weight, self.block_size)
+            self.loss = {
+                k.item(): torch.zeros((), device=weight.device, dtype=torch.float32)
+                for k in torch.arange(0, 1.0 + alpha_step, alpha_step)
+            }
+            self.best_scale = None
+            self.best_alpha = None
+            # Track if input quantizer was originally enabled
+            self.is_input_quantized = self.input_quantizer.is_enabled
+            self.num_tokens = 0
+            self.is_enabled = True
+
+        def setup(self):
+            """Setup for AWQ calibration: disable input quantizer and patch forward if needed."""
+            if self.input_quantizer.is_enabled:
+                self.input_quantizer.disable()
+            if self.input_quantizer.axis not in [None, -1]:
+                self.is_enabled = False
+                return
+            self.input_quantizer.axis = -1
+
+            # Patch forward for w13 to collect stats before routing
+            if self.projection_name == "w13":
+                bind_forward_method(self.module, forward_fusedmoe_w13, "_forward_no_awq_w13")
+
+        def cleanup(self):
+            """Cleanup after calibration."""
+            if hasattr(self.module, "_if_calib"):
+                delattr(self.module, "_if_calib")
+            if self.projection_name == "w13":
+                unpatch_forward_method(self.module, "_forward_no_awq_w13")
+
     def get_weight_scale(weight, block_size=None):
         org_shape = weight.shape
         slice_after_padding = None
@@ -493,7 +700,9 @@ def get_weight_scale(weight, block_size=None):
         scale = scale.view(org_shape)
         if slice_after_padding is not None:
             scale = scale[..., slice_after_padding]
-        scale = scale.mean(0).to(torch.float32)
+        # Average to 1D [input_dim] for both 2D and 3D weights
+        dims_to_reduce = list(range(scale.ndim - 1))
+        scale = scale.mean(dim=dims_to_reduce).to(torch.float32)
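+        # e.g. a fused-expert weight of shape [num_experts, out_dim, in_dim]
+        # reduces over dims (0, 1) to a per-input-channel [in_dim] scale, while
+        # a 2D [out_dim, in_dim] weight reduces over dim 0 as before.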
         return scale
 
     def get_act_scale(x):
@@ -586,8 +795,128 @@ def forward(self, input, *args, **kwargs):
             module.awq_lite = AWQLiteHelper(module, name)
             module.awq_lite.setup()
 
+    # Setup vLLM FusedMoE modules
+    fused_moe_modules = []
+    for name, module in model.named_modules():
+        if is_vllm_fused_moe(module) and (
+            module.w13_weight_quantizer.is_enabled or module.w2_weight_quantizer.is_enabled
+        ):
+            with enable_weight_access_and_writeback(module, model):
+                # Create helpers for both projections
+                module.awq_lite_w13 = AWQLiteFusedMoEHelper(module, name, "w13")
+                module.awq_lite_w2 = AWQLiteFusedMoEHelper(module, name, "w2")
+                module.awq_lite_w13.setup()
+                module.awq_lite_w2.setup()
+            fused_moe_modules.append(module)
+
+    # Patch the global invoke_fused_moe_kernel function if we have FusedMoE modules
+    original_invoke_kernel = None
+    if fused_moe_modules:
+        try:
+            # Import vLLM package
+            import importlib
+
+            vllm_fused_moe_package = importlib.import_module(
+                "vllm.model_executor.layers.fused_moe.fused_moe"
+            )
+
+            # Store original kernel
+            original_invoke_kernel = vllm_fused_moe_package.invoke_fused_moe_kernel
+
+            def patched_invoke_fused_moe_kernel(A, B, C, *args, **kwargs):  # noqa: N803
+                """Patched kernel for AWQ calibration of FusedMoE modules."""
+                # Find module by weight tensor
+                target_module = None
+                helper = None
+                input_q = None
+                weight_q = None
+
+                for mod in fused_moe_modules:
+                    if B is mod.w13_weight:
+                        target_module = mod
+                        helper = mod.awq_lite_w13
+                        input_q = mod.w13_input_quantizer
+                        weight_q = mod.w13_weight_quantizer
+                        break
+                    elif B is mod.w2_weight:
+                        target_module = mod
+                        helper = mod.awq_lite_w2
+                        input_q = mod.w2_input_quantizer
+                        weight_q = mod.w2_weight_quantizer
+                        break
+
+                if (
+                    target_module is None
+                    or helper is None
+                    or not helper.is_enabled
+                    or A.numel() == 0
+                ):
+                    return original_invoke_kernel(A, B, C, *args, **kwargs)
+
+                assert helper is not None and input_q is not None and weight_q is not None
+
+                # Compute ground truth without quantization
+                weight_q.disable()
+                c_actual = torch.empty_like(C)
+                original_invoke_kernel(A, B, c_actual, *args, **kwargs)
+                weight_q.enable()
+
+                if AWQLiteFusedMoEHelper.cache_mode:
+                    # Cache mode: collect w2 stats (w13 stats collected at forward level)
+                    if helper.projection_name == "w2":
+                        a_local = A.to_local() if hasattr(A, "to_local") else A
+                        helper.act_scale += get_act_scale(a_local)
+                        helper.num_cache_steps += 1
+                        helper.num_tokens += a_local.numel() / a_local.shape[-1]
+
+                        if helper.is_input_quantized:
+                            with set_quantizer_by_cfg_context(input_q, {"*": {"enable": True}}):
+                                max_calibrate(input_q, lambda q: q(a_local), False)
+
+                    C[:] = c_actual
+                    return
+
+                # Search mode: try different alpha values
+                for alpha in helper.loss:
+                    awq_scale = get_scale(helper.act_scale, helper.weight_scale, alpha, None)
+
+                    # Reshape scale for weight broadcasting
+                    if B.ndim == 3:
+                        scale_reshaped = awq_scale.view(1, 1, -1)
+                    elif B.ndim == 2:
+                        scale_reshaped = awq_scale.view(1, -1)
+                    else:
+                        raise ValueError(f"Unsupported weight dimensions: {B.ndim}")
+
+                    # Apply AWQ scaling: (input * (1/scale)) @ (weight * scale)
+                    if helper.projection_name == "w2":
+                        input_q._enable_pre_quant_scale = True
+                        input_q.pre_quant_scale = (1 / awq_scale).to(B.dtype)
+                    weight_q._enable_pre_quant_scale = True
+                    weight_q.pre_quant_scale = scale_reshaped.to(B.dtype)
+
+                    c_search = torch.empty_like(C)
+                    original_invoke_kernel(A, B, c_search, *args, **kwargs)
+
+                    loss = (c_search - c_actual).float().pow(2).mean()
+                    helper.loss[alpha] += loss
+
+                    # Disable temporary scales
+                    weight_q._enable_pre_quant_scale = False
+                    if helper.projection_name == "w2":
+                        input_q._enable_pre_quant_scale = False
+
+                helper.num_search_steps += 1
+                C[:] = c_actual
+
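+            # The search pass accumulates, for each alpha in the grid
+            # [0, alpha_step, ..., 1.0], the MSE between the fake-quantized and
+            # unquantized kernel outputs; postprocess_fusedmoe later keeps the
+            # alpha with the smallest accumulated loss.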
+            # Apply the patch globally
+            vllm_fused_moe_package.invoke_fused_moe_kernel = patched_invoke_fused_moe_kernel  # type: ignore[attr-defined]
+        except ImportError:
+            warnings.warn("vLLM not installed, skipping FusedMoE AWQ-lite calibration")
+
     # Collect activation scale values
     AWQLiteHelper.cache_mode = True
+    AWQLiteFusedMoEHelper.cache_mode = True
     print_rank_0("awq_lite: Caching activation statistics...")
 
     # Lets enable stats collection
@@ -631,7 +960,37 @@ def sync_act_scale_across_dp(module, data_parallel_group):
             module.parallel_state.data_parallel_group,
         )
 
+    # Sync FusedMoE activation scales
+    for name, module in model.named_modules():
+        if hasattr(module, "awq_lite_w13") and hasattr(module, "awq_lite_w2"):
+            module._if_calib = True
+
+            for helper in [module.awq_lite_w13, module.awq_lite_w2]:
+                if helper.num_cache_steps > 0:
+                    helper.act_scale = helper.act_scale / helper.num_cache_steps
+
+                    has_nan_local = torch.any(torch.isnan(helper.act_scale)) or torch.any(
+                        torch.isnan(helper.weight_scale)
+                    )
+                    has_nan = DistributedProcessGroup.get_dist_syncd_obj(
+                        has_nan_local,
+                        module.parallel_state.data_parallel_group,
+                        lambda objs: any(objs),
+                    )
+
+                    if has_nan:
+                        helper.is_enabled = False
+                    elif module.parallel_state.data_parallel_group.is_initialized():
+                        dist.all_reduce(
+                            helper.act_scale,
+                            op=dist.ReduceOp.AVG,
+                            group=module.parallel_state.data_parallel_group.group,
+                        )
+                else:
+                    helper.is_enabled = False
+
     AWQLiteHelper.cache_mode = False
+    AWQLiteFusedMoEHelper.cache_mode = False
     print_rank_0("awq_lite: Searching parameters...")
     with torch.no_grad():
         forward_loop(model)
@@ -675,6 +1034,29 @@ def postprocess(module, name):
         if not debug:
             delattr(module, "awq_lite")
 
+    # Restore original vLLM kernel and post-process FusedMoE modules
+    if fused_moe_modules and original_invoke_kernel is not None:
+        try:
+            import importlib
+
+            vllm_fused_moe_package = importlib.import_module(
+                "vllm.model_executor.layers.fused_moe.fused_moe"
+            )
+            vllm_fused_moe_package.invoke_fused_moe_kernel = original_invoke_kernel  # type: ignore[attr-defined]
+        except ImportError:
+            pass
+
+    for name, module in model.named_modules():
+        if hasattr(module, "awq_lite_w13") and hasattr(module, "awq_lite_w2"):
+            for proj_name, helper_attr in [("w13", "awq_lite_w13"), ("w2", "awq_lite_w2")]:
+                helper = getattr(module, helper_attr)
+
+                with enable_weight_access_and_writeback(module, model):
+                    postprocess_fusedmoe(module, name, proj_name, helper, get_scale)
+
+                if not debug:
+                    delattr(module, helper_attr)
+
 
 @torch.no_grad()
 def awq_clip(
diff --git a/modelopt/torch/quantization/plugins/vllm.py b/modelopt/torch/quantization/plugins/vllm.py
index 11954e69f..d3dc2f20f 100644
--- a/modelopt/torch/quantization/plugins/vllm.py
+++ b/modelopt/torch/quantization/plugins/vllm.py
@@ -27,6 +27,13 @@
 vllm_fused_moe_package = importlib.import_module("vllm.model_executor.layers.fused_moe.fused_moe")
 
 
+def _assign_weight(module, name, weight):
+    """Rebind ``module.<name>`` to ``weight``, preserving Parameter-ness."""
+    target = getattr(module, name)
+    if isinstance(target, torch.nn.Parameter):
+        setattr(module, name, torch.nn.Parameter(weight, requires_grad=target.requires_grad))
+    else:
+        setattr(module, name, weight)
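+
+
+# _assign_weight is used below to swap a fake-quantized weight in for a single
+# kernel invocation and then restore the original, e.g.:
+#     _assign_weight(self, "w13_weight", self.w13_weight_quantizer(self.w13_weight))
+#     ...  # run the kernel
+#     _assign_weight(self, "w13_weight", original_weight)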
+
+
 class FakeQuantMethod:
     """A class that implements fake quantization methods for vLLM models.
@@ -150,44 +157,39 @@ def invoke_fused_moe_quantized(
         **kwargs,
     ):
         if B is self.w13_weight:
-            # First layer of expert
-            A = self.w13_input_quantizer(A)  # noqa: N806
             if self.w13_weight_quantizer.is_enabled:
                 original_weight = self.w13_weight
-                self.w13_weight = self.w13_weight_quantizer(self.w13_weight)
-                vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
-                self.w13_weight = original_weight
+                _assign_weight(self, "w13_weight", self.w13_weight_quantizer(self.w13_weight))
+                self._original_invoke_kernel(A, B, C, *args, **kwargs)
+                _assign_weight(self, "w13_weight", original_weight)
             else:
-                vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
+                self._original_invoke_kernel(A, B, C, *args, **kwargs)
             if self.w13_output_quantizer.is_enabled:
                 C[:] = self.w13_output_quantizer(C)
         elif B is self.w2_weight:
             A = self.w2_input_quantizer(A)  # noqa: N806
             if self.w2_weight_quantizer.is_enabled:
                 original_weight = self.w2_weight
-                self.w2_weight = self.w2_weight_quantizer(self.w2_weight)
-                vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
-                self.w2_weight = original_weight
+                _assign_weight(self, "w2_weight", self.w2_weight_quantizer(self.w2_weight))
+                self._original_invoke_kernel(A, B, C, *args, **kwargs)
+                _assign_weight(self, "w2_weight", original_weight)
             else:
-                vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
+                self._original_invoke_kernel(A, B, C, *args, **kwargs)
             if self.w2_output_quantizer.is_enabled:
                 C[:] = self.w2_output_quantizer(C)
         else:
             raise ValueError("Cannot determine first or second layer of expert")
 
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
-        # This is again due to the bad coding of vLLM
-        # fused_moe submodule is overwritten by the fused_moe function
-        # so we need to import the fused_moe module explicitly
-        assert vllm_fused_moe_package.invoke_fused_moe_kernel is not None
-        # This context manager will conflict with torch.compile
-        # with replace_function(
-        #     vllm_fused_moe_package,
-        #     "invoke_fused_moe_kernel",
-        #     self.invoke_fused_moe_quantized,
-        # ):
-        self._invoke_fused_moe_quantized = self.invoke_fused_moe_quantized
-        self.invoke_fused_moe_quantized = self.invoke_fused_moe_quantized
-        output = super().forward(hidden_states, router_logits)
-        self.invoke_fused_moe_quantized = self._invoke_fused_moe_quantized
+        hidden_states = self.w13_input_quantizer(hidden_states)
+
+        # Temporarily patch kernel to apply quantization
+        self._original_invoke_kernel = vllm_fused_moe_package.invoke_fused_moe_kernel
+        vllm_fused_moe_package.invoke_fused_moe_kernel = self.invoke_fused_moe_quantized  # type: ignore[attr-defined]
+
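+        # The try/finally below guarantees the module-global kernel pointer is
+        # restored even if the forward raises, so the patch cannot outlive the call.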
+        try:
+            output = super().forward(hidden_states, router_logits)
+        finally:
+            vllm_fused_moe_package.invoke_fused_moe_kernel = self._original_invoke_kernel  # type: ignore[attr-defined]
+
         return output
diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py
index 6167daf23..39c95e27f 100644
--- a/modelopt/torch/quantization/utils.py
+++ b/modelopt/torch/quantization/utils.py
@@ -266,6 +266,21 @@ def is_quantized_row_parallel_linear(module):
     return is_quantized_linear(module) and getattr(module, "_is_row_parallel", False)
 
 
+def is_vllm_fused_moe(module):
+    """Check if a module is a vLLM FusedMoE module."""
+    from .nn import QuantModule, TensorQuantizer
+
+    return (
+        isinstance(module, QuantModule)
+        and hasattr(module, "w13_weight")
+        and hasattr(module, "w2_weight")
+        and isinstance(getattr(module, "w13_input_quantizer", None), TensorQuantizer)
+        and isinstance(getattr(module, "w2_input_quantizer", None), TensorQuantizer)
+        and hasattr(module, "w13_weight_quantizer")
+        and hasattr(module, "w2_weight_quantizer")
+    )
+
+
 def is_quantized_parallel_linear(module):
     """Check if a module is a quantized parallel linear module."""
     return is_quantized_column_parallel_linear(module) or is_quantized_row_parallel_linear(module)