Skip to content

Commit 7aed388

Browse files
[mxfp8 moe training] simplify e8m0 -> fp32 calc
stack-info: PR: #3201, branch: danielvegamyhre/stack/80
1 parent 710192d commit 7aed388

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

torchao/prototype/mx_formats/kernels.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1371,10 +1371,9 @@ def _dequant_mxfp8_kernel(
13711371

13721372
@triton.jit
13731373
def _e8m0_to_fp32(scale_e8m0):
1374-
e8m0_exponent_bias = 127
13751374
e8m0_nan_val = 255
1376-
s_offset = scale_e8m0.to(tl.int16) - e8m0_exponent_bias
1377-
s_fp = tl.exp2(s_offset.to(tl.float32))
1375+
fp32_mantissa_bits = 23
1376+
s_fp = scale_e8m0 << fp32_mantissa_bits
13781377
s_fp = tl.where(scale_e8m0 != e8m0_nan_val, s_fp, float("nan"))
13791378
return s_fp.to(tl.float32)
13801379

0 commit comments

Comments
 (0)