Skip to content
This repository was archived by the owner on Oct 31, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions trax/tf_numpy/jax_tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ def onp_fun(a, b):
check_xla = not set((lhs_dtype, rhs_dtype)).intersection(
(onp.int32, onp.int64))

tol = {onp.float64: 1e-14}
tol = {onp.float64: 1e-14, onp.float16: 0.04}
tol = max(jtu.tolerance(lhs_dtype, tol), jtu.tolerance(rhs_dtype, tol))
self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True,
check_incomplete_shape=True,
Expand Down Expand Up @@ -1301,8 +1301,12 @@ def testCumSumProd(self, axis, shape, dtype, out_dtype, onp_op, lnp_op, rng_fact
tol=tol)
# XLA lacks int64 Cumsum/Cumprod kernels (b/168841378).
check_xla = out_dtype != onp.int64
rtol = None
if out_dtype == onp.float16:
rtol = 2e-3
self._CompileAndCheck(
lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True,
lnp_fun, args_maker, check_dtypes=True, rtol=rtol,
check_incomplete_shape=True,
check_experimental_compile=check_xla,
check_xla_forced_compile=check_xla)

Expand Down
4 changes: 3 additions & 1 deletion trax/tf_numpy/jax_tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,10 @@ def format_test_name_suffix(opname, shapes, dtypes):

# We use special symbols, represented as singleton objects, to distinguish
# between NumPy scalars, Python scalars, and 0-D arrays.
class ScalarShape(object):
class ScalarShape:
def __len__(self): return 0
def __getitem__(self, i):
raise IndexError(f'index {i} out of range.')
class _NumpyScalar(ScalarShape): pass
class _PythonScalar(ScalarShape): pass
NUMPY_SCALAR_SHAPE = _NumpyScalar()
Expand Down