Skip to content

Commit c299472

Browse files
committed
support buffer protocol
1 parent 97cc18b commit c299472

File tree

3 files changed

+133
-45
lines changed

3 files changed

+133
-45
lines changed

dartsclone/_dartsclone.pxd

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,6 @@ cdef extern from "darts.h":
5151

5252
cdef class DoubleArray:
    # Pointer to the wrapped C++ double-array object.
    cdef CppDoubleArray *wrapped
    # Buffer view of the object most recently handed to set_array(); kept so
    # the memory the C++ side points at stays alive until it is released.
    cdef Py_buffer _buf
    # One-element shape and stride storage referenced by the Py_buffer that
    # __getbuffer__ fills in (the exported view is one-dimensional).
    cdef Py_ssize_t _shape[1]
    cdef Py_ssize_t _strides[1]

dartsclone/_dartsclone.pyx

Lines changed: 119 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,58 @@
1-
from libc.stdlib cimport malloc, free
1+
from libc.stdlib cimport calloc, free
2+
3+
cdef extern from "Python.h":
    ctypedef struct PyObject
    # PyObject_GetBuffer returns -1 with a Python exception set on failure.
    # Declaring `except -1` makes Cython propagate that exception at every
    # call site; without it a caller that simply returns would hand back a
    # result while the error indicator is still set, which CPython reports
    # as a SystemError instead of the original TypeError/BufferError.
    int PyObject_GetBuffer(PyObject *exporter, Py_buffer *view, int flags) except -1
    # Releases a view obtained with PyObject_GetBuffer.
    void PyBuffer_Release(Py_buffer *view)
    # Request flag: the exporter must provide a C-contiguous buffer.
    const int PyBUF_C_CONTIGUOUS
28

39

410
cdef class DoubleArray:
511
def __cinit__(self):
    # The stride slot is written once here and later handed out unchanged
    # by __getbuffer__, whose items are one byte wide.
    self._strides[0] = 1
    # Allocate the C++ object this wrapper drives; freed in __dealloc__.
    self.wrapped = new CppDoubleArray()
714

815
def __dealloc__(self):
    # A non-NULL `obj` means set_array() stored an acquired view in
    # self._buf; give it back to its exporter before tearing down.
    if <PyObject *>self._buf.obj != NULL:
        PyBuffer_Release(&self._buf)

    # Destroy the C++ object created in __cinit__.
    del self.wrapped
1019

1120
def __getstate__(self):
    """Return the array contents copied into a ``bytes`` object for pickling."""
    contents = self.array()
    return bytes(contents)
1322

1423
def __setstate__(self, array):
    """Restore pickled state by passing ``array`` to :meth:`set_array`."""
    self.set_array(array)
1625

17-
def array(self):
18-
cdef size_t total_size = self.wrapped.total_size()
19-
cdef char[:] data = <char[:total_size]>self.wrapped.array()
20-
return bytes(data)
26+
def __getbuffer__(self, Py_buffer *buffer, int flags):
    # Export the wrapped array as a read-only, one-dimensional view of
    # unsigned bytes ('B', itemsize 1) whose length is total_size().
    #
    # NOTE(review): `flags` is not inspected, so a consumer asking for a
    # writable buffer is still handed a read-only one rather than getting
    # BufferError -- confirm whether that is intended.
    # NOTE(review): the view points straight at the wrapped array and
    # __releasebuffer__ does no bookkeeping; presumably callers must not
    # rebuild or clear the array while a view is alive -- verify.
    cdef Py_ssize_t nbytes = self.wrapped.total_size()

    # Shape storage lives on the instance so the pointer stored in the
    # view remains valid after this call returns.
    self._shape[0] = nbytes

    buffer.obj = self
    buffer.buf = <char *>self.wrapped.array()
    buffer.len = nbytes
    buffer.itemsize = 1
    buffer.format = 'B'
    buffer.readonly = True
    buffer.ndim = 1
    buffer.shape = self._shape
    buffer.strides = self._strides
    buffer.suboffsets = NULL
    buffer.internal = NULL
38+
39+
def __releasebuffer__(self, Py_buffer *buffer):
    # Nothing to undo: __getbuffer__ allocates nothing per view, since the
    # shape and stride arrays it hands out are stored on the instance.
    pass
2141

