
Commit bda4536

HIT-cwhHAOCHENYE authored and committed

Cwh/speed up (#1076)

* fix bugs without fsdp
* add fp8 docs
* save profile files to exp_dir
* support liger ce loss
* add rms_norm kernel
1 parent ad6b6e0 commit bda4536

File tree: 13 files changed (+1366 / -12 lines)

Image assets added: 53.7 KB, 12.1 KB, 62.6 KB
Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
# FP8 Training

Hopper-architecture GPUs introduce a new data type, FP8 (8-bit floating point), which can significantly improve the efficiency of matrix multiplication. This section describes how to train with FP8 in XTuner.

## Why FP8

1. Lower communication volume and faster communication: XTuner V1 is built on PyTorch FSDP. Compared with BF16, FP8 communication significantly alleviates FSDP's inherent communication bottleneck.
2. Higher matrix-multiplication throughput.
3. Lower memory usage: compared with BF16 training, the Linear and Grouped Linear layers keep FP8 tensors rather than BF16 tensors in the PyTorch computation graph, which greatly reduces the graph's memory footprint.
4. Accuracy is preserved: to avoid the trap of "fast but wrong", XTuner uses fine-grained FP8 quantization, improving training speed without sacrificing training accuracy.

## Benchmark

Parallel config | Precision | SeqLen | GlobalBatchSize | GPUNum | TimePerIter (s) | Tokens/GPU/Second
-- | -- | -- | -- | -- | -- | --
tp1, ep1, pp1 | BF16 | 65536 | 256 | 256 | 32.77 | 2000
tp1, ep1, pp1 | FP8 | 65536 | 256 | 256 | 26.75 | 2450

[profile data](https://drive.google.com/file/d/1TW-DbsUCckKJS36-5YHJo73L1Nvlpv6h/view?usp=sharing)

## How to train with FP8 in XTuner

### Environment setup

First, check that the GPU is a Hopper (or newer) architecture:

```python
import torch

print(torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9))
```

Install the `AdaptiveGEMM` library:

```{code-block} shell
:caption: Install AdaptiveGEMM

pip install git+https://github.com/InternLM/AdaptiveGEMM.git@main
```

### Using XTuner's Linear and Grouped Linear modules

```python
import torch
from xtuner.v1.float8 import TileWiseFloat8Linear, TileWiseFloat8GroupedLinear

# (bs, seq, dim)
x = torch.randn(1, 32768, 1024, device='cuda', dtype=torch.bfloat16, requires_grad=True)
linear = TileWiseFloat8Linear(in_features=1024, out_features=2048, bias=False, device='cuda', dtype=torch.bfloat16)
out = linear(x)
out.mean().backward()

x = torch.randn(1, 32768, 1024, device='cuda', dtype=torch.bfloat16)
grouped_linear = TileWiseFloat8GroupedLinear(in_features=1024, out_features=2048, num_routed_experts=4, moe_bias=False).to(dtype=torch.bfloat16, device='cuda')
tokens_per_expert = torch.tensor([1000, 4000, 6000, 32768 - 11000], device='cuda')
out = grouped_linear(x, tokens_per_expert)
out.mean().backward()
```

```{tip}
:class: margin

1. Benchmarking `TileWiseFloat8Linear` and `TileWiseFloat8GroupedLinear` in isolation will not show the ideal end-to-end speed, because quantizing the weights is relatively expensive. Best training efficiency requires combining them with FSDP: FP8 communication can then be used, and each rank only quantizes its own parameter shard, so the weight-quantization overhead becomes negligible. See the next subsection for usage.

2. It is normal for the first fwd + bwd pass to be slow; subsequent passes return to normal speed.
```

### Training with XTuner FP8

Step 1: build a model_cfg instance as described in [Choosing a model](model-cfg) and configure float8_cfg:

```{code-block} python
:caption: Build the model config

from xtuner.v1.model import Qwen3Dense8BConfig
from xtuner.v1.float8.config import Float8Config, ScalingGranularity

float8_cfg = Float8Config(
    scaling_granularity_gemm=ScalingGranularity.TILEWISE,
    scaling_granularity_grouped_gemm=ScalingGranularity.TILEWISE,
)

model_cfg = Qwen3Dense8BConfig(float8_cfg=float8_cfg)
```

Step 2: build the `trainer` following the rest of [Fine-tuning large models with the Trainer](trainer-sft).

Step 3: launch training. The complete code is shown below:

````{toggle}
```diff
from xtuner.v1.model import Qwen3Dense8BConfig
from xtuner.v1.config import LRConfig, AdamWConfig
from xtuner.v1.train import Trainer
+ from xtuner.v1.float8.config import Float8Config, ScalingGranularity

+ float8_cfg = Float8Config(
+     scaling_granularity_gemm=ScalingGranularity.TILEWISE,
+     scaling_granularity_grouped_gemm=ScalingGranularity.TILEWISE,
+ )

- model_cfg = Qwen3Dense8BConfig()
+ model_cfg = Qwen3Dense8BConfig(float8_cfg=float8_cfg)
dataset_cfg = []
optim_cfg = AdamWConfig(lr=6e-05)
lr_cfg = LRConfig(lr_type="cosine", lr_min=1e-6)

load_from = "<model path>"  # must be set when fine-tuning; otherwise training starts from scratch
tokenizer = "<tokenizer path, usually the same as the model path>"

trainer = Trainer(
    model_cfg=model_cfg,
    tokenizer_path=tokenizer,
    load_from=load_from,
    optim_cfg=optim_cfg,
    dataset_cfg=dataset_cfg,
    lr_cfg=lr_cfg,
)
trainer.fit()
```
````

After writing the Python script above, save it as `toy_train.py`; we can then launch distributed training with `torchrun`:

```{code-block} bash
:caption: Launch training

torchrun --nproc_per_node=8 toy_train.py
```

Congratulations, you have built your own XTuner FP8 training entry point! Feel free to customize your training parameters in this script.

## XTuner FP8 training strategy

### FP8 quantization

XTuner uses symmetric quantization:

```python
s = absmax(x) / q_max
q = clip(x / s, q_min, q_max)
```

XTuner supports three quantization granularities: Tensor-Wise, Block-Wise, and Tile-Wise, as illustrated below. Elements of the same color share one set of quantization parameters. In practice, block_size and tile_size are usually set to 128.

![fp8_granularity](../../../assets/images/float8/fp8_granularity.png)

XTuner uses "just-in-time scaling": the scaling factors (scales) are computed on the fly from the input tensor.
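
To make the granularity concrete, here is a minimal, illustrative sketch of symmetric tile-wise quantization to `torch.float8_e4m3fn` in plain PyTorch. The helper name and tile layout are ours for illustration and are not part of XTuner's API; XTuner's actual kernels fuse this step into the GEMM pipeline.

```python
import torch

def quantize_tile_wise(x: torch.Tensor, tile_size: int = 128):
    """Symmetric per-tile quantization along the last dim (illustration only)."""
    q_max = torch.finfo(torch.float8_e4m3fn).max
    orig_shape = x.shape
    # Split the last dimension into tiles of `tile_size` (assumed to divide evenly here).
    x_tiles = x.reshape(-1, orig_shape[-1] // tile_size, tile_size)
    # s = absmax(x) / q_max, computed per tile
    scales = x_tiles.abs().amax(dim=-1, keepdim=True).float() / q_max
    scales = scales.clamp(min=1e-12)  # guard against all-zero tiles
    # q = clip(x / s, q_min, q_max)
    q = (x_tiles.float() / scales).clamp(-q_max, q_max).to(torch.float8_e4m3fn)
    return q.reshape(orig_shape), scales

x = torch.randn(32, 1024, dtype=torch.bfloat16)
q, scales = quantize_tile_wise(x)
x_hat = q.reshape(32, -1, 128).float() * scales  # dequantize to inspect the quantization error
```

Block-Wise quantization follows the same formula but shares one scale per 128x128 block of a weight matrix rather than per 1x128 tile of an activation.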

### FP8 kernels

Building on [DeepGemm](https://github.com/deepseek-ai/DeepGEMM/tree/3b3783d06cd4d06ac4ba048633e604151d1ee535), we extended two Grouped-GEMM-related capabilities (thanks to the DeepSeek team for their contribution to the open-source community):

1. Support for Group Size M that is not a multiple of 128, which real training workloads require; see our paper [TMA-Adaptive FP8 Grouped GEMM](https://arxiv.org/abs/2508.16584) for details.
2. A Group K GEMM kernel for the backward pass of Grouped Linear.

Note that, to keep performance on target, the Group K GEMM kernel requires Group Size K to be a multiple of 128. This places stricter requirements on our AutoGrad design, as detailed in the next subsection.

### FP8 mixed-precision training

XTuner's FP8 follows the FP8 training strategy used in DeepSeek V3, shown in the figure below. The main compute-intensive operators (e.g. GEMM and Grouped GEMM) use FP8 to accelerate computation: each operator takes FP8 inputs and produces BF16 outputs. The GEMMs in the three Linear modules in the figure all run in FP8; we name them Fprop (Forward Pass), Dgrad (Activation Backward Pass), and Wgrad (Weight Backward Pass). Compared with BF16, FP8 halves the theoretical GEMM time. In addition, only FP8 tensors need to be stored in the PyTorch computation graph to complete the backward pass, which reduces the graph's memory footprint.

![fp8_overall](../../../assets/images/float8/fp8_overall.png)
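
For reference, the three GEMMs behind a Linear layer can be written out as follows. This is a plain BF16 sketch of what Fprop, Dgrad, and Wgrad compute; it does not show the FP8 quantization of the operands that XTuner actually applies.

```python
import torch

# x: activations (tokens, in_dim), w: weight (out_dim, in_dim), grad_y: output grad (tokens, out_dim)
x = torch.randn(4096, 1024, dtype=torch.bfloat16)
w = torch.randn(2048, 1024, dtype=torch.bfloat16)
grad_y = torch.randn(4096, 2048, dtype=torch.bfloat16)

y = x @ w.t()            # Fprop: forward GEMM
grad_x = grad_y @ w      # Dgrad: gradient w.r.t. the activations
grad_w = grad_y.t() @ x  # Wgrad: gradient w.r.t. the weight; contracts over the token dimension
```

Because Wgrad contracts over the token dimension, it is the path where the Group K GEMM constraint from the previous subsection (Group Size K a multiple of 128) comes into play.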

Going further, XTuner refines the AutoGrad logic of FP8 Linear and Grouped Linear. We take the more complex Grouped Linear as the example. As shown in the figure below, in the Forward and Backward dx computations, the activations are quantized Tile-Wise and the model weights Block-Wise. In the Backward dw computation, to maximize performance, the grad output is quantized Tile-Wise while the forward input X is quantized Block-Wise.

One detail in the figure deserves explanation: in the Backward dw computation we apply Transpose + Block-Wise FP8 Quantize + Pad to 128x + Transpose to the forward input X. This is because, to reach the desired efficiency, the FP8 GEMM and Grouped GEMM kernels require the lhs matrix to be Row-Major and the rhs matrix Column-Major; and, as noted in the previous subsection, the Group K GEMM kernel requires Group Size K to be divisible by 128. We fuse Transpose + Block-Wise FP8 Quantize + Pad to 128x into a single kernel for efficiency.

![fp8_autograd](../../../assets/images/float8/fp8_autograd.png)
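
To make the padding requirement concrete, here is a simplified, unfused sketch of padding each expert's token block of the forward input X to a multiple of 128 before the Wgrad grouped GEMM. The helper is hypothetical: XTuner fuses the transpose, Block-Wise FP8 quantization, and padding into a single kernel, and the exact operation order differs from this sketch.

```python
import torch
import torch.nn.functional as F

def pad_expert_groups(x: torch.Tensor, tokens_per_expert: torch.Tensor, multiple: int = 128):
    """Zero-pad each expert's token block so the Wgrad K dimension is a multiple of 128."""
    groups = []
    start = 0
    for n in tokens_per_expert.tolist():
        group = x[start:start + n]          # (n, dim)
        pad = (-n) % multiple
        # Zero rows contribute nothing to dW, so padding is numerically safe.
        groups.append(F.pad(group, (0, 0, 0, pad)))
        start += n
    return torch.cat(groups, dim=0)

x = torch.randn(11000, 1024, dtype=torch.bfloat16)   # forward input, tokens grouped by expert
tokens_per_expert = torch.tensor([1000, 4000, 6000])
x_padded = pad_expert_groups(x, tokens_per_expert)   # each group's K is now 128-aligned
x_for_wgrad = x_padded.t().contiguous()              # transpose so the kernel sees its required operand layout
```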

docs/zh_cn/pretrain_sft/advanced_tutorial/index.rst

Lines changed: 1 addition & 0 deletions
@@ -8,4 +8,5 @@
 model.md
 dataset.md
 loss.md
+float8.md
 profile.md

docs/zh_cn/pretrain_sft/tutorial/llm_trainer.md

Lines changed: 2 additions & 0 deletions
@@ -1,8 +1,10 @@
+(trainer-sft)=
 # Fine-tuning large models with the Trainer

 In the previous [tutorial](../../get_started/sft.md) we launched a fine-tuning run from the command line in the simplest possible way. Behind that quick start is XTuner's core component, the `Trainer`. In this section we take a first look at the Trainer and control each stage of training at a finer granularity.


+(model-cfg)=
 ## Choosing a model:

 The Trainer builds the model from a configuration. Taking the built-in `Qwen3 8B` as an example, here is how to quickly obtain a model config instance

xtuner/v1/float8/float8_gmm_tile_wise.py

Lines changed: 4 additions & 1 deletion
@@ -291,9 +291,12 @@ def forward(self, input: torch.Tensor, tokens_per_expert, decoding: bool = False
             weight_fp8 = view_weight.apply(weight_fp8, self.ori_local_shape)
         else:
             weight = weight.view(*self.ori_local_shape)
-            weight_fp8 = weight_to_per_block_float8_dynamic.apply(weight, torch.float8_e4m3fn, group_size=128)
+            weight_fp8 = weight_to_per_block_float8_dynamic.apply(weight, torch.float8_e4m3fn, 128)

+        orig_shape = input.shape
+        input = input.view(-1, input.shape[-1])
         out = fp8_gmm_weight_per_block_act_per_tile.apply(input, weight_fp8, tokens_per_expert)
+        out = out.view(*orig_shape[:-1], -1)
         return out

     @property

xtuner/v1/float8/float8_linear_tile_wise.py

Lines changed: 1 addition & 1 deletion
@@ -229,7 +229,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             weight_fp8 = slice_weight.apply(weight, self.ori_shape) if self.is_padded else weight
         else:
             weight = weight.view(*self.ori_shape)
-            weight_fp8 = weight_to_per_block_float8_dynamic.apply(weight, torch.float8_e4m3fn, group_size=128)
+            weight_fp8 = weight_to_per_block_float8_dynamic.apply(weight, torch.float8_e4m3fn, 128)

         out = fp8_matmul_weight_per_block_act_per_tile.apply(input, weight_fp8)

xtuner/v1/loss/ce_loss.py

Lines changed: 41 additions & 1 deletion
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Annotated, Literal, cast
+from typing import Annotated, Any, Literal, cast

 import torch
 import torch.distributed as dist
@@ -27,12 +27,17 @@ class CELossConfig(BaseLossConfig):
         loss_reduction (str): The reduction mode for the loss. Options are "token", "sample", and "square".
     """

+    mode: Annotated[Literal["eager", "chunk", "liger"], Parameter(help="loss calculation mode")] = "eager"  # type: ignore
     loss_reduction: Annotated[Literal["token", "sample", "square"], Parameter(help="loss reduction mode")] = "token"

     @property
     def loss_ctx_cls(self) -> type["CELossContext"]:
         return CELossContext

+    def model_post_init(self, __context: Any) -> None:
+        if self.mode == "liger":
+            assert self.loss_reduction == "token", "Currently, cannot use liger kernel with sample or square reduction"
+

 class CELossKwargs(BaseLossKwargs):
     """Keyword arguments for cross-entropy loss computation.
@@ -76,6 +81,17 @@ class CELossContext(BaseLossContext[CELossContextInputItem]):
     loss_cfg: CELossConfig
     loss_kwargs: CELossKwargs

+    def __init__(self, loss_cfg: CELossConfig, loss_kwargs: CELossKwargs):
+        super().__init__(loss_cfg, loss_kwargs)
+
+        if loss_cfg.mode == "liger":
+            from liger_kernel.transformers.fused_linear_cross_entropy import (
+                LigerFusedLinearCrossEntropyLoss,
+            )
+            self.liger_loss_fct = LigerFusedLinearCrossEntropyLoss(reduction="sum")
+        else:
+            self.liger_loss_fct = None
+
     @classmethod
     def build_batches_loss_kwargs(
         cls,
@@ -181,3 +197,27 @@ def pack(cls, loss_ctx_list: list[Self]) -> Self: # type: ignore
         shifted_loss_weights = torch.cat([i.loss_weight for i in loss_kwargs], dim=-1)
         cat_loss_kwargs = CELossKwargs(shifted_labels=shifted_labels, loss_weight=shifted_loss_weights)
         return cls(loss_cfg=loss_cfg, loss_kwargs=cat_loss_kwargs)
+
+    def chunk_mode(
+        self,
+        hidden_states: torch.Tensor,
+        head_weight: torch.Tensor,
+        head_bias: torch.Tensor | None,
+        loss_kwargs: CELossKwargs,
+    ):
+        if self.loss_cfg.mode == "chunk":
+            return super().chunk_mode(hidden_states, head_weight, head_bias, loss_kwargs)
+        else:
+            assert self.liger_loss_fct is not None, "liger_loss_fct must be initialized in liger mode"
+            shifted_labels = loss_kwargs.shifted_labels  # (bs, seq_len)
+            loss_weight = loss_kwargs.loss_weight  # (bs, seq_len)
+
+            bs, seq, dim = hidden_states.shape
+            hidden_states = hidden_states.reshape(bs * seq, dim)
+            shifted_labels = shifted_labels.flatten()
+            # the liger kernel does not support reduction="none"
+            loss = self.liger_loss_fct(head_weight, hidden_states, shifted_labels)
+            mask = loss_weight != 0
+            w = loss_weight.sum() / mask.sum()
+            loss = loss * w
+            return loss, None
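
The diff above adds a `liger` mode to `CELossConfig`. A minimal construction sketch, based only on the fields visible in this diff, might look like the following; note that `model_post_init` currently requires `loss_reduction="token"` when liger mode is selected.

```python
from xtuner.v1.loss.ce_loss import CELossConfig

# "liger" dispatches the loss to LigerFusedLinearCrossEntropyLoss.
loss_cfg = CELossConfig(mode="liger", loss_reduction="token")

# CELossConfig(mode="liger", loss_reduction="sample") would trip the assertion in model_post_init.
```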

xtuner/v1/ops/rms_norm.py renamed to xtuner/v1/ops/rms_norm/__init__.py

Lines changed: 11 additions & 7 deletions
@@ -1,10 +1,8 @@
-from typing import Protocol
+from functools import partial

 import torch

-
-class RMSNormProtocol(Protocol):
-    def __call__(self, x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor: ...
+from .protocol import RMSNormProtocol


 def native_rms_norm(x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor:
@@ -19,7 +17,13 @@ def npu_rms_norm(x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch
     return torch_npu.npu_rms_norm(x, weight, epsilon=epsilon)[0]


-def get_rms_norm() -> RMSNormProtocol:
+def gpu_rms_norm(x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor:
+    from .gpu import rms_norm_fn
+
+    return rms_norm_fn(x, weight, bias=None, eps=epsilon)
+
+
+def get_rms_norm_fn() -> RMSNormProtocol:
     from xtuner.v1.utils import get_device

     device = get_device()
@@ -28,7 +32,7 @@ def get_rms_norm() -> RMSNormProtocol:
     elif device == "npu":
         return npu_rms_norm
     else:
-        return native_rms_norm
+        return gpu_rms_norm


-rms_norm = get_rms_norm()
+rms_norm = get_rms_norm_fn()
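
Based on the `RMSNormProtocol` signature visible in this diff (`(x, weight, epsilon) -> Tensor`), a minimal usage sketch of the dispatched `rms_norm` could look like this; the shapes and import path are our own assumptions.

```python
import torch
from xtuner.v1.ops.rms_norm import rms_norm  # resolved to the GPU/NPU/native kernel at import time

x = torch.randn(2, 16, 1024, device="cuda", dtype=torch.bfloat16)
weight = torch.ones(1024, device="cuda", dtype=torch.bfloat16)
y = rms_norm(x, weight, 1e-6)  # same signature regardless of the backend chosen by get_rms_norm_fn()
```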
