@@ -107,7 +107,19 @@ def f6_e3m2_unpacked_to_f32(x: torch.Tensor):
     import triton.language as tl

     @triton.jit
-    def _fp4_packed_to_bf16(x_packed):
+    def _fp4_packed_to_bf16(
+        x_packed,
+        sign_mask_f4,
+        mantissa_mask_f4,
+        mbits_f4_e2m1,
+        ebits_f4_e2m1,
+        f4_e2m1_exp_bias,
+        mbits_f32,
+        ebits_f32,
+        f32_exp_bias,
+        zero_bits_f32,
+        zero_point_five_bits_f32,
+    ):
         """
         Input: a tensor of packed fp4 values
         Output: a tensor of bfloat16 values
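The new arguments replace reads of module-level constants inside the JIT'd helper. For reference while reading the hunks below, plausible values for these parameters are sketched here; they follow directly from the fp4 E2M1 and IEEE-754 float32 layouts, but they are assumptions, not copied from the repository's constants module.

# Assumed values, implied by the E2M1 (1 sign / 2 exponent / 1 mantissa bit) and
# float32 (1 sign / 8 exponent / 23 mantissa bit) layouts; not taken from the repo.
SIGN_MASK_F4 = 0x8                      # 0b1000, sign bit of a 4-bit value
MANTISSA_MASK_F4 = 0x1                  # 0b0001, the single mantissa bit
EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1
F4_E2M1_EXP_BIAS = 1                    # 2 ** (EBITS_F4_E2M1 - 1) - 1
EBITS_F32, MBITS_F32 = 8, 23
F32_EXP_BIAS = 127                      # 2 ** (EBITS_F32 - 1) - 1
ZERO_BITS_F32 = 0x0                     # float32 bit pattern of +0.0
ZERO_POINT_FIVE_BITS_F32 = 0x3F000000   # float32 bit pattern of 0.5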
@@ -123,7 +135,7 @@ def _fp4_packed_to_bf16(x_packed):
         # output = x_unpacked.to(tl.float32)

         # save the sign
-        sign_f4 = x & SIGN_MASK_F4
+        sign_f4 = x & sign_mask_f4

         # set everything to positive, will add sign back at the end
         x_pos = x ^ sign_f4
@@ -138,25 +150,25 @@ def _fp4_packed_to_bf16(x_packed):
         denormal_mask = x_pos == 1

         # calculate the new exponent and shift it to bits 2:9 of the result
-        exp_biased_f4 = x_pos >> MBITS_F4_E2M1
-        exp_biased_f32 = exp_biased_f4 - F4_E2M1_EXP_BIAS + F32_EXP_BIAS
-        exp_biased_f32 = exp_biased_f32.to(tl.int32) << MBITS_F32
+        exp_biased_f4 = x_pos >> mbits_f4_e2m1
+        exp_biased_f32 = exp_biased_f4 - f4_e2m1_exp_bias + f32_exp_bias
+        exp_biased_f32 = exp_biased_f32.to(tl.int32) << mbits_f32

         # shift the mantissa to bits 10:32 of the result
-        mantissa_f4 = x_pos & MANTISSA_MASK_F4
-        mantissa_f32 = mantissa_f4.to(tl.int32) << (MBITS_F32 - MBITS_F4_E2M1)
+        mantissa_f4 = x_pos & mantissa_mask_f4
+        mantissa_f32 = mantissa_f4.to(tl.int32) << (mbits_f32 - mbits_f4_e2m1)
         output = mantissa_f32

         # combine the pieces
         result = exp_biased_f32 | mantissa_f32
         # result[zero_mask] = ZERO_BITS_F32
-        result = tl.where(zero_mask, ZERO_BITS_F32, result)
+        result = tl.where(zero_mask, zero_bits_f32, result)
         # result[denormal_mask] = ZERO_POINT_FIVE_BITS_F32
-        result = tl.where(denormal_mask, ZERO_POINT_FIVE_BITS_F32, result)
+        result = tl.where(denormal_mask, zero_point_five_bits_f32, result)

         # add sign back
         sign_f32 = sign_f4.to(tl.int32) << (
-            MBITS_F32 - MBITS_F4_E2M1 + EBITS_F32 - EBITS_F4_E2M1
+            mbits_f32 - mbits_f4_e2m1 + ebits_f32 - ebits_f4_e2m1
         )
         result = result | sign_f32

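To make the bit manipulation in this hunk concrete, here is a minimal pure-Python reference for decoding one unpacked E2M1 nibble. It mirrors the kernel's exponent rebias, mantissa shift, and special cases, but it is only a sketch built on the assumed constant values above, not the shipped implementation.

import struct

def fp4_e2m1_nibble_to_f32(x: int) -> float:
    # Decode one unpacked E2M1 nibble (0..15) the same way the kernel does.
    sign = x & 0x8                            # save the sign bit
    x_pos = x ^ sign                          # work on the magnitude
    if x_pos == 0:                            # +/- 0.0
        bits = 0x0
    elif x_pos == 1:                          # the only denormal encodes +/- 0.5
        bits = 0x3F000000
    else:
        exp_f32 = (x_pos >> 1) - 1 + 127      # rebias: E2M1 bias 1 -> float32 bias 127
        mant_f32 = (x_pos & 0x1) << (23 - 1)  # move the single mantissa bit into place
        bits = (exp_f32 << 23) | mant_f32
    if sign:
        bits |= 0x80000000                    # restore the sign in bit 31
    return struct.unpack("f", struct.pack("I", bits))[0]

assert fp4_e2m1_nibble_to_f32(0b0111) == 6.0  # largest E2M1 magnitude: 1.5 * 2**2
assert fp4_e2m1_nibble_to_f32(0b1001) == -0.5 # sign bit applied to the denormal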
@@ -174,6 +186,16 @@ def triton_f4_to_bf16_kernel(
         x_ptr,
         output_ptr,
         n_elements_in,
+        sign_mask_f4: tl.constexpr,
+        mantissa_mask_f4: tl.constexpr,
+        mbits_f4_e2m1: tl.constexpr,
+        ebits_f4_e2m1: tl.constexpr,
+        f4_e2m1_exp_bias: tl.constexpr,
+        mbits_f32: tl.constexpr,
+        ebits_f32: tl.constexpr,
+        f32_exp_bias: tl.constexpr,
+        zero_bits_f32: tl.constexpr,
+        zero_point_five_bits_f32: tl.constexpr,
         BLOCK_SIZE_IN: tl.constexpr,
     ):
         pid = tl.program_id(axis=0)
@@ -187,7 +209,19 @@ def triton_f4_to_bf16_kernel(

         # packed uint8
         x_packed = tl.load(x_ptr + offsets_in, mask=mask_in)
-        output = _fp4_packed_to_bf16(x_packed)
+        output = _fp4_packed_to_bf16(
+            x_packed,
+            sign_mask_f4,
+            mantissa_mask_f4,
+            mbits_f4_e2m1,
+            ebits_f4_e2m1,
+            f4_e2m1_exp_bias,
+            mbits_f32,
+            ebits_f32,
+            f32_exp_bias,
+            zero_bits_f32,
+            zero_point_five_bits_f32,
+        )

         # set up output offsets
         block_start_out = pid * BLOCK_SIZE_OUT
@@ -213,6 +247,18 @@ def triton_f4_to_scaled_bf16_kernel(
         output_ptr,
         n_elements_in,
         mx_block_size: tl.constexpr,
+        sign_mask_f4: tl.constexpr,
+        mantissa_mask_f4: tl.constexpr,
+        mbits_f4_e2m1: tl.constexpr,
+        ebits_f4_e2m1: tl.constexpr,
+        f4_e2m1_exp_bias: tl.constexpr,
+        mbits_f32: tl.constexpr,
+        ebits_f32: tl.constexpr,
+        f32_exp_bias: tl.constexpr,
+        zero_bits_f32: tl.constexpr,
+        zero_point_five_bits_f32: tl.constexpr,
+        e8m0_exponent_bias: tl.constexpr,
+        e8m0_exponent_nan_val: tl.constexpr,
         BLOCK_SIZE_IN: tl.constexpr,
     ):
         pid = tl.program_id(axis=0)
@@ -227,7 +273,19 @@ def triton_f4_to_scaled_bf16_kernel(
         mask_in = offsets_in < n_elements_in
         # packed uint8
         x_packed = tl.load(x_ptr + offsets_in, mask=mask_in)
-        output = _fp4_packed_to_bf16(x_packed)
+        output = _fp4_packed_to_bf16(
+            x_packed,
+            sign_mask_f4,
+            mantissa_mask_f4,
+            mbits_f4_e2m1,
+            ebits_f4_e2m1,
+            f4_e2m1_exp_bias,
+            mbits_f32,
+            ebits_f32,
+            f32_exp_bias,
+            zero_bits_f32,
+            zero_point_five_bits_f32,
+        )

         # load scale
         block_start_s = pid * BLOCK_SIZE_S
@@ -236,9 +294,9 @@ def triton_f4_to_scaled_bf16_kernel(
         s = tl.load(s_ptr + offsets_s, mask=mask_s)

         # create the scale in bf16
-        s_offset = s.to(tl.int16) - E8M0_EXPONENT_BIAS
+        s_offset = s.to(tl.int16) - e8m0_exponent_bias
         s_fp = libdevice.pow(2.0, s_offset).to(tl.bfloat16)
-        s_fp = tl.where(s != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))
+        s_fp = tl.where(s != e8m0_exponent_nan_val, s_fp, float("nan"))

         # multiply output by scale
         # TODO(later): see if manipulating the exponent instead of fp
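The scale path above decodes an e8m0 exponent byte into a bfloat16 multiplier as 2 ** (s - bias), with one code point reserved for NaN. A small eager-mode sketch, assuming e8m0_exponent_bias == 127 and e8m0_exponent_nan_val == 255 (the natural values for an 8-bit biased exponent):

import torch

def e8m0_to_bf16(s: torch.Tensor) -> torch.Tensor:
    # s: uint8 tensor of e8m0-encoded scales; bias 127 and NaN code 255 are assumed.
    s_offset = s.to(torch.int16) - 127
    s_fp = torch.pow(2.0, s_offset.to(torch.float32)).to(torch.bfloat16)
    nan = torch.tensor(float("nan"), dtype=torch.bfloat16)
    return torch.where(s != 255, s_fp, nan)

# 127 -> 1.0, 129 -> 4.0, 255 -> NaN
print(e8m0_to_bf16(torch.tensor([127, 129, 255], dtype=torch.uint8)))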
@@ -263,6 +321,16 @@ def triton_f4_to_bf16_kernel(
         x_ptr,
         output_ptr,
         n_elements_in,
+        sign_mask_f4,
+        mantissa_mask_f4,
+        mbits_f4_e2m1,
+        ebits_f4_e2m1,
+        f4_e2m1_exp_bias,
+        mbits_f32,
+        ebits_f32,
+        f32_exp_bias,
+        zero_bits_f32,
+        zero_point_five_bits_f32,
         BLOCK_SIZE_IN,
     ):
         raise AssertionError("unsupported without triton")
@@ -273,6 +341,18 @@ def triton_f4_to_scaled_bf16_kernel(
         output_ptr,
         n_elements_in,
         mx_block_size,
+        sign_mask_f4,
+        mantissa_mask_f4,
+        mbits_f4_e2m1,
+        ebits_f4_e2m1,
+        f4_e2m1_exp_bias,
+        mbits_f32,
+        ebits_f32,
+        f32_exp_bias,
+        zero_bits_f32,
+        zero_point_five_bits_f32,
+        e8m0_exponent_bias,
+        e8m0_exponent_nan_val,
         BLOCK_SIZE_IN,
     ):
         raise AssertionError("unsupported without triton")
@@ -294,7 +374,22 @@ def triton_f4_to_bf16(x: torch.Tensor):
     grid = lambda meta: (  # noqa: E731
         triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]),
     )  # noqa: E731,E501
-    triton_f4_to_bf16_kernel[grid](x, output, n_elements_in, BLOCK_SIZE_IN=512)
+    triton_f4_to_bf16_kernel[grid](
+        x,
+        output,
+        n_elements_in,
+        sign_mask_f4=SIGN_MASK_F4,
+        mantissa_mask_f4=MANTISSA_MASK_F4,
+        mbits_f4_e2m1=MBITS_F4_E2M1,
+        ebits_f4_e2m1=EBITS_F4_E2M1,
+        f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS,
+        mbits_f32=MBITS_F32,
+        ebits_f32=EBITS_F32,
+        f32_exp_bias=F32_EXP_BIAS,
+        zero_bits_f32=ZERO_BITS_F32,
+        zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32,
+        BLOCK_SIZE_IN=512,
+    )
     return output

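A hypothetical call site for this wrapper, assuming a CUDA device with Triton available, the module path torchao.prototype.mx_formats.custom_cast, and an output layout of two bfloat16 values per packed uint8 byte (two fp4 nibbles per byte):

import torch
from torchao.prototype.mx_formats.custom_cast import triton_f4_to_bf16

# 1024 packed bytes -> 2048 bfloat16 values (assumed 2-nibbles-per-byte layout)
x_packed = torch.randint(0, 256, (1024,), dtype=torch.uint8, device="cuda")
y = triton_f4_to_bf16(x_packed)
print(y.shape, y.dtype)  # expected: torch.Size([2048]) torch.bfloat16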
@@ -318,7 +413,23 @@ def triton_f4_to_scaled_bf16(
         triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]),
     )
     triton_f4_to_scaled_bf16_kernel[grid](
-        x, s_e8m0, output, n_elements_in, mx_block_size
+        x,
+        s_e8m0,
+        output,
+        n_elements_in,
+        mx_block_size,
+        sign_mask_f4=SIGN_MASK_F4,
+        mantissa_mask_f4=MANTISSA_MASK_F4,
+        mbits_f4_e2m1=MBITS_F4_E2M1,
+        ebits_f4_e2m1=EBITS_F4_E2M1,
+        f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS,
+        mbits_f32=MBITS_F32,
+        ebits_f32=EBITS_F32,
+        f32_exp_bias=F32_EXP_BIAS,
+        zero_bits_f32=ZERO_BITS_F32,
+        zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32,
+        e8m0_exponent_bias=E8M0_EXPONENT_BIAS,
+        e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL,
     )
     return output
