@@ -18,7 +18,7 @@ __global__ void device_per_token_quant_bf16_to_fp8_general(
1818 const int32_t bid = blockIdx .x ;
1919 const int32_t tid = threadIdx .x ;
2020 constexpr fp32_t FP8_E4M3_MAX = 448 .0f ; // Maximum value representable in FP8 E4M3 format
21-
21+
2222 const bf16_t * _input = input + bid * N; // Input pointer for the token
2323 fp8_e4m3_t * _output = output + bid * N; // Output pointer for the token
2424
@@ -44,15 +44,14 @@ __global__ void device_per_token_quant_bf16_to_fp8_general(
4444 const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);
4545
4646 // Compute the scale factor with epsilon to avoid division by zero
47- constexpr fp32_t epsilon = 1e-7f ;
48- const fp32_t scale = reduced_max / FP8_E4M3_MAX;
49- const fp32_t inv_scale = 1 .0f / (scale + epsilon);
47+ constexpr fp32_t epsilon = 1 .0f / (FP8_E4M3_MAX * 512 .0f );
48+ const fp32_t scale = fmaxf (epsilon, reduced_max / FP8_E4M3_MAX);
5049
5150 for (int32_t i = tid; i < N; i += TPB) {
5251 local_bf16 = workspace1[i];
53-
52+
5453 fp32_t tmp = cvt_bf16_f32 (local_bf16);
55- fp32_t x = tmp * inv_scale ;
54+ fp32_t x = tmp / scale ;
5655 local_f8 = fp8_e4m3_t (x);
5756
5857 _output[i] = local_f8;
@@ -77,7 +76,7 @@ __global__ void device_per_token_quant_bf16_to_fp8_vpt(
7776 const int32_t bid = blockIdx .x ;
7877 const int32_t tid = threadIdx .x ;
7978 constexpr fp32_t FP8_E4M3_MAX = 448 .0f ; // Maximum value representable in FP8 E4M3 format
80-
79+
8180 const bf16_t * _input = input + bid * N; // Input pointer for the token
8281 fp8_e4m3_t * _output = output + bid * N; // Output pointer for the token
8382
@@ -110,9 +109,8 @@ __global__ void device_per_token_quant_bf16_to_fp8_vpt(
110109 const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);
111110
112111 // Compute the scale factor with epsilon to avoid division by zero
113- constexpr fp32_t epsilon = 1e-7f ;
114- const fp32_t scale = reduced_max / FP8_E4M3_MAX;
115- const fp32_t inv_scale = 1 .0f / (scale + epsilon);
112+ constexpr fp32_t epsilon = 1 .0f / (FP8_E4M3_MAX * 512 .0f );
113+ const fp32_t scale = fmaxf (epsilon, reduced_max / FP8_E4M3_MAX);
116114
117115 for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
118116 vec_copy<sizeof (bf16_t ) * VPT>(workspace2 + (i >> 1 ), local_bf16);
@@ -122,10 +120,10 @@ __global__ void device_per_token_quant_bf16_to_fp8_vpt(
122120 fp32x2_t x = bf16x2_to_fp32x2 (local_bf16[2 * j + 0 ]);
123121 fp32x2_t y = bf16x2_to_fp32x2 (local_bf16[2 * j + 1 ]);
124122 fp32x4_t ret = make_float4 (
125- x.x * inv_scale ,
126- x.y * inv_scale ,
127- y.x * inv_scale ,
128- y.y * inv_scale
123+ x.x / scale ,
124+ x.y / scale ,
125+ y.x / scale ,
126+ y.y / scale
129127 );
130128 local_f8[j] = fp8x4_e4m3_t (ret);
131129 }
@@ -155,7 +153,7 @@ __global__ void device_per_token_quant_bf16_to_fp8(
155153 const int32_t bid = blockIdx .x ;
156154 const int32_t tid = threadIdx .x ;
157155 constexpr fp32_t FP8_E4M3_MAX = 448 .0f ; // Maximum value representable in FP8 E4M3 format
158-
156+
159157 const bf16_t * _input = input + bid * N; // Input pointer for the token
160158 fp8_e4m3_t * _output = output + bid * N; // Output pointer for the token
161159
@@ -188,9 +186,8 @@ __global__ void device_per_token_quant_bf16_to_fp8(
188186 const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);
189187
190188 // Compute the scale factor with epsilon to avoid division by zero
191- constexpr fp32_t epsilon = 1e-7f ;
192- const fp32_t scale = reduced_max / FP8_E4M3_MAX;
193- const fp32_t inv_scale = 1 .0f / (scale + epsilon);
189+ constexpr fp32_t epsilon = 1 .0f / (FP8_E4M3_MAX * 512 .0f );
190+ const fp32_t scale = fmaxf (epsilon, reduced_max / FP8_E4M3_MAX);
194191
195192 for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
196193 vec_copy<sizeof (bf16_t ) * VPT>(workspace + (i >> 1 ), local_bf16);
@@ -200,10 +197,10 @@ __global__ void device_per_token_quant_bf16_to_fp8(
200197 fp32x2_t x = bf16x2_to_fp32x2 (local_bf16[2 * j + 0 ]);
201198 fp32x2_t y = bf16x2_to_fp32x2 (local_bf16[2 * j + 1 ]);
202199 fp32x4_t ret = make_float4 (
203- x.x * inv_scale ,
204- x.y * inv_scale ,
205- y.x * inv_scale ,
206- y.y * inv_scale
200+ x.x / scale ,
201+ x.y / scale ,
202+ y.x / scale ,
203+ y.y / scale
207204 );
208205 local_f8[j] = fp8x4_e4m3_t (ret);
209206 }
0 commit comments