Skip to content

Commit 27126d2

Browse files
committed
[fix] fp8 precision
1 parent a3fa6c6 commit 27126d2

File tree

1 file changed

+19
-22
lines changed

1 file changed

+19
-22
lines changed

csrc/quant/per_token_quantize_bf16_fp8.cu

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ __global__ void device_per_token_quant_bf16_to_fp8_general(
1818
const int32_t bid = blockIdx.x;
1919
const int32_t tid = threadIdx.x;
2020
constexpr fp32_t FP8_E4M3_MAX = 448.0f; // Maximum value representable in FP8 E4M3 format
21-
21+
2222
const bf16_t* _input = input + bid * N; // Input pointer for the token
2323
fp8_e4m3_t* _output = output + bid * N; // Output pointer for the token
2424

@@ -44,15 +44,14 @@ __global__ void device_per_token_quant_bf16_to_fp8_general(
4444
const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);
4545

4646
// Compute the scale factor with epsilon to avoid division by zero
47-
constexpr fp32_t epsilon = 1e-7f;
48-
const fp32_t scale = reduced_max / FP8_E4M3_MAX;
49-
const fp32_t inv_scale = 1.0f / (scale + epsilon);
47+
constexpr fp32_t epsilon = 1.0f / (FP8_E4M3_MAX * 512.0f);
48+
const fp32_t scale = fmaxf(epsilon, reduced_max / FP8_E4M3_MAX);
5049

5150
for (int32_t i = tid; i < N; i += TPB) {
5251
local_bf16 = workspace1[i];
53-
52+
5453
fp32_t tmp = cvt_bf16_f32(local_bf16);
55-
fp32_t x = tmp * inv_scale;
54+
fp32_t x = tmp / scale;
5655
local_f8 = fp8_e4m3_t(x);
5756

5857
_output[i] = local_f8;
@@ -77,7 +76,7 @@ __global__ void device_per_token_quant_bf16_to_fp8_vpt(
7776
const int32_t bid = blockIdx.x;
7877
const int32_t tid = threadIdx.x;
7978
constexpr fp32_t FP8_E4M3_MAX = 448.0f; // Maximum value representable in FP8 E4M3 format
80-
79+
8180
const bf16_t* _input = input + bid * N; // Input pointer for the token
8281
fp8_e4m3_t* _output = output + bid * N; // Output pointer for the token
8382

@@ -110,9 +109,8 @@ __global__ void device_per_token_quant_bf16_to_fp8_vpt(
110109
const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);
111110

112111
// Compute the scale factor with epsilon to avoid division by zero
113-
constexpr fp32_t epsilon = 1e-7f;
114-
const fp32_t scale = reduced_max / FP8_E4M3_MAX;
115-
const fp32_t inv_scale = 1.0f / (scale + epsilon);
112+
constexpr fp32_t epsilon = 1.0f / (FP8_E4M3_MAX * 512.0f);
113+
const fp32_t scale = fmaxf(epsilon, reduced_max / FP8_E4M3_MAX);
116114

117115
for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
118116
vec_copy<sizeof(bf16_t) * VPT>(workspace2 + (i >> 1), local_bf16);
@@ -122,10 +120,10 @@ __global__ void device_per_token_quant_bf16_to_fp8_vpt(
122120
fp32x2_t x = bf16x2_to_fp32x2(local_bf16[2 * j + 0]);
123121
fp32x2_t y = bf16x2_to_fp32x2(local_bf16[2 * j + 1]);
124122
fp32x4_t ret = make_float4(
125-
x.x * inv_scale,
126-
x.y * inv_scale,
127-
y.x * inv_scale,
128-
y.y * inv_scale
123+
x.x / scale,
124+
x.y / scale,
125+
y.x / scale,
126+
y.y / scale
129127
);
130128
local_f8[j] = fp8x4_e4m3_t(ret);
131129
}
@@ -155,7 +153,7 @@ __global__ void device_per_token_quant_bf16_to_fp8(
155153
const int32_t bid = blockIdx.x;
156154
const int32_t tid = threadIdx.x;
157155
constexpr fp32_t FP8_E4M3_MAX = 448.0f; // Maximum value representable in FP8 E4M3 format
158-
156+
159157
const bf16_t* _input = input + bid * N; // Input pointer for the token
160158
fp8_e4m3_t* _output = output + bid * N; // Output pointer for the token
161159

@@ -188,9 +186,8 @@ __global__ void device_per_token_quant_bf16_to_fp8(
188186
const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);
189187

190188
// Compute the scale factor with epsilon to avoid division by zero
191-
constexpr fp32_t epsilon = 1e-7f;
192-
const fp32_t scale = reduced_max / FP8_E4M3_MAX;
193-
const fp32_t inv_scale = 1.0f / (scale + epsilon);
189+
constexpr fp32_t epsilon = 1.0f / (FP8_E4M3_MAX * 512.0f);
190+
const fp32_t scale = fmaxf(epsilon, reduced_max / FP8_E4M3_MAX);
194191

195192
for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
196193
vec_copy<sizeof(bf16_t) * VPT>(workspace + (i >> 1), local_bf16);
@@ -200,10 +197,10 @@ __global__ void device_per_token_quant_bf16_to_fp8(
200197
fp32x2_t x = bf16x2_to_fp32x2(local_bf16[2 * j + 0]);
201198
fp32x2_t y = bf16x2_to_fp32x2(local_bf16[2 * j + 1]);
202199
fp32x4_t ret = make_float4(
203-
x.x * inv_scale,
204-
x.y * inv_scale,
205-
y.x * inv_scale,
206-
y.y * inv_scale
200+
x.x / scale,
201+
x.y / scale,
202+
y.x / scale,
203+
y.y / scale
207204
);
208205
local_f8[j] = fp8x4_e4m3_t(ret);
209206
}

0 commit comments

Comments
 (0)