Skip to content

Commit 407ce37

Browse files
committed
MAINT: dedupe real/complex helpers
1 parent d9c7f3b commit 407ce37

2 files changed

Lines changed: 2 additions & 22 deletions

File tree

array_api_tests/dtype_helpers.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -231,26 +231,6 @@ def real_dtype_for(dtyp):
231231
return real_dtype
232232

233233

234-
def complex_for_float(dtyp):
235-
"""For a real or complex dtype, return a matching complex dtype."""
236-
if api_version <= '2021.12':
237-
raise TypeError("complex dtypes require api_version >= 2022.12.")
238-
239-
if dtyp not in all_float_dtypes:
240-
raise ValueError(f"expected a real dtype, got {dtyp}.")
241-
242-
if dtyp == xp.float32:
243-
return xp.complex64
244-
elif dtyp == xp.float64:
245-
return xp.complex128
246-
elif dtyp == xp.complex64:
247-
return xp.complex64
248-
elif dtype == xp.complex128:
249-
return xp.complex128
250-
else:
251-
raise ValueError(f"Unknown dtype {dtyp}.")
252-
253-
254234
def _make_dtype_mapping_from_names(mapping: Dict[str, Any]) -> EqualityMapping:
255235
dtype_value_pairs = []
256236
for name, value in mapping.items():

array_api_tests/test_linalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def test_eig(x):
343343

344344
eigenvalues = res.eigenvalues
345345
eigenvectors = res.eigenvectors
346-
expected_dtype = dh.complex_for_float(x.dtype)
346+
expected_dtype = dh.complex_dtype_for(x.dtype)
347347

348348
ph.assert_dtype("eig", in_dtype=x.dtype, out_dtype=eigenvalues.dtype,
349349
expected=expected_dtype, repr_name="eigenvalues.dtype")
@@ -373,7 +373,7 @@ def test_eig(x):
373373
@given(x=arrays(dtype=all_floating_dtypes(), shape=square_matrix_shapes))
374374
def test_eigvals(x):
375375
res = linalg.eigvals(x)
376-
expected_dtype = dh.complex_for_float(x.dtype)
376+
expected_dtype = dh.complex_dtype_for(x.dtype)
377377

378378
ph.assert_dtype("eigvals", in_dtype=x.dtype, out_dtype=res.dtype,
379379
expected=expected_dtype, repr_name="eigvals")

0 commit comments

Comments
 (0)