@@ -1,11 +1,13 @@
from .base_weight import BaseWeight
from .mm_weight import (
MMWeightPack,
MMWeightTpl,
MultiMMWeightTpl,
ROWMMWeight,
COLMMWeight,
MultiROWMMWeight,
ROWBMMWeight,
AWQMultiMMWeightTpl,
)
from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight
from .fused_moe_weight_tp import FusedMoeWeightTP
@@ -1,3 +1,4 @@
import torch
from abc import ABC, abstractmethod
from typing import Dict
@@ -14,7 +15,7 @@ def load_hf_weights(self, weights):

@abstractmethod
def verify_load(self):
pass


class BaseWeightTpl(BaseWeight):
@@ -24,30 +25,8 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, data_type: to
self.device_id_ = get_current_device_id()
self.data_type_ = data_type

def _slice_weight(self, weight: torch.Tensor):
# slice weight
return weight.to(self.data_type_)

def _slice_bias(self, bias: torch.Tensor):
# slice bias
return bias.to(self.data_type_)

def _slice_weight_scale(self, weight_scale: torch.Tensor):
# slice weight scale and zero point
return weight_scale

def _load_weights(self, weights: Dict[str, torch.Tensor]) -> None:
# load weight
pass

def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None:
# load quantization scale
pass

def load_hf_weights(self, weights):
self._load_weights(weights)
self._load_scales(weights)
return
raise NotImplementedError("load_hf_weights must implement this method")

def verify_load(self):
pass
raise NotImplementedError("verify_load must implement this method")

Large diffs are not rendered by default.

@@ -1,16 +1,13 @@
from .mm_weight import (
MMWeightPack,
MMWeightTpl,
MultiMMWeightTpl,
AWQMultiMMWeightTpl,
)
from .rowmm_weight import (
from .mm_factory import (
MMWeight,
ROWMMWeight,
ROWBMMWeight,
MultiROWMMWeight,
W8A8B128ROWMMWeight,
W8A8B128ROWBMMWeight,
W8A8B128MultiROWMMWeight,
)
from .colmm_weight import (
ROWBMMWeight,
COLMMWeight,
W8A8B128COLMMWeight,
)
@@ -1,25 +1,38 @@
import torch
from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import (
MMWeight,
MMWeightTpl,
generate_scale_name,
SingleMMWeightTpl,
DeepGemmFP8W8A8B128MMWeight,
AWQMMWeightTpl,
)
from lightllm.common.quantization import Quantcfg
from lightllm.utils.dist_utils import get_current_device_id
from lightllm.common.quantization.quantize_method import QuantizationMethod
from typing import Dict, List, Optional
from .mm_slicer import ColSliceMixin, QuantizedRowSliceMixin, QuantizedColSliceMixin


class COLMMWeight(MMWeight):
@classmethod
def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool):
if quant_method is None or not quantized_weight:
return UnquantizedCOLMMWeight
else:
return W8A8B128COLMMWeight
class UnquantizedCOLMMWeight(SingleMMWeightTpl):
def __init__(
self,
weight_name: str,
data_type: torch.dtype,
bias_name: Optional[str] = None,
quant_method: QuantizationMethod = None,
tp_rank: int = None,
tp_world_size: int = None,
) -> None:
super().__init__(
weight_name=weight_name,
data_type=data_type,
bias_name=bias_name,
quant_method=quant_method,
tp_rank=tp_rank,
tp_world_size=tp_world_size,
)
self.param_slicer = ColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)


class UnquantizedCOLMMWeight(MMWeightTpl):
class DeepGemmFP8W8A8B128COLMMWeight(DeepGemmFP8W8A8B128MMWeight):
def __init__(
self,
weight_name: str,
@@ -29,24 +42,40 @@ def __init__(
tp_rank: int = None,
tp_world_size: int = None,
) -> None:
super().__init__(data_type, quant_method, tp_rank, tp_world_size)
self.weight_name = weight_name
self.bias_name = bias_name
self.has_bias = bias_name is not None
super().__init__(
weight_name=weight_name,
data_type=data_type,
bias_name=bias_name,
quant_method=quant_method,
tp_rank=tp_rank,
tp_world_size=tp_world_size,
)
self.param_slicer = QuantizedColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)

def _slice_weight(self, tensor):
assert tensor.shape[1] % self.tp_world_size_ == 0, f"tp slice error {tensor.shape[1]} % {self.tp_world_size_}"
tp_size = tensor.shape[1] // self.tp_world_size_
return tensor[:, tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)].to(self.data_type_)

def _slice_bias(self, bias):
"""
Because the column-parallel (Colmm) TP computation ends with a reduce, dividing bias by tp_world_size here saves one extra step.
"""
return (bias / self.tp_world_size_).to(self.data_type_)
class AWQCOLMMWeight(AWQMMWeightTpl):
def __init__(
self,
weight_name: str,
data_type: torch.dtype,
bias_name: Optional[str] = None,
quant_method: QuantizationMethod = None,
tp_rank: int = None,
tp_world_size: int = None,
) -> None:
super().__init__(
weight_name=weight_name,
data_type=data_type,
bias_name=bias_name,
quant_method=quant_method,
tp_rank=tp_rank,
tp_world_size=tp_world_size,
)
# Note: this is intentional, not a mistake; AWQ stores the weight as in x out.
self.param_slicer = QuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)


class W8A8B128COLMMWeight(MMWeightTpl):
class AWQMARLINCOLMMWeight(AWQCOLMMWeight):
def __init__(
self,
weight_name: str,
@@ -56,44 +85,37 @@ def __init__(
tp_rank: int = None,
tp_world_size: int = None,
) -> None:
super().__init__(data_type, quant_method, tp_rank, tp_world_size)
self.weight_name = weight_name
self.bias_name = bias_name
self.has_bias = bias_name is not None

self.weight_scale_name, self.act_scale_name = generate_scale_name(
weight_name, quant_method.weight_scale_suffix, quant_method.act_scale_suffix
super().__init__(
weight_name=weight_name,
data_type=data_type,
bias_name=bias_name,
quant_method=quant_method,
tp_rank=tp_rank,
tp_world_size=tp_world_size,
)
self.weight_scale: Optional[torch.Tensor] = None
self.block_size = self.quant_method.block_size
self.quantized_weight = True

def _slice_weight(self, tensor):
assert tensor.shape[1] % self.tp_world_size_ == 0, f"tp slice error {tensor.shape[1]} % {self.tp_world_size_}"
tp_size = tensor.shape[1] // self.tp_world_size_
return tensor[:, tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]

def _slice_weight_scale(self, weight_scale: torch.Tensor):
assert (
weight_scale.shape[1] % self.tp_world_size_ == 0
), f"tp slice error {weight_scale.shape[1]} % {self.tp_world_size_}"
tp_size = weight_scale.shape[1] // self.tp_world_size_
scale_start = tp_size * self.tp_rank_
scale_end = tp_size * (self.tp_rank_ + 1)
return weight_scale[:, scale_start:scale_end].to(torch.float)
def _process_weight(self, weight: torch.Tensor) -> None:
new_weight = self.quant_method._process_weight_after_loading(weight.cuda(get_current_device_id()))
self.mm_param.weight = new_weight
return

def _process_weight_scale(self, weight_scale) -> None:
self.weight_scale = weight_scale.cuda(get_current_device_id()).transpose(0, 1)
def _process_weight_scale(self, weight_scale: torch.Tensor) -> None:
new_weight_scale = self.quant_method._process_weight_scale_after_loading(
weight_scale.cuda(get_current_device_id()).to(self.data_type_)
)
self.mm_param.weight_scale = new_weight_scale
return

def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None:
if self.weight_scale_name in weights:
weight_scale = self._slice_weight_scale(weights[self.weight_scale_name])
self._process_weight_scale(weight_scale)
if self.weight_scale is not None and isinstance(self.weight, torch.Tensor):
# The None stored in weight is a slot reserved for the static activation quantization scale.
self.weight = [
self.weight,
self.weight_scale,
None,
]
def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> None:
new_weight_zero_point = self.quant_method._process_weight_zero_point_after_loading(
weight_zero_point.cuda(get_current_device_id())
)
self.mm_param.weight_zero_point = new_weight_zero_point
return


COLMM_WEIGHT_CLS_MAP = {
"deepgemm-fp8w8a8-b128": DeepGemmFP8W8A8B128COLMMWeight,
"awq": AWQCOLMMWeight,
"awq_marlin": AWQMARLINCOLMMWeight,
}
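
For reference, the column-parallel slicing that ColSliceMixin/QuantizedColSliceMixin are expected to encapsulate is the same logic removed from the old UnquantizedCOLMMWeight above: split the weight along dim 1 per TP rank, and pre-divide the bias by the world size because the column-parallel matmul ends in an all-reduce. A standalone sketch of that behaviour (illustrative only, not the actual mm_slicer code):

import torch

def col_slice_weight(weight: torch.Tensor, tp_rank: int, tp_world_size: int) -> torch.Tensor:
    # Column-parallel: each rank keeps a contiguous block of input columns (dim 1).
    assert weight.shape[1] % tp_world_size == 0, f"tp slice error {weight.shape[1]} % {tp_world_size}"
    tp_size = weight.shape[1] // tp_world_size
    return weight[:, tp_size * tp_rank : tp_size * (tp_rank + 1)]

def col_slice_bias(bias: torch.Tensor, tp_world_size: int) -> torch.Tensor:
    # The matmul result is all-reduced across ranks, so dividing the bias here
    # avoids an extra scaling step after the reduce.
    return bias / tp_world_size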
@@ -0,0 +1,83 @@
from lightllm.common.quantization import Quantcfg
from lightllm.common.quantization.quantize_method import QuantizationMethod
from typing import Type, Union, Dict, Optional, Tuple
from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import (
MMWeightTpl,
MultiMMWeightTpl,
BMMWeightTpl,
)
from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.rowmm_weight import (
UnquantizedROWMMWeight,
UnquantizedROWBMMWeight,
UnquantizedMultiROWMMWeight,
ROWMM_WEIGHT_CLS_MAP,
MULTI_ROWMM_WEIGHT_CLS_MAP,
)
from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.colmm_weight import (
UnquantizedCOLMMWeight,
COLMM_WEIGHT_CLS_MAP,
)


class MMWeight:
def __new__(cls, **kwargs):
quant_cfg = kwargs.pop("quant_cfg", None)
layer_num_ = kwargs.pop("layer_num", None)
name = kwargs.pop("name", None)
quant_method, quantized_weight = cls._get_quant_method(quant_cfg, layer_num_, name)
kwargs["quant_method"] = quant_method
mmcls = cls._get_mmcls(quant_method, quantized_weight)
return mmcls(**kwargs)

@classmethod
def _get_quant_method(cls, quant_cfg: Quantcfg, layer_num_: int, name: str) -> Tuple[Optional[QuantizationMethod], bool]:
if quant_cfg is None:
return None, False
quant_method = quant_cfg.get_quant_method(layer_num_, name)
if quant_method is None:
return None, False
quant_method.hf_quantization_config = quant_cfg.hf_quantization_config
quantized_weight = quant_cfg.quantized_weight
return quant_method, quantized_weight

@classmethod
def _get_mmcls(
cls, quant_method: QuantizationMethod, quantized_weight: bool
) -> Type[Union[MMWeightTpl, MultiMMWeightTpl, BMMWeightTpl]]:
raise NotImplementedError("Subclasses must implement _get_mmcls method")


class ROWMMWeight(MMWeight):
@classmethod
def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool):
if quant_method is None or not quantized_weight:
return UnquantizedROWMMWeight

return ROWMM_WEIGHT_CLS_MAP[quant_method.method_name]


class MultiROWMMWeight(MMWeight):
@classmethod
def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool):
if quant_method is None or not quantized_weight:
return UnquantizedMultiROWMMWeight

return MULTI_ROWMM_WEIGHT_CLS_MAP[quant_method.method_name]


class ROWBMMWeight(MMWeight):
@classmethod
def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool):
if quant_method is None or not quantized_weight:
return UnquantizedROWBMMWeight
else:
# TODO: Implement more quantization weight
raise NotImplementedError("ROWBMMWeight is not implemented")


class COLMMWeight(MMWeight):
@classmethod
def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool):
if quant_method is None or not quantized_weight:
return UnquantizedCOLMMWeight
return COLMM_WEIGHT_CLS_MAP[quant_method.method_name]
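
The factory subclasses differ only in which unquantized template they fall back to and which registry they consult by quant_method.method_name. A hedged usage sketch (the weight name, dtype, and keyword set are illustrative; real call sites pass whatever the selected template's __init__ accepts):

import torch

# Hypothetical call site: with quant_cfg=None, MMWeight.__new__ resolves no
# quantization method and returns an UnquantizedROWMMWeight instance; with a
# quantized checkpoint, the class registered for quant_method.method_name in
# ROWMM_WEIGHT_CLS_MAP would be instantiated instead.
q_proj = ROWMMWeight(
    weight_name="model.layers.0.self_attn.q_proj.weight",  # illustrative key
    data_type=torch.bfloat16,
    bias_name=None,
    quant_cfg=None,   # popped in MMWeight.__new__
    layer_num=0,      # popped in MMWeight.__new__
    name="q_proj",    # popped in MMWeight.__new__
)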