From 5a41c66d89cb43d54bc9e262cba2abc079929d9b Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Tue, 28 Oct 2025 23:27:25 -0700 Subject: [PATCH] Allow type promotion of small integer types. PiperOrigin-RevId: 825386768 --- jax/_src/dtypes.py | 41 ++++++++++++++++++++++++++++------------- tests/dtypes_test.py | 31 +++++-------------------------- 2 files changed, 33 insertions(+), 39 deletions(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index f10ee20a507b..63637a01336b 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -678,27 +678,42 @@ def _type_promotion_lattice(strict: bool, x64: bool) -> dict[JAXType, list[JAXTy x64: allow promotions that form x64 types from non-x64 inputs? """ b1, = _bool_types - uint2, uint4, u1, u2, u4, u8, int2, int4, i1, i2, i4, i8 = _int_types - *f1_types, bf, f2, f4, f8 = _float_types - c4, c8 = _complex_types + u2, u4, u8, u16, u32, u64, i2, i4, i8, i16, i32, i64 = _int_types + *small_float_types, bf16, f16, f32, f64 = _float_types + c64, c128 = _complex_types i_, f_, c_ = _weak_types if not strict: - out: dict[JAXType, list[JAXType]] - out = { - b1: [i_], - i_: [u1, uint2, uint4, i1, int2, int4], - uint2: [], uint4: [], u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_], - int2: [], int4: [], i1: [i2], i2: [i4], i4: [i8], i8: [f_], - f_: [*f1_types, bf, f2, c_], - **{t: [] for t in f1_types}, bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8], - c_: [c4], c4: [c8], c8: [], + out: dict[JAXType, list[JAXType]] = { + b1: [i_], + i_: [i2, u2], + u2: [i4, u4], + u4: [i8, u8], + u8: [i16, u16], + u16: [i32, u32], + u32: [i64, u64], + u64: [f_], + i2: [i4], + i4: [i8], + i8: [i16], + i16: [i32], + i32: [i64], + i64: [f_], + f_: [*small_float_types, bf16, f16, c_], + **{t: [] for t in small_float_types}, + bf16: [f32], + f16: [f32], + f32: [f64, c64], + f64: [c128], + c_: [c64], + c64: [c128], + c128: [], } # If x64 mode is not enabled, then we want to avoid any promotions that form # 64-bit types from non-64-bit inputs. There's only one of these in the # entire promotion lattice, namely u4xi4->i8, which we can avoid by # replacing it with u4xi4->i4. if not x64: - out[u4] = [i4, u8] + out[u32] = [i32, u64] return out else: return { diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index c566c6c28612..d4bf9eccfa71 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -277,13 +277,13 @@ def testPromoteDtypesStandard(self): ) small_fp_dtypes = set(fp8_dtypes + fp4_dtypes) - implicit_int_dtypes = set(signed_dtypes + unsigned_dtypes) - set(intn_dtypes) + int_dtypes = set(signed_dtypes + unsigned_dtypes) for t1 in all_dtypes: self.assertEqual(t1, dtypes.promote_types(t1, t1)) self.assertEqual(t1, dtypes.promote_types(t1, np.bool_)) # TODO(zhangqiaorjc): Consider more dtype promotion rules for fp8. - if t1 in small_fp_dtypes or t1 in intn_dtypes: + if t1 in small_fp_dtypes: assertTypePromotionError(t1, np.complex128) else: self.assertEqual( @@ -297,10 +297,8 @@ def testPromoteDtypesStandard(self): and (t1 != np.bool_) and (t2 != np.bool_) and ( - t1 in intn_dtypes or - t2 in intn_dtypes or - (t1 in small_fp_dtypes and t2 not in implicit_int_dtypes) or - (t2 in small_fp_dtypes and t1 not in implicit_int_dtypes) + (t1 in small_fp_dtypes and t2 not in int_dtypes) + or (t2 in small_fp_dtypes and t1 not in int_dtypes) ) ): assertTypePromotionError(t1, t2) @@ -317,10 +315,7 @@ def testPromoteDtypesStandard(self): # inexact types. for t in float_dtypes + complex_dtypes: for i in bool_dtypes + signed_dtypes + unsigned_dtypes: - if i in intn_dtypes: - assertTypePromotionError(t, i) - else: - self.assertEqual(t, dtypes.promote_types(t, i)) + self.assertEqual(t, dtypes.promote_types(t, i)) # Promotions between exact types, or between inexact types, match NumPy. for groups in [bool_dtypes + np_signed_dtypes + np_unsigned_dtypes, @@ -1126,22 +1121,6 @@ def testFloat4PromotionError(self): ".*4-bit floats do not support implicit promotion"): x + y - @jax.numpy_dtype_promotion('standard') - @jtu.run_on_devices('tpu') - def testInt2PromotionError(self): - for dtype in intn_dtypes: - if dtype.name == 'int2' or dtype.name == 'uint2': - # TODO(b/343490729): Remove continue once the bug is fixed. - continue - - x = jnp.array(1, dtype=dtype) - y = jnp.array(1, dtype='int32') - with self.assertRaisesRegex( - dtypes.TypePromotionError, - '.*[24]-bit integers do not support implicit promotion', - ): - x + y - @jtu.sample_product( dtype=all_dtypes, weak_type=[True, False],