diff --git a/src/blosc2/lazyexpr.py b/src/blosc2/lazyexpr.py index 7613c9f2..dd177957 100644 --- a/src/blosc2/lazyexpr.py +++ b/src/blosc2/lazyexpr.py @@ -510,13 +510,15 @@ def check_smaller_shape(value_shape, shape, slice_shape, slice_): This follows the NumPy broadcasting rules. """ # slice_shape must be as long as shape - if len(slice_shape) != len(shape): - raise ValueError("slice_shape must be as long as shape") + if len(slice_shape) != len(slice_): + raise ValueError("slice_shape must be as long as slice_") + no_nones_shape = tuple(sh for sh, s in zip(slice_shape, slice_, strict=True) if s is not None) + no_nones_slice = tuple(s for sh, s in zip(slice_shape, slice_, strict=True) if s is not None) is_smaller_shape = any( - s > (1 if i >= len(value_shape) else value_shape[i]) for i, s in enumerate(slice_shape) + s > (1 if i >= len(value_shape) else value_shape[i]) for i, s in enumerate(no_nones_shape) ) slice_past_bounds = any( - s.stop > (1 if i >= len(value_shape) else value_shape[i]) for i, s in enumerate(slice_) + s.stop > (1 if i >= len(value_shape) else value_shape[i]) for i, s in enumerate(no_nones_slice) ) return len(value_shape) < len(shape) or is_smaller_shape or slice_past_bounds @@ -547,10 +549,31 @@ def compute_smaller_slice(larger_shape, smaller_shape, larger_slice): """ Returns the slice of the smaller array that corresponds to the slice of the larger array. """ - diff_dims = len(larger_shape) - len(smaller_shape) + j_small = len(smaller_shape) - 1 + j_large = len(larger_shape) - 1 + smaller_shape_nones = [] + larger_shape_nones = [] + for s in reversed(larger_slice): + if s is None: + smaller_shape_nones.append(1) + larger_shape_nones.append(1) + else: + if j_small >= 0: + smaller_shape_nones.append(smaller_shape[j_small]) + j_small -= 1 + if j_large >= 0: + larger_shape_nones.append(larger_shape[j_large]) + j_large -= 1 + smaller_shape_nones.reverse() + larger_shape_nones.reverse() + diff_dims = len(larger_shape_nones) - len(smaller_shape_nones) return tuple( - larger_slice[i] if smaller_shape[i - diff_dims] != 1 else slice(0, larger_shape[i]) - for i in range(diff_dims, len(larger_shape)) + None + if larger_slice[i] is None + else ( + larger_slice[i] if smaller_shape_nones[i - diff_dims] != 1 else slice(0, larger_shape_nones[i]) + ) + for i in range(diff_dims, len(larger_shape_nones)) ) @@ -1694,7 +1717,6 @@ def slices_eval_getitem( _slice_bcast = tuple(slice(i, i + 1) if isinstance(i, int) else i for i in _slice.raw) slice_shape = ndindex.ndindex(_slice_bcast).newshape(shape) # includes dummy dimensions _slice = _slice.raw - offset = tuple(s.start for s in _slice_bcast) # offset for the udf # Get the slice of each operand slice_operands = {} @@ -1715,6 +1737,7 @@ def slices_eval_getitem( # Evaluate the expression using slices of operands if callable(expression): + offset = tuple(0 if s is None else s.start for s in _slice_bcast) # offset for the udf result = np.empty(slice_shape, dtype=dtype) expression(tuple(slice_operands.values()), result, offset=offset) else: @@ -2161,7 +2184,7 @@ def chunked_eval( # noqa: C901 """ try: # standardise slice to be ndindex.Tuple - item = () if item in (None, slice(None, None, None)) else item + item = () if item == slice(None, None, None) else item item = item if isinstance(item, tuple) else (item,) item = tuple( slice(s.start, s.stop, 1 if s.step is None else s.step) if isinstance(s, slice) else s diff --git a/tests/ndarray/test_getitem.py b/tests/ndarray/test_getitem.py index 7fadd897..b309b06f 100644 --- a/tests/ndarray/test_getitem.py +++ b/tests/ndarray/test_getitem.py @@ -18,6 +18,13 @@ ([77, 134, 13], [31, 13, 5], [7, 8, 3], (slice(3, 7), slice(50, 100), 7), np.float64), ([77, 134, 13], [31, 13, 5], [7, 8, 3], (slice(3, 56, 3), slice(100, 50, -4), 7), np.float64), ([12, 13, 14, 15, 16], [5, 5, 5, 5, 5], [2, 2, 2, 2, 2], (slice(1, 3), ..., slice(3, 6)), np.float32), + ( + [12, 13, 14, 15, 16], + [5, 5, 5, 5, 5], + [2, 2, 2, 2, 2], + (None, slice(1, 3), None, ..., slice(3, 6)), + np.float32, + ), ] diff --git a/tests/ndarray/test_lazyexpr.py b/tests/ndarray/test_lazyexpr.py index 88136602..e102e6d0 100644 --- a/tests/ndarray/test_lazyexpr.py +++ b/tests/ndarray/test_lazyexpr.py @@ -90,6 +90,11 @@ def test_simple_getitem(array_fixture): res = expr[sl] np.testing.assert_allclose(res, nres[sl]) + # Test None indexing + sl = (None, slice(3, 8), None) + res = expr[sl] + np.testing.assert_allclose(res, nres[sl]) + # Mix Proxy and NDArray operands def test_proxy_simple_getitem(array_fixture): @@ -114,14 +119,13 @@ def test_mix_operands(array_fixture): np.testing.assert_allclose(expr[:], nres) np.testing.assert_allclose(expr.compute()[:], nres) - # TODO: fix this - # expr = na2 + a1 - # nres = ne_evaluate("na2 + na1") - # sl = slice(100) - # res = expr[sl] - # np.testing.assert_allclose(res, nres[sl]) - # np.testing.assert_allclose(expr[:], nres) - # np.testing.assert_allclose(expr.compute()[:], nres) + expr = na2 + a1 + nres = ne_evaluate("na2 + na1") + sl = slice(100) + res = expr[sl] + np.testing.assert_allclose(res, nres[sl]) + np.testing.assert_allclose(expr[:], nres) + np.testing.assert_allclose(expr.compute()[:], nres) expr = a1 + na2 + a3 nres = ne_evaluate("na1 + na2 + na3") @@ -151,19 +155,13 @@ def test_mix_operands(array_fixture): np.testing.assert_allclose(expr[:], nres) np.testing.assert_allclose(expr.compute()[:], nres) - # TODO: support this case - # expr = a1 + na2 * a3 - # print("--------------------------------------------------------") - # print(type(expr)) - # print(expr.expression) - # print(expr.operands) - # print("--------------------------------------------------------") - # nres = ne_evaluate("na1 + na2 * na3") - # sl = slice(100) - # res = expr[sl] - # np.testing.assert_allclose(res, nres[sl]) - # np.testing.assert_allclose(expr[:], nres) - # np.testing.assert_allclose(expr.compute()[:], nres) + expr = a1 + na2 * a3 + nres = ne_evaluate("na1 + na2 * na3") + sl = slice(100) + res = expr[sl] + np.testing.assert_allclose(res, nres[sl]) + np.testing.assert_allclose(expr[:], nres) + np.testing.assert_allclose(expr.compute()[:], nres) # Add more test functions to test different aspects of the code diff --git a/tests/ndarray/test_lazyudf.py b/tests/ndarray/test_lazyudf.py index 29b44e39..9ea7229a 100644 --- a/tests/ndarray/test_lazyudf.py +++ b/tests/ndarray/test_lazyudf.py @@ -259,7 +259,7 @@ def test_params(chunked_eval): [ ((40, 20), (30, 10), (5, 5), (slice(0, 5), slice(5, 20)), "eval.b2nd", False), ((13, 13, 10), (10, 10, 5), (5, 5, 3), (slice(0, 12), slice(3, 13), ...), "eval.b2nd", True), - ((13, 13), (10, 10), (5, 5), (slice(3, 8), slice(9, 12)), None, False), + ((13, 13), (10, 10), (5, 5), (slice(3, 8), None, slice(9, 12)), None, False), ], ) def test_getitem(shape, chunks, blocks, slices, urlpath, contiguous, chunked_eval):