Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,27 +678,42 @@ def _type_promotion_lattice(strict: bool, x64: bool) -> dict[JAXType, list[JAXTy
x64: allow promotions that form x64 types from non-x64 inputs?
"""
b1, = _bool_types
uint2, uint4, u1, u2, u4, u8, int2, int4, i1, i2, i4, i8 = _int_types
*f1_types, bf, f2, f4, f8 = _float_types
c4, c8 = _complex_types
u2, u4, u8, u16, u32, u64, i2, i4, i8, i16, i32, i64 = _int_types
*small_float_types, bf16, f16, f32, f64 = _float_types
c64, c128 = _complex_types
i_, f_, c_ = _weak_types
if not strict:
out: dict[JAXType, list[JAXType]]
out = {
b1: [i_],
i_: [u1, uint2, uint4, i1, int2, int4],
uint2: [], uint4: [], u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
int2: [], int4: [], i1: [i2], i2: [i4], i4: [i8], i8: [f_],
f_: [*f1_types, bf, f2, c_],
**{t: [] for t in f1_types}, bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8],
c_: [c4], c4: [c8], c8: [],
out: dict[JAXType, list[JAXType]] = {
b1: [i_],
i_: [i2, u2],
u2: [i4, u4],
u4: [i8, u8],
u8: [i16, u16],
u16: [i32, u32],
u32: [i64, u64],
u64: [f_],
i2: [i4],
i4: [i8],
i8: [i16],
i16: [i32],
i32: [i64],
i64: [f_],
f_: [*small_float_types, bf16, f16, c_],
**{t: [] for t in small_float_types},
bf16: [f32],
f16: [f32],
f32: [f64, c64],
f64: [c128],
c_: [c64],
c64: [c128],
c128: [],
}
# If x64 mode is not enabled, then we want to avoid any promotions that form
# 64-bit types from non-64-bit inputs. There's only one of these in the
# entire promotion lattice, namely u4xi4->i8, which we can avoid by
# replacing it with u4xi4->i4.
if not x64:
out[u4] = [i4, u8]
out[u32] = [i32, u64]
return out
else:
return {
Expand Down
31 changes: 5 additions & 26 deletions tests/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,13 @@ def testPromoteDtypesStandard(self):
)

small_fp_dtypes = set(fp8_dtypes + fp4_dtypes)
implicit_int_dtypes = set(signed_dtypes + unsigned_dtypes) - set(intn_dtypes)
int_dtypes = set(signed_dtypes + unsigned_dtypes)

for t1 in all_dtypes:
self.assertEqual(t1, dtypes.promote_types(t1, t1))
self.assertEqual(t1, dtypes.promote_types(t1, np.bool_))
# TODO(zhangqiaorjc): Consider more dtype promotion rules for fp8.
if t1 in small_fp_dtypes or t1 in intn_dtypes:
if t1 in small_fp_dtypes:
assertTypePromotionError(t1, np.complex128)
else:
self.assertEqual(
Expand All @@ -297,10 +297,8 @@ def testPromoteDtypesStandard(self):
and (t1 != np.bool_)
and (t2 != np.bool_)
and (
t1 in intn_dtypes or
t2 in intn_dtypes or
(t1 in small_fp_dtypes and t2 not in implicit_int_dtypes) or
(t2 in small_fp_dtypes and t1 not in implicit_int_dtypes)
(t1 in small_fp_dtypes and t2 not in int_dtypes)
or (t2 in small_fp_dtypes and t1 not in int_dtypes)
)
):
assertTypePromotionError(t1, t2)
Expand All @@ -317,10 +315,7 @@ def testPromoteDtypesStandard(self):
# inexact types.
for t in float_dtypes + complex_dtypes:
for i in bool_dtypes + signed_dtypes + unsigned_dtypes:
if i in intn_dtypes:
assertTypePromotionError(t, i)
else:
self.assertEqual(t, dtypes.promote_types(t, i))
self.assertEqual(t, dtypes.promote_types(t, i))

# Promotions between exact types, or between inexact types, match NumPy.
for groups in [bool_dtypes + np_signed_dtypes + np_unsigned_dtypes,
Expand Down Expand Up @@ -1126,22 +1121,6 @@ def testFloat4PromotionError(self):
".*4-bit floats do not support implicit promotion"):
x + y

@jax.numpy_dtype_promotion('standard')
@jtu.run_on_devices('tpu')
def testInt2PromotionError(self):
for dtype in intn_dtypes:
if dtype.name == 'int2' or dtype.name == 'uint2':
# TODO(b/343490729): Remove continue once the bug is fixed.
continue

x = jnp.array(1, dtype=dtype)
y = jnp.array(1, dtype='int32')
with self.assertRaisesRegex(
dtypes.TypePromotionError,
'.*[24]-bit integers do not support implicit promotion',
):
x + y

@jtu.sample_product(
dtype=all_dtypes,
weak_type=[True, False],
Expand Down
Loading