Commit 80df6b2

Small improvements (but nthreads has to be 1 still)
Parent: b25398f

2 files changed: 38 additions and 7 deletions

src/blosc2/blosc2_ext.pyx

Lines changed: 33 additions & 6 deletions
@@ -23,6 +23,7 @@ from cpython cimport (
     PyBytes_FromStringAndSize,
     PyObject_GetBuffer,
 )
+from cpython.ref cimport Py_INCREF, Py_DECREF
 from cpython.pycapsule cimport PyCapsule_GetPointer, PyCapsule_New
 from cython.operator cimport dereference
 from libc.stdint cimport uintptr_t
@@ -553,6 +554,7 @@ ctypedef struct udf_udata:
     b2nd_array_t *array
     int64_t chunks_in_array[B2ND_MAX_DIM]
     int64_t blocks_in_chunk[B2ND_MAX_DIM]
+    void* numexpr_handle  # Cached numexpr compiled expression handle
 
 MAX_TYPESIZE = BLOSC2_MAXTYPESIZE
 MAX_BUFFERSIZE = BLOSC2_MAX_BUFFERSIZE
@@ -778,6 +780,9 @@ cdef _check_cparams(blosc2_cparams *cparams):
         raise ValueError("Cannot use multi-threading with user defined Python filters")
 
     if cparams.prefilter != NULL:
+        # Note: last_expr_prefilter uses numexpr C API which is more thread-friendly,
+        # but we still access Python objects (dicts, lists) in aux_last_expr,
+        # so multi-threading can cause issues. Keep single-threaded for safety.
         raise ValueError("`nthreads` must be 1 when a prefilter is set")
 
 cdef _check_dparams(blosc2_dparams* dparams, blosc2_cparams* cparams=NULL):
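
For context, here is roughly how this guard surfaces from Python. A minimal sketch, assuming python-blosc2's public `SChunk` constructor and `SChunk.prefilter` decorator; the exact moment `_check_cparams` runs may vary, so the multi-threaded case is only expected to raise the `ValueError` shown above.

```python
import numpy as np
import blosc2

# Registering a Python prefilter while cparams asks for several threads is
# expected to be rejected by _check_cparams (sketch; names from python-blosc2).
schunk_mt = blosc2.SChunk(chunksize=200 * 4, cparams={"typesize": 4, "nthreads": 4})
try:
    @schunk_mt.prefilter(np.int32, np.int32)
    def shift(input, output, offset):
        output[:] = input + 1
except ValueError as exc:
    print("rejected:", exc)

# With nthreads=1 the same prefilter registers fine.
schunk_st = blosc2.SChunk(chunksize=200 * 4, cparams={"typesize": 4, "nthreads": 1})

@schunk_st.prefilter(np.int32, np.int32)
def shift1(input, output, offset):
    output[:] = input + 1
```
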
@@ -1674,13 +1679,25 @@ cdef class SChunk:
             raise RuntimeError("Could not create compression context")
 
     cpdef remove_prefilter(self, func_name, _new_ctx=True):
+        cdef udf_udata* udf_data
+        cdef user_filters_udata* udata
+
         if func_name is not None and func_name in blosc2.prefilter_funcs:
             del blosc2.prefilter_funcs[func_name]
 
-        # From Python the preparams->udata with always have the field py_func
-        cdef user_filters_udata * udata = <user_filters_udata *>self.schunk.storage.cparams.preparams.user_data
-        free(udata.py_func)
-        free(self.schunk.storage.cparams.preparams.user_data)
+        # Clean up the numexpr handle if this is a last_expr_prefilter
+        if self.schunk.storage.cparams.prefilter == <blosc2_prefilter_fn>last_expr_prefilter:
+            udf_data = <udf_udata*>self.schunk.storage.cparams.preparams.user_data
+            if udf_data.numexpr_handle != NULL:
+                Py_DECREF(<object>udf_data.numexpr_handle)
+            free(udf_data.py_func)
+            free(udf_data)
+        else:
+            # From Python the preparams->udata with always have the field py_func
+            udata = <user_filters_udata*>self.schunk.storage.cparams.preparams.user_data
+            free(udata.py_func)
+            free(udata)
+
         free(self.schunk.storage.cparams.preparams)
         self.schunk.storage.cparams.preparams = NULL
         self.schunk.storage.cparams.prefilter = NULL
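
The new `Py_DECREF` releases the reference that `_set_pref_last_expr` now takes when it caches the handle (see the last hunk of this file); every cached reference has to be dropped exactly once or the compiled expression leaks. A plain-Python illustration of that pairing, using `ctypes.pythonapi` to mimic the manual incref/decref (illustrative only, not blosc2 code):

```python
import ctypes
import sys

handle = object()                    # stand-in for the compiled numexpr expression
ref = ctypes.py_object(handle)
base = sys.getrefcount(handle)

# Equivalent of Py_INCREF(<object>udata.numexpr_handle) when the handle is cached:
ctypes.pythonapi.Py_IncRef(ref)
assert sys.getrefcount(handle) == base + 1   # the cache now keeps the object alive

# Equivalent of the Py_DECREF added to remove_prefilter: release exactly once.
ctypes.pythonapi.Py_DecRef(ref)
assert sys.getrefcount(handle) == base
```
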
@@ -1837,7 +1854,9 @@ cdef int aux_last_expr(udf_udata *udata, int64_t nchunk, int32_t nblock,
     offset = tuple(start_ndim[i] for i in range(udata.array.ndim))
 
     # Use numexpr C API for faster evaluation
-    numexpr_handle = numexpr_get_last_compiled()
+    # Use the cached handle from udata (set during _set_pref_last_expr)
+    # This allows multi-threading since all threads share the same handle
+    numexpr_handle = udata.numexpr_handle
     if numexpr_handle != NULL:
         # Get the variable names order from the compiled expression
         compiled_ex = <object>numexpr_handle
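
What the cached handle buys: each prefilter call now reuses one compiled expression instead of asking numexpr for its most recently compiled one on every block. A hedged sketch of the same reuse pattern with numexpr's public `NumExpr` API (the C-level handle above is internal to this extension):

```python
import numpy as np
import numexpr as ne

# Compile the expression once (the analogue of caching numexpr_handle in udata)...
compiled = ne.NumExpr("2 * a + b")

# ...then evaluate it block by block without recompiling, which is what
# aux_last_expr now does with udata.numexpr_handle for every chunk/block.
a = np.arange(1_000_000, dtype=np.float64)
b = np.ones_like(a)
out = np.empty_like(a)
block = 100_000
for start in range(0, a.size, block):
    sl = slice(start, start + block)
    out[sl] = compiled(a[sl], b[sl])

np.testing.assert_allclose(out, 2 * a + b)
```
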
@@ -2789,7 +2808,15 @@ cdef class NDArray:
         cparams.prefilter = <blosc2_prefilter_fn> last_expr_prefilter
 
         cdef blosc2_prefilter_params* preparams = <blosc2_prefilter_params *> malloc(sizeof(blosc2_prefilter_params))
-        preparams.user_data = self._fill_udf_udata(func_id, inputs_id)
+        cdef udf_udata* udata = self._fill_udf_udata(func_id, inputs_id)
+
+        # Get and cache the numexpr compiled expression handle for multi-threading
+        udata.numexpr_handle = numexpr_get_last_compiled()
+        # Increment reference count to keep the expression alive across threads
+        if udata.numexpr_handle != NULL:
+            Py_INCREF(<object>udata.numexpr_handle)
+
+        preparams.user_data = udata
         cparams.preparams = preparams
         _check_cparams(cparams)
 
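`numexpr_get_last_compiled()` (defined elsewhere in this file) appears to hand back numexpr's most recently compiled expression, which is process-global state; caching it in `udata` at set-up time and taking a reference pins the expression this prefilter was built for, even if something else compiles a new one later. numexpr's public face of that global state is `evaluate`/`re_evaluate`, sketched here under that assumption:

```python
import numpy as np
import numexpr as ne

a = np.arange(10, dtype=np.float64)
b = np.full(10, 3.0)

# evaluate() compiles "2*a + b" and remembers it as the *last* compiled expression.
first = ne.evaluate("2 * a + b")

# re_evaluate() reuses that global "last compiled" state with new operands.
# If anything else calls evaluate() in between, the global state changes,
# which is why the prefilter now pins its own handle in udata instead.
second = ne.re_evaluate(local_dict={"a": a + 1, "b": b})

np.testing.assert_allclose(second, 2 * (a + 1) + b)
```
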

src/blosc2/lazyexpr.py

Lines changed: 5 additions & 1 deletion
@@ -1263,6 +1263,9 @@ def fast_eval( # noqa: C901
 
     if True:
         cparams = kwargs.pop("cparams", blosc2.CParams())
+        # Force single-threaded execution for prefilter evaluation
+        # The prefilter callback accesses Python objects which aren't thread-safe
+        # across blosc2's C threads. numexpr does its own multi-threading internally.
         if cparams.nthreads > 1:
             prev_nthreads = cparams.nthreads
             cparams.nthreads = 1
@@ -1280,7 +1283,8 @@ def fast_eval( # noqa: C901
         # Physical allocation happens here (when writing):
         res_eval[...] = aux
         res_eval.schunk.remove_prefilter(func_name)
-        res_eval.schunk.cparams.nthreads = prev_nthreads
+        if cparams.nthreads > 1:
+            res_eval.schunk.cparams.nthreads = prev_nthreads
 
     return res_eval
 
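
In the second hunk the restore becomes conditional, since `prev_nthreads` is only assigned when threading was actually reduced. A standalone sketch of that save/lower/restore pattern (illustrative names, not the blosc2 API):

```python
from contextlib import contextmanager

class Params:
    def __init__(self, nthreads):
        self.nthreads = nthreads

@contextmanager
def single_threaded(params):
    """Temporarily force nthreads=1; restore only if it was actually lowered."""
    lowered = params.nthreads > 1
    prev = params.nthreads
    if lowered:
        params.nthreads = 1
    try:
        yield params
    finally:
        if lowered:                      # mirrors the guard added before the restore
            params.nthreads = prev

p = Params(nthreads=8)
with single_threaded(p):
    assert p.nthreads == 1               # prefilter-style evaluation would run here
assert p.nthreads == 8                   # restored afterwards

q = Params(nthreads=1)
with single_threaded(q):
    assert q.nthreads == 1
assert q.nthreads == 1                   # nothing to restore, no stray prev needed
```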