22-
def set_array(self, const unsigned char[::1] array, size_t size=0):
23-
self.wrapped.set_array(<const void*> &array[0], size)
42+
def array(self):
    """Return a ``memoryview`` over this object's exported buffer.

    The view is produced through ``__getbuffer__``, so no copy is made;
    wrap the result in ``bytes(...)`` when an independent copy is needed.
    """
    exported = memoryview(self)
    return exported
44+
45+
def set_array(self, array, size_t size=0):
    """Point the wrapped double-array at the memory of ``array``.

    ``array`` must export a C-contiguous buffer.  The acquired view is
    stored on the instance (replacing and releasing any earlier one) so
    the exporter stays alive while the C++ object refers to its memory.
    ``size`` is forwarded unchanged to the C++ ``set_array``.

    Raises ``ValueError`` when the buffer is this object's own array.
    """
    cdef Py_buffer view
    if PyObject_GetBuffer(<PyObject *>array, &view, PyBUF_C_CONTIGUOUS) < 0:
        # NOTE(review): a plain return is only correct if the failure has
        # already been propagated as an exception.  If PyObject_GetBuffer
        # is declared without an exception value, the error it set is
        # still pending here -- confirm the extern declaration.
        return

    # Adopting a view of our own storage would make the array reference
    # itself; refuse, and hand the freshly acquired view back first.
    if view.buf == self.wrapped.array():
        PyBuffer_Release(&view)
        raise ValueError("passed buffer refers to itself")

    # Drop the view kept from a previous set_array() call, if any.
    if <PyObject *>self._buf.obj != NULL:
        PyBuffer_Release(&self._buf)

    self._buf = view
    self.wrapped.set_array(view.buf, size)
2456

2557
def clear(self):
2658
self.wrapped.clear()
@@ -41,26 +73,45 @@ cdef class DoubleArray:
4173
lengths = None,
4274
values = None):
4375
cdef size_t num_keys = len(keys)
44-
cdef const char** _keys = <const char**> malloc(num_keys * sizeof(char*))
76+
cdef const char** _keys = NULL
77+
cdef Py_buffer* _buf = NULL
4578
cdef size_t *_lengths = NULL
4679
cdef int *_values = NULL
47-
for i, key in enumerate(keys):
48-
_keys[i] = key
49-
if lengths is not None:
50-
_lengths = <size_t *> malloc(num_keys * sizeof(size_t))
51-
for i, length in enumerate(lengths):
52-
_lengths[i] = length
53-
if values is not None:
54-
_values = <int *> malloc(num_keys * sizeof(int))
55-
for i, value in enumerate(values):
56-
_values[i] = value
80+
5781
try:
82+
_keys = <const char**> calloc(num_keys, sizeof(char*))
83+
if _keys == NULL:
84+
raise MemoryError("failed to allocate memory for key array")
85+
_buf = <Py_buffer *> calloc(num_keys, sizeof(Py_buffer))
86+
if _buf == NULL:
87+
raise MemoryError("failed to allocate memory for buffer")
88+
for i, key in enumerate(keys):
89+
if PyObject_GetBuffer(<PyObject *>key, &_buf[i], PyBUF_C_CONTIGUOUS) < 0:
90+
return
91+
_keys[i] = <const char *> _buf[i].buf
92+
if lengths is not None:
93+
_lengths = <size_t *> calloc(num_keys, sizeof(size_t))
94+
if _lengths == NULL:
95+
raise MemoryError("failed to allocate memory for length array")
96+
for i, length in enumerate(lengths):
97+
_lengths[i] = length
98+
if values is not None:
99+
_values = <int *> calloc(num_keys, sizeof(int))
100+
if _values == NULL:
101+
raise MemoryError("failed to allocate memory for value array")
102+
for i, value in enumerate(values):
103+
_values[i] = value
58104
self.wrapped.build(num_keys, _keys, <const size_t*> _lengths, <const int*> _values, NULL)
59105
finally:
60-
free(_keys)
61-
if lengths is not None:
106+
if _keys != NULL:
107+
free(_keys)
108+
if _buf != NULL:
109+
for i in range(num_keys):
110+
PyBuffer_Release(&_buf[i])
111+
free(_buf)
112+
if _lengths != NULL:
62113
free(_lengths)
63-
if values is not None:
114+
if _values != NULL:
64115
free(_values)
65116

