diff --git a/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py b/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py
index 540988266d1..20f47cf36f7 100644
--- a/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py
+++ b/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py
@@ -29,6 +29,7 @@
     weight_quantize_xpu,
     xpu_moe_layer,
 )
+from fastdeploy.model_executor.utils import default_weight_loader, set_weight_attrs
 
 
 class XPUMoEMethod(MoEMethodBase):
@@ -61,78 +62,146 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
         """
         create weight process.
         """
-        self.up_gate_proj_weight_shape = [
-            layer.num_local_experts,
-            layer.moe_intermediate_size * 2,
-            layer.hidden_size,
-        ]
-        self.down_proj_weight_shape = [
-            layer.num_local_experts,
-            layer.hidden_size,
-            layer.moe_intermediate_size,
-        ]
-        if self.moe_quant_type in ["weight_only_int4", "w4a8"]:
-            self.up_gate_proj_weight_shape[-1] //= 2
-            self.down_proj_weight_shape[-1] //= 2
+        if layer.fd_config.load_config.load_choices == "default_v1" and self.moe_quant_type in ["w16a16"]:
+            self.up_gate_proj_weight_shape = [
+                layer.num_local_experts,
+                layer.moe_intermediate_size * 2,
+                layer.hidden_size,
+            ]
+            self.down_proj_weight_shape = [layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size]
+            extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}
 
-        setattr(
-            layer,
-            self.added_weight_attrs[0],
-            layer.create_parameter(
+            layer.up_gate_proj_weight = layer.create_parameter(
                 shape=self.up_gate_proj_weight_shape,
-                dtype=self.weight_dtype,
+                dtype=layer.weight_dtype,
                 default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        setattr(
-            layer,
-            self.added_weight_attrs[1],
-            layer.create_parameter(
+            )
+
+            layer.down_proj_weight = layer.create_parameter(
                 shape=self.down_proj_weight_shape,
-                dtype=self.weight_dtype,
+                dtype=layer.weight_dtype,
                 default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
+            )
+
+            set_weight_attrs(
+                layer.up_gate_proj_weight,
+                {
+                    "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
+                    "weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
+                },
+            )
+            set_weight_attrs(
+                layer.down_proj_weight,
+                {
+                    "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
+                    "weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
+                },
+            )
 
-        if self.moe_quant_type in ["weight_only_int8", "w8a8", "weight_only_int4", "w4a8"]:
-            self.up_gate_proj_scale_shape = [
+            if layer.with_bias:
+                layer.up_gate_proj_bias = layer.create_parameter(
+                    shape=[layer.num_experts, layer.moe_intermediate_size * 2],
+                    dtype=layer.weight_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                )
+
+                layer.down_proj_bias = layer.create_parameter(
+                    shape=[layer.num_experts, layer.hidden_size],
+                    dtype=layer.weight_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                )
+                set_weight_attrs(
+                    layer.up_gate_proj_bias,
+                    {
+                        "weight_loader": extra_weight_attrs.get(
+                            "weight_loader", default_weight_loader(layer.fd_config)
+                        ),
+                        "model_format": extra_weight_attrs.get("model_format", ""),
+                    },
+                )
+                set_weight_attrs(
+                    layer.down_proj_bias,
+                    {
+                        "weight_loader": extra_weight_attrs.get(
+                            "weight_loader", default_weight_loader(layer.fd_config)
+                        ),
+                        "model_format": extra_weight_attrs.get("model_format", ""),
+                    },
+                )
+
+        else:
+            self.up_gate_proj_weight_shape = [
                 layer.num_local_experts,
                 layer.moe_intermediate_size * 2,
+                layer.hidden_size,
             ]
-            self.down_proj_scale_shape = [
+            self.down_proj_weight_shape = [
                 layer.num_local_experts,
                 layer.hidden_size,
+                layer.moe_intermediate_size,
             ]
+            if self.moe_quant_type in ["weight_only_int4", "w4a8"]:
+                self.up_gate_proj_weight_shape[-1] //= 2
+                self.down_proj_weight_shape[-1] //= 2
+
             setattr(
                 layer,
-                self.added_scale_attrs[0],
+                self.added_weight_attrs[0],
                 layer.create_parameter(
-                    shape=self.up_gate_proj_scale_shape,
-                    dtype=self.scale_dtype,
+                    shape=self.up_gate_proj_weight_shape,
+                    dtype=self.weight_dtype,
                     default_initializer=paddle.nn.initializer.Constant(0),
                 ),
             )
             setattr(
                 layer,
-                self.added_scale_attrs[1],
+                self.added_weight_attrs[1],
                 layer.create_parameter(
-                    shape=self.down_proj_scale_shape,
-                    dtype=self.scale_dtype,
+                    shape=self.down_proj_weight_shape,
+                    dtype=self.weight_dtype,
                     default_initializer=paddle.nn.initializer.Constant(0),
                 ),
             )
 
-            if self.moe_quant_type in ["w8a8", "w4a8"]:
-                for in_scale_name in self.added_in_scale_attrs:
+            if self.moe_quant_type in ["weight_only_int8", "w8a8", "weight_only_int4", "w4a8"]:
+                self.up_gate_proj_scale_shape = [
+                    layer.num_local_experts,
+                    layer.moe_intermediate_size * 2,
+                ]
+                self.down_proj_scale_shape = [
+                    layer.num_local_experts,
+                    layer.hidden_size,
+                ]
                 setattr(
                     layer,
-                    in_scale_name,
+                    self.added_scale_attrs[0],
                     layer.create_parameter(
-                        shape=[layer.num_local_experts],
+                        shape=self.up_gate_proj_scale_shape,
                         dtype=self.scale_dtype,
                         default_initializer=paddle.nn.initializer.Constant(0),
                     ),
                 )
+                setattr(
+                    layer,
+                    self.added_scale_attrs[1],
+                    layer.create_parameter(
+                        shape=self.down_proj_scale_shape,
+                        dtype=self.scale_dtype,
+                        default_initializer=paddle.nn.initializer.Constant(0),
+                    ),
+                )
+
+                if self.moe_quant_type in ["w8a8", "w4a8"]:
+                    for in_scale_name in self.added_in_scale_attrs:
+                        setattr(
+                            layer,
+                            in_scale_name,
+                            layer.create_parameter(
+                                shape=[layer.num_local_experts],
+                                dtype=self.scale_dtype,
+                                default_initializer=paddle.nn.initializer.Constant(0),
+                            ),
+                        )
 
     def process_loaded_weights(self, layer: nn.Layer, state_dict):
         up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py
index 4ef3f6e451a..28278d5654f 100644
--- a/fastdeploy/model_executor/utils.py
+++ b/fastdeploy/model_executor/utils.py
@@ -261,8 +261,8 @@ def v1_loader_support(fd_config):
     def _err_msg(msg: str) -> str:
         logger.info(msg + "; fallback to the v0 loader for model loading.")
 
-    if not current_platform.is_cuda():
-        _err_msg("v1loader currently does not support backends other than CUDA")
+    if not (current_platform.is_cuda() or current_platform.is_xpu()):
+        _err_msg("v1loader currently only supports GPU and XPU backends")
         return False
 
     if is_pre_sliced_weight(fd_config.model_config.model):
diff --git a/scripts/run_ci_xpu.sh b/scripts/run_ci_xpu.sh
index 38cda9120d3..f2be141950a 100644
--- a/scripts/run_ci_xpu.sh
+++ b/scripts/run_ci_xpu.sh
@@ -26,7 +26,6 @@ echo "build whl"
 bash custom_ops/xpu_ops/download_dependencies.sh develop
 export CLANG_PATH=$(pwd)/custom_ops/xpu_ops/third_party/xtdk
 export XVLLM_PATH=$(pwd)/custom_ops/xpu_ops/third_party/xvllm
-
 bash build.sh || exit 1
 
 echo "pip others"
@@ -54,7 +53,8 @@ python -m fastdeploy.entrypoints.openai.api_server \
     --num-gpu-blocks-override 16384 \
     --max-model-len 32768 \
     --max-num-seqs 128 \
-    --quantization wint4 > server.log 2>&1 &
+    --quantization wint4 \
+    --load-choices default > server.log 2>&1 &
 sleep 60
 
 # health check
@@ -121,7 +121,8 @@ python -m fastdeploy.entrypoints.openai.api_server \
     --num-gpu-blocks-override 16384 \
     --max-model-len 32768 \
     --max-num-seqs 64 \
-    --quantization "W4A8" > server.log 2>&1 &
+    --quantization "W4A8" \
+    --load-choices default > server.log 2>&1 &
 sleep 60
 
 # health check
@@ -191,7 +192,8 @@ python -m fastdeploy.entrypoints.openai.api_server \
     --enable-mm \
     --mm-processor-kwargs '{"video_max_frames": 30}' \
     --limit-mm-per-prompt '{"image": 10, "video": 3}' \
-    --reasoning-parser ernie-45-vl > server.log 2>&1 &
+    --reasoning-parser ernie-45-vl \
+    --load-choices default > server.log 2>&1 &
 sleep 60
 
 # health check
diff --git a/tests/ci_use/XPU_45T/run_ep.py b/tests/ci_use/XPU_45T/run_ep.py
index c82242aa394..e411396d69a 100644
--- a/tests/ci_use/XPU_45T/run_ep.py
+++ b/tests/ci_use/XPU_45T/run_ep.py
@@ -44,6 +44,7 @@ def test_fd_ep():
         quantization="wint4",
         engine_worker_queue_port=engine_worker_queue_port,
         max_num_seqs=8,
+        load_choices="default",
     )
 
     try:
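Note: the "weight_loader" callable attached via set_weight_attrs() in the fused_moe.py hunk above is what the v1 ("default_v1") loading path invokes for each checkpoint tensor of the MoE parameters. The snippet below is a minimal illustrative sketch of that calling convention only; make_weight_loader is a hypothetical name, and the transpose handling and set_value copy are assumptions for illustration, not FastDeploy's actual default_weight_loader implementation.

    # Illustrative sketch (assumed behavior, not the real default_weight_loader).
    import paddle


    def make_weight_loader(fd_config):
        def weight_loader(param, loaded_weight):
            # Assumes loaded_weight is a paddle.Tensor. "weight_need_transpose"
            # (set above for torch-format checkpoints) presumably signals that
            # the checkpoint tensor is stored transposed relative to the
            # Paddle parameter layout.
            if getattr(param, "weight_need_transpose", False):
                loaded_weight = loaded_weight.T
            assert list(param.shape) == list(loaded_weight.shape), "shape mismatch"
            # Copy the checkpoint tensor into the pre-created parameter.
            param.set_value(paddle.cast(loaded_weight, param.dtype))

        return weight_loader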