Skip to content

Commit e9c8606

Browse files
committed
MAINT: avoid explicit xp.float64 in test_special_cases
1 parent 3ab78bd commit e9c8606

2 files changed

Lines changed: 11 additions & 9 deletions

File tree

array_api_tests/test_creation_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def test_arange(dtype, data):
198198
), f"out[0]={out[0]}, but should be {_start} {f_func}"
199199
except Exception as exc:
200200
ph.add_note(exc, repro_snippet)
201-
raise
201+
raise
202202

203203

204204
@given(shape=hh.shapes(min_side=1), data=st.data())

array_api_tests/test_special_cases.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1556,10 +1556,11 @@ def test_unary(func_name, func, case):
15561556
def test_binary(func_name, func, case, data):
15571557
# We don't use example() like in test_unary because the same internal shared
15581558
# strategies used in both x1's and x2's don't "sync" with example() draws.
1559-
x1_value = data.draw(case.x1_cond_from_dtype(xp.float64), label="x1_value")
1560-
x2_value = data.draw(case.x2_cond_from_dtype(xp.float64), label="x2_value")
1561-
x1 = xp.asarray(x1_value, dtype=xp.float64)
1562-
x2 = xp.asarray(x2_value, dtype=xp.float64)
1559+
dtyp = dh.widest_real_dtype # float64 if available else float32
1560+
x1_value = data.draw(case.x1_cond_from_dtype(dtyp), label="x1_value")
1561+
x2_value = data.draw(case.x2_cond_from_dtype(dtyp), label="x2_value")
1562+
x1 = xp.asarray(x1_value, dtype=dtyp)
1563+
x2 = xp.asarray(x2_value, dtype=dtyp)
15631564

15641565
out = func(x1, x2)
15651566
out_value = float(out)
@@ -1576,10 +1577,11 @@ def test_binary(func_name, func, case, data):
15761577
@given(data=st.data())
15771578
def test_iop(iop_name, iop, case, data):
15781579
# See test_binary comment
1579-
x1_value = data.draw(case.x1_cond_from_dtype(xp.float64), label="x1_value")
1580-
x2_value = data.draw(case.x2_cond_from_dtype(xp.float64), label="x2_value")
1581-
x1 = xp.asarray(x1_value, dtype=xp.float64)
1582-
x2 = xp.asarray(x2_value, dtype=xp.float64)
1580+
dtyp = dh.widest_real_dtype
1581+
x1_value = data.draw(case.x1_cond_from_dtype(dtyp), label="x1_value")
1582+
x2_value = data.draw(case.x2_cond_from_dtype(dtyp), label="x2_value")
1583+
x1 = xp.asarray(x1_value, dtype=dtyp)
1584+
x2 = xp.asarray(x2_value, dtype=dtyp)
15831585

15841586
res = iop(x1, x2)
15851587
res_value = float(res)

0 commit comments

Comments
 (0)