66117
def open(self, file_name,
@@ -88,39 +139,66 @@ cdef class DoubleArray:
88139
size_t length = 0,
89140
size_t node_pos = 0,
90141
pair_type=True):
91-
cdef const char *_key = key
92-
if pair_type:
93-
return self.__exact_match_search_pair_type(_key, length, node_pos)
94-
else:
95-
return self.__exact_match_search(_key, length, node_pos)
142+
cdef Py_buffer buf
143+
if PyObject_GetBuffer(<PyObject *>key, &buf, PyBUF_C_CONTIGUOUS) < 0:
144+
return
145+
try:
146+
if length == 0:
147+
if buf.len == 0:
148+
raise ValueError("buffer cannot be empty")
149+
length = buf.len
150+
if pair_type:
151+
return self.__exact_match_search_pair_type(<const char *>buf.buf, length, node_pos)
152+
else:
153+
return self.__exact_match_search(<const char *>buf.buf, length, node_pos)
154+
finally:
155+
PyBuffer_Release(&buf)
96156

97157
def common_prefix_search(self, key,
                         size_t max_num_results = 0,
                         size_t length = 0,
                         size_t node_pos = 0,
                         pair_type=True):
    """Run a common-prefix search for ``key``.

    ``key`` may be any object exporting a C-contiguous buffer.

    Parameters:
        max_num_results: capacity of the result array; ``0`` means "use
            the key's size in bytes".
        length: number of key bytes to use; ``0`` means the whole buffer.
        node_pos: forwarded unchanged to the underlying search.
        pair_type: selects the pair-type helper when true, otherwise the
            plain helper.

    Returns whatever the selected helper returns.

    Raises ``ValueError`` if the buffer is empty while ``length`` is 0, or
    if an explicit ``length`` is larger than the buffer.
    """
    cdef Py_buffer buf
    if PyObject_GetBuffer(<PyObject *>key, &buf, PyBUF_C_CONTIGUOUS) < 0:
        # NOTE(review): only correct if the failure was already propagated
        # as an exception; confirm PyObject_GetBuffer is declared with an
        # exception value, otherwise its error is still pending here.
        return
    try:
        if length == 0:
            if buf.len == 0:
                raise ValueError("buffer cannot be empty")
            length = buf.len
        elif length > <size_t>buf.len:
            # The C++ side trusts `length`; reject values that would make
            # it read past the end of the caller's buffer.
            raise ValueError("length exceeds the size of the key buffer")
        if max_num_results == 0:
            # Use the acquired view's byte count rather than len(key):
            # the latter needs __len__ and counts items, not bytes.
            max_num_results = <size_t>buf.len
        if pair_type:
            return self.__common_prefix_search_pair_type(<const char *>buf.buf, max_num_results, length, node_pos)
        else:
            return self.__common_prefix_search(<const char *>buf.buf, max_num_results, length, node_pos)
    finally:
        # Always hand the view back, including on the error paths above.
        PyBuffer_Release(&buf)
109178

110179
def traverse(self, key,
             size_t node_pos,
             size_t key_pos,
             size_t length = 0):
    """Call the wrapped ``traverse`` on ``key`` and return its int result.

    ``key`` may be any object exporting a C-contiguous buffer.
    ``node_pos`` and ``key_pos`` are forwarded unchanged.  ``length`` is
    the number of key bytes to use; ``0`` means the whole buffer.

    Raises ``ValueError`` if the buffer is empty while ``length`` is 0, or
    if an explicit ``length`` is larger than the buffer.
    """
    cdef Py_buffer buf
    cdef int result
    if PyObject_GetBuffer(<PyObject *>key, &buf, PyBUF_C_CONTIGUOUS) < 0:
        # NOTE(review): only correct if the failure was already propagated
        # as an exception; confirm PyObject_GetBuffer is declared with an
        # exception value, otherwise its error is still pending here.
        return
    try:
        if length == 0:
            if buf.len == 0:
                raise ValueError("buffer cannot be empty")
            length = buf.len
        elif length > <size_t>buf.len:
            # The C++ side trusts `length`; reject values that would make
            # it read past the end of the caller's buffer.
            raise ValueError("length exceeds the size of the key buffer")
        # The view pins the key's memory, so the GIL can be dropped while
        # the C++ code walks it.
        with nogil:
            result = self.wrapped.traverse(<const char *>buf.buf, node_pos, key_pos, length)
        return result
    finally:
        # Always hand the view back, including on the error paths above.
        PyBuffer_Release(&buf)
119197

120198
def __exact_match_search(self, const char *key,
                         size_t length = 0,
                         size_t node_pos = 0):
    """Run the wrapped exact-match search and return the value it stores.

    The output slot starts at 0, so that is what comes back if the C++
    call leaves it untouched.
    """
    cdef int found = 0
    # Pure C++ work on a raw pointer: no Python objects are touched, so
    # the GIL is released for the duration of the call.
    with nogil:
        self.wrapped.exact_match_search(key, found, length, node_pos)
    return found
@@ -137,7 +215,7 @@ cdef class DoubleArray:
137215
size_t max_num_results,
138216
size_t length,
139217
size_t node_pos):
140-
cdef int *results = <int *> malloc(max_num_results * sizeof(int))
218+
cdef int *results = <int *> calloc(max_num_results, sizeof(int))
141219
cdef int result_len
142220
try:
143221
with nogil:
@@ -153,7 +231,7 @@ cdef class DoubleArray:
153231
size_t max_num_results,
154232
size_t length,
155233
size_t node_pos):
156-
cdef result_pair_type *results = <result_pair_type *> malloc(max_num_results * sizeof(result_pair_type))
234+
cdef result_pair_type *results = <result_pair_type *> calloc(max_num_results, sizeof(result_pair_type))
157235
cdef result_pair_type result
158236
cdef int result_len
159237
try:

test/test_darts.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class DoubleArrayTest(unittest.TestCase):
1212
def test_darts_no_values(self):
1313
keys = ['test', 'テスト', 'テストケース']
1414
darts = DoubleArray()
15-
darts.build(sorted([key.encode() for key in keys]))
15+
darts.build([key.encode() for key in keys])
1616
self.assertEqual(1, darts.exact_match_search('テスト'.encode(), pair_type=False))
1717
self.assertEqual(0, darts.common_prefix_search('testcase'.encode(), pair_type=False)[0])
1818
self.assertEqual(0, darts.exact_match_search('test'.encode(), pair_type=False))
@@ -21,7 +21,7 @@ def test_darts_no_values(self):
2121
def test_darts_with_values(self):
2222
keys = ['test', 'テスト', 'テストケース']
2323
darts = DoubleArray()
24-
darts.build(sorted([key.encode() for key in keys]), values=[3, 5, 1])
24+
darts.build([key.encode() for key in keys], values=[3, 5, 1])
2525
self.assertEqual(5, darts.exact_match_search('テスト'.encode(), pair_type=False))
2626
self.assertEqual(3, darts.common_prefix_search('testcase'.encode(), pair_type=False)[0])
2727
self.assertEqual(1, darts.exact_match_search('テストケース'.encode(), pair_type=False))
@@ -30,7 +30,7 @@ def test_darts_with_values(self):
3030
def test_darts_save(self):
3131
keys = ['test', 'テスト', 'テストケース']
3232
darts = DoubleArray()
33-
darts.build(sorted([key.encode() for key in keys]), values=[3, 5, 1])
33+
darts.build([key.encode() for key in keys], values=[3, 5, 1])
3434
with tempfile.NamedTemporaryFile('wb') as output_file:
3535
darts.save(output_file.name)
3636
output_file.flush()
@@ -54,13 +54,20 @@ def test_darts_pickle(self):
5454
def test_darts_array(self):
    """An array exported by one DoubleArray can be loaded into another."""
    encoded_keys = [word.encode() for word in ['test', 'テスト', 'テストケース']]

    source = DoubleArray()
    source.build(encoded_keys, values=[3, 5, 1])
    exported = source.array()

    # Load the exported array into a fresh instance and query that one.
    restored = DoubleArray()
    restored.set_array(exported)

    exact = restored.exact_match_search('テスト'.encode(), pair_type=False)
    self.assertEqual(5, exact)
    prefixes = restored.common_prefix_search('testcase'.encode(), pair_type=False)
    self.assertEqual(3, prefixes[0])
6363

64+
def test_darts_buffers(self):
    """Keys and queries may be passed as memoryviews instead of bytes."""
    key_views = [memoryview(word.encode()) for word in ['test', 'テスト', 'テストケース']]

    darts = DoubleArray()
    darts.build(key_views, values=[3, 5, 1])

    exact = darts.exact_match_search(memoryview('テスト'.encode()), pair_type=False)
    self.assertEqual(5, exact)
    prefixes = darts.common_prefix_search(memoryview('testcase'.encode()), pair_type=False)
    self.assertEqual(3, prefixes[0])
70+
6471

6572
if __name__ == "__main__":
6673
unittest.main()

0 commit comments

Comments
 (0)