diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index 9db819dfd..396f1fc11 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -1,11 +1,13 @@ from .base_weight import BaseWeight from .mm_weight import ( + MMWeightPack, MMWeightTpl, MultiMMWeightTpl, ROWMMWeight, COLMMWeight, MultiROWMMWeight, ROWBMMWeight, + AWQMultiMMWeightTpl, ) from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight from .fused_moe_weight_tp import FusedMoeWeightTP diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py index b67fc1b43..544dcb2fa 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py @@ -24,30 +24,8 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, data_type: to self.device_id_ = get_current_device_id() self.data_type_ = data_type - def _slice_weight(self, weight: torch.Tensor): - # slice weight - return weight.to(self.data_type_) - - def _slice_bias(self, bias: torch.Tensor): - # slice bias - return bias.to(self.data_type_) - - def _slice_weight_scale(self, weight_scale: torch.Tensor): - # slice weight scale and zero point - return weight_scale - - def _load_weights(self, weights: Dict[str, torch.Tensor]) -> None: - # load weight - pass - - def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None: - # load quantization scale - pass - def load_hf_weights(self, weights): - self._load_weights(weights) - self._load_scales(weights) - return + raise NotImplementedError("load_hf_weights must implement this method") def verify_load(self): - pass + raise NotImplementedError("verify_load must implement this method") diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py index 3e61178f3..ece9de8b8 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py @@ -7,7 +7,56 @@ from lightllm.common.quantization import Quantcfg -class FusedMoeWeightTP(BaseWeight): +class FusedMoeWeightTP: + def __new__( + cls, + gate_proj_name: str, + down_proj_name: str, + up_proj_name: str, + e_score_correction_bias_name: str, + weight_prefix: str, + n_routed_experts: int, + num_fused_shared_experts: int, + split_inter_size: int, + data_type: torch.dtype, + network_config: Dict[str, Any], + layer_num: int, + quant_cfg: Quantcfg = None, + ): + quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") + if quant_method is not None and quant_method.method_name == "awq_marlin": + return FusedAWQMARLINMoeWeightTP( + gate_proj_name=gate_proj_name, + down_proj_name=down_proj_name, + up_proj_name=up_proj_name, + e_score_correction_bias_name=e_score_correction_bias_name, + weight_prefix=weight_prefix, + n_routed_experts=n_routed_experts, + num_fused_shared_experts=num_fused_shared_experts,
split_inter_size=split_inter_size, + data_type=data_type, + network_config=network_config, + layer_num=layer_num, + quant_cfg=quant_cfg, + ) + else: + return FusedBaseMoeWeightTP( + gate_proj_name=gate_proj_name, + down_proj_name=down_proj_name, + up_proj_name=up_proj_name, + e_score_correction_bias_name=e_score_correction_bias_name, + weight_prefix=weight_prefix, + n_routed_experts=n_routed_experts, + num_fused_shared_experts=num_fused_shared_experts, + split_inter_size=split_inter_size, + data_type=data_type, + network_config=network_config, + layer_num=layer_num, + quant_cfg=quant_cfg, + ) + + +class FusedBaseMoeWeightTP(BaseWeight): def __init__( self, gate_proj_name: str, @@ -245,3 +294,374 @@ def _cuda(self, cpu_tensor): def verify_load(self): return self.w1 is not None and self.w2 is not None + + +class FusedAWQMARLINMoeWeightTP(BaseWeight): + def __init__( + self, + gate_proj_name: str, + down_proj_name: str, + up_proj_name: str, + e_score_correction_bias_name: str, + weight_prefix: str, + n_routed_experts: int, + num_fused_shared_experts: int, + split_inter_size: int, + data_type: torch.dtype, + network_config: Dict[str, Any], + layer_num: int, + quant_cfg: Quantcfg = None, + ) -> None: + super().__init__() + self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") + self.quantized_weight = quant_cfg.quantized_weight + if self.quant_method is not None: + self.weight_scale_suffix = self.quant_method.weight_scale_suffix + self.weight_zero_point_suffix = self.quant_method.weight_zero_point_suffix + self.quant_method.is_moe = True + hf_quantization_config = network_config.get("quantization_config", None) + self.num_bits = hf_quantization_config.get("bits", 4) + self.group_size = hf_quantization_config.get("group_size", 128) + self.pack_factor = 32 // self.num_bits + self.has_processed_weight = False + assert self.quant_method.method_name == "awq_marlin" + + self.w1_weight_name = gate_proj_name + self.w2_weight_name = down_proj_name + self.w3_weight_name = up_proj_name + + self.e_score_correction_bias_name = e_score_correction_bias_name + self.weight_prefix = weight_prefix + assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now." 
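+        # fused shared experts (if any) are appended after the routed experts and handled as extra expert slots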
+ self.n_routed_experts = n_routed_experts + num_fused_shared_experts + self.num_fused_shared_experts = num_fused_shared_experts + self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0) + self.split_inter_size = split_inter_size + self.data_type_ = data_type + self.tp_rank_ = get_current_rank_in_dp() + self.experts_up_projs = [None] * self.n_routed_experts + self.experts_gate_projs = [None] * self.n_routed_experts + self.experts_up_proj_scales = [None] * self.n_routed_experts + self.experts_up_proj_zero_points = [None] * self.n_routed_experts + self.experts_gate_proj_scales = [None] * self.n_routed_experts + self.experts_gate_proj_zero_points = [None] * self.n_routed_experts + self.e_score_correction_bias = None + self.w2_list = [None] * self.n_routed_experts + self.w2_scale_list = [None] * self.n_routed_experts + self.w2_zero_point_list = [None] * self.n_routed_experts + self.scoring_func = network_config.get("scoring_func", "softmax") + self.w1 = [None, None, None] # weight, weight_scale, zero_point + self.w2 = [None, None, None] # weight, weight_scale, zero_point + self.lock = threading.Lock() + + def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): + from lightllm.common.fused_moe.topk_select import select_experts + + topk_weights, topk_ids = select_experts( + hidden_states=input_tensor, + router_logits=router_logits, + correction_bias=self.e_score_correction_bias, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + scoring_func=self.scoring_func, + ) + topk_weights.mul_(self.routed_scaling_factor) + if self.num_fused_shared_experts > 0: + pad_topk_ids = ( + torch.arange( + start=self.n_routed_experts - self.num_fused_shared_experts, + end=self.n_routed_experts, + step=1, + dtype=topk_ids.dtype, + device="cuda", + ) + .view(1, self.num_fused_shared_experts) + .repeat(topk_ids.shape[0], 1) + ) + pad_topk_weights = torch.full( + (topk_weights.shape[0], self.num_fused_shared_experts), + fill_value=1.0, + device="cuda", + dtype=topk_weights.dtype, + ) + + topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) + topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1) + + w1, w1_scale, w1_zero_point = self.w1 + w2, w2_scale, w2_zero_point = self.w2 + + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe + + fused_marlin_moe( + input_tensor, + w1, + w2, + None, + None, + w1_scale, + w2_scale, + router_logits, + topk_weights, + topk_ids, + quant_type_id=self.quant_method.vllm_quant_type.id, + apply_router_weight_on_input=False, + global_num_experts=-1, + expert_map=None, + w1_zeros=w1_zero_point, + w2_zeros=w2_zero_point, + workspace=self.workspace, + inplace=True, + ) + + return + + def _fuse(self): + self._fuse_weight() + self._fuse_weight_scale() + self._fuse_weight_zero_point() + + def _fuse_weight(self): + with self.lock: + if ( + hasattr(self, "experts_up_projs") + and None not in self.experts_up_projs + and None not in self.experts_gate_projs + and None not in self.w2_list + ): + gate_in_dim, gate_out_dim = self.experts_gate_projs[0].shape + up_in_dim, up_out_dim = self.experts_up_projs[0].shape + assert gate_in_dim == up_in_dim + total_expert_num = self.n_routed_experts + + w1 = torch.empty( + (total_expert_num, gate_in_dim, gate_out_dim + up_out_dim), dtype=torch.int32, device="cpu" + ) + + for i_experts in range(self.n_routed_experts): + w1[i_experts, :, 0:gate_out_dim] 
= self.experts_gate_projs[i_experts] + w1[i_experts, :, gate_out_dim:] = self.experts_up_projs[i_experts] + + inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1] + w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size) + self.w1[0] = self._cuda(w1) + self.w2[0] = self._cuda(w2) + delattr(self, "w2_list") + delattr(self, "experts_up_projs") + delattr(self, "experts_gate_projs") + + def _fuse_weight_scale(self): + with self.lock: + if ( + hasattr(self, "experts_up_proj_scales") + and None not in self.experts_up_proj_scales + and None not in self.experts_gate_proj_scales + and None not in self.w2_scale_list + ): + gate_in_dim, gate_out_dim = self.experts_gate_proj_scales[0].shape + up_in_dim, up_out_dim = self.experts_up_proj_scales[0].shape + dtype = self.experts_gate_proj_scales[0].dtype + assert gate_in_dim == up_in_dim + total_expert_num = self.n_routed_experts + w1_scale = torch.empty( + (total_expert_num, gate_in_dim, gate_out_dim + up_out_dim), dtype=dtype, device="cpu" + ) + for i_experts in range(self.n_routed_experts): + w1_scale[i_experts, :, 0:gate_out_dim] = self.experts_gate_proj_scales[i_experts] + w1_scale[i_experts, :, gate_out_dim:] = self.experts_up_proj_scales[i_experts] + inter_shape, hidden_size = self.w2_scale_list[0].shape[0], self.w2_scale_list[0].shape[1] + w2_scale = torch._utils._flatten_dense_tensors(self.w2_scale_list).view( + len(self.w2_scale_list), inter_shape, hidden_size + ) + self.w1[1] = self._cuda(w1_scale).to(self.data_type_) + self.w2[1] = self._cuda(w2_scale).to(self.data_type_) + delattr(self, "w2_scale_list") + delattr(self, "experts_up_proj_scales") + delattr(self, "experts_gate_proj_scales") + + def _fuse_weight_zero_point(self): + with self.lock: + if ( + hasattr(self, "experts_up_proj_zero_points") + and None not in self.experts_up_proj_zero_points + and None not in self.experts_gate_proj_zero_points + and None not in self.w2_zero_point_list + ): + gate_in_dim, gate_out_dim = self.experts_gate_proj_zero_points[0].shape + up_in_dim, up_out_dim = self.experts_up_proj_zero_points[0].shape + assert gate_in_dim == up_in_dim + total_expert_num = self.n_routed_experts + w1_zero_point = torch.empty( + (total_expert_num, gate_in_dim, gate_out_dim + up_out_dim), dtype=torch.int32, device="cpu" + ) + for i_experts in range(self.n_routed_experts): + w1_zero_point[i_experts, :, 0:gate_out_dim] = self.experts_gate_proj_zero_points[i_experts] + w1_zero_point[i_experts, :, gate_out_dim:] = self.experts_up_proj_zero_points[i_experts] + inter_shape, hidden_size = self.w2_zero_point_list[0].shape[0], self.w2_zero_point_list[0].shape[1] + w2_zero_point = torch._utils._flatten_dense_tensors(self.w2_zero_point_list).view( + len(self.w2_zero_point_list), inter_shape, hidden_size + ) + self.w1[2] = self._cuda(w1_zero_point) + self.w2[2] = self._cuda(w2_zero_point) + delattr(self, "w2_zero_point_list") + delattr(self, "experts_up_proj_zero_points") + delattr(self, "experts_gate_proj_zero_points") + + def load_hf_weights(self, weights): + self._load_weight(weights) + self._load_weight_scale(weights) + self._load_weight_zero_point(weights) + self._fuse() + self._process_weight_after_loading() + + def _load_weight(self, weights: Dict[str, torch.Tensor]) -> None: + # awq quantization weight shape: in x out + if self.e_score_correction_bias_name in weights: + self.e_score_correction_bias = self._cuda(weights[self.e_score_correction_bias_name]) + for i_experts in range(self.n_routed_experts): + 
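# gate/up qweights are sliced along their output dim (dim 1); the down qweight along its input dim (dim 0) +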
w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.qweight" + w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.qweight" + w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.qweight" + + if w1_weight in weights: + self.experts_gate_projs[i_experts] = weights[w1_weight][ + :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1) + ] + if w3_weight in weights: + self.experts_up_projs[i_experts] = weights[w3_weight][ + :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1) + ] + + if w2_weight in weights: + self.w2_list[i_experts] = weights[w2_weight][ + self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), : + ] + + def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: + for i_experts in range(self.n_routed_experts): + w1_scale = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.{self.weight_scale_suffix}" + w2_scale = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.{self.weight_scale_suffix}" + w3_scale = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.{self.weight_scale_suffix}" + split_inter_size = self.split_inter_size * self.pack_factor + if w1_scale in weights: + self.experts_gate_proj_scales[i_experts] = weights[w1_scale][ + :, + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), + ] + if w3_scale in weights: + self.experts_up_proj_scales[i_experts] = weights[w3_scale][ + :, + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), + ] + + if w2_scale in weights: + self.w2_scale_list[i_experts] = weights[w2_scale][ + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), + :, + ] + + def _load_weight_zero_point(self, weights: Dict[str, torch.Tensor]) -> None: + for i_experts in range(self.n_routed_experts): + w1_zero_point = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.{self.weight_zero_point_suffix}" + w2_zero_point = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.{self.weight_zero_point_suffix}" + w3_zero_point = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.{self.weight_zero_point_suffix}" + if w1_zero_point in weights: + self.experts_gate_proj_zero_points[i_experts] = weights[w1_zero_point][ + :, + self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), + ] + if w3_zero_point in weights: + self.experts_up_proj_zero_points[i_experts] = weights[w3_zero_point][ + :, + self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), + ] + if w2_zero_point in weights: + self.w2_zero_point_list[i_experts] = weights[w2_zero_point][ + self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), + :, + ] + + def _process_weight_after_loading(self): + with self.lock: + if None in self.w1 or None in self.w2 or self.has_processed_weight: + return + self.has_processed_weight = True + from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops + + assert HAS_VLLM, "moe awq marlin quantization requires kernels of vllm" + + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_moe_permute_scales, + moe_awq_to_marlin_zero_points, + marlin_make_workspace_new, + ) + + num_experts = self.n_routed_experts + device = self.w1[0].device + + self.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + self.w2_g_idx_sort_indices = 
torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + self.w1[0] = vllm_ops.awq_marlin_moe_repack( + self.w1[0], + self.w13_g_idx_sort_indices, + size_k=self.w1[0].shape[1], + size_n=self.w1[0].shape[2] * self.pack_factor, + num_bits=self.num_bits, + ) + + self.w2[0] = vllm_ops.awq_marlin_moe_repack( + self.w2[0], + self.w2_g_idx_sort_indices, + size_k=self.w2[0].shape[1], + size_n=self.w2[0].shape[2] * self.pack_factor, + num_bits=self.num_bits, + ) + + # Why does this take the intermediate size for size_k? + self.w1[1] = marlin_moe_permute_scales( + s=self.w1[1], + size_k=self.split_inter_size * self.pack_factor, + size_n=self.w1[1].shape[2], + group_size=self.group_size, + ) + + self.w2[1] = marlin_moe_permute_scales( + s=self.w2[1], + size_k=self.split_inter_size * self.pack_factor, + size_n=self.w2[1].shape[2], + group_size=self.group_size, + ) + + self.w1[2] = moe_awq_to_marlin_zero_points( + self.w1[2], + size_k=self.w1[2].shape[1], + size_n=self.w1[2].shape[2] * self.pack_factor, + num_bits=self.num_bits, + ) + + self.w2[2] = moe_awq_to_marlin_zero_points( + self.w2[2], + size_k=self.w2[2].shape[1], + size_n=self.w2[2].shape[2] * self.pack_factor, + num_bits=self.num_bits, + ) + + self.workspace = marlin_make_workspace_new(device, 4) + + def _cuda(self, cpu_tensor): + device_id = get_current_device_id() + if self.quantized_weight: + return cpu_tensor.cuda(device_id) + return cpu_tensor.cuda(device_id) + + def verify_load(self): + return self.w1 is not None and self.w2 is not None diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py index 263112435..ea343b41d 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py @@ -1,16 +1,13 @@ from .mm_weight import ( + MMWeightPack, MMWeightTpl, MultiMMWeightTpl, + AWQMultiMMWeightTpl, ) -from .rowmm_weight import ( +from .mm_factory import ( + MMWeight, ROWMMWeight, - ROWBMMWeight, MultiROWMMWeight, - W8A8B128ROWMMWeight, - W8A8B128ROWBMMWeight, - W8A8B128MultiROWMMWeight, -) -from .colmm_weight import ( + ROWBMMWeight, COLMMWeight, - W8A8B128COLMMWeight, ) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py index d6d064cf4..1b4e3e815 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py @@ -1,25 +1,38 @@ import torch from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import ( - MMWeight, - MMWeightTpl, - generate_scale_name, + SingleMMWeightTpl, + DeepGemmFP8W8A8B128MMWeight, + AWQMMWeightTpl, ) from lightllm.common.quantization import Quantcfg from lightllm.utils.dist_utils import get_current_device_id from lightllm.common.quantization.quantize_method import QuantizationMethod from typing import Dict, List, Optional +from .mm_slicer import ColSliceMixin, QuantizedRowSliceMixin, QuantizedColSliceMixin -class COLMMWeight(MMWeight): - @classmethod - def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool): - if quant_method is None or not quantized_weight: - return UnquantizedCOLMMWeight - else: - return W8A8B128COLMMWeight +class UnquantizedCOLMMWeight(SingleMMWeightTpl): + 
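# column-parallel linear: the weight is sliced along dim 1 (input features) by ColSliceMixin +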
def __init__( + self, + weight_name: str, + data_type: torch.dtype, + bias_name: Optional[str] = None, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + ) -> None: + super().__init__( + weight_name=weight_name, + data_type=data_type, + bias_name=bias_name, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + ) + self.param_slicer = ColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) -class UnquantizedCOLMMWeight(MMWeightTpl): +class DeepGemmFP8W8A8B128COLMMWeight(DeepGemmFP8W8A8B128MMWeight): def __init__( self, weight_name: str, @@ -29,24 +42,40 @@ def __init__( tp_rank: int = None, tp_world_size: int = None, ) -> None: - super().__init__(data_type, quant_method, tp_rank, tp_world_size) - self.weight_name = weight_name - self.bias_name = bias_name - self.has_bias = bias_name is not None + super().__init__( + weight_name=weight_name, + data_type=data_type, + bias_name=bias_name, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + ) + self.param_slicer = QuantizedColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - def _slice_weight(self, tensor): - assert tensor.shape[1] % self.tp_world_size_ == 0, f"tp slice error {tensor.shape[1]} % {self.tp_world_size_}" - tp_size = tensor.shape[1] // self.tp_world_size_ - return tensor[:, tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)].to(self.data_type_) - def _slice_bias(self, bias): - """ - 因为 Colmm 列 tp 切分的计算,最后会有一个 reduce 操作,直接将 bias / tp_world_size 可以节省一步计算。 - """ - return (bias / self.tp_world_size_).to(self.data_type_) +class AWQCOLMMWeight(AWQMMWeightTpl): + def __init__( + self, + weight_name: str, + data_type: torch.dtype, + bias_name: Optional[str] = None, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + ) -> None: + super().__init__( + weight_name=weight_name, + data_type=data_type, + bias_name=bias_name, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + ) + # 注意这里不是错误,因为awq的weight是按inxout存的 + self.param_slicer = QuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) -class W8A8B128COLMMWeight(MMWeightTpl): +class AWQMARLINCOLMMWeight(AWQCOLMMWeight): def __init__( self, weight_name: str, @@ -56,44 +85,37 @@ def __init__( tp_rank: int = None, tp_world_size: int = None, ) -> None: - super().__init__(data_type, quant_method, tp_rank, tp_world_size) - self.weight_name = weight_name - self.bias_name = bias_name - self.has_bias = bias_name is not None - - self.weight_scale_name, self.act_scale_name = generate_scale_name( - weight_name, quant_method.weight_scale_suffix, quant_method.act_scale_suffix + super().__init__( + weight_name=weight_name, + data_type=data_type, + bias_name=bias_name, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, ) - self.weight_scale: Optional[torch.Tensor] = None - self.block_size = self.quant_method.block_size - self.quantized_weight = True - def _slice_weight(self, tensor): - assert tensor.shape[1] % self.tp_world_size_ == 0, f"tp slice error {tensor.shape[1]} % {self.tp_world_size_}" - tp_size = tensor.shape[1] // self.tp_world_size_ - return tensor[:, tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] - - def _slice_weight_scale(self, weight_scale: torch.Tensor): - assert ( - weight_scale.shape[1] % self.tp_world_size_ == 0 - ), f"tp slice error {weight_scale.shape[1]} % {self.tp_world_size_}" - tp_size = weight_scale.shape[1] // self.tp_world_size_ - 
scale_start = tp_size * self.tp_rank_ - scale_end = tp_size * (self.tp_rank_ + 1) - return weight_scale[:, scale_start:scale_end].to(torch.float) + def _process_weight(self, weight: torch.Tensor) -> torch.Tensor: + new_weight = self.quant_method._process_weight_after_loading(weight.cuda(get_current_device_id())) + self.mm_param.weight = new_weight + return - def _process_weight_scale(self, weight_scale) -> None: - self.weight_scale = weight_scale.cuda(get_current_device_id()).transpose(0, 1) + def _process_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: + new_weight_scale = self.quant_method._process_weight_scale_after_loading( + weight_scale.cuda(get_current_device_id()).to(self.data_type_) + ) + self.mm_param.weight_scale = new_weight_scale + return - def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None: - if self.weight_scale_name in weights: - weight_scale = self._slice_weight_scale(weights[self.weight_scale_name]) - self._process_weight_scale(weight_scale) - if self.weight_scale is not None and isinstance(self.weight, torch.Tensor): - # weight 中保存的 None 是为 激活静态量化 scale 预留的扩展位置。 - self.weight = [ - self.weight, - self.weight_scale, - None, - ] + def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: + new_weight_zero_point = self.quant_method._process_weight_zero_point_after_loading( + weight_zero_point.cuda(get_current_device_id()) + ) + self.mm_param.weight_zero_point = new_weight_zero_point return + + +COLMM_WEIGHT_CLS_MAP = { + "deepgemm-fp8w8a8-b128": DeepGemmFP8W8A8B128COLMMWeight, + "awq": AWQCOLMMWeight, + "awq_marlin": AWQMARLINCOLMMWeight, +} diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py new file mode 100644 index 000000000..a6486bfa8 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py @@ -0,0 +1,83 @@ +from lightllm.common.quantization import Quantcfg +from lightllm.common.quantization.quantize_method import QuantizationMethod +from typing import Type, Union, Dict +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import ( + MMWeightTpl, + MultiMMWeightTpl, + BMMWeightTpl, +) +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.rowmm_weight import ( + UnquantizedROWMMWeight, + UnquantizedROWBMMWeight, + UnquantizedMultiROWMMWeight, + ROWMM_WEIGHT_CLS_MAP, + MULTI_ROWMM_WEIGHT_CLS_MAP, +) +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.colmm_weight import ( + UnquantizedCOLMMWeight, + COLMM_WEIGHT_CLS_MAP, +) + + +class MMWeight: + def __new__(cls, **kwargs): + quant_cfg = kwargs.pop("quant_cfg", None) + layer_num_ = kwargs.pop("layer_num", None) + name = kwargs.pop("name", None) + quant_method, quantized_weight = cls._get_quant_method(quant_cfg, layer_num_, name) + kwargs["quant_method"] = quant_method + mmcls = cls._get_mmcls(quant_method, quantized_weight) + return mmcls(**kwargs) + + @classmethod + def _get_quant_method(cls, quant_cfg: Quantcfg, layer_num_: int, name: str) -> QuantizationMethod: + if quant_cfg is None: + return None, False + quant_method = quant_cfg.get_quant_method(layer_num_, name) + if quant_method is None: + return None, False + quant_method.hf_quantization_config = quant_cfg.hf_quantization_config + quantized_weight = quant_cfg.quantized_weight + return quant_method, quantized_weight + + @classmethod + def _get_mmcls( + cls, quant_method: QuantizationMethod, 
quantized_weight: bool + ) -> Type[Union[MMWeightTpl, MultiMMWeightTpl, BMMWeightTpl]]: + raise NotImplementedError("Subclasses must implement _get_mmcls method") + + +class ROWMMWeight(MMWeight): + @classmethod + def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool): + if quant_method is None or not quantized_weight: + return UnquantizedROWMMWeight + + return ROWMM_WEIGHT_CLS_MAP[quant_method.method_name] + + +class MultiROWMMWeight(MMWeight): + @classmethod + def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool): + if quant_method is None or not quantized_weight: + return UnquantizedMultiROWMMWeight + + return MULTI_ROWMM_WEIGHT_CLS_MAP[quant_method.method_name] + + +class ROWBMMWeight(MMWeight): + @classmethod + def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool): + if quant_method is None or not quantized_weight: + return UnquantizedROWBMMWeight + else: + # TODO: Implement more quantization weight + raise NotImplementedError("ROWBMMWeight is not implemented") + + +class COLMMWeight(MMWeight): + @classmethod + def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool): + if quant_method is None or not quantized_weight: + return UnquantizedCOLMMWeight + return COLMM_WEIGHT_CLS_MAP[quant_method.method_name] diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py new file mode 100644 index 000000000..6c90deaa7 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py @@ -0,0 +1,117 @@ +import torch +from typing import Optional +from abc import ABC, abstractmethod +from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size + + +class SliceMixinBase(ABC): + """切片操作的Mixin基类""" + + def __init__(self, tp_rank: int = None, tp_world_size: int = None): + self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() + self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() + + @abstractmethod + def _slice_weight(self, weight: torch.Tensor): + pass + + @abstractmethod + def _slice_bias(self, bias): + pass + + +class SliceMixinTpl(SliceMixinBase): + def __init__(self, tp_rank: int = None, tp_world_size: int = None): + super().__init__(tp_rank, tp_world_size) + + def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor: + raise NotImplementedError("slice_weight must implement this method") + + def _slice_bias(self, bias) -> torch.Tensor: + raise NotImplementedError("slice_bias must implement this method") + + def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: + raise NotImplementedError("slice_weight_scale must implement this method") + + def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: + raise NotImplementedError("slice_weight_zero_point must implement this method") + + +# 默认weight 的shape是 outxin,这也是目前最通用的约定。 +# 所以row-wise是沿着dim=0进行切分,col-wise是沿着dim=1进行切分。 +class RowSliceMixin(SliceMixinTpl): + def __init__(self, tp_rank: int = None, tp_world_size: int = None): + super().__init__(tp_rank, tp_world_size) + + def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor: + assert weight.shape[0] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[0]} % {self.tp_world_size_}" + tp_size = weight.shape[0] // self.tp_world_size_ + return weight[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] + + def _slice_bias(self, 
bias) -> torch.Tensor: + assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}" + tp_size = bias.shape[0] // self.tp_world_size_ + return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] + + +# 量化切片默认实现方式是group-wise的量化,所以weight_scale 和weight_zero_point ndims跟weight一样。 +# 后续按需要,扩展per-tensor、per-channel的量化方式。 +class QuantizedRowSliceMixin(RowSliceMixin): + def __init__(self, tp_rank: int = None, tp_world_size: int = None): + super().__init__(tp_rank, tp_world_size) + + def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: + assert ( + weight_scale.shape[0] % self.tp_world_size_ == 0 + ), f"tp slice error {weight_scale.shape[0]} % {self.tp_world_size_}" + tp_size = weight_scale.shape[0] // self.tp_world_size_ + scale_start = tp_size * self.tp_rank_ + scale_end = tp_size * (self.tp_rank_ + 1) + return weight_scale[scale_start:scale_end] + + def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: + assert ( + weight_zero_point.shape[0] % self.tp_world_size_ == 0 + ), f"tp slice error {weight_zero_point.shape[0]} % {self.tp_world_size_}" + tp_size = weight_zero_point.shape[0] // self.tp_world_size_ + zero_point_start = tp_size * self.tp_rank_ + zero_point_end = tp_size * (self.tp_rank_ + 1) + return weight_zero_point[zero_point_start:zero_point_end] + + +class ColSliceMixin(SliceMixinTpl): + def __init__(self, tp_rank: int = None, tp_world_size: int = None): + super().__init__(tp_rank, tp_world_size) + + def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor: + assert weight.shape[1] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[1]} % {self.tp_world_size_}" + tp_size = weight.shape[1] // self.tp_world_size_ + return weight[:, tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] + + def _slice_bias(self, bias) -> torch.Tensor: + assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}" + tp_size = bias.shape[0] // self.tp_world_size_ + return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] + + +class QuantizedColSliceMixin(ColSliceMixin): + def __init__(self, tp_rank: int = None, tp_world_size: int = None): + super().__init__(tp_rank, tp_world_size) + + def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: + assert ( + weight_scale.shape[1] % self.tp_world_size_ == 0 + ), f"tp slice error {weight_scale.shape[1]} % {self.tp_world_size_}" + tp_size = weight_scale.shape[1] // self.tp_world_size_ + scale_start = tp_size * self.tp_rank_ + scale_end = tp_size * (self.tp_rank_ + 1) + return weight_scale[:, scale_start:scale_end] + + def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: + assert ( + weight_zero_point.shape[1] % self.tp_world_size_ == 0 + ), f"tp slice error {weight_zero_point.shape[1]} % {self.tp_world_size_}" + tp_size = weight_zero_point.shape[1] // self.tp_world_size_ + zero_point_start = tp_size * self.tp_rank_ + zero_point_end = tp_size * (self.tp_rank_ + 1) + return weight_zero_point[:, zero_point_start:zero_point_end] diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index 79787bb18..a98c47f52 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -1,6 +1,7 @@ import os import torch from abc import 
abstractmethod +from dataclasses import dataclass from typing import Optional, Tuple, List, Dict, Union, Type from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.common.quantization.quantize_method import QuantizationMethod @@ -8,18 +9,29 @@ from lightllm.common.quantization import Quantcfg from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.log_utils import init_logger +from .mm_slicer import SliceMixinTpl logger = init_logger(__name__) -def generate_scale_name(name, weight_scale_suffix, act_scale_suffix): - weight_scale_name = None - act_scale_name = None - if weight_scale_suffix is not None: - weight_scale_name = ".".join(name.split(".")[:-1] + [weight_scale_suffix]) - if act_scale_suffix is not None: - act_scale_name = ".".join(name.split(".")[:-1] + [act_scale_suffix]) - return weight_scale_name, act_scale_name +@dataclass +class MMWeightPack: + weight: Optional[torch.Tensor] = None + bias: Optional[torch.Tensor] = None + weight_scale: Optional[torch.Tensor] = None + weight_zero_point: Optional[torch.Tensor] = None + + has_bias: bool = False + has_weight_scale: bool = False + has_weight_zero_point: bool = False + + def is_ready(self) -> bool: + return ( + self.weight is not None + and (not self.has_bias or (self.has_bias and self.bias is not None)) + and (not self.has_weight_scale or (self.has_weight_scale and self.weight_scale is not None)) + and (not self.has_weight_zero_point or (self.has_weight_zero_point and self.weight_zero_point is not None)) + ) class MMWeightTpl(BaseWeightTpl): @@ -29,59 +41,125 @@ def __init__( quant_method: QuantizationMethod = None, tp_rank: int = None, tp_world_size: int = None, + has_bias: bool = False, + has_weight_scale: bool = False, + has_weight_zero_point: bool = False, ) -> None: super().__init__(tp_rank, tp_world_size, data_type) self.quant_method = quant_method - self.weight: Optional[torch.Tensor] = None - self.bias: Optional[torch.Tensor] = None - # quantized_weight 用于标记加载的权重是已经量化的权重格式 - # 不需要做在线量化 - self.quantized_weight: bool = False - # 标记是否存在 bias, 由子类初始化 - self.has_bias: bool = None + self.mm_param: MMWeightPack = MMWeightPack( + has_bias=has_bias, + has_weight_scale=has_weight_scale, + has_weight_zero_point=has_weight_zero_point, + ) + self.param_slicer: SliceMixinTpl = None def mm( self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True ) -> torch.Tensor: if self.quant_method is not None: return self.quant_method.apply( - input_tensor, self.weight, self.bias, out, use_custom_tensor_mananger=use_custom_tensor_mananger + input_tensor, self.mm_param, out, use_custom_tensor_mananger=use_custom_tensor_mananger ) if out is None: - shape = (input_tensor.shape[0], self.weight.shape[1]) + shape = (input_tensor.shape[0], self.mm_param.weight.shape[1]) dtype = input_tensor.dtype device = input_tensor.device if use_custom_tensor_mananger: out = g_cache_manager.alloc_tensor(shape, dtype, device=device, is_graph_out=False) else: out = torch.empty(shape, dtype=dtype, device=device) - if self.bias is None: - return torch.mm(input_tensor, self.weight, out=out) - return torch.addmm(self.bias, input_tensor, self.weight, out=out) + if self.mm_param.bias is None: + return torch.mm(input_tensor, self.mm_param.weight, out=out) + return torch.addmm(self.mm_param.bias, input_tensor, self.mm_param.weight, out=out) + + def load_hf_weights(self, weights): + raise NotImplementedError("load_hf_weights must implement this method") def 
verify_load(self) -> bool: - load_ok = True - # Verify weight. The weight must be not None. - load_ok = load_ok and self.weight is not None - # Verify bias. If bias_name is set, it must be not None. - if self.has_bias: - load_ok = load_ok and self.bias is not None - return load_ok + return self.mm_param.is_ready() - def _process_weight(self, weight) -> None: - if self.quant_method is not None and not self.quantized_weight: - self.weight = self.quant_method.quantize(weight.to(self.data_type_).cuda(get_current_device_id())) + def _process_weight(self, weight: torch.Tensor) -> None: + # 由于所有的量化算法,都会产生一个scale,所以只要没有scale,就说明需要在线对weight进行量化 + if self.quant_method is not None and not self.mm_param.has_weight_scale: + quantized_weight, weight_scale, weight_zero_point = self.quant_method.quantize( + weight.to(self.data_type_).cuda(get_current_device_id()) + ) + self.mm_param.weight = quantized_weight + self.mm_param.weight_scale = weight_scale + self.mm_param.weight_zero_point = weight_zero_point return # 让 k dim 更连续,大多数split k 算法的算子可能能更快 - self.weight = weight.cuda(get_current_device_id()).transpose(0, 1) + self.mm_param.weight = weight.to(self.data_type_).cuda(get_current_device_id()).transpose(0, 1) + return + + def _process_bias(self, bias: torch.Tensor) -> None: + self.mm_param.bias = bias.to(self.data_type_).cuda(get_current_device_id()) + return + + def _process_weight_scale(self, weight_scale: torch.Tensor) -> None: + raise NotImplementedError("process_weight_scale must implement this method") - def _load_weights(self, weights: Dict[str, torch.Tensor]) -> None: + def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> None: + raise NotImplementedError("process_weight_zero_point must implement this method") + + def _load_weight(self, weights: Dict[str, torch.Tensor]) -> None: + raise NotImplementedError("load_weight_scale must implement this method") + + def _load_bias(self, weights: Dict[str, torch.Tensor]) -> None: + raise NotImplementedError("load_bias must implement this method") + + def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: + raise NotImplementedError("load_weight_scale must implement this method") + + def _load_weight_zero_point(self, weights: Dict[str, torch.Tensor]) -> None: + raise NotImplementedError("load_weight_zero_point must implement this method") + + def _fuse_weights(self, dim: int = 0) -> None: + raise NotImplementedError("fuse_weights must implement this method") + + +class SingleMMWeightTpl(MMWeightTpl): + def __init__( + self, + weight_name: str, + bias_name: Optional[str] = None, + data_type: torch.dtype = None, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + has_weight_scale: bool = False, + has_weight_zero_point: bool = False, + ) -> None: + super().__init__( + data_type=data_type, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + has_bias=bias_name is not None, + has_weight_scale=has_weight_scale, + has_weight_zero_point=has_weight_zero_point, + ) + self.weight_name = weight_name + self.bias_name = bias_name + return + + def _load_weight(self, weights: Dict[str, torch.Tensor]) -> None: if self.weight_name in weights: - weight = self._slice_weight(weights[self.weight_name]) + weight = weights[self.weight_name] + weight = self.param_slicer._slice_weight(weight) self._process_weight(weight) + return + def _load_bias(self, weights: Dict[str, torch.Tensor]) -> None: if self.bias_name in weights: - self.bias = 
self._slice_bias(weights[self.bias_name]).cuda(get_current_device_id()) + bias = self.param_slicer._slice_bias(weights[self.bias_name]) + self._process_bias(bias) + return + + def load_hf_weights(self, weights): + self._load_weight(weights) + self._load_bias(weights) return @@ -89,54 +167,77 @@ class MultiMMWeightTpl(MMWeightTpl): def __init__( self, weight_names: List[str], - data_type: torch.dtype, bias_names: Optional[List[str]] = None, + data_type: torch.dtype = None, quant_method: QuantizationMethod = None, tp_rank: int = None, tp_world_size: int = None, + has_weight_scale: bool = False, + has_weight_zero_point: bool = False, ) -> None: - super().__init__(data_type, quant_method, tp_rank, tp_world_size) - + has_bias = bias_names is not None and any(b is not None for b in bias_names) + super().__init__( + data_type=data_type, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + has_bias=has_bias, + has_weight_scale=has_weight_scale, + has_weight_zero_point=has_weight_zero_point, + ) self.weight_names = weight_names - self.bias_names = bias_names - self.weights = [None] * len(self.weight_names) - if self.bias_names is not None: - self.biases = [None] * len(self.bias_names) - self.has_bias = all(b is not None for b in self.bias_names) and len(bias_names) > 0 - else: - self.biases = None - self.has_bias = False - - def _pre_porcess_weights(self, weights: Dict[str, torch.Tensor]) -> None: + self.bias_names = bias_names if bias_names is not None else [] + self.mm_params: List[MMWeightPack] = [ + MMWeightPack( + weight=None, + bias=None, + weight_scale=None, + weight_zero_point=None, + has_bias=has_bias, + has_weight_scale=has_weight_scale, + has_weight_zero_point=has_weight_zero_point, + ) + for _ in range(len(weight_names)) + ] + + def _load_weight(self, weights: Dict[str, torch.Tensor]) -> None: for i in range(len(self.weight_names)): if self.weight_names[i] in weights: - weight = weights[self.weight_names[i]] - self.weights[i] = self._slice_weight(weight) - if self.has_bias and self.bias_names[i] in weights: - bias = weights[self.bias_names[i]] - self.biases[i] = self._slice_bias(bias) - - def _fuse_weights(self) -> None: - if self.weight is None and (None not in self.weights): - weight = torch.cat(self.weights, dim=0) - self._process_weight(weight) - delattr(self, "weights") + weight_i = weights[self.weight_names[i]] + weight_i = self.param_slicer._slice_weight(weight_i) + self.mm_params[i].weight = weight_i + return - if self.has_bias and self.bias is None and (None not in self.biases): - self.bias = torch.cat(self.biases, dim=0).cuda(get_current_device_id()) - delattr(self, "biases") - return self + def _load_bias(self, weights: Dict[str, torch.Tensor]) -> None: + for i in range(len(self.bias_names)): + if self.bias_names[i] in weights: + bias_i = weights[self.bias_names[i]] + bias_i = self.param_slicer._slice_bias(bias_i) + self.mm_params[i].bias = bias_i.to(self.data_type_) + return - def _load_weights(self, weights: Dict[str, torch.Tensor]) -> None: - self._pre_porcess_weights(weights) + def _fuse_weights(self, dim: int = 0) -> None: + if self.mm_param.weight is None and all(p.weight is not None for p in self.mm_params): + weight = torch.cat([p.weight for p in self.mm_params], dim=dim) + self._process_weight(weight) + for p in self.mm_params: + p.weight = None + + if self.mm_param.has_bias and self.mm_param.bias is None and all(p.bias is not None for p in self.mm_params): + bias = torch.cat([p.bias for p in self.mm_params], dim=dim) + 
self._process_bias(bias) + for p in self.mm_params: + p.bias = None + return def load_hf_weights(self, weights): - super().load_hf_weights(weights) - self._fuse_weights() + self._load_weight(weights) + self._load_bias(weights) + self._fuse_weights(dim=0) return -class BMMWeightTpl(MMWeightTpl): +class BMMWeightTpl(SingleMMWeightTpl): def mm( self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True ) -> torch.Tensor: @@ -146,7 +247,7 @@ def bmm( self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True ) -> torch.Tensor: # 目前 bmm 不支持量化运算操作 - fpweight = self.weight + fpweight = self.mm_param.weight if out is None: shape = (input_tensor.shape[0], input_tensor.shape[1], fpweight.shape[2]) dtype = input_tensor.dtype @@ -155,34 +256,304 @@ def bmm( out = g_cache_manager.alloc_tensor(shape, dtype, device=device, is_graph_out=False) else: out = torch.empty(shape, dtype=dtype, device=device) - if self.bias is None: + if self.mm_param.bias is None: return torch.bmm(input_tensor, fpweight, out=out) - return torch.addbmm(self.bias, input_tensor, fpweight, out=out) + return torch.addbmm(self.mm_param.bias, input_tensor, fpweight, out=out) + + def _process_weight(self, weight) -> None: + self.mm_param.weight = weight.cuda(get_current_device_id()) + + +class SingleQuantizedMMWeightTpl(SingleMMWeightTpl): + def __init__( + self, + weight_name: str, + bias_name: Optional[str] = None, + data_type: torch.dtype = None, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + has_weight_scale: bool = True, + has_weight_zero_point: bool = False, # 目前较多的是对称量化,所以默认没有zero_point + ) -> None: + super().__init__( + weight_name=weight_name, + bias_name=bias_name, + data_type=data_type, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + has_weight_scale=has_weight_scale, + has_weight_zero_point=has_weight_zero_point, + ) + assert quant_method is not None, "quant_method is not set" + assert quant_method.weight_scale_suffix is not None, "weight_scale_suffix is not set" + self.weight_scale_name = weight_name.replace("weight", quant_method.weight_scale_suffix) + if has_weight_zero_point: + assert quant_method.weight_zero_point_suffix is not None, "weight_zero_point_suffix is not set" + self.weight_zero_point_name = weight_name.replace("weight", quant_method.weight_zero_point_suffix) + if quant_method.weight_suffix is not None: + self.weight_name = weight_name.replace("weight", quant_method.weight_suffix) + return + + def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: + if self.weight_scale_name is not None and self.weight_scale_name in weights: + weight_scale = weights[self.weight_scale_name] + weight_scale = self.param_slicer._slice_weight_scale(weight_scale) + self._process_weight_scale(weight_scale) + + def _load_weight_zero_point(self, weights: Dict[str, torch.Tensor]) -> None: + if self.mm_param.has_weight_zero_point and self.weight_zero_point_name in weights: + weight_zero_point = weights[self.weight_zero_point_name] + weight_zero_point = self.param_slicer._slice_weight_zero_point(weight_zero_point) + self._process_weight_zero_point(weight_zero_point) + + def load_hf_weights(self, weights): + self._load_weight(weights) + self._load_bias(weights) + self._load_weight_scale(weights) + self._load_weight_zero_point(weights) + return + + # 不同的量化算法,往往需要不同的处理方式,所以强制要求实现这些方法 + def _process_weight(self, weight: torch.Tensor) -> None: 
+ raise NotImplementedError("Quantized weight process_weight must implement this method") + + def _process_weight_scale(self, weight_scale: torch.Tensor) -> None: + raise NotImplementedError("Quantized weight process_weight_scale must implement this method") + + def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> None: + raise NotImplementedError("Quantized weight process_weight_zero_point must implement this method") + + +class MultiQuantizedMMWeightTpl(MultiMMWeightTpl): + def __init__( + self, + weight_names: List[str], + bias_names: Optional[List[str]] = None, + data_type: torch.dtype = None, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + has_weight_scale: bool = True, + has_weight_zero_point: bool = False, + ) -> None: + super().__init__( + weight_names=weight_names, + bias_names=bias_names, + data_type=data_type, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + has_weight_scale=has_weight_scale, + has_weight_zero_point=has_weight_zero_point, + ) + assert quant_method is not None, "quant_method is not set" + assert quant_method.weight_scale_suffix is not None, "weight_scale_suffix is not set" + self.weight_scale_names = [ + weight_name.replace("weight", quant_method.weight_scale_suffix) for weight_name in weight_names + ] + if has_weight_zero_point: + assert quant_method.weight_zero_point_suffix is not None, "weight_zero_point_suffix is not set" + self.weight_zero_point_names = [ + weight_name.replace("weight", quant_method.weight_zero_point_suffix) for weight_name in weight_names + ] + if quant_method.weight_suffix is not None: + self.weight_names = [ + weight_name.replace("weight", quant_method.weight_suffix) for weight_name in weight_names + ] + return + + def _load_weight(self, weights: Dict[str, torch.Tensor]) -> None: + for i in range(len(self.weight_names)): + if self.weight_names[i] in weights: + weight = weights[self.weight_names[i]] + weight = self.param_slicer._slice_weight(weight) + self.mm_params[i].weight = weight + + def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: + for i in range(len(self.weight_names)): + if self.weight_scale_names[i] is not None and self.weight_scale_names[i] in weights: + weight_scale = weights[self.weight_scale_names[i]] + weight_scale = self.param_slicer._slice_weight_scale(weight_scale) + self.mm_params[i].weight_scale = weight_scale.to(self.data_type_) + + def _load_weight_zero_point(self, weights: Dict[str, torch.Tensor]) -> None: + for i in range(len(self.weight_names)): + if self.mm_params[i].has_weight_zero_point and self.weight_zero_point_names[i] in weights: + weight_zero_point = weights[self.weight_zero_point_names[i]] + weight_zero_point = self.param_slicer._slice_weight_zero_point(weight_zero_point) + self.mm_params[i].weight_zero_point = weight_zero_point + return + + def _fuse_weights(self, dim: int = 0) -> None: + super()._fuse_weights(dim=dim) + if self.mm_param.weight_scale is None and (None not in [p.weight_scale for p in self.mm_params]): + # awq 保存的量化参数,weight shape 是 in x out。所以这里的cat dim 是 1 + weight_scale = torch.cat([p.weight_scale for p in self.mm_params], dim=dim).cuda(get_current_device_id()) + self._process_weight_scale(weight_scale) + for p in self.mm_params: + p.weight_scale = None + + if self.mm_param.weight_zero_point is None and (None not in [p.weight_zero_point for p in self.mm_params]): + weight_zero_point = torch.cat([p.weight_zero_point for p in self.mm_params], dim=dim) + 
self._process_weight_zero_point(weight_zero_point) + for p in self.mm_params: + p.weight_zero_point = None + torch.cuda.empty_cache() + return + + # 不同的量化算法,往往需要不同的处理方式,所以强制要求实现这些方法 + def _process_weight(self, weight: torch.Tensor) -> None: + raise NotImplementedError("Quantized weight process_weight must implement this method") + + def _process_weight_scale(self, weight_scale: torch.Tensor) -> None: + raise NotImplementedError("Quantized weight process_weight_scale must implement this method") + + def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> None: + raise NotImplementedError("Quantized weight process_weight_zero_point must implement this method") + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: + self._load_weight(weights) + self._load_bias(weights) + self._load_weight_scale(weights) + self._load_weight_zero_point(weights) + self._fuse_weights(dim=0) + return + + +class DeepGemmFP8W8A8B128MMWeight(SingleQuantizedMMWeightTpl): + def __init__( + self, + weight_name: str, + data_type: torch.dtype, + bias_name: Optional[str] = None, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + ) -> None: + super().__init__( + weight_name=weight_name, + bias_name=bias_name, + data_type=data_type, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + has_weight_scale=True, + has_weight_zero_point=False, + ) + + def _process_weight_scale(self, weight_scale) -> None: + self.mm_param.weight_scale = weight_scale.to(torch.float).cuda(get_current_device_id()).transpose(0, 1) + return def _process_weight(self, weight) -> None: - self.weight = weight.cuda(get_current_device_id()) - - -class MMWeight: - def __new__(cls, **kwargs): - quant_cfg = kwargs.pop("quant_cfg", None) - layer_num_ = kwargs.pop("layer_num", None) - name = kwargs.pop("name", None) - quant_method, quantized_weight = cls._get_quant_method(quant_cfg, layer_num_, name) - kwargs["quant_method"] = quant_method - mmcls = cls._get_mmcls(quant_method, quantized_weight) - return mmcls(**kwargs) - - @classmethod - def _get_quant_method(cls, quant_cfg: Quantcfg, layer_num_: int, name: str) -> QuantizationMethod: - if quant_cfg is None: - return None, False - quant_method = quant_cfg.get_quant_method(layer_num_, name) - quantized_weight = quant_cfg.quantized_weight - return quant_method, quantized_weight - - @classmethod - def _get_mmcls( - cls, quant_method: QuantizationMethod, quantized_weight: bool - ) -> Type[Union[MMWeightTpl, MultiMMWeightTpl, BMMWeightTpl]]: - raise NotImplementedError("Subclasses must implement _get_mmcls method") + self.mm_param.weight = weight.cuda(get_current_device_id()).transpose(0, 1) + return + + +class DeepGemmFP8W8A8B128MultiMMWeight(MultiQuantizedMMWeightTpl): + def __init__( + self, + weight_names: str, + data_type: torch.dtype, + bias_names: Optional[str] = None, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + ) -> None: + super().__init__( + weight_names=weight_names, + bias_names=bias_names, + data_type=data_type, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + has_weight_scale=True, + has_weight_zero_point=False, + ) + + def _process_weight_scale(self, weight_scale) -> None: + self.mm_param.weight_scale = weight_scale.to(torch.float).cuda(get_current_device_id()).transpose(0, 1) + return + + def _process_weight(self, weight) -> None: + self.mm_param.weight = weight.cuda(get_current_device_id()).transpose(0, 1) + 
return + + +class AWQMMWeightTpl(SingleQuantizedMMWeightTpl): + def __init__( + self, + weight_name: str, + bias_name: Optional[str] = None, + data_type: torch.dtype = None, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + ) -> None: + super().__init__( + weight_name=weight_name, + bias_name=bias_name, + data_type=data_type, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + has_weight_scale=True, + has_weight_zero_point=True, + ) + + def _process_weight(self, weight: torch.Tensor) -> None: + self.mm_param.weight = weight.cuda(get_current_device_id()) + return + + def _process_weight_scale(self, weight_scale: torch.Tensor) -> None: + self.mm_param.weight_scale = weight_scale.to(self.data_type_).cuda(get_current_device_id()) + return + + def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> None: + self.mm_param.weight_zero_point = weight_zero_point.cuda(get_current_device_id()) + return + + +class AWQMultiMMWeightTpl(MultiQuantizedMMWeightTpl): + def __init__( + self, + weight_names: List[str], + bias_names: Optional[List[str]] = None, + data_type: torch.dtype = None, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + ) -> None: + super().__init__( + weight_names=weight_names, + bias_names=bias_names, + data_type=data_type, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + has_weight_scale=True, + has_weight_zero_point=True, + ) + + def _process_weight(self, weight: torch.Tensor) -> None: + self.mm_param.weight = weight.cuda(get_current_device_id()) + return + + def _process_weight_scale(self, weight_scale: torch.Tensor) -> None: + self.mm_param.weight_scale = weight_scale.to(self.data_type_).cuda(get_current_device_id()) + return + + def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> None: + self.mm_param.weight_zero_point = weight_zero_point.cuda(get_current_device_id()) + return + + def load_hf_weights(self, weights): + self._load_weight(weights) + self._load_bias(weights) + self._load_weight_scale(weights) + self._load_weight_zero_point(weights) + # 由于awq的储存格式是inxout,所以拼接dim是 1 + self._fuse_weights(dim=1) + return diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py index c90d7c1a3..599162c64 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py @@ -1,51 +1,21 @@ import torch from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import ( - MMWeight, - MMWeightTpl, - BMMWeightTpl, + SingleMMWeightTpl, MultiMMWeightTpl, - generate_scale_name, + DeepGemmFP8W8A8B128MMWeight, + DeepGemmFP8W8A8B128MultiMMWeight, + AWQMMWeightTpl, + AWQMultiMMWeightTpl, + BMMWeightTpl, ) from lightllm.common.quantization import Quantcfg from lightllm.utils.dist_utils import get_current_device_id from lightllm.common.quantization.quantize_method import QuantizationMethod from typing import Dict, List, Optional +from .mm_slicer import RowSliceMixin, QuantizedRowSliceMixin, QuantizedColSliceMixin -class ROWMMWeight(MMWeight): - @classmethod - def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool): - if quant_method is None or not quantized_weight: - return UnquantizedROWMMWeight - else: - return W8A8B128ROWMMWeight - # TODO: Implement 
more quantization weight - return None - - -class MultiROWMMWeight(MMWeight): - @classmethod - def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool): - if quant_method is None or not quantized_weight: - return UnquantizedMultiROWMMWeight - else: - return W8A8B128MultiROWMMWeight - # TODO: Implement more quantization weight - return None - - -class ROWBMMWeight(MMWeight): - @classmethod - def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool): - if quant_method is None or not quantized_weight: - return UnquantizedROWBMMWeight - else: - return W8A8B128ROWBMMWeight - # TODO: Implement more quantization weight - return None - - -class UnquantizedROWMMWeight(MMWeightTpl): +class UnquantizedROWMMWeight(SingleMMWeightTpl): def __init__( self, weight_name: str, @@ -55,23 +25,39 @@ def __init__( tp_rank: int = None, tp_world_size: int = None, ) -> None: - self.weight_name = weight_name - self.bias_name = bias_name - self.has_bias = bias_name is not None - super().__init__(data_type, quant_method, tp_rank, tp_world_size) + super().__init__( + weight_name=weight_name, + bias_name=bias_name, + data_type=data_type, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + ) + self.param_slicer = RowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - def _slice_weight(self, weight: torch.Tensor): - assert weight.shape[0] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[0]} % {self.tp_world_size_}" - tp_size = weight.shape[0] // self.tp_world_size_ - return weight[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)].to(self.data_type_) - def _slice_bias(self, bias): - assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}" - tp_size = bias.shape[0] // self.tp_world_size_ - return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)].to(self.data_type_) +class UnquantizedMultiROWMMWeight(MultiMMWeightTpl): + def __init__( + self, + weight_names: str, + data_type: torch.dtype, + bias_names: Optional[str] = None, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + ) -> None: + super().__init__( + weight_names=weight_names, + data_type=data_type, + bias_names=bias_names, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + ) + self.param_slicer = RowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) -class W8A8B128ROWMMWeight(UnquantizedROWMMWeight): +class DeepGemmFP8W8A8B128ROWMMWeight(DeepGemmFP8W8A8B128MMWeight): def __init__( self, weight_name: str, @@ -81,58 +67,19 @@ def __init__( tp_rank: int = None, tp_world_size: int = None, ) -> None: - super().__init__(weight_name, data_type, bias_name, quant_method, tp_rank, tp_world_size) - - self.weight_scale_name, _ = generate_scale_name( - weight_name, quant_method.weight_scale_suffix, quant_method.act_scale_suffix + super().__init__( + weight_name=weight_name, + data_type=data_type, + bias_name=bias_name, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, ) - self.weight_scale: Optional[torch.Tensor] = None - self.quantized_weight = True - - def _slice_weight(self, weight: torch.Tensor): - assert weight.shape[0] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[0]} % {self.tp_world_size_}" - tp_size = weight.shape[0] // self.tp_world_size_ - return weight[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] - - def _slice_bias(self, bias): - assert bias.shape[0] % self.tp_world_size_ == 0, 
f"tp slice error {bias.shape[0]} % {self.tp_world_size_}" - tp_size = bias.shape[0] // self.tp_world_size_ - return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] - - def _slice_weight_scale(self, weight_scale: torch.Tensor): - assert ( - weight_scale.shape[0] % self.tp_world_size_ == 0 - ), f"tp slice error {weight_scale.shape[0]} % {self.tp_world_size_}" - tp_size = weight_scale.shape[0] // self.tp_world_size_ - scale_start = tp_size * self.tp_rank_ - scale_end = tp_size * (self.tp_rank_ + 1) - return weight_scale.to(torch.float)[scale_start:scale_end] - - def _process_weight_scale(self, weight_scale) -> None: - self.weight_scale = weight_scale.cuda(get_current_device_id()).transpose(0, 1) - - def _process_weight(self, weight) -> None: - self.weight = weight.cuda(get_current_device_id()).transpose(0, 1) - - def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None: - if self.weight_scale_name in weights: - weight_scale = weights[self.weight_scale_name] - weight_scale = self._slice_weight_scale(weight_scale) - self._process_weight_scale(weight_scale) - - if self.weight_scale is not None and isinstance(self.weight, torch.Tensor): - self.weight = [ - self.weight, - self.weight_scale, - None, # placeholder for input scale - ] + self.param_slicer = QuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) return -class UnquantizedMultiROWMMWeight(MultiMMWeightTpl): - _slice_weight = UnquantizedROWMMWeight._slice_weight - _slice_bias = UnquantizedROWMMWeight._slice_bias - +class DeepGemmFP8W8A8B128MultiROWMMWeight(DeepGemmFP8W8A8B128MultiMMWeight): def __init__( self, weight_names: str, @@ -142,66 +89,83 @@ def __init__( tp_rank: int = None, tp_world_size: int = None, ) -> None: - super().__init__(weight_names, data_type, bias_names, quant_method, tp_rank, tp_world_size) - + super().__init__( + weight_names=weight_names, + data_type=data_type, + bias_names=bias_names, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + ) + self.param_slicer = QuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) -class W8A8B128MultiROWMMWeight(UnquantizedMultiROWMMWeight): - _slice_weight = W8A8B128ROWMMWeight._slice_weight - _slice_bias = W8A8B128ROWMMWeight._slice_bias - _slice_weight_scale = W8A8B128ROWMMWeight._slice_weight_scale +class UnquantizedROWBMMWeight(BMMWeightTpl): def __init__( self, - weight_names: str, + weight_name: str, data_type: torch.dtype, - bias_names: Optional[str] = None, + bias_name: Optional[str] = None, quant_method: QuantizationMethod = None, tp_rank: int = None, tp_world_size: int = None, ) -> None: - super().__init__(weight_names, data_type, bias_names, quant_method, tp_rank, tp_world_size) - self.weight_scale_names = [] - self.weight_scale: Optional[torch.Tensor] = None - self.weight_scales = [None] * len(self.weight_names) - for weight_name in weight_names: - weight_scale_name, act_scale_name = generate_scale_name( - weight_name, quant_method.weight_scale_suffix, quant_method.act_scale_suffix - ) - self.weight_scale_names.append(weight_scale_name) - self.quantized_weight = True - - def _load_scales(self, weights): - for i in range(len(self.weight_names)): - if self.weight_scale_names[i] in weights: - weight_scale = weights[self.weight_scale_names[i]] - weight_scale = self._slice_weight_scale(weight_scale) - self.weight_scales[i] = weight_scale - - def _process_weight_scale(self, weight_scale) -> None: - self.weight_scale = weight_scale.cuda(get_current_device_id()).transpose(0, 1) + super().__init__( 
+ weight_name=weight_name, + data_type=data_type, + bias_name=bias_name, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + ) + self.param_slicer = RowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - def _process_weight(self, weight) -> None: - self.weight = weight.cuda(get_current_device_id()).transpose(0, 1) - def _fuse_weights(self) -> None: - super()._fuse_weights() - if self.weight_scale is None and (None not in self.weight_scales): - weight_scale = torch.cat(self.weight_scales, dim=0).cuda(get_current_device_id()) - self._process_weight_scale(weight_scale) - delattr(self, "weight_scales") +class AWQROWMMWeight(AWQMMWeightTpl): + def __init__( + self, + weight_name: str, + data_type: torch.dtype, + bias_name: Optional[str] = None, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + ) -> None: + super().__init__( + weight_name=weight_name, + data_type=data_type, + bias_name=bias_name, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + ) + # 注意这里不是错误,因为awq的weight是按inxout存的 + self.param_slicer = QuantizedColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - if self.weight_scale is not None and isinstance(self.weight, torch.Tensor): - self.weight = [ - self.weight, - self.weight_scale, - None, - ] +class AWQMultiROWMMWeight(AWQMultiMMWeightTpl): + def __init__( + self, + weight_names: List[str], + data_type: torch.dtype, + bias_names: Optional[List[str]] = None, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + ) -> None: + super().__init__( + weight_names=weight_names, + data_type=data_type, + bias_names=bias_names, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + ) + # 注意这里不是错误,因为awq的weight是按inxout存的 + self.param_slicer = QuantizedColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) -class UnquantizedROWBMMWeight(BMMWeightTpl): - _slice_weight = UnquantizedROWMMWeight._slice_weight - _slice_bias = UnquantizedROWMMWeight._slice_bias +class AWQMARLINROWMMWeight(AWQROWMMWeight): def __init__( self, weight_name: str, @@ -211,48 +175,68 @@ def __init__( tp_rank: int = None, tp_world_size: int = None, ) -> None: - self.weight_name = weight_name - self.bias_name = bias_name - self.has_bias = bias_name is not None - super().__init__(data_type, quant_method, tp_rank, tp_world_size) + super().__init__(weight_name, data_type, bias_name, quant_method, tp_rank, tp_world_size) + + def _process_weight(self, weight: torch.Tensor) -> torch.Tensor: + new_weight = self.quant_method._process_weight_after_loading(weight.cuda(get_current_device_id())) + self.mm_param.weight = new_weight + return + + def _process_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: + new_weight_scale = self.quant_method._process_weight_scale_after_loading( + weight_scale.cuda(get_current_device_id()).to(self.data_type_) + ) + self.mm_param.weight_scale = new_weight_scale + return + def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: + new_weight_zero_point = self.quant_method._process_weight_zero_point_after_loading( + weight_zero_point.cuda(get_current_device_id()) + ) + self.mm_param.weight_zero_point = new_weight_zero_point + return -class W8A8B128ROWBMMWeight(UnquantizedROWBMMWeight): - _slice_weight = W8A8B128ROWMMWeight._slice_weight - _slice_bias = W8A8B128ROWMMWeight._slice_bias +class AWQMARLINMultiROWMMWeight(AWQMultiROWMMWeight): def __init__( self, - weight_name: 
str, + weight_names: List[str], data_type: torch.dtype, - bias_name: Optional[str] = None, + bias_names: Optional[List[str]] = None, quant_method: QuantizationMethod = None, tp_rank: int = None, tp_world_size: int = None, - weight_scale_suffix: Optional[str] = None, - act_scale_suffix: Optional[str] = None, ) -> None: - super().__init__(weight_name, data_type, bias_name, quant_method, tp_rank, tp_world_size) - self.weight_scale_name, self.act_scale_name = generate_scale_name( - weight_name, weight_scale_suffix, act_scale_suffix + super().__init__(weight_names, data_type, bias_names, quant_method, tp_rank, tp_world_size) + + def _process_weight(self, weight: torch.Tensor) -> torch.Tensor: + new_weight = self.quant_method._process_weight_after_loading(weight.cuda(get_current_device_id())) + self.mm_param.weight = new_weight + return + + def _process_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: + new_weight_scale = self.quant_method._process_weight_scale_after_loading( + weight_scale.cuda(get_current_device_id()).to(self.data_type_) ) - self.weight_scale: Optional[torch.Tensor] = None - self.quantized_weight = True + self.mm_param.weight_scale = new_weight_scale + return + + def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: + new_weight_zero_point = self.quant_method._process_weight_zero_point_after_loading( + weight_zero_point.cuda(get_current_device_id()) + ) + self.mm_param.weight_zero_point = new_weight_zero_point + return - def _slice_weight_scale(self, weight_scale: torch.Tensor): - tp_size = weight_scale.shape[0] // self.tp_world_size_ - scale_start = tp_size * self.tp_rank_ - scale_end = tp_size * (self.tp_rank_ + 1) - return weight_scale[scale_start:scale_end].to(torch.float) - def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None: - if self.weight_scale_name is not None and self.weight_scale_name in weights: - weight_scale = weights[self.weight_scale_name] - weight_scale = self._slice_weight_scale(weight_scale) +ROWMM_WEIGHT_CLS_MAP = { + "deepgemm-fp8w8a8-b128": DeepGemmFP8W8A8B128ROWMMWeight, + "awq": AWQROWMMWeight, + "awq_marlin": AWQMARLINROWMMWeight, +} - if self.weight_scale is not None and isinstance(self.weight, torch.Tensor): - self.weight = [ - self.weight, - self.weight_scale, - None, - ] +MULTI_ROWMM_WEIGHT_CLS_MAP = { + "deepgemm-fp8w8a8-b128": DeepGemmFP8W8A8B128MultiROWMMWeight, + "awq": AWQMultiROWMMWeight, + "awq_marlin": AWQMARLINMultiROWMMWeight, +} diff --git a/lightllm/common/quantization/__init__.py b/lightllm/common/quantization/__init__.py index 611a3407a..26f59258c 100644 --- a/lightllm/common/quantization/__init__.py +++ b/lightllm/common/quantization/__init__.py @@ -5,6 +5,7 @@ from .w8a8_quant import * from .triton_quant.triton_quant import * from .deepgemm_quant import * +from .awq_quant import * from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -43,10 +44,14 @@ def _mapping_quant_method(self): else: self.quant_type = "vllm-fp8w8a8-b128" logger.info(f"select fp8w8a8-b128 quant way: {self.quant_type}") - - else: - # TODO: more quant method - pass + elif self.hf_quantization_method == "awq": + self.quant_type = "awq" + if is_awq_marlin_compatible(self.hf_quantization_config): + self.quant_type = "awq_marlin" + logger.info(f"select awq quant way: {self.quant_type}") + else: + # TODO: more quant method + pass def _parse_custom_cfg(self, custom_cfg_path): self.quant_cfg = collections.defaultdict(dict) @@ -58,7 +63,6 @@ def _parse_custom_cfg(self, 
custom_cfg_path): self.quant_type = data["quant_type"] for layer_quant_cfg in data.get("mix_bits", []): - print(layer_quant_cfg) name = layer_quant_cfg["name"] layer_nums = layer_quant_cfg.get("layer_nums", range(self.layer_num)) layer_quant_type = layer_quant_cfg["quant_type"] diff --git a/lightllm/common/quantization/awq_quant.py b/lightllm/common/quantization/awq_quant.py new file mode 100644 index 000000000..6a00bcf80 --- /dev/null +++ b/lightllm/common/quantization/awq_quant.py @@ -0,0 +1,219 @@ +import os +import torch +from .quantize_method import QuantizationMethod +from .registry import QUANTMETHODS +import torch.nn.functional as F +from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops, cutlass_scaled_mm +from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops +from typing import Any +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack + +if HAS_VLLM: + awq_dequantize = vllm_ops.awq_dequantize + awq_gemm = vllm_ops.awq_gemm + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_marlin_supported, + marlin_permute_scales, + awq_to_marlin_zero_points, + should_use_atomic_add_reduce, + marlin_make_empty_g_idx, + marlin_make_workspace_new, + ) + from vllm.scalar_type import scalar_types + + TYPE_MAP = { + 4: scalar_types.uint4, + 8: scalar_types.uint8, + } + + +class AWQBaseQuantizationMethod(QuantizationMethod): + def __init__(self): + super().__init__() + assert HAS_VLLM, "vllm are not installed, you can't use quant api of them." + from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager + + self.cache_manager = g_cache_manager + + def quantize(self, weight: torch.Tensor): + raise NotImplementedError("AWQ online quantization is not supported yet.") + + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "MMWeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) -> torch.Tensor: + raise NotImplementedError("AWQ online quantization is not supported yet.") + + @property + def method_name(self): + return "awq-base" + + +@QUANTMETHODS.register("awq") +class AWQW4A16QuantizationMethod(AWQBaseQuantizationMethod): + def __init__(self): + super().__init__() + self.pack_factor = 8 + self.weight_scale_suffix = "scales" + self.weight_zero_point_suffix = "qzeros" + self.weight_suffix = "qweight" + + @property + def method_name(self): + return "awq" + + def quantize(self, weight: torch.Tensor): + raise NotImplementedError("AWQ online quantization is not supported yet.") + + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "MMWeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) -> torch.Tensor: + qweight = weight_pack.weight + weight_scale = weight_pack.weight_scale + qzeros = weight_pack.weight_zero_point + bias = weight_pack.bias + + NEED_DEQUANT_WEIGHT = input_tensor.shape[:-1].numel() >= 256 + if NEED_DEQUANT_WEIGHT: + fpweight = awq_dequantize(qweight, weight_scale, qzeros, 0, 0, 0) + out = torch.matmul(input_tensor, fpweight) + else: + out = awq_gemm(input_tensor, qweight, weight_scale, qzeros, self.pack_factor) + + if bias is not None: + out.add_(bias) + return out + + +@QUANTMETHODS.register("awq_marlin") +class AWQMARLINW4A16QuantizationMethod(AWQBaseQuantizationMethod): + def __init__(self): + 
super().__init__() + self.pack_factor = 8 + self.nbits = 4 + self.weight_scale_suffix = "scales" + self.weight_zero_point_suffix = "qzeros" + self.weight_suffix = "qweight" + self.g_idx = marlin_make_empty_g_idx(torch.device("cuda")) + self.g_idx_sort_indices = marlin_make_empty_g_idx(torch.device("cuda")) + self.workspace = marlin_make_workspace_new(torch.device("cuda")) + self.vllm_quant_type = TYPE_MAP[self.nbits] + + @property + def method_name(self): + return "awq_marlin" + + def quantize(self, weight: torch.Tensor): + raise NotImplementedError("AWQ online quantization is not supported yet.") + + def _process_weight_after_loading(self, weight: torch.Tensor) -> torch.Tensor: + assert self.hf_quantization_config is not None, "hf_quantization_config is not set" + self.k = weight.shape[0] + self.n = weight.shape[1] * self.pack_factor + return vllm_ops.awq_marlin_repack( + weight, + size_k=weight.shape[0], + size_n=weight.shape[1] * self.pack_factor, + num_bits=self.hf_quantization_config["bits"], + ) + + def _process_weight_scale_after_loading(self, weight_scale: torch.Tensor) -> torch.Tensor: + assert self.hf_quantization_config is not None, "hf_quantization_config is not set" + group_size = self.hf_quantization_config["group_size"] + return marlin_permute_scales( + weight_scale, + size_k=weight_scale.shape[0] * group_size, + size_n=weight_scale.shape[1], + group_size=self.hf_quantization_config["group_size"], + ) + + def _process_weight_zero_point_after_loading(self, weight_zero_point: torch.Tensor) -> torch.Tensor: + return awq_to_marlin_zero_points( + weight_zero_point, + size_k=weight_zero_point.shape[0], + size_n=weight_zero_point.shape[1] * self.pack_factor, + num_bits=self.hf_quantization_config["bits"], + ) + + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "MMWeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) -> torch.Tensor: + qweight = weight_pack.weight + weight_scale = weight_pack.weight_scale + qzeros = weight_pack.weight_zero_point + bias = weight_pack.bias + reshaped_x = input_tensor.reshape(-1, input_tensor.shape[-1]) + + use_atomic_add = should_use_atomic_add_reduce( + m=reshaped_x.size(0), + n=self.n, + k=self.k, + device=input_tensor.device, + dtype=input_tensor.dtype, + ) + + out = vllm_ops.gptq_marlin_gemm( + reshaped_x, + None, + qweight, + bias, + weight_scale, + None, + qzeros, + self.g_idx, + self.g_idx_sort_indices, + self.workspace, + self.vllm_quant_type, + size_m=reshaped_x.shape[0], + size_n=self.n, + size_k=self.k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=True, + is_zp_float=False, + ) + + if bias is not None: + out.add_(bias) + return out + + +# adapted from +# https://github.com/vllm-project/vllm/blob/aef368aa08572505b820db01da82e2fbb3d43a72/vllm/model_executor/layers/quantization/awq_marlin.py#L211-L212 +def is_awq_marlin_compatible(quantization_config: dict[str, Any]): + # Extract data from quant config. + quant_method = quantization_config.get("quant_method", "").lower() + num_bits = quantization_config.get("bits") + group_size = quantization_config.get("group_size") + zero_point = quantization_config.get("zero_point") + + if not torch.cuda.is_available(): + return False + + if quant_method != "awq": + return False + + # If we cannot find the info needed in the config, cannot convert. 
+ if num_bits is None or group_size is None or zero_point is None: + return False + + if num_bits not in TYPE_MAP: + return False + + return check_marlin_supported(quant_type=TYPE_MAP[num_bits], group_size=group_size, has_zp=zero_point) diff --git a/lightllm/common/quantization/deepgemm_quant.py b/lightllm/common/quantization/deepgemm_quant.py index 964db67cf..7d3fc5358 100644 --- a/lightllm/common/quantization/deepgemm_quant.py +++ b/lightllm/common/quantization/deepgemm_quant.py @@ -7,7 +7,10 @@ per_token_group_quant_fp8, tma_align_input_scale, ) +from typing import TYPE_CHECKING, Optional +if TYPE_CHECKING: + from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack try: HAS_DEEPGEMM = True import deep_gemm @@ -27,9 +30,19 @@ def quantize(self, weight: torch.Tensor): """ """ pass - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): - """ """ - pass + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "MMWeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) -> torch.Tensor: + raise NotImplementedError("Not implemented") + + @property + def method_name(self): + return "deepgemm-base" @QUANTMETHODS.register(["deepgemm-fp8w8a8-b128"]) @@ -37,23 +50,33 @@ class DeepGEMMFP8w8a8B128QuantizationMethod(DeepGEMMBaseQuantizationMethod): def __init__(self): super().__init__() self.block_size = 128 + self.weight_suffix = None + self.weight_zero_point_suffix = None self.weight_scale_suffix = "weight_scale_inv" - self.act_scale_suffix = None # no support for static input tensor scale for ds model. + + @property + def method_name(self): + return "deepgemm-fp8w8a8-b128" def quantize(self, weight: torch.Tensor): from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant return weight_quant(weight, self.block_size) - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True): - if len(weights) == 3: - qweight, weight_scale, input_scale = weights - else: - qweight, weight_scale = weights - input_scale = None + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "MMWeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) -> torch.Tensor: + qweight = weight_pack.weight + weight_scale = weight_pack.weight_scale + input_scale = None alloc_func = torch.empty if not use_custom_tensor_mananger else self.cache_manager.empty m, k = input_tensor.shape - n = weights[0].shape[1] + n = qweight.shape[1] if input_scale is None: qinput_tensor, input_scale = per_token_group_quant_fp8( input_tensor, diff --git a/lightllm/common/quantization/quantize_method.py b/lightllm/common/quantization/quantize_method.py index b7b4c3705..5a7db15fc 100644 --- a/lightllm/common/quantization/quantize_method.py +++ b/lightllm/common/quantization/quantize_method.py @@ -1,13 +1,19 @@ import torch from abc import ABC, abstractmethod from lightllm.utils.dist_utils import get_current_device_id +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack class QuantizationMethod(ABC): def __init__(self): super().__init__() self.device_id_ = get_current_device_id() + self.weight_suffix = None self.weight_scale_suffix = None + self.weight_zero_point_suffix = None self.act_scale_suffix = None 
@abstractmethod @@ -15,5 +21,18 @@ def quantize(self, weights: torch.Tensor): pass @abstractmethod - def apply(self, input_tensor, weight, bias=None, out=None, use_custom_tensor_mananger=True): + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "MMWeightPack", + bias: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) -> torch.Tensor: + pass + + @property + @abstractmethod + def method_name(self): pass diff --git a/lightllm/common/quantization/registry.py b/lightllm/common/quantization/registry.py index 350f7fd1c..674a22b60 100644 --- a/lightllm/common/quantization/registry.py +++ b/lightllm/common/quantization/registry.py @@ -1,3 +1,7 @@ +from .quantize_method import QuantizationMethod +from typing import Type + + class QuantMethodFactory: def __init__(self): self._quant_methods = {} @@ -13,7 +17,7 @@ def decorator(cls): return decorator - def get(self, key, *args, **kwargs): + def get(self, key, *args, **kwargs) -> Type[QuantizationMethod]: if key == "none": return None quant_method_class = self._quant_methods.get(key) diff --git a/lightllm/common/quantization/torchao_quant.py b/lightllm/common/quantization/torchao_quant.py index 67677b50c..df8d1319d 100644 --- a/lightllm/common/quantization/torchao_quant.py +++ b/lightllm/common/quantization/torchao_quant.py @@ -3,12 +3,13 @@ from .quantize_method import QuantizationMethod from .registry import QUANTMETHODS import torch.nn.functional as F +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack try: HAS_TORCH_AO = True - from torchao.dtypes import to_affine_quantized_intx, AffineQuantizedTensor - from torchao.dtypes import TensorCoreTiledLayoutType - from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain from torchao.quantization import ( int4_weight_only, int8_weight_only, @@ -38,11 +39,24 @@ def quantize(self, weight: torch.Tensor): dummy_linear = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) dummy_linear.weight = torch.nn.Parameter(weight.cuda(self.device_id_)) quantize_(dummy_linear, self.quant_func) - return dummy_linear.weight - - def apply(self, input_tensor, weights, bias=None, out=None, use_custom_tensor_mananger=True): + return dummy_linear.weight, None, None + + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "MMWeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) -> torch.Tensor: + weights = weight_pack.weight + bias = weight_pack.bias return F.linear(input_tensor, weights, bias) + @property + def method_name(self): + return "ao-base" + @QUANTMETHODS.register(["ao-w4a16-256"]) class AOW4A16QuantizationMethodGroup256(AOBaseQuantizationMethod): @@ -51,6 +65,10 @@ def __init__(self): self.group_size = 256 self.quant_func = int4_weight_only(group_size=self.group_size) + @property + def method_name(self): + return "ao-w4a16-256" + @QUANTMETHODS.register(["ao-w4a16-128"]) class AOW4A16QuantizationMethodGroup128(AOBaseQuantizationMethod): @@ -59,6 +77,10 @@ def __init__(self): self.group_size = 128 self.quant_func = int4_weight_only(group_size=self.group_size) + @property + def method_name(self): + return "ao-w4a16-128" + @QUANTMETHODS.register(["ao-w4a16-64"]) class AOW4A16QuantizationMethodGroup64(AOBaseQuantizationMethod): @@ -67,6 +89,10 @@ def 
__init__(self): self.group_size = 64 self.quant_func = int4_weight_only(group_size=self.group_size) + @property + def method_name(self): + return "ao-w4a16-64" + @QUANTMETHODS.register(["ao-w4a16-32"]) class AOW4A16QuantizationMethodGroup32(AOBaseQuantizationMethod): @@ -75,6 +101,10 @@ def __init__(self): self.group_size = 32 self.quant_func = int4_weight_only(group_size=self.group_size) + @property + def method_name(self): + return "ao-w4a16-32" + @QUANTMETHODS.register("ao-w8a8") class AOW8A8QuantizationMethod(AOBaseQuantizationMethod): @@ -82,6 +112,10 @@ def __init__(self): super().__init__() self.quant_func = int8_dynamic_activation_int8_weight() + @property + def method_name(self): + return "ao-w8a8" + @QUANTMETHODS.register("ao-w8a16") class AOW8A16QuantizationMethod(AOBaseQuantizationMethod): @@ -89,6 +123,10 @@ def __init__(self): super().__init__() self.quant_func = int8_weight_only() + @property + def method_name(self): + return "ao-w8a16" + @QUANTMETHODS.register("ao-fp8w8a16") class AOFP8W8A16QuantizationMethod(AOBaseQuantizationMethod): @@ -98,6 +136,10 @@ def __init__(self): assert is_cuda_8_9, "FP8 requires GPU with compute capability >= 8.9" self.quant_func = float8_weight_only() + @property + def method_name(self): + return "ao-fp8w8a16" + @QUANTMETHODS.register("ao-fp6w6a16") class AOFP6W6A16QuantizationMethod(AOBaseQuantizationMethod): @@ -105,3 +147,7 @@ def __init__(self): super().__init__() assert TORCH_VERSION_AT_LEAST_2_5, "torchao fp6 requires torch >=2.5" self.quant_func = fpx_weight_only(3, 2) + + @property + def method_name(self): + return "ao-fp6w6a16" diff --git a/lightllm/common/quantization/triton_quant/triton_quant.py b/lightllm/common/quantization/triton_quant/triton_quant.py index a8d6a0055..a79e3f65a 100644 --- a/lightllm/common/quantization/triton_quant/triton_quant.py +++ b/lightllm/common/quantization/triton_quant/triton_quant.py @@ -5,6 +5,10 @@ from lightllm.common.quantization.registry import QUANTMETHODS from .fp8.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul from .fp8.fp8act_quant_kernel import per_token_group_quant_fp8 +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack class TritonBaseQuantizationMethod(QuantizationMethod): @@ -15,12 +19,17 @@ def __init__(self): self.cache_manager = g_cache_manager def quantize(self, weight: torch.Tensor): - """ """ pass - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): - """ """ - pass + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "MMWeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) -> torch.Tensor: + raise NotImplementedError("Not implemented") @QUANTMETHODS.register(["triton-fp8w8a8-block128"]) @@ -29,13 +38,25 @@ def __init__(self): super().__init__() self.is_moe = False self.block_size = 128 + self.weight_suffix = None + self.weight_zero_point_suffix = None + self.weight_scale_suffix = "weight_scale_inv" def quantize(self, weight: torch.Tensor): # TODO block-wise quant kernel pass - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True): - qweight, weight_scale, input_scale = weights + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "MMWeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) 
-> torch.Tensor: + qweight = weight_pack.weight + weight_scale = weight_pack.weight_scale + input_scale = None m, k = input_tensor.shape n = qweight.shape[1] alloc_func = torch.empty if not use_custom_tensor_mananger else self.cache_manager.empty diff --git a/lightllm/common/quantization/w8a8_quant.py b/lightllm/common/quantization/w8a8_quant.py index 1c38b625f..ea5b66bce 100644 --- a/lightllm/common/quantization/w8a8_quant.py +++ b/lightllm/common/quantization/w8a8_quant.py @@ -3,11 +3,15 @@ from .quantize_method import QuantizationMethod from .registry import QUANTMETHODS import torch.nn.functional as F +from typing import Optional, TYPE_CHECKING from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8 from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops, cutlass_scaled_mm from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops +if TYPE_CHECKING: + from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack + if HAS_LIGHTLLM_KERNEL: def scaled_fp8_quant(tensor, *args, **kwargs): @@ -27,12 +31,21 @@ def __init__(self): self.cache_manager = g_cache_manager def quantize(self, weight: torch.Tensor): - """ """ pass - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): - """ """ - pass + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "MMWeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) -> torch.Tensor: + raise NotImplementedError("Not implemented") + + @property + def method_name(self): + return "w8a8-base" @QUANTMETHODS.register(["vllm-w8a8", "w8a8"]) @@ -47,17 +60,21 @@ def quantize(self, weight: torch.Tensor): scale = weight.abs().max(dim=-1)[0] / 127 weight = weight.transpose(0, 1) / scale.reshape(1, -1) weight = torch.round(weight.clamp(min=-128, max=127)).to(dtype=torch.int8) - return weight.cuda(self.device_id_), scale.cuda(self.device_id_) - - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True): + return weight.cuda(self.device_id_), scale.cuda(self.device_id_), None + + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "MMWeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) -> torch.Tensor: input_scale = None - if len(weights) == 3: - qweight, weight_scale, input_scale = weights - elif len(weights) == 2: - qweight, weight_scale = weights - else: - raise ValueError("vllm-quant Weights must be a tuple of length 2 or 3.") - + qweight = weight_pack.weight + weight_scale = weight_pack.weight_scale + bias = weight_pack.bias + input_scale = None # dynamic quantization for input tensor x_q, x_scale, x_zp = vllm_ops.scaled_int8_quant(input_tensor, scale=input_scale, azp=None, symmetric=True) m = input_tensor.shape[0] n = qweight.shape[1] @@ -71,6 +88,10 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_ cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias) return out + @property + def method_name(self): + return "vllm-w8a8" + @QUANTMETHODS.register(["vllm-fp8w8a8", "fp8w8a8"]) class FP8w8a8QuantizationMethod(BaseQuantizationMethod): @@ -84,9 +105,9 @@ def quantize(self, weight: torch.Tensor): qweight, weight_scale = scaled_fp8_quant( 
weight.contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True ) - return qweight.transpose(0, 1), weight_scale + return qweight.transpose(0, 1), weight_scale, None - def quantize_moe(self, weight): + def quantize_moe(self, weight: torch.Tensor): num_experts = weight.shape[0] qweights = [] weight_scales = [] @@ -100,10 +121,20 @@ def quantize_moe(self, weight): weight_scale = torch.stack(weight_scales, dim=0).contiguous() return qweights, weight_scale - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True): + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "MMWeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) -> torch.Tensor: + qweight = weight_pack.weight + weight_scale = weight_pack.weight_scale + bias = weight_pack.bias x_q, x_scale = scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True) m = input_tensor.shape[0] - n = weights[0].shape[1] + n = qweight.shape[1] if out is None: if use_custom_tensor_mananger: out = self.cache_manager.alloc_tensor( @@ -111,9 +142,13 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_ ) else: out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device) - cutlass_scaled_mm(out, x_q, weights[0], x_scale, weights[1], bias) + cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias) return out + @property + def method_name(self): + return "vllm-fp8w8a8" + @QUANTMETHODS.register(["vllm-fp8w8a8-b128", "fp8w8a8-b128"]) class FP8w8a8B128QuantizationMethod(BaseQuantizationMethod): @@ -121,16 +156,25 @@ def __init__(self): super().__init__() self.block_size = 128 self.weight_scale_suffix = "weight_scale_inv" - self.act_scale_suffix = None # no support for static input tensor scale for ds model. 
def quantize(self, weight: torch.Tensor): raise Exception("Not implemented") - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True): - qweight, weight_scale, input_scale = weights + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "MMWeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) -> torch.Tensor: + qweight = weight_pack.weight + weight_scale = weight_pack.weight_scale + bias = weight_pack.bias + input_scale = None # dynamic quantization for input tensor m, k = input_tensor.shape - n = weights[0].shape[1] + n = qweight.shape[1] alloc_func = torch.empty if not use_custom_tensor_mananger else self.cache_manager.empty if input_scale is None: qinput_tensor, input_scale = per_token_group_quant_fp8( @@ -152,3 +196,7 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_ input_scale = input_scale.t().contiguous().t() cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias) return out + + @property + def method_name(self): + return "vllm-fp8w8a8-b128" diff --git a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py index 0d2a3084d..b6de41140 100644 --- a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py @@ -1,6 +1,6 @@ import torch import numpy as np -from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.rowmm_weight import ROWMMWeight +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight import ROWMMWeight from lightllm.common.basemodel.layer_weights.meta_weights.norm_weight import NormWeight from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight diff --git a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py index a67cd5ac2..661c450f0 100644 --- a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py @@ -3,7 +3,7 @@ import numpy as np from lightllm.common.basemodel.layer_weights.meta_weights.gpt_oss_fused_moe_weight_tp import GPTOSSFusedMoeWeightTP -from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.rowmm_weight import ROWMMWeight +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight import ROWMMWeight from lightllm.common.basemodel.layer_weights.meta_weights.norm_weight import NormWeight, TpNormWeight from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight from lightllm.utils.log_utils import init_logger
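The sketch below is not part of the patch; it is a minimal illustration of how the new AWQ plumbing in this diff fits together, assuming lightllm with this change applied and a vllm install that provides the Marlin kernels. The quantization_config values are hypothetical and mirror a typical 4-bit AutoAWQ checkpoint; the rest only uses names introduced by this diff (is_awq_marlin_compatible, ROWMM_WEIGHT_CLS_MAP).

    from lightllm.common.quantization.awq_quant import is_awq_marlin_compatible
    from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.rowmm_weight import (
        ROWMM_WEIGHT_CLS_MAP,
    )

    # Hypothetical HF quantization_config, as read from an AutoAWQ checkpoint's config.json.
    hf_quantization_config = {
        "quant_method": "awq",
        "bits": 4,
        "group_size": 128,
        "zero_point": True,
    }

    # Mirrors the new branch in Quantcfg._mapping_quant_method: plain "awq" keeps the
    # awq_gemm / awq_dequantize path, while compatible configs are upgraded to
    # "awq_marlin", whose weights are repacked for the Marlin GEMM at load time.
    quant_type = "awq"
    if is_awq_marlin_compatible(hf_quantization_config):
        quant_type = "awq_marlin"

    # The row-parallel MM weight class is then looked up by method_name, e.g.
    # AWQROWMMWeight for "awq" or AWQMARLINROWMMWeight for "awq_marlin".
    row_weight_cls = ROWMM_WEIGHT_CLS_MAP[quant_type]
    print(quant_type, row_weight_cls.__name__)

At run time the selected weight class loads qweight, scales and qzeros into its MMWeightPack, and the matching QuantizationMethod.apply(input_tensor, weight_pack, ...) reads weight, weight_scale, weight_zero_point and bias from that pack, replacing the old tuple-of-tensors convention used before this patch.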