We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 710192d, commit 7aed388 (Copy full SHA for 7aed388)
torchao/prototype/mx_formats/kernels.py
@@ -1371,10 +1371,9 @@ def _dequant_mxfp8_kernel(
1371
1372
@triton.jit
def _e8m0_to_fp32(scale_e8m0):
    """Decode E8M0 scale bytes into float32 values.

    An E8M0 byte ``e`` is a biased power-of-two exponent representing
    ``2 ** (e - 127)``; the code 255 is reserved for NaN. Because float32
    uses the same exponent bias, the value is built by placing ``e`` directly
    in the float32 exponent field and reinterpreting the bits, instead of
    calling ``exp2``.

    Args:
        scale_e8m0: integer tensor holding raw E8M0 exponent bytes
            (presumably uint8 -- confirm against the caller; only the low
            8 bits are used).

    Returns:
        A float32 tensor of the same shape with the decoded scales.
    """
    e8m0_nan_val = 255
    e8m0_min_exp_val = 0
    fp32_mantissa_bits = 23
    # Bit pattern of 2**-127 in float32: exponent field 0, top mantissa bit set.
    fp32_two_pow_neg_127_bits = 0x00400000

    # Widen before shifting: shifting an 8-bit value left by 23 would overflow.
    # Masking keeps the raw byte even if the input dtype is signed.
    exp_bits = scale_e8m0.to(tl.int32) & 0xFF
    bits = exp_bits << fp32_mantissa_bits

    # Exponent byte 0 means 2**-127, which is subnormal in float32; a plain
    # shift would produce +0.0 instead.
    bits = tl.where(exp_bits == e8m0_min_exp_val, fp32_two_pow_neg_127_bits, bits)

    # Reinterpret the bits; a numeric int->float conversion would be wrong here.
    s_fp = bits.to(tl.float32, bitcast=True)

    # 255 << 23 is the bit pattern of +inf, but E8M0 defines 255 as NaN.
    s_fp = tl.where(exp_bits != e8m0_nan_val, s_fp, float("nan"))
    return s_fp
1380
0 commit comments