4 changes: 4 additions & 0 deletions fastdeploy/envs.py
@@ -128,6 +128,10 @@
"FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR": lambda: os.getenv("FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR", "</think>"),
# Timeout for cache_transfer_manager process exit
"FD_CACHE_PROC_EXIT_TIMEOUT": lambda: int(os.getenv("FD_CACHE_PROC_EXIT_TIMEOUT", "600")),
# FP4 dense GEMM backend, can be flashinfer-cutlass, flashinfer-trtllm, flashinfer-cudnn or None (default is None)
"FD_NVFP4_GEMM_BACKEND": lambda: os.getenv("FD_NVFP4_GEMM_BACKEND", None),
# FlashInfer MoE backend, can be flashinfer-cutlass, flashinfer-trtllm or None (default is None)
"FD_FLASHINFER_MOE_BACKEND": lambda: os.getenv("FD_FLASHINFER_MOE_BACKEND", None),
# Count for cache_transfer_manager process error
"FD_CACHE_PROC_ERROR_COUNT": lambda: int(os.getenv("FD_CACHE_PROC_ERROR_COUNT", "10")),
# API_KEY required for service authentication
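For context, the two new switches follow the same lazy-lookup pattern as the rest of `environment_variables`: nothing is read from the environment until the entry's lambda is called. A minimal standalone sketch of that pattern; the `resolve_backend` helper and its validation are illustrative assumptions, not FastDeploy API:

```python
import os

# Minimal sketch of the lazy env-var registry pattern used above.
environment_variables = {
    # FP4 dense GEMM backend: flashinfer-cutlass, flashinfer-trtllm, flashinfer-cudnn, or unset
    "FD_NVFP4_GEMM_BACKEND": lambda: os.getenv("FD_NVFP4_GEMM_BACKEND", None),
    # FlashInfer MoE backend: flashinfer-cutlass, flashinfer-trtllm, or unset
    "FD_FLASHINFER_MOE_BACKEND": lambda: os.getenv("FD_FLASHINFER_MOE_BACKEND", None),
}


def resolve_backend(name, allowed):
    """Read a backend switch lazily and validate it against the allowed values."""
    value = environment_variables[name]()
    if value is not None and value not in allowed:
        raise ValueError(f"{name}={value!r} is not one of {sorted(allowed)}")
    return value


os.environ["FD_FLASHINFER_MOE_BACKEND"] = "flashinfer-cutlass"
print(resolve_backend("FD_FLASHINFER_MOE_BACKEND", {"flashinfer-cutlass", "flashinfer-trtllm"}))
# -> flashinfer-cutlass
```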
40 changes: 40 additions & 0 deletions fastdeploy/flashinfer.py
@@ -0,0 +1,40 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import functools
import importlib
import importlib.util
import os
import shutil

from paddleformers.utils.log import logger


@functools.cache
def has_flashinfer() -> bool:
"""Return `True` if FlashInfer is available."""
if os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() not in ["1", "on", "true"]:
# FlashInfer currently requires the Paddle compatible API
logger.warning("FlashInfer requires the Paddle compatible API; set PADDLE_COMPATIBLE_API=1 to enable it.")
return False
# Use find_spec to check if the module exists without importing it;
# this avoids potential CUDA initialization side effects.
if importlib.util.find_spec("flashinfer") is None:
return False
# Also check that nvcc is available, since it is required to JIT compile flashinfer
if shutil.which("nvcc") is None:
return False
return True
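A short usage sketch of how a caller could combine `has_flashinfer()` with the new `FD_FLASHINFER_MOE_BACKEND` switch when choosing a MoE backend. `select_moe_backend` is a hypothetical helper written for illustration, and it assumes `fastdeploy.envs` resolves entries as attributes, which is suggested but not shown by this diff:

```python
from fastdeploy import envs
from fastdeploy.flashinfer import has_flashinfer


def select_moe_backend() -> str:
    """Hypothetical selection helper: honor the env switch only when FlashInfer is usable."""
    requested = envs.FD_FLASHINFER_MOE_BACKEND  # assumption: envs exposes entries as attributes
    if requested and has_flashinfer():
        # e.g. "flashinfer-cutlass" or "flashinfer-trtllm"
        return requested
    return "default"
```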
1 change: 0 additions & 1 deletion fastdeploy/model_executor/layers/linear.py
@@ -82,7 +82,6 @@ def process_loaded_weights(self, layer, weights) -> None:
layer.weight.set_value(weights)

def apply(self, layer: nn.Layer, x: paddle.Tensor) -> paddle.Tensor:

linear_out = paddle.matmul(x, layer.weight)
if layer.with_bias:
linear_out = paddle.add(linear_out, layer.bias)
41 changes: 37 additions & 4 deletions fastdeploy/model_executor/layers/moe/moe.py
@@ -285,10 +285,10 @@ def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_
expert_param = param[expert_id - self.expert_id_offset]
dim = -1 if shard_dim else 0
param_shard_size = expert_param.shape[dim] // 2
if shard_id == "gate":
switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
if (shard_id == "gate" and not switch_w13) or (shard_id == "up" and switch_w13):
param_shard_offset = 0
else:
# shard_id == "up" (or "gate" when load_up_proj_weight_first is set)
param_shard_offset = param_shard_size
expert_param = slice_fn(
expert_param, shard_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size
@@ -302,8 +302,12 @@
)

# To ensure compatibility across backends, apply an extra transpose for GCU and XPU,
# or a reshape when the loaded weight's rank differs from the parameter's
if expert_param.shape != loaded_weight.shape:
loaded_weight = loaded_weight.transpose([1, 0])
if len(expert_param.shape) != len(loaded_weight.shape):
loaded_weight = loaded_weight.reshape(expert_param.shape)
else:
loaded_weight = loaded_weight.transpose([1, 0])
assert expert_param.shape == loaded_weight.shape, (
f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
)
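The offset choice above now depends on both `shard_id` and the quant method's `load_up_proj_weight_first` flag. A small self-contained sketch of the same slicing decision, using NumPy and a stand-in `fused` array rather than a real expert parameter:

```python
import numpy as np


def gate_up_slice(fused, shard_id, switch_w13, dim=-1):
    """Return the half of a fused [gate|up] tensor that `shard_id` should be written into.

    When `switch_w13` (load_up_proj_weight_first) is set, the layout is [up|gate],
    so "up" occupies the first half and "gate" the second.
    """
    half = fused.shape[dim] // 2
    first_half = (shard_id == "gate" and not switch_w13) or (shard_id == "up" and switch_w13)
    offset = 0 if first_half else half
    return np.take(fused, range(offset, offset + half), axis=dim)


fused = np.zeros((8, 4))  # e.g. one expert's fused weight of shape [hidden, 2 * intermediate]
print(gate_up_slice(fused, "up", switch_w13=False).shape)   # (8, 2): second half by default
print(gate_up_slice(fused, "gate", switch_w13=True).shape)  # (8, 2): gate moves to the second half
```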
@@ -360,6 +364,32 @@ def _load_fused_experts_weight(self, param, loaded_weight):
for i in range(self.num_local_experts):
param.tensor_track.mark(start=0, batch_id=i)

def _load_per_tensor_weight_scale(
self,
param,
expert_id,
loaded_weight,
shard_id,
):
loaded_weight = get_tensor(loaded_weight)
expert_param = param[expert_id - self.expert_id_offset]
if shard_id in ["gate", "up"]:
idx = 0 if shard_id == "gate" else 1
if expert_param[idx].shape != loaded_weight.shape:
if len(expert_param[idx].shape) != len(loaded_weight.shape):
loaded_weight = loaded_weight.reshape(expert_param[idx].shape)
else:
loaded_weight = loaded_weight.transpose([1, 0])

expert_param[idx].set_value(loaded_weight)
elif shard_id == "down":
if expert_param.shape != loaded_weight.shape:
if len(expert_param.shape) != len(loaded_weight.shape):
loaded_weight = loaded_weight.reshape(expert_param.shape)
else:
loaded_weight = loaded_weight.transpose([1, 0])
expert_param.set_value(loaded_weight)

def _load_expert_weight(
self,
param,
@@ -368,7 +398,10 @@ def _load_expert_weight(
shard_id,
shard_dim=None,
):
if shard_id == "down":
weight_type = getattr(param, "weight_type", None)
if weight_type in ["weight_scale_2", "input_scale"]:
self._load_per_tensor_weight_scale(param, expert_id, loaded_weight, shard_id)
elif shard_id == "down":
self._load_down_weight(param, expert_id, loaded_weight, shard_id, shard_dim)
elif shard_id in ["gate", "up"]:
self._load_gate_up_weight(param, expert_id, loaded_weight, shard_id, shard_dim)
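The new `_load_per_tensor_weight_scale` path, together with the `weight_type` check in `_load_expert_weight`, routes scalar NVFP4 scales (`weight_scale_2`, `input_scale`) before the regular shard routing. A standalone sketch of the implied storage layout; the `[num_experts, 2]` gate/up shape and the flat down shape are assumptions for illustration only:

```python
import numpy as np

num_experts = 4
# Assumed layouts: one scalar scale per projection per expert.
gate_up_scales = np.zeros((num_experts, 2), dtype=np.float32)  # column 0 = gate, column 1 = up
down_scales = np.zeros((num_experts,), dtype=np.float32)


def store_per_tensor_scale(expert_id, shard_id, value):
    """Mirror of the dispatch: gate/up scales land in their slot, down scales in a flat array."""
    if shard_id in ("gate", "up"):
        idx = 0 if shard_id == "gate" else 1
        gate_up_scales[expert_id, idx] = value
    elif shard_id == "down":
        down_scales[expert_id] = value


store_per_tensor_scale(2, "up", 0.0123)
print(gate_up_scales[2])  # [0.     0.0123]
```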
8 changes: 8 additions & 0 deletions fastdeploy/model_executor/layers/quantization/__init__.py
@@ -33,6 +33,7 @@
"mix_quant",
"tensor_wise_fp8",
"kvcache",
"modelopt_fp4",
]


@@ -99,6 +100,11 @@ def _get_offline_quant_config_name(quantization_config, is_torch_weight, is_v1_l
has_block_size = "weight_block_size" in quantization_config
if quant_method == "fp8" and has_block_size:
quant_config_name = "block_wise_fp8"
elif quant_method == "modelopt":
if quantization_config.get("quant_algo", "") == "NVFP4":
quant_config_name = "modelopt_fp4"
else:
raise ValueError("modelopt only supports NVFP4 quantization.")
else:
raise ValueError("Torch weight offline quantization only supports block-wise FP8.")
else:
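In practice, the new branch means a ModelOpt checkpoint whose `quantization_config` carries `quant_method: modelopt` and `quant_algo: NVFP4` resolves to the `modelopt_fp4` config name. A simplified, standalone reproduction of that resolution for illustration; it reads `quant_method` straight from the dict and drops the surrounding torch-weight checks:

```python
def offline_quant_config_name(quantization_config):
    """Simplified sketch of the name resolution for torch-weight offline quantization."""
    quant_method = quantization_config.get("quant_method", "")
    if quant_method == "fp8" and "weight_block_size" in quantization_config:
        return "block_wise_fp8"
    if quant_method == "modelopt":
        if quantization_config.get("quant_algo", "") == "NVFP4":
            return "modelopt_fp4"
        raise ValueError("modelopt only supports NVFP4 quantization.")
    raise ValueError("Torch weight offline quantization only supports block-wise FP8 and ModelOpt NVFP4.")


print(offline_quant_config_name({"quant_method": "modelopt", "quant_algo": "NVFP4"}))  # modelopt_fp4
```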
@@ -116,6 +122,7 @@ def get_quantization_config(quantization: str) -> Type[QuantConfigBase]:
from .block_wise_fp8 import BlockWiseFP8Config
from .kv_cache import KvCacheQuantConfig
from .mix_quant import MixQuantConfig
from .nvfp4 import ModelOptNvFp4Config
from .tensor_wise_fp8 import TensorWiseFP8Config
from .w4a8 import W4A8Config
from .w4afp8 import W4AFP8Config
@@ -137,6 +144,7 @@ def get_quantization_config(quantization: str) -> Type[QuantConfigBase]:
"tensor_wise_fp8": TensorWiseFP8Config,
"kvcache": KvCacheQuantConfig,
"mix_quant": MixQuantConfig,
"modelopt_fp4": ModelOptNvFp4Config,
}

return method_to_config[quantization]