Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 32 additions & 9 deletions src/blosc2/lazyexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,13 +510,15 @@ def check_smaller_shape(value_shape, shape, slice_shape, slice_):
This follows the NumPy broadcasting rules.
"""
# slice_shape must be as long as shape
if len(slice_shape) != len(shape):
raise ValueError("slice_shape must be as long as shape")
if len(slice_shape) != len(slice_):
raise ValueError("slice_shape must be as long as slice_")
no_nones_shape = tuple(sh for sh, s in zip(slice_shape, slice_, strict=True) if s is not None)
no_nones_slice = tuple(s for sh, s in zip(slice_shape, slice_, strict=True) if s is not None)
is_smaller_shape = any(
s > (1 if i >= len(value_shape) else value_shape[i]) for i, s in enumerate(slice_shape)
s > (1 if i >= len(value_shape) else value_shape[i]) for i, s in enumerate(no_nones_shape)
)
slice_past_bounds = any(
s.stop > (1 if i >= len(value_shape) else value_shape[i]) for i, s in enumerate(slice_)
s.stop > (1 if i >= len(value_shape) else value_shape[i]) for i, s in enumerate(no_nones_slice)
)
return len(value_shape) < len(shape) or is_smaller_shape or slice_past_bounds

Expand Down Expand Up @@ -547,10 +549,31 @@ def compute_smaller_slice(larger_shape, smaller_shape, larger_slice):
"""
Returns the slice of the smaller array that corresponds to the slice of the larger array.
"""
diff_dims = len(larger_shape) - len(smaller_shape)
j_small = len(smaller_shape) - 1
j_large = len(larger_shape) - 1
smaller_shape_nones = []
larger_shape_nones = []
for s in reversed(larger_slice):
if s is None:
smaller_shape_nones.append(1)
larger_shape_nones.append(1)
else:
if j_small >= 0:
smaller_shape_nones.append(smaller_shape[j_small])
j_small -= 1
if j_large >= 0:
larger_shape_nones.append(larger_shape[j_large])
j_large -= 1
smaller_shape_nones.reverse()
larger_shape_nones.reverse()
diff_dims = len(larger_shape_nones) - len(smaller_shape_nones)
return tuple(
larger_slice[i] if smaller_shape[i - diff_dims] != 1 else slice(0, larger_shape[i])
for i in range(diff_dims, len(larger_shape))
None
if larger_slice[i] is None
else (
larger_slice[i] if smaller_shape_nones[i - diff_dims] != 1 else slice(0, larger_shape_nones[i])
)
for i in range(diff_dims, len(larger_shape_nones))
)


Expand Down Expand Up @@ -1694,7 +1717,6 @@ def slices_eval_getitem(
_slice_bcast = tuple(slice(i, i + 1) if isinstance(i, int) else i for i in _slice.raw)
slice_shape = ndindex.ndindex(_slice_bcast).newshape(shape) # includes dummy dimensions
_slice = _slice.raw
offset = tuple(s.start for s in _slice_bcast) # offset for the udf

# Get the slice of each operand
slice_operands = {}
Expand All @@ -1715,6 +1737,7 @@ def slices_eval_getitem(

# Evaluate the expression using slices of operands
if callable(expression):
offset = tuple(0 if s is None else s.start for s in _slice_bcast) # offset for the udf
result = np.empty(slice_shape, dtype=dtype)
expression(tuple(slice_operands.values()), result, offset=offset)
else:
Expand Down Expand Up @@ -2161,7 +2184,7 @@ def chunked_eval( # noqa: C901
"""
try:
# standardise slice to be ndindex.Tuple
item = () if item in (None, slice(None, None, None)) else item
item = () if item == slice(None, None, None) else item
item = item if isinstance(item, tuple) else (item,)
item = tuple(
slice(s.start, s.stop, 1 if s.step is None else s.step) if isinstance(s, slice) else s
Expand Down
7 changes: 7 additions & 0 deletions tests/ndarray/test_getitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@
([77, 134, 13], [31, 13, 5], [7, 8, 3], (slice(3, 7), slice(50, 100), 7), np.float64),
([77, 134, 13], [31, 13, 5], [7, 8, 3], (slice(3, 56, 3), slice(100, 50, -4), 7), np.float64),
([12, 13, 14, 15, 16], [5, 5, 5, 5, 5], [2, 2, 2, 2, 2], (slice(1, 3), ..., slice(3, 6)), np.float32),
(
[12, 13, 14, 15, 16],
[5, 5, 5, 5, 5],
[2, 2, 2, 2, 2],
(None, slice(1, 3), None, ..., slice(3, 6)),
np.float32,
),
]


Expand Down
40 changes: 19 additions & 21 deletions tests/ndarray/test_lazyexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ def test_simple_getitem(array_fixture):
res = expr[sl]
np.testing.assert_allclose(res, nres[sl])

# Test None indexing
sl = (None, slice(3, 8), None)
res = expr[sl]
np.testing.assert_allclose(res, nres[sl])


# Mix Proxy and NDArray operands
def test_proxy_simple_getitem(array_fixture):
Expand All @@ -114,14 +119,13 @@ def test_mix_operands(array_fixture):
np.testing.assert_allclose(expr[:], nres)
np.testing.assert_allclose(expr.compute()[:], nres)

# TODO: fix this
# expr = na2 + a1
# nres = ne_evaluate("na2 + na1")
# sl = slice(100)
# res = expr[sl]
# np.testing.assert_allclose(res, nres[sl])
# np.testing.assert_allclose(expr[:], nres)
# np.testing.assert_allclose(expr.compute()[:], nres)
expr = na2 + a1
nres = ne_evaluate("na2 + na1")
sl = slice(100)
res = expr[sl]
np.testing.assert_allclose(res, nres[sl])
np.testing.assert_allclose(expr[:], nres)
np.testing.assert_allclose(expr.compute()[:], nres)

expr = a1 + na2 + a3
nres = ne_evaluate("na1 + na2 + na3")
Expand Down Expand Up @@ -151,19 +155,13 @@ def test_mix_operands(array_fixture):
np.testing.assert_allclose(expr[:], nres)
np.testing.assert_allclose(expr.compute()[:], nres)

# TODO: support this case
# expr = a1 + na2 * a3
# print("--------------------------------------------------------")
# print(type(expr))
# print(expr.expression)
# print(expr.operands)
# print("--------------------------------------------------------")
# nres = ne_evaluate("na1 + na2 * na3")
# sl = slice(100)
# res = expr[sl]
# np.testing.assert_allclose(res, nres[sl])
# np.testing.assert_allclose(expr[:], nres)
# np.testing.assert_allclose(expr.compute()[:], nres)
expr = a1 + na2 * a3
nres = ne_evaluate("na1 + na2 * na3")
sl = slice(100)
res = expr[sl]
np.testing.assert_allclose(res, nres[sl])
np.testing.assert_allclose(expr[:], nres)
np.testing.assert_allclose(expr.compute()[:], nres)


# Add more test functions to test different aspects of the code
Expand Down
2 changes: 1 addition & 1 deletion tests/ndarray/test_lazyudf.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def test_params(chunked_eval):
[
((40, 20), (30, 10), (5, 5), (slice(0, 5), slice(5, 20)), "eval.b2nd", False),
((13, 13, 10), (10, 10, 5), (5, 5, 3), (slice(0, 12), slice(3, 13), ...), "eval.b2nd", True),
((13, 13), (10, 10), (5, 5), (slice(3, 8), slice(9, 12)), None, False),
((13, 13), (10, 10), (5, 5), (slice(3, 8), None, slice(9, 12)), None, False),
],
)
def test_getitem(shape, chunks, blocks, slices, urlpath, contiguous, chunked_eval):
Expand Down
Loading