Skip to content

ENH: allow running on torch.mps device, which does not have float64/complex128#434

Merged
ev-br merged 1 commit intodata-apis:masterfrom
ev-br:torch_mps
Apr 3, 2026
Merged

ENH: allow running on torch.mps device, which does not have float64/complex128#434
ev-br merged 1 commit intodata-apis:masterfrom
ev-br:torch_mps

Conversation

@ev-br
Copy link
Copy Markdown
Member

@ev-br ev-br commented Apr 3, 2026

towards #431

Adding the following to __init__.py

import array_api_compat.torch as xp
xp_name = xp.__name__
xp.set_default_device("mps")

and running

$ ARRAY_API_TESTS_SKIP_DTYPES=uint32,uint64,uint16,float64,complex128  pytest array_api_tests/ -vs --skips-file=../array-api-compat/torch-xfails.txt --max-examples=10

generates a lot of failures, which, however, fall into three categories:

  1. test_special_cases : a lot of edge cases are not implemented on an MPS device;
  2. linear algebra: aten::linalg_eig' is not currently implemented, also aten::_linalg_svd.U, aten::linalg_qr.out
  3. assorted edge cases on MPS with numerically large values: for instance, test_linspace fails on MPS and passes on CPU:
(Pdb) xp.linspace(start, stop, 100, dtype=xp.float32)[-1] - stop
tensor(-67108864., device='mps:0')
(Pdb) p np.linspace(start, stop, 100, dtype=np.float32)[-1] - stop
np.float32(0.0)
(Pdb) p np.linspace(start, stop, 100, dtype=np.float32, device="cpu")[-1] - stop
np.float32(0.0)
(Pdb) start, stop, num
(0.0, 1072614462193664.0, 100)

@ev-br ev-br merged commit a71706f into data-apis:master Apr 3, 2026
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant