Skip to content

Commit 2ad3f17

Browse files
chr1sj0nesGoogle-ML-Automation
authored andcommitted
Allow type promotion of small integer types.
PiperOrigin-RevId: 825386768
1 parent f6101c5 commit 2ad3f17

File tree

2 files changed

+33
-23
lines changed

2 files changed

+33
-23
lines changed

jax/_src/dtypes.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -678,27 +678,42 @@ def _type_promotion_lattice(strict: bool, x64: bool) -> dict[JAXType, list[JAXTy
678678
x64: allow promotions that form x64 types from non-x64 inputs?
679679
"""
680680
b1, = _bool_types
681-
uint2, uint4, u1, u2, u4, u8, int2, int4, i1, i2, i4, i8 = _int_types
682-
*f1_types, bf, f2, f4, f8 = _float_types
683-
c4, c8 = _complex_types
681+
u2, u4, u8, u16, u32, u64, i2, i4, i8, i16, i32, i64 = _int_types
682+
*small_float_types, bf16, f16, f32, f64 = _float_types
683+
c64, c128 = _complex_types
684684
i_, f_, c_ = _weak_types
685685
if not strict:
686-
out: dict[JAXType, list[JAXType]]
687-
out = {
688-
b1: [i_],
689-
i_: [u1, uint2, uint4, i1, int2, int4],
690-
uint2: [], uint4: [], u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
691-
int2: [], int4: [], i1: [i2], i2: [i4], i4: [i8], i8: [f_],
692-
f_: [*f1_types, bf, f2, c_],
693-
**{t: [] for t in f1_types}, bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8],
694-
c_: [c4], c4: [c8], c8: [],
686+
out: dict[JAXType, list[JAXType]] = {
687+
b1: [i_],
688+
i_: [i2, u2],
689+
u2: [i4, u4],
690+
u4: [i8, u8],
691+
u8: [i16, u16],
692+
u16: [i32, u32],
693+
u32: [i64, u64],
694+
u64: [f_],
695+
i2: [i4],
696+
i4: [i8],
697+
i8: [i16],
698+
i16: [i32],
699+
i32: [i64],
700+
i64: [f_],
701+
f_: [*small_float_types, bf16, f16, c_],
702+
**{t: [] for t in small_float_types},
703+
bf16: [f32],
704+
f16: [f32],
705+
f32: [f64, c64],
706+
f64: [c128],
707+
c_: [c64],
708+
c64: [c128],
709+
c128: [],
695710
}
696711
# If x64 mode is not enabled, then we want to avoid any promotions that form
697712
# 64-bit types from non-64-bit inputs. There's only one of these in the
698713
# entire promotion lattice, namely u4xi4->i8, which we can avoid by
699714
# replacing it with u4xi4->i4.
700715
if not x64:
701-
out[u4] = [i4, u8]
716+
out[u32] = [i32, u64]
702717
return out
703718
else:
704719
return {

tests/dtypes_test.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -277,13 +277,13 @@ def testPromoteDtypesStandard(self):
277277
)
278278

279279
small_fp_dtypes = set(fp8_dtypes + fp4_dtypes)
280-
implicit_int_dtypes = set(signed_dtypes + unsigned_dtypes) - set(intn_dtypes)
280+
int_dtypes = set(signed_dtypes + unsigned_dtypes)
281281

282282
for t1 in all_dtypes:
283283
self.assertEqual(t1, dtypes.promote_types(t1, t1))
284284
self.assertEqual(t1, dtypes.promote_types(t1, np.bool_))
285285
# TODO(zhangqiaorjc): Consider more dtype promotion rules for fp8.
286-
if t1 in small_fp_dtypes or t1 in intn_dtypes:
286+
if t1 in small_fp_dtypes:
287287
assertTypePromotionError(t1, np.complex128)
288288
else:
289289
self.assertEqual(
@@ -297,10 +297,8 @@ def testPromoteDtypesStandard(self):
297297
and (t1 != np.bool_)
298298
and (t2 != np.bool_)
299299
and (
300-
t1 in intn_dtypes or
301-
t2 in intn_dtypes or
302-
(t1 in small_fp_dtypes and t2 not in implicit_int_dtypes) or
303-
(t2 in small_fp_dtypes and t1 not in implicit_int_dtypes)
300+
(t1 in small_fp_dtypes and t2 not in int_dtypes)
301+
or (t2 in small_fp_dtypes and t1 not in int_dtypes)
304302
)
305303
):
306304
assertTypePromotionError(t1, t2)
@@ -317,10 +315,7 @@ def testPromoteDtypesStandard(self):
317315
# inexact types.
318316
for t in float_dtypes + complex_dtypes:
319317
for i in bool_dtypes + signed_dtypes + unsigned_dtypes:
320-
if i in intn_dtypes:
321-
assertTypePromotionError(t, i)
322-
else:
323-
self.assertEqual(t, dtypes.promote_types(t, i))
318+
self.assertEqual(t, dtypes.promote_types(t, i))
324319

325320
# Promotions between exact types, or between inexact types, match NumPy.
326321
for groups in [bool_dtypes + np_signed_dtypes + np_unsigned_dtypes,

0 commit comments

Comments
 (0)