diff --git a/docs/quantization/nvfp4.md b/docs/quantization/nvfp4.md
new file mode 100644
index 00000000000..c8edd091c79
--- /dev/null
+++ b/docs/quantization/nvfp4.md
@@ -0,0 +1,74 @@
+
+# NVFP4 Quantization
+NVFP4 is an innovative 4-bit floating-point format introduced by NVIDIA. For details, please refer to [Introducing NVFP4 for Efficient and Accurate Low-Precision Inference](https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/).
+
+Built on [FlashInfer](https://github.com/flashinfer-ai/flashinfer), FastDeploy supports inference of NVFP4-quantized models in the format produced by [ModelOpt](https://github.com/NVIDIA/TensorRT-Model-Optimizer).
+
+- Note: this feature currently supports only FP4-quantized models from the Ernie and Qwen series.
+
+## How to Use
+### Environment Setup
+- **Supported Hardware**: GPU with compute capability (SM) >= 100
+- **PaddlePaddle Version**: 3.3.0 or higher
+- **FastDeploy Version**: 2.4.0 or higher
+
+#### 1. FastDeploy Installation
+First, set up the FastDeploy base environment by following the [FastDeploy NVIDIA GPU Environment Installation Guide](../../get_started/installation/nvidia_gpu.md).
+
+#### 2. FlashInfer Installation
+```bash
+git clone -b support-paddlepaddle-with-compatible-api-and-tvmffi https://github.com/PFCCLab/flashinfer/ --recursive
+
+cd flashinfer
+python -m pip install -v .
+```
+
+### Running the Inference Service
+- Note: the environment variable `PADDLE_COMPATIBLE_API=true` must be set, and the FlashInfer build above must be installed correctly.
+```bash
+export PADDLE_COMPATIBLE_API=true
+python -m fastdeploy.entrypoints.openai.api_server \
+    --model nv-community/Qwen3-30B-A3B-FP4 \
+    --port 8180 \
+    --metrics-port 8181 \
+    --engine-worker-queue-port 8182 \
+    --cache-queue-port 8183 \
+    --tensor-parallel-size 1 \
+    --max-model-len 32768 \
+    --max-num-seqs 128
+```
+
+### API Access
+Send requests to the service with the following command:
+
+```shell
+curl -X POST "http://0.0.0.0:8180/v1/chat/completions" \
+-H "Content-Type: application/json" \
+-d '{
+  "messages": [
+    {"role": "user", "content": "把李白的静夜思改写为现代诗"}
+  ]
+}'
+```
+
+The FastDeploy service interface is compatible with the OpenAI protocol, so you can also send requests with the following Python code.
+
+```python
+import openai
+host = "0.0.0.0"
+port = "8180"
+client = openai.Client(base_url=f"http://{host}:{port}/v1", api_key="null")
+
+response = client.chat.completions.create(
+    model="null",
+    messages=[
+        {"role": "system", "content": "I'm a helpful AI assistant."},
+        {"role": "user", "content": "把李白的静夜思改写为现代诗"},
+    ],
+    stream=True,
+)
+for chunk in response:
+    if chunk.choices[0].delta:
+        print(chunk.choices[0].delta.content, end='')
+print('\n')
+```
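+
+### Optional: Selecting the FlashInfer Backend
+NVFP4 GEMM and fused MoE kernels are dispatched through FlashInfer. When the variables below are left unset, FastDeploy automatically selects `flashinfer-cutlass` if FlashInfer is available. A minimal sketch of overriding that choice with the environment variables added for this feature in `fastdeploy/envs.py` is shown below; note that the fused MoE path currently runs only with the CUTLASS backend, so treat the other values as experimental.
+```bash
+# Optional overrides; leave unset for automatic selection (flashinfer-cutlass when available)
+export FD_NVFP4_GEMM_BACKEND=flashinfer-cutlass      # or flashinfer-trtllm / flashinfer-cudnn
+export FD_FLASHINFER_MOE_BACKEND=flashinfer-cutlass  # flashinfer-trtllm support is planned
+```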
diff --git a/docs/zh/quantization/nvfp4.md b/docs/zh/quantization/nvfp4.md
new file mode 100644
index 00000000000..62e6e36aa57
--- /dev/null
+++ b/docs/zh/quantization/nvfp4.md
@@ -0,0 +1,75 @@
+[English](../../quantization/nvfp4.md)
+
+# NVFP4量化
+NVFP4 是 NVIDIA 引入的创新 4 位浮点格式，详细介绍请参考 [Introducing NVFP4 for Efficient and Accurate Low-Precision Inference](https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/)。
+
+基于 [FlashInfer](https://github.com/flashinfer-ai/flashinfer)，FastDeploy 支持 [ModelOpt](https://github.com/NVIDIA/TensorRT-Model-Optimizer) 产出格式的 NVFP4 量化模型推理。
+
+- 注：目前该功能仅支持 Ernie / Qwen 系列的 FP4 量化模型。
+
+## 如何使用
+### 环境安装
+- **支持硬件**：GPU sm >= 100
+- **PaddlePaddle 版本**：3.3.0 或更高版本
+- **FastDeploy 版本**：2.4.0 或更高版本
+
+#### 1. FastDeploy 安装
+首先请根据 [FastDeploy NVIDIA GPU 环境安装指南](../../get_started/installation/nvidia_gpu.md)，安装 FastDeploy 基础环境。
+
+#### 2. FlashInfer 安装
+```bash
+git clone -b support-paddlepaddle-with-compatible-api-and-tvmffi https://github.com/PFCCLab/flashinfer/ --recursive
+
+cd flashinfer
+python -m pip install -v .
+```
+
+### 运行推理服务
+- 注意：需要设置环境变量 `export PADDLE_COMPATIBLE_API=true`，并正确安装上述 FlashInfer。
+```bash
+export PADDLE_COMPATIBLE_API=true
+python -m fastdeploy.entrypoints.openai.api_server \
+    --model nv-community/Qwen3-30B-A3B-FP4 \
+    --port 8180 \
+    --metrics-port 8181 \
+    --engine-worker-queue-port 8182 \
+    --cache-queue-port 8183 \
+    --tensor-parallel-size 1 \
+    --max-model-len 32768 \
+    --max-num-seqs 128
+```
+
+### 接口访问
+通过如下命令发起服务请求：
+
+```shell
+curl -X POST "http://0.0.0.0:8180/v1/chat/completions" \
+-H "Content-Type: application/json" \
+-d '{
+  "messages": [
+    {"role": "user", "content": "把李白的静夜思改写为现代诗"}
+  ]
+}'
+```
+
+FastDeploy 服务接口兼容 OpenAI 协议，可以通过如下 Python 代码发起服务请求。
+
+```python
+import openai
+host = "0.0.0.0"
+port = "8180"
+client = openai.Client(base_url=f"http://{host}:{port}/v1", api_key="null")
+
+response = client.chat.completions.create(
+    model="null",
+    messages=[
+        {"role": "system", "content": "I'm a helpful AI assistant."},
+        {"role": "user", "content": "把李白的静夜思改写为现代诗"},
+    ],
+    stream=True,
+)
+for chunk in response:
+    if chunk.choices[0].delta:
+        print(chunk.choices[0].delta.content, end='')
+print('\n')
+```
diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py
index 23030d6a80b..558f7b3dc09 100644
--- a/fastdeploy/envs.py
+++ b/fastdeploy/envs.py
@@ -128,6 +128,10 @@
     "FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR": lambda: os.getenv("FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR", ""),
     # Timeout for cache_transfer_manager process exit
     "FD_CACHE_PROC_EXIT_TIMEOUT": lambda: int(os.getenv("FD_CACHE_PROC_EXIT_TIMEOUT", "600")),
+    # NVFP4 dense GEMM backend: flashinfer-cutlass, flashinfer-trtllm, flashinfer-cudnn, or None (default is None)
+    "FD_NVFP4_GEMM_BACKEND": lambda: os.getenv("FD_NVFP4_GEMM_BACKEND", None),
+    # FlashInfer MoE backend: flashinfer-cutlass, flashinfer-trtllm, or None (default is None)
+    "FD_FLASHINFER_MOE_BACKEND": lambda: os.getenv("FD_FLASHINFER_MOE_BACKEND", None),
     # Count for cache_transfer_manager process error
     "FD_CACHE_PROC_ERROR_COUNT": lambda: int(os.getenv("FD_CACHE_PROC_ERROR_COUNT", "10")),
     # API_KEY required for service authentication
diff --git a/fastdeploy/flashinfer.py b/fastdeploy/flashinfer.py
new file mode 100644
index 00000000000..23634faed5f
--- /dev/null
+++ b/fastdeploy/flashinfer.py
@@ -0,0 +1,40 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" + +import functools +import importlib +import importlib.util +import os +import shutil + +from paddleformers.utils.log import logger + + +@functools.cache +def has_flashinfer() -> bool: + """Return `True` if FlashInfer is available.""" + # Use find_spec to check if the module exists without importing it + # This avoids potential CUDA initialization side effects + if os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() not in ["1", "on", "true"]: + # currently must support by Paddle compatible API + logger.warning("FlashInfer is not supported by Paddle compatible API.") + return False + if importlib.util.find_spec("flashinfer") is None: + return False + # Also check if nvcc is available since it's required to JIT compile flashinfer + if shutil.which("nvcc") is None: + return False + return True diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index e7725be6d23..3227edac765 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -82,7 +82,6 @@ def process_loaded_weights(self, layer, weights) -> None: layer.weight.set_value(weights) def apply(self, layer: nn.Layer, x: paddle.Tensor) -> paddle.Tensor: - linear_out = paddle.matmul(x, layer.weight) if layer.with_bias: linear_out = paddle.add(linear_out, layer.bias) diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index ede87972185..4959029d917 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -285,10 +285,10 @@ def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_ expert_param = param[expert_id - self.expert_id_offset] dim = -1 if shard_dim else 0 param_shard_size = expert_param.shape[dim] // 2 - if shard_id == "gate": + switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False) + if (shard_id == "gate" and not switch_w13) or (shard_id == "up" and switch_w13): param_shard_offset = 0 else: - # shard_id == "up": param_shard_offset = param_shard_size expert_param = slice_fn( expert_param, shard_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size @@ -302,8 +302,12 @@ def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_ ) # To ensure compatibility across backends, apply an extra transpose for GCU and XPU + if expert_param.shape != loaded_weight.shape: - loaded_weight = loaded_weight.transpose([1, 0]) + if len(expert_param.shape) != len(loaded_weight.shape): + loaded_weight = loaded_weight.reshape(expert_param.shape) + else: + loaded_weight = loaded_weight.transpose([1, 0]) assert expert_param.shape == loaded_weight.shape, ( f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})" ) @@ -360,6 +364,32 @@ def _load_fused_experts_weight(self, param, loaded_weight): for i in range(self.num_local_experts): param.tensor_track.mark(start=0, batch_id=i) + def _load_per_tensor_weight_scale( + self, + param, + expert_id, + loaded_weight, + shard_id, + ): + loaded_weight = get_tensor(loaded_weight) + expert_param = param[expert_id - self.expert_id_offset] + if shard_id in ["gate", "up"]: + idx = 0 if shard_id == "gate" else 1 + if expert_param[idx].shape != loaded_weight.shape: + if len(expert_param[idx].shape) != len(loaded_weight.shape): + loaded_weight = loaded_weight.reshape(expert_param[idx].shape) + else: + loaded_weight = loaded_weight.transpose([1, 0]) + + expert_param[idx].set_value(loaded_weight) + elif 
shard_id == "down": + if expert_param.shape != loaded_weight.shape: + if len(expert_param.shape) != len(loaded_weight.shape): + loaded_weight = loaded_weight.reshape(expert_param.shape) + else: + loaded_weight = loaded_weight.transpose([1, 0]) + expert_param.set_value(loaded_weight) + def _load_expert_weight( self, param, @@ -368,7 +398,10 @@ def _load_expert_weight( shard_id, shard_dim=None, ): - if shard_id == "down": + weight_type = getattr(param, "weight_type", None) + if weight_type in ["weight_scale_2", "input_scale"]: + self._load_per_tensor_weight_scale(param, expert_id, loaded_weight, shard_id) + elif shard_id == "down": self._load_down_weight(param, expert_id, loaded_weight, shard_id, shard_dim) elif shard_id in ["gate", "up"]: self._load_gate_up_weight(param, expert_id, loaded_weight, shard_id, shard_dim) diff --git a/fastdeploy/model_executor/layers/quantization/__init__.py b/fastdeploy/model_executor/layers/quantization/__init__.py index f8716369852..a6bffde03db 100644 --- a/fastdeploy/model_executor/layers/quantization/__init__.py +++ b/fastdeploy/model_executor/layers/quantization/__init__.py @@ -33,6 +33,7 @@ "mix_quant", "tensor_wise_fp8", "kvcache", + "modelopt_fp4", ] @@ -99,6 +100,11 @@ def _get_offline_quant_config_name(quantization_config, is_torch_weight, is_v1_l has_block_size = "weight_block_size" in quantization_config if quant_method == "fp8" and has_block_size: quant_config_name = "block_wise_fp8" + elif quant_method == "modelopt": + if quantization_config.get("quant_algo", "") == "NVFP4": + quant_config_name = "modelopt_fp4" + else: + raise ValueError("modelopt only supports NVFP4 quantization.") else: raise ValueError("Torch weight offline quantization only supports block-wise FP8.") else: @@ -116,6 +122,7 @@ def get_quantization_config(quantization: str) -> Type[QuantConfigBase]: from .block_wise_fp8 import BlockWiseFP8Config from .kv_cache import KvCacheQuantConfig from .mix_quant import MixQuantConfig + from .nvfp4 import ModelOptNvFp4Config from .tensor_wise_fp8 import TensorWiseFP8Config from .w4a8 import W4A8Config from .w4afp8 import W4AFP8Config @@ -137,6 +144,7 @@ def get_quantization_config(quantization: str) -> Type[QuantConfigBase]: "tensor_wise_fp8": TensorWiseFP8Config, "kvcache": KvCacheQuantConfig, "mix_quant": MixQuantConfig, + "modelopt_fp4": ModelOptNvFp4Config, } return method_to_config[quantization] diff --git a/fastdeploy/model_executor/layers/quantization/nvfp4.py b/fastdeploy/model_executor/layers/quantization/nvfp4.py new file mode 100644 index 00000000000..49c88be7882 --- /dev/null +++ b/fastdeploy/model_executor/layers/quantization/nvfp4.py @@ -0,0 +1,628 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +from typing import Optional + +import paddle +from paddleformers.utils.log import logger + +import fastdeploy +from fastdeploy import envs +from fastdeploy.flashinfer import has_flashinfer +from fastdeploy.model_executor.layers.moe import FusedMoE +from fastdeploy.model_executor.utils import ( + create_parameter_and_copy, + free_tensor, + set_weight_attrs, +) + +from .quant_base import QuantConfigBase, QuantMethodBase + +if has_flashinfer(): + paddle.compat.enable_torch_proxy() + from flashinfer import fp4_quantize + from flashinfer import mm_fp4 as fp4_gemm + from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe +else: + logger.warning("FlashInfer is not installed. For nvFp4 inference, please install Flashinfer.") + + +def next_power_of_2(n: int): + return 1 << (n - 1).bit_length() if n > 0 else 1 + + +class ModelOptNvFp4Config(QuantConfigBase): + """ + quantization config for ModelOpt Nvfp4 datatype + """ + + def __init__( + self, + is_checkpoint_nvfp4_serialized: bool, + kv_cache_quant_algo: str | None, + exclude_modules: list[str], + group_size: int = 16, + is_checkpoint_bf16: bool = False, + ) -> None: + self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized + if is_checkpoint_nvfp4_serialized: + logger.warning( + "Detected ModelOpt NVFP4 checkpoint. Please note that" + " the format is experimental and could change in future." + ) + + self.group_size = group_size + self.kv_cache_quant_algo = kv_cache_quant_algo + self.exclude_modules = exclude_modules + + self.quant_max_bound = 6 + self.quant_min_bound = -6 + self.quant_round_type = 1 + self.is_checkpoint_bf16 = is_checkpoint_bf16 + + def name(self) -> str: + return "modelopt_fp4" + + @classmethod + def from_config(cls, config: dict) -> "ModelOptNvFp4Config": + quant_config = config + quant_method = quant_config.get("quant_algo", "") + if not quant_method: + raise ValueError("Missing 'quant_algo' in quantization config") + + # Handle kv_cache_quant_algo with proper type validation + kv_cache_quant_algo_raw = quant_config.get("kv_cache_quant_algo") + if kv_cache_quant_algo_raw is None: + # No KV cache quantization by default + kv_cache_quant_algo = None + elif isinstance(kv_cache_quant_algo_raw, str): + kv_cache_quant_algo = kv_cache_quant_algo_raw + else: + raise ValueError(f"kv_cache_quant_algo must be a string, got " f"{type(kv_cache_quant_algo_raw)}") + + # Handle group_size with proper type validation + group_size_raw = quant_config.get("group_size") + if group_size_raw is None: + group_size = 16 # Default value + elif isinstance(group_size_raw, int): + group_size = group_size_raw + else: + try: + group_size = int(group_size_raw) + except (ValueError, TypeError): + raise ValueError(f"group_size must be an integer, got {type(group_size_raw)}") from None + + # "exclude_modules" is the key in the legacy hf_quant_config.json + exclude_modules = quant_config.get("exclude_modules", []) + if not isinstance(exclude_modules, list): + raise ValueError(f"exclude_modules must be a list, got {type(exclude_modules)}") + + is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method + + # For FP4, these fields are required + if is_checkpoint_nvfp4_serialized and "quantization" in config: + # Check if required fields are present in the quantization config + quant_config = config["quantization"] + required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"] + missing_fields = [field for field in required_fields if field not in quant_config] + if missing_fields: + raise ValueError( + 
f"NVFP4 quantization requires the following fields in " f"hf_quant_config.json: {missing_fields}" + ) + + return cls( + is_checkpoint_nvfp4_serialized=is_checkpoint_nvfp4_serialized, + kv_cache_quant_algo=kv_cache_quant_algo, + exclude_modules=exclude_modules, + group_size=group_size, + ) + + def get_quant_method(self, layer) -> Optional[QuantMethodBase]: + """ + Get quantization method. + """ + if isinstance(layer, FusedMoE): + return ModelOptNvFp4FusedMoE(self) + else: + return ModelOptNvFp4LinearMethod(self) + + return None + + +class ModelOptNvFp4LinearMethod(QuantMethodBase): + """Linear method for Model Optimizer NVFP4. + Supports loading NVFP4 checkpoints with the following structure: + + input_scale: paddle.float32, scalar , + weight: NVFP4(represented as byte) Shape: [1, X, y/2] + weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale, + weight_scale_2: paddle.float32, scalar, + Args: quant_config: The ModelOpt quantization config. + """ + + def __init__(self, quant_config: ModelOptNvFp4Config) -> None: + self.quant_config = quant_config + + self.backend = "none" + if envs.FD_NVFP4_GEMM_BACKEND is None: + if has_flashinfer(): + self.backend = "flashinfer-cutlass" + elif envs.FD_NVFP4_GEMM_BACKEND.startswith("flashinfer-"): + self.backend = envs.FD_NVFP4_GEMM_BACKEND + assert has_flashinfer(), f"FlashInfer is required for {self.backend}" + + if self.backend == "none": + raise ValueError( + "No valid NVFP4 GEMM backend found. Please check your platform capability and installtion of Flashinfer." + ) + + logger.info(f"Using {self.backend} for NVFP4 GEMM") + + def create_weights( + self, + layer, + **extra_weight_attrs, + ): + extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"] + weight_shape = layer.weight_shape[::-1] + weight_shape[1] = weight_shape[1] // 2 + layer.weight_dtype = "uint8" + input_scale_shape = [1] + weight_scale_shape = [layer.weight_shape[::-1][0], layer.weight_shape[::-1][1] // self.quant_config.group_size] + weight_scale_2_shape = [1] + + self._create_main_weight(layer, weight_shape, extra_weight_attrs) + self._create_input_scale(layer, input_scale_shape) + self._create_weight_scales(layer, weight_scale_shape, weight_scale_2_shape, extra_weight_attrs) + + def _create_main_weight(self, layer, weight_shape, extra_weight_attrs): + """创建主权重参数 + + 参数: + layer: 当前层对象 + weight_shape: 权重形状 + extra_weight_attrs: 额外权重属性 + """ + layer.weight = layer.create_parameter( + shape=weight_shape, + dtype=layer.weight_dtype, + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), + ) + set_weight_attrs( + layer.weight, + extra_weight_attrs, + ) + + def _create_input_scale(self, layer, input_scale_shape): + """创建输入缩放参数 + + 参数: + layer: 当前层对象 + input_scale_shape: 输入缩放形状 + """ + layer.input_scale = layer.create_parameter( + shape=input_scale_shape, + dtype=paddle.float32, + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), + ) + + def _create_weight_scales(self, layer, weight_scale_shape, weight_scale_2_shape, extra_weight_attrs): + """创建权重缩放参数 + + 参数: + layer: 当前层对象 + weight_scale_shape: 权重缩放形状 + weight_scale_2_shape: 权重缩放2形状 + extra_weight_attrs: 额外权重属性 + """ + layer.weight_scale_2 = layer.create_parameter( + shape=weight_scale_2_shape, + dtype=paddle.float32, + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), + ) + layer.weight_scale = layer.create_parameter( + shape=weight_scale_shape, + dtype=paddle.float8_e4m3fn, + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), + ) + 
set_weight_attrs( + layer.weight_scale, + extra_weight_attrs, + ) + + def process_weights_after_loading(self, layer) -> None: + def _process_scale_interleaved(scales): + scale_dim = len(scales.shape) + if scale_dim == 2: + scales = scales.unsqueeze(0) + assert len(scales.shape) == 3 + B, M, K = scales.shape + round_up_multiple = lambda x, m: (x + m - 1) // m * m + M_padded = round_up_multiple(M, 128) + K_padded = round_up_multiple(K, 4) + padded_scales = paddle.empty([B, M_padded, K_padded], dtype=scales.dtype) + padded_scales[:B, :M, :K].copy_(scales) + batches, rows, cols = padded_scales.shape + assert rows % 128 == 0 + assert cols % 4 == 0 + padded_scales = padded_scales.reshape(batches, rows // 128, 4, 32, cols // 4, 4) + padded_scales = padded_scales.transpose([0, 1, 4, 3, 2, 5]) + padded_scales = padded_scales.contiguous().to(paddle.device.get_device()) + padded_scales = ( + padded_scales.reshape(M_padded, K_padded) + if scale_dim == 2 + else padded_scales.reshape(B, M_padded, K_padded) + ) + return padded_scales + + input_scale_2 = layer.input_scale.max().to(paddle.float32) + weight_scale_2 = layer.weight_scale_2.max().to(paddle.float32) + alpha = input_scale_2 * weight_scale_2 + input_scale_inv = (1 / input_scale_2).to(paddle.float32) + weight_scale_interleaved = _process_scale_interleaved(layer.weight_scale) + free_tensor(layer.input_scale) + free_tensor(layer.weight_scale_2) + + layer.weight_scale_2 = layer.create_parameter( + shape=weight_scale_2.shape, + dtype=weight_scale_2.dtype, + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), + ) + layer.input_scale = layer.create_parameter( + shape=input_scale_2.shape, + dtype=input_scale_2.dtype, + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), + ) + layer.alpha = layer.create_parameter( + shape=alpha.shape, + dtype=alpha.dtype, + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), + ) + layer.input_scale_inv = layer.create_parameter( + shape=input_scale_inv.shape, + dtype=input_scale_inv.dtype, + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), + ) + layer.weight_scale_interleaved = layer.create_parameter( + shape=weight_scale_interleaved.shape, + dtype=weight_scale_interleaved.dtype, + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), + ) + layer.weight_scale_2.copy_(weight_scale_2, False) + layer.input_scale.copy_(input_scale_2, False) + layer.alpha.copy_(alpha, False) + layer.input_scale_inv.copy_(input_scale_inv, False) + layer.weight_scale_interleaved.copy_(weight_scale_interleaved, False) + + def apply( + self, + layer, + x, + ): + x_m, _ = x.shape + w_n, _ = layer.weight.shape + output_shape = [x_m, w_n] + output_dtype = x.dtype + + # Quantize BF16 or FP16 to (FP4 and interleaved block scale) + x_fp4, x_scale_interleaved = fp4_quantize(x, layer.input_scale_inv) + + assert x_fp4.dtype == paddle.uint8 + assert layer.weight.dtype == paddle.uint8 + assert layer.weight_scale_interleaved.dtype == paddle.float8_e4m3fn + assert layer.alpha.dtype == paddle.float32 + + if self.backend.startswith("flashinfer-"): + backend = self.backend[len("flashinfer-") :] + else: + raise ValueError(f"Unsupported backend: {self.backend}.") + + w = layer.weight.T + w_scale_interleaved = layer.weight_scale_interleaved.T + + if backend == "cutlass": + x_scale_interleaved = x_scale_interleaved.view(paddle.uint8) + w_scale_interleaved = w_scale_interleaved.view(paddle.uint8) + out = fp4_gemm(x_fp4, w, x_scale_interleaved, w_scale_interleaved, 
layer.alpha, output_dtype, backend=backend) + if layer.with_bias: + out = paddle.add(out, layer.bias) + return out.view(*output_shape) + + +class ModelOptNvFp4FusedMoE(QuantMethodBase): + """Fused MoE method for Model Optimizer NVFP4. + Supports loading NVFP4 checkpoints with the following structure: + + input_scale: paddle.float32, scalar , + weight: NVFP4(represented as byte) Shape: [1, X, y/2] + weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale, + weight_scale_2: paddle.float32, scalar, + Args: + quant_config: The ModelOpt quantization config. + moe_config: The MoE configuration. + layer: The linear layer. + """ + + def __init__(self, quant_config: ModelOptNvFp4Config): + self.quant_config = quant_config + self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"] + self.added_scale_attrs = [ + "up_gate_proj_weight_scale", + "down_proj_weight_scale", + ] + self.quant_config = quant_config + self.backend = "none" + + if envs.FD_FLASHINFER_MOE_BACKEND is None: + # currently support flashinfer-cutlass, flashinfer-trtllm will support in the future + if has_flashinfer(): + self.backend = "flashinfer-cutlass" + elif envs.FD_FLASHINFER_MOE_BACKEND.startswith("flashinfer-"): + self.backend = envs.FD_FLASHINFER_MOE_BACKEND + assert has_flashinfer(), f"FlashInfer is required for MoE backend {self.backend}" + + if self.backend == "none": + raise ValueError( + "No valid NVFP4 flashinfer MoE backend found. Please check your platform capability and installtion of FlashInfer." + ) + + logger.info(f"Using {self.backend} for NVFP4 FusedMoE") + + def create_weights(self, layer, **extra_weight_attrs): + """ + Triton MoE create weight process. + """ + self.up_gate_proj_weight_shape = [ + layer.num_local_experts, + layer.moe_intermediate_size * 2, + layer.hidden_size // 2, + ] + self.down_proj_weight_shape = [ + layer.num_local_experts, + layer.hidden_size, + layer.moe_intermediate_size // 2, + ] + self.up_gate_proj_scale_shape = [ + layer.num_local_experts, + layer.moe_intermediate_size * 2, + layer.hidden_size // self.quant_config.group_size, + ] + self.down_proj_scale_shape = [ + layer.num_local_experts, + layer.hidden_size, + layer.moe_intermediate_size // self.quant_config.group_size, + ] + + self.weight_scale_dtype = paddle.float8_e4m3fn + self.weight_dtype = paddle.uint8 + self.added_scale_attrs = ["up_gate_proj_weight_scale", "down_proj_weight_scale"] + up_gate_proj_weight_name = self.added_weight_attrs[0] + down_proj_weight_name = self.added_weight_attrs[1] + up_gate_proj_scale_name = self.added_scale_attrs[0] + down_proj_scale_name = self.added_scale_attrs[1] + setattr( + layer, + up_gate_proj_weight_name, + layer.create_parameter( + shape=self.up_gate_proj_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + down_proj_weight_name, + layer.create_parameter( + shape=self.down_proj_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + # weight_scale + setattr( + layer, + up_gate_proj_scale_name, + layer.create_parameter( + shape=self.up_gate_proj_scale_shape, + dtype=self.weight_scale_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + down_proj_scale_name, + layer.create_parameter( + shape=self.down_proj_scale_shape, + dtype=self.weight_scale_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + # weight_scale_2 + layer.up_gate_proj_weight_scale_2 = layer.create_parameter( + 
shape=[layer.num_local_experts, 2], + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ) + layer.down_proj_weight_scale_2 = layer.create_parameter( + shape=[layer.num_local_experts], + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ) + # input_scale + layer.up_gate_proj_input_scale = layer.create_parameter( + shape=[layer.num_local_experts, 2], + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ) + layer.down_proj_input_scale = layer.create_parameter( + shape=[layer.num_local_experts], + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ) + + set_weight_attrs( + getattr(layer, up_gate_proj_weight_name), + {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}, + ) + set_weight_attrs( + getattr(layer, up_gate_proj_scale_name), + {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}, + ) + + set_weight_attrs( + getattr(layer, down_proj_weight_name), + {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}, + ) + set_weight_attrs( + getattr(layer, down_proj_scale_name), + {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}, + ) + + set_weight_attrs( + layer.up_gate_proj_weight_scale_2, + {**extra_weight_attrs, "weight_type": "weight_scale_2"}, + ) + set_weight_attrs(layer.down_proj_weight_scale_2, {**extra_weight_attrs, "weight_type": "weight_scale_2"}) + set_weight_attrs(layer.up_gate_proj_input_scale, {**extra_weight_attrs, "weight_type": "input_scale"}) + set_weight_attrs(layer.down_proj_input_scale, {**extra_weight_attrs, "weight_type": "input_scale"}) + + def swizzle_blockscale(self, scale): + assert scale.dtype == paddle.float8_e4m3fn + # Pad and blockwise interleave weight_scale + scale_dim = len(scale.shape) + if len(scale.shape) == 2: + scale = scale.unsqueeze(0) + assert len(scale.shape) == 3 + B, M, K = scale.shape + round_up_multiple = lambda x, m: (x + m - 1) // m * m + M_padded = round_up_multiple(M, 128) + K_padded = round_up_multiple(K, 4) + padded_scale = paddle.empty([B, M_padded, K_padded], dtype=scale.dtype) + padded_scale[:B, :M, :K].copy_(scale) + batches, rows, cols = padded_scale.shape + assert rows % 128 == 0 + assert cols % 4 == 0 + padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, cols // 4, 4) + swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5)) + swizzled_scale = swizzled_scale.contiguous().to(paddle.device.get_device()) + return ( + swizzled_scale.reshape(M_padded, K_padded) + if scale_dim == 2 + else swizzled_scale.reshape(B, M_padded, K_padded) + ) + + @property + def load_up_proj_weight_first(self) -> bool: + # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13 + # 目前默认给True + return True + + def process_weights_after_loading(self, layer): + """ """ + up_gate_proj_weight_scale_2 = layer.up_gate_proj_weight_scale_2[:, 0] + free_tensor(layer.up_gate_proj_weight_scale_2) + create_parameter_and_copy(layer, name="up_gate_proj_weight_scale_2", weight=up_gate_proj_weight_scale_2) + up_gate_proj_input_scale = paddle.max(layer.up_gate_proj_input_scale).cast("float32") + down_proj_input_scale = paddle.max(layer.down_proj_input_scale).cast("float32") + + # Create shared parameters + create_parameter_and_copy( + layer, "g1_alphas", (up_gate_proj_input_scale * up_gate_proj_weight_scale_2).cast("float32") + ) + create_parameter_and_copy( + layer, "g2_alphas", (down_proj_input_scale * 
layer.down_proj_weight_scale_2).cast("float32") + ) + create_parameter_and_copy( + layer, "up_gate_proj_input_scale_quant", (1 / up_gate_proj_input_scale).cast("float32") + ) + create_parameter_and_copy(layer, "down_proj_input_scale_quant", (1 / down_proj_input_scale).cast("float32")) + + for name, weight_scale in [ + ("up_gate", layer.up_gate_proj_weight_scale), + ("down", layer.down_proj_weight_scale), + ]: + assert weight_scale.shape[2] % 16 == 0, f"Expected {name}_weight_scale.dim(2) to be divisible by 16" + assert ( + weight_scale.dtype == paddle.float8_e4m3fn + ), f"{name} Weight Blockscale must be represented as FP8-E4M3" + + up_gate_proj_blockscale_swizzled = self.swizzle_blockscale(layer.up_gate_proj_weight_scale) + free_tensor(layer.up_gate_proj_weight_scale) + layer.up_gate_proj_weight_scale = None + create_parameter_and_copy( + layer, name="up_gate_proj_blockscale_swizzled", weight=up_gate_proj_blockscale_swizzled + ) + down_proj_blockscale_swizzled = self.swizzle_blockscale(layer.down_proj_weight_scale) + free_tensor(layer.down_proj_weight_scale) + layer.down_proj_weight_scale = None + create_parameter_and_copy(layer, name="down_proj_blockscale_swizzled", weight=down_proj_blockscale_swizzled) + + def apply(self, layer, x, gate): + """ + flashinfer nvfp4 fusedmoe for Model Optimizer + """ + gate_out = gate(x.cast("float32")) + topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( + gate_out, + layer.gate_correction_bias, + layer.top_k, + True, # apply_norm_weight, + False, + ) + + output_dtype = x.dtype + x_sf = None + output = paddle.empty_like(x) + + if self.backend == "flashinfer-cutlass": + # flashinfer cutlass + _ = flashinfer_cutlass_fused_moe( + input=x, + token_selected_experts=topk_ids.to(paddle.int), + token_final_scales=topk_weights, + fc1_expert_weights=getattr(layer, self.added_weight_attrs[0]).view(paddle.long), + fc2_expert_weights=getattr(layer, self.added_weight_attrs[1]).view(paddle.long), + output_dtype=output_dtype, + input_sf=x_sf, + quant_scales=[ + layer.up_gate_proj_input_scale_quant, + layer.up_gate_proj_blockscale_swizzled.view(paddle.int32), + layer.g1_alphas, + layer.down_proj_input_scale_quant, + layer.down_proj_blockscale_swizzled.view(paddle.int32), + layer.g2_alphas, + ], + ep_size=layer.ep_size, + ep_rank=layer.ep_rank, + tp_size=layer.tp_size, + tp_rank=layer.tp_rank, + tune_max_num_tokens=next_power_of_2(x.shape[0]), + output=output, + ) + + return output + + # flashinfer-trtllm + return output diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py index 3b42e0294e6..29ea12a728b 100644 --- a/fastdeploy/model_executor/utils.py +++ b/fastdeploy/model_executor/utils.py @@ -267,6 +267,19 @@ def free_tensor(tensor): del tensor +def create_parameter_and_copy(layer, name, weight): + setattr( + layer, + name, + layer.create_parameter( + shape=weight.shape, + dtype=weight.dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + getattr(layer, name).copy_(weight, False) + + def fd_cast(weight, param): if weight.dtype != param.dtype: if weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn: @@ -489,6 +502,8 @@ def fn(loaded_weight_name, is_moe): # Can be extended to other offline quantization suffixes if needed. 
if (is_moe and moe_quant_type == "block_wise_fp8") or (not is_moe and dense_quant_type == "block_wise_fp8"): fd_suffix_map = fp8_suffix_map + else: + fd_suffix_map = {} for ckpt_suffix, fd_suffix in fd_suffix_map.items(): if re.search(rf"{ckpt_suffix}$", loaded_weight_name): loaded_weight_name = loaded_weight_name.replace(ckpt_suffix, fd_suffix) diff --git a/tests/quantization/test_modelopt_nvfp4.py b/tests/quantization/test_modelopt_nvfp4.py new file mode 100644 index 00000000000..6015a0dff03 --- /dev/null +++ b/tests/quantization/test_modelopt_nvfp4.py @@ -0,0 +1,94 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import unittest +from unittest import mock + +import paddle + +from fastdeploy.flashinfer import has_flashinfer +from fastdeploy.model_executor.layers.linear import QKVParallelLinear +from fastdeploy.model_executor.layers.moe import FusedMoE +from fastdeploy.model_executor.layers.quantization.nvfp4 import ( + ModelOptNvFp4Config, + ModelOptNvFp4FusedMoE, + ModelOptNvFp4LinearMethod, +) + + +def get_sm_version(): + prop = paddle.device.cuda.get_device_properties() + cc = prop.major * 10 + prop.minor + return cc + + +@unittest.skipIf( + not paddle.is_compiled_with_cuda() or get_sm_version() < 100, + "Nvfp4 do not support sm < 100.", +) +class TestModelOptNvFp4Config(unittest.TestCase): + def setUp(self): + prop = paddle.device.cuda.get_device_properties() + self.sm_version = prop.major * 10 + prop.minor + + self.raw_config = { + "config_groups": { + "group_0": { + "input_activations": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": 16}, + "weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": 16}, + "targets": ["Linear"], + } + }, + "quant_algo": "NVFP4", + "producer": {"name": "modelopt", "version": "0.34.1.dev85+g7a72957d"}, + "quant_method": "modelopt", + } + + self.config = ModelOptNvFp4Config.from_config(self.raw_config) + + def test_name(self): + """Test name() method""" + self.assertEqual(self.config.name(), "modelopt_fp4") + + def test_from_config(self): + """Test from_config with full dict""" + cfg = ModelOptNvFp4Config.from_config(self.raw_config) + self.assertFalse(cfg.is_checkpoint_bf16) + self.assertTrue(cfg.is_checkpoint_nvfp4_serialized) + self.assertEqual(cfg.group_size, 16) + self.assertEqual(cfg.exclude_modules, []) + self.assertEqual(cfg.kv_cache_quant_algo, None) + self.assertEqual(cfg.quant_max_bound, 6) + self.assertEqual(cfg.quant_min_bound, -6) + self.assertEqual(cfg.quant_round_type, 1) + + @unittest.skipIf(not has_flashinfer(), "Skip if no FlashInfer available") + def test_get_quant_method_linear(self): + """Test get_quant_method with a linear layer""" + layer = mock.Mock(spec=QKVParallelLinear) + method = self.config.get_quant_method(layer) + assert isinstance(method, ModelOptNvFp4LinearMethod) + + @unittest.skipIf(not has_flashinfer(), "Skip if no FlashInfer available") + def test_get_quant_method_fused_moe(self): + """Test get_quant_method 
with a MoE layer"""
+        layer = mock.Mock(spec=FusedMoE)
+        method = self.config.get_quant_method(layer)
+        assert isinstance(method, ModelOptNvFp4FusedMoE)
+
+
+if __name__ == "__main__":
+    unittest.main()
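+
+
+# A minimal, CPU-only sanity check for the `next_power_of_2` helper defined in
+# fastdeploy/model_executor/layers/quantization/nvfp4.py. This is a suggested
+# sketch: it assumes pytest-style collection, which also discovers classes
+# defined below the `__main__` guard; move it above the guard if the file is
+# run directly as a script.
+class TestNextPowerOfTwo(unittest.TestCase):
+    def test_values(self):
+        from fastdeploy.model_executor.layers.quantization.nvfp4 import (
+            next_power_of_2,
+        )
+
+        # 0 and 1 both map to 1; exact powers of two are returned unchanged.
+        self.assertEqual(next_power_of_2(0), 1)
+        self.assertEqual(next_power_of_2(1), 1)
+        self.assertEqual(next_power_of_2(16), 16)
+        # Other values round up to the next power of two.
+        self.assertEqual(next_power_of_2(3), 4)
+        self.assertEqual(next_power_of_2(17), 32)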