From fe71cb4d792b224c1a26e04a2ca77a3768a5bcce Mon Sep 17 00:00:00 2001 From: Tom White Date: Wed, 19 Jun 2024 09:16:50 +0100 Subject: [PATCH] Use ndindex for implementing 'index' (#481) --- cubed/core/ops.py | 126 ++++++++++++++++------------------ cubed/tests/test_array_api.py | 37 ++++++++-- cubed/tests/test_indexing.py | 28 ++++++-- 3 files changed, 112 insertions(+), 79 deletions(-) diff --git a/cubed/core/ops.py b/cubed/core/ops.py index 90df1f72..83b8e90b 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -8,17 +8,11 @@ from typing import TYPE_CHECKING, Any, Sequence, Union from warnings import warn +import ndindex import numpy as np import zarr from tlz import concat, first, partition from toolz import accumulate, map -from zarr.indexing import ( - IntDimIndexer, - OrthogonalIndexer, - SliceDimIndexer, - is_integer_list, - replace_ellipsis, -) from cubed import config from cubed.backend_array_api import namespace as nxp @@ -407,72 +401,69 @@ def index(x, key): if not isinstance(key, tuple): key = (key,) - # No op case - if all(isinstance(ind, slice) and ind == slice(None) for ind in key): - return x - - # Remove None values, to be filled in with expand_dims at end - where_none = [i for i, ind in enumerate(key) if ind is None] - for i, a in enumerate(where_none): - n = sum(isinstance(ind, Integral) for ind in key[:a]) - if n: - where_none[i] -= n - key = tuple(ind for ind in key if ind is not None) - - # Replace arrays with lists - note that this may trigger a computation! - selection = tuple( - dim_sel.compute().tolist() if isinstance(dim_sel, CoreArray) else dim_sel + # Replace Cubed arrays with NumPy arrays - note that this may trigger a computation! + key = tuple( + dim_sel.compute() if isinstance(dim_sel, CoreArray) else dim_sel for dim_sel in key ) - # Replace np.ndarray with lists - # Note that this shouldn't be needed, instead change xarray to return array API types - # (Variable.__getitem__ -> _broadcast_indexes creates an OuterIndexer that always uses np.array) - selection = tuple( - dim_sel.tolist() if isinstance(dim_sel, np.ndarray) else dim_sel - for dim_sel in selection - ) - # Replace ellipsis with slices - selection = replace_ellipsis(selection, x.shape) + + # Canonicalize index + idx = ndindex.ndindex(key) + idx = idx.expand(x.shape) + + # Remove newaxis values, to be filled in with expand_dims at end + where_newaxis = [ + i for i, ia in enumerate(idx.args) if isinstance(ia, ndindex.Newaxis) + ] + for i, a in enumerate(where_newaxis): + n = sum(isinstance(ia, ndindex.Integer) for ia in idx.args[:a]) + if n: + where_newaxis[i] -= n + idx = ndindex.Tuple(*(ia for ia in idx.args if not isinstance(ia, ndindex.Newaxis))) + selection = idx.raw # Check selection is supported - if any( - s.step is not None and s.step < 1 for s in selection if isinstance(s, slice) - ): + if any(ia.step < 1 for ia in idx.args if isinstance(ia, ndindex.Slice)): raise NotImplementedError(f"Slice step must be >= 1: {key}") - assert all(isinstance(s, (slice, list, Integral)) for s in selection) - where_list = [i for i, ind in enumerate(selection) if is_integer_list(ind)] - if len(where_list) > 1: + if not all( + isinstance(ia, (ndindex.Integer, ndindex.Slice, ndindex.IntegerArray)) + for ia in idx.args + ): + raise NotImplementedError( + "Only integer, slice, or integer array indexes are allowed." + ) + if sum(1 for ia in idx.args if isinstance(ia, ndindex.IntegerArray)) > 1: raise NotImplementedError("Only one integer array index is allowed.") - # Use a Zarr indexer just to find the resulting array shape and chunks - # Need to use an in-memory representation since the Zarr file has not been written yet - inmem = zarr.empty_like(x.zarray_maybe_lazy, store=zarr.storage.MemoryStore()) - indexer = OrthogonalIndexer(selection, inmem) - - def chunk_len_for_indexer(s): - if not isinstance(s, SliceDimIndexer): - return s.dim_chunk_len - return max(s.dim_chunk_len // s.step, 1) - - def merged_chunk_len_for_indexer(s): - if not isinstance(s, SliceDimIndexer): - return s.dim_chunk_len - if s.step is None or s.step == 1: - return s.dim_chunk_len - if (s.dim_chunk_len // s.step) < 1: - return s.dim_chunk_len - # note that this may not be the same as s.dim_chunk_len + # Use ndindex to find the resulting array shape and chunks + + def chunk_len_for_indexer(ia, c): + if not isinstance(ia, ndindex.Slice): + return c + return max(c // ia.step, 1) + + def merged_chunk_len_for_indexer(ia, c): + if not isinstance(ia, ndindex.Slice): + return c + if ia.step == 1: + return c + if (c // ia.step) < 1: + return c + # note that this may not be the same as c # but it is guaranteed to be a multiple of the corresponding # value returned by chunk_len_for_indexer, which is required # by merge_chunks - return (s.dim_chunk_len // s.step) * s.step + return (c // ia.step) * ia.step - shape = indexer.shape + shape = idx.newshape(x.shape) + if shape == x.shape: + # no op case + return x dtype = x.dtype chunks = tuple( - chunk_len_for_indexer(s) - for s in indexer.dim_indexers - if not isinstance(s, IntDimIndexer) + chunk_len_for_indexer(ia, c) + for ia, c in zip(idx.args, x.chunksize) + if not isinstance(ia, ndindex.Integer) ) target_chunks = normalize_chunks(chunks, shape, dtype=dtype) @@ -496,14 +487,14 @@ def merged_chunk_len_for_indexer(s): # merge chunks for any dims with step > 1 so they are # the same size as the input (or slightly smaller due to rounding) merged_chunks = tuple( - merged_chunk_len_for_indexer(s) - for s in indexer.dim_indexers - if not isinstance(s, IntDimIndexer) + merged_chunk_len_for_indexer(ia, c) + for ia, c in zip(idx.args, x.chunksize) + if not isinstance(ia, ndindex.Integer) ) if chunks != merged_chunks: out = merge_chunks(out, merged_chunks) - for axis in where_none: + for axis in where_newaxis: from cubed.array_api.manipulation_functions import expand_dims out = expand_dims(out, axis=axis) @@ -545,7 +536,8 @@ def _target_chunk_selection(target_chunks, idx, selection): j = idx[i] sel.append(slice(start[j], start[j + 1], step)) i += 1 - elif is_integer_list(s): + # ndindex uses np.ndarray for integer arrays + elif isinstance(s, np.ndarray): # find the cumulative chunk starts target_chunk_starts = [0] + list( accumulate(add, [c for c in target_chunks[i]]) @@ -554,9 +546,11 @@ def _target_chunk_selection(target_chunks, idx, selection): j = idx[i] sel.append(s[target_chunk_starts[j] : target_chunk_starts[j + 1]]) i += 1 - else: + elif isinstance(s, int): sel.append(s) # don't increment i since integer indexes don't have a dimension in the target + else: + raise ValueError(f"Unsupported selection: {s}") return tuple(sel) diff --git a/cubed/tests/test_array_api.py b/cubed/tests/test_array_api.py index c5fb2d6f..2be010f0 100644 --- a/cubed/tests/test_array_api.py +++ b/cubed/tests/test_array_api.py @@ -216,15 +216,16 @@ def test_negative(spec, executor): @pytest.mark.parametrize( "ind", [ + Ellipsis, 6, - (6, None), # add a new dimension - (None, 6), # add a new dimension + (6, xp.newaxis), + (xp.newaxis, 6), slice(None), slice(10), slice(3, None), slice(3, 3), slice(3, 10), - (slice(10), None), # add a new dimension + (slice(10), xp.newaxis), ], ) def test_index_1d(spec, ind): @@ -235,17 +236,20 @@ def test_index_1d(spec, ind): @pytest.mark.parametrize( "ind", [ + Ellipsis, (2, 3), - (None, 2, 3), # add a new dimension + (xp.newaxis, 2, 3), + (Ellipsis, slice(2, 4)), (slice(None), slice(2, 2)), (slice(None), slice(2, 4)), (slice(3), slice(2, None)), (slice(1, None), slice(4)), + (slice(1, 3), Ellipsis), (slice(1, 1), slice(None)), (slice(1, 3), slice(None)), - (None, slice(None), slice(2, 4)), # add a new dimension - (slice(None), None, slice(2, 4)), # add a new dimension - (slice(None), slice(2, 4), None), # add a new dimension + (xp.newaxis, slice(None), slice(2, 4)), + (slice(None), xp.newaxis, slice(2, 4)), + (slice(None), slice(2, 4), xp.newaxis), (slice(None), 1), (1, slice(2, 4)), ], @@ -260,6 +264,25 @@ def test_index_2d(spec, ind): assert_array_equal(a[ind].compute(), x[ind]) +@pytest.mark.parametrize( + "ind", + [ + Ellipsis, + (slice(None), slice(None)), + (slice(0, 4), slice(None)), + (slice(None), slice(0, 4)), + (slice(0, 4), slice(0, 4)), + ], +) +def test_index_2d_no_op(spec, ind): + a = xp.asarray( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], + chunks=(2, 2), + spec=spec, + ) + assert a is a[ind] + + @pytest.mark.parametrize( "shape, chunks, ind, new_chunks_expected", [ diff --git a/cubed/tests/test_indexing.py b/cubed/tests/test_indexing.py index e9ac27e5..efa8c6cf 100644 --- a/cubed/tests/test_indexing.py +++ b/cubed/tests/test_indexing.py @@ -18,8 +18,8 @@ def spec(tmp_path): "ind", [ [6, 7, 2, 9, 10], - ([6, 7, 2, 9, 10], None), # add a new dimension - (None, [6, 7, 2, 9, 10]), # add a new dimension + ([6, 7, 2, 9, 10], xp.newaxis), + (xp.newaxis, [6, 7, 2, 9, 10]), ], ) def test_int_array_index_1d(spec, ind): @@ -33,9 +33,9 @@ def test_int_array_index_1d(spec, ind): [ (slice(None), [2, 1]), ([1, 2, 1], slice(None)), - (None, slice(None), [2, 1]), - (slice(None), None, [2, 1]), - (slice(None), [2, 1], None), + (xp.newaxis, slice(None), [2, 1]), + (slice(None), xp.newaxis, [2, 1]), + (slice(None), [2, 1], xp.newaxis), ], ) def test_int_array_index_2d(spec, ind): @@ -49,6 +49,22 @@ def test_int_array_index_2d(spec, ind): assert_array_equal(b[ind].compute(), x[ind]) +@pytest.mark.parametrize( + "ind", + [ + (slice(None), [0, 1, 2, 3]), + ([0, 1, 2, 3], slice(None)), + ], +) +def test_int_array_index_2d_no_op(spec, ind): + a = xp.asarray( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], + chunks=(3, 3), + spec=spec, + ) + assert a is a[ind] + + def test_multiple_int_array_indexes(spec): with pytest.raises(NotImplementedError): a = xp.asarray( @@ -56,4 +72,4 @@ def test_multiple_int_array_indexes(spec): chunks=(2, 2), spec=spec, ) - a[[1, 2, 1], [2, 1]] + a[[1, 2, 1], [2, 1, 0]]