@@ -678,27 +678,42 @@ def _type_promotion_lattice(strict: bool, x64: bool) -> dict[JAXType, list[JAXTy
678678 x64: allow promotions that form x64 types from non-x64 inputs?
679679 """
680680 b1 , = _bool_types
681- uint2 , uint4 , u1 , u2 , u4 , u8 , int2 , int4 , i1 , i2 , i4 , i8 = _int_types
682- * f1_types , bf , f2 , f4 , f8 = _float_types
683- c4 , c8 = _complex_types
681+ u2 , u4 , u8 , u16 , u32 , u64 , i2 , i4 , i8 , i16 , i32 , i64 = _int_types
682+ * small_float_types , bf16 , f16 , f32 , f64 = _float_types
683+ c64 , c128 = _complex_types
684684 i_ , f_ , c_ = _weak_types
685685 if not strict :
686- out : dict [JAXType , list [JAXType ]]
687- out = {
688- b1 : [i_ ],
689- i_ : [u1 , uint2 , uint4 , i1 , int2 , int4 ],
690- uint2 : [], uint4 : [], u1 : [i2 , u2 ], u2 : [i4 , u4 ], u4 : [i8 , u8 ], u8 : [f_ ],
691- int2 : [], int4 : [], i1 : [i2 ], i2 : [i4 ], i4 : [i8 ], i8 : [f_ ],
692- f_ : [* f1_types , bf , f2 , c_ ],
693- ** {t : [] for t in f1_types }, bf : [f4 ], f2 : [f4 ], f4 : [f8 , c4 ], f8 : [c8 ],
694- c_ : [c4 ], c4 : [c8 ], c8 : [],
686+ out : dict [JAXType , list [JAXType ]] = {
687+ b1 : [i_ ],
688+ i_ : [i2 , u2 ],
689+ u2 : [i4 , u4 ],
690+ u4 : [i8 , u8 ],
691+ u8 : [i16 , u16 ],
692+ u16 : [i32 , u32 ],
693+ u32 : [i64 , u64 ],
694+ u64 : [f_ ],
695+ i2 : [i4 ],
696+ i4 : [i8 ],
697+ i8 : [i16 ],
698+ i16 : [i32 ],
699+ i32 : [i64 ],
700+ i64 : [f_ ],
701+ f_ : [* small_float_types , bf16 , f16 , c_ ],
702+ ** {t : [] for t in small_float_types },
703+ bf16 : [f32 ],
704+ f16 : [f32 ],
705+ f32 : [f64 , c64 ],
706+ f64 : [c128 ],
707+ c_ : [c64 ],
708+ c64 : [c128 ],
709+ c128 : [],
695710 }
696711 # If x64 mode is not enabled, then we want to avoid any promotions that form
697712 # 64-bit types from non-64-bit inputs. There's only one of these in the
698713 # entire promotion lattice, namely u4xi4->i8, which we can avoid by
699714 # replacing it with u4xi4->i4.
700715 if not x64 :
701- out [u4 ] = [i4 , u8 ]
716+ out [u32 ] = [i32 , u64 ]
702717 return out
703718 else :
704719 return {
0 commit comments