Skip to content

Commit b5b2c8d

Browse files
committed
ENH: add minimal dlpack tests
No cross-library tests, just `from_dlpack` with the same producer and consumer.
1 parent cf7bc9f commit b5b2c8d

File tree

1 file changed

+127
-0
lines changed

1 file changed

+127
-0
lines changed

array_api_tests/test_dlpack.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
from enum import Enum
2+
3+
from hypothesis import given, strategies as st
4+
from . import _array_module as xp
5+
from . import pytest_helpers as ph
6+
from . import hypothesis_helpers as hh
7+
8+
# dlpack Enum values,
9+
# https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack_device__.html
10+
11+
class DLPackDeviceEnum(Enum):
12+
CPU = 1
13+
CUDA = 2
14+
CPU_PINNED = 3
15+
OPENCL = 4
16+
VULKAN = 7
17+
METAL = 8
18+
VPI = 9
19+
ROCM = 10
20+
CUDA_MANAGED = 13
21+
ONE_API = 14
22+
23+
24+
def _compatible_devices(devices):
25+
"""Given a list of devices, filter out dlpack-incompatible ones."""
26+
# XXX: there seems to be no better way than try-catch for __dlpack_device__()
27+
28+
# XXX: this process actually fails with CuPy because CuPy ignores the device= argument
29+
# cf https://github.com/data-apis/array-api-compat/issues/337 and
30+
# https://github.com/cupy/cupy/issues/9848
31+
# Luckily, CuPy only supports CUDA devices, and they are all compatible.
32+
compatible_ = []
33+
for device in devices:
34+
x = xp.empty(2, device=device)
35+
try:
36+
x.__dlpack_device__()
37+
except:
38+
# case in point: torch.device(type="meta") raises
39+
# ValueError: Unknown device type meta for Dlpack
40+
pass
41+
else:
42+
# no exception => device is compatible
43+
compatible_.append(device)
44+
return compatible_
45+
46+
47+
@given(dtype=hh.all_dtypes, data=st.data())
48+
def test_dlpack_device(dtype, data):
49+
"""Test the array object __dlpack_device__ method."""
50+
# TODO: 1. generate inputs on non-default devices
51+
x = xp.empty(3, dtype=dtype)
52+
device_type, device_id = x.__dlpack_device__()
53+
54+
assert DLPackDeviceEnum(int(device_type))
55+
assert isinstance(device_id, int)
56+
57+
58+
@given(
59+
x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_dims=1, max_side=2)),
60+
copy_kw=hh.kwargs(
61+
copy=st.booleans() | st.none()
62+
),
63+
max_version_kw=hh.kwargs(
64+
max_version=st.tuples(
65+
st.integers(min_value=0, max_value=2),
66+
st.integers(min_value=0, max_value=0)
67+
)
68+
),
69+
dl_device_kw=hh.kwargs(
70+
dl_device=st.tuples( # XXX: the 2023.12 standard only mandates ... kDLCPU ?
71+
st.just(DLPackDeviceEnum.CPU.value),
72+
st.just(0)
73+
)
74+
),
75+
data=st.data()
76+
)
77+
def test_dunder_dlpack(x, copy_kw, max_version_kw, dl_device_kw, data):
78+
repro_snippet = ph.format_snippet(
79+
f"x.__dlpack__ with {copy_kw = }, {max_version_kw = } and {dl_device_kw = }"
80+
)
81+
82+
try:
83+
x.__dlpack__(**copy_kw, **max_version_kw, **dl_device_kw)
84+
# apparently, we cannot do anything with the DLPack capsule from python
85+
except Exception as exc:
86+
ph.add_note(exc, repro_snippet)
87+
raise
88+
89+
90+
@given(
91+
x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_dims=1, max_side=2)),
92+
copy_kw=hh.kwargs(copy=st.booleans()),
93+
data=st.data()
94+
)
95+
def test_from_dlpack(x, copy_kw, data):
96+
# TODO: 1. test copy; 2. generate inputs on non-default devices;
97+
# 3. test for copy=False cross-device transfers
98+
# 4. test 0D arrays / numpy scalars (the latter do not support dlpack ATM)
99+
100+
copy = copy_kw["copy"] if copy_kw else None
101+
if copy is False:
102+
# XXX there is no way to tell if a no-copy cross-device transfer is meant to succeed
103+
devices = [x.device]
104+
else:
105+
devices = xp.__array_namespace_info__().devices()
106+
devices = _compatible_devices(devices)
107+
108+
tgt_device_kw = data.draw(
109+
hh.kwargs(device=st.sampled_from(devices) | st.none())
110+
)
111+
tgt_device = tgt_device_kw['device'] if tgt_device_kw else None
112+
113+
repro_snippet = ph.format_snippet(
114+
f"y = from_dlpack({x!r}, **tgt_device_kw, **copy_kw) with {tgt_device_kw=} and {copy_kw=}"
115+
)
116+
try:
117+
y = xp.from_dlpack(x, **tgt_device_kw, **copy_kw)
118+
119+
if tgt_device is None:
120+
assert y.device == x.device
121+
assert xp.all(y == x)
122+
else:
123+
assert y.device == tgt_device
124+
125+
except Exception as exc:
126+
ph.add_note(exc, repro_snippet)
127+
raise

0 commit comments

Comments
 (0)