Skip to content

Commit

Permalink
Use ndindex for implementing 'index'
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Jun 18, 2024
1 parent c61ffd5 commit 053810c
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 79 deletions.
126 changes: 60 additions & 66 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,11 @@
from typing import TYPE_CHECKING, Any, Sequence, Union
from warnings import warn

import ndindex
import numpy as np
import zarr
from tlz import concat, first, partition
from toolz import accumulate, map
from zarr.indexing import (
IntDimIndexer,
OrthogonalIndexer,
SliceDimIndexer,
is_integer_list,
replace_ellipsis,
)

from cubed import config
from cubed.backend_array_api import namespace as nxp
Expand Down Expand Up @@ -407,72 +401,69 @@ def index(x, key):
if not isinstance(key, tuple):
key = (key,)

# No op case
if all(isinstance(ind, slice) and ind == slice(None) for ind in key):
return x

# Remove None values, to be filled in with expand_dims at end
where_none = [i for i, ind in enumerate(key) if ind is None]
for i, a in enumerate(where_none):
n = sum(isinstance(ind, Integral) for ind in key[:a])
if n:
where_none[i] -= n
key = tuple(ind for ind in key if ind is not None)

# Replace arrays with lists - note that this may trigger a computation!
selection = tuple(
dim_sel.compute().tolist() if isinstance(dim_sel, CoreArray) else dim_sel
# Replace Cubed arrays with NumPy arrays - note that this may trigger a computation!
key = tuple(
dim_sel.compute() if isinstance(dim_sel, CoreArray) else dim_sel
for dim_sel in key
)
# Replace np.ndarray with lists
# Note that this shouldn't be needed, instead change xarray to return array API types
# (Variable.__getitem__ -> _broadcast_indexes creates an OuterIndexer that always uses np.array)
selection = tuple(
dim_sel.tolist() if isinstance(dim_sel, np.ndarray) else dim_sel
for dim_sel in selection
)
# Replace ellipsis with slices
selection = replace_ellipsis(selection, x.shape)

# Canonicalize index
idx = ndindex.ndindex(key)
idx = idx.expand(x.shape)

# Remove newaxis values, to be filled in with expand_dims at end
where_newaxis = [
i for i, ia in enumerate(idx.args) if isinstance(ia, ndindex.Newaxis)
]
for i, a in enumerate(where_newaxis):
n = sum(isinstance(ia, ndindex.Integer) for ia in idx.args[:a])
if n:
where_newaxis[i] -= n
idx = ndindex.Tuple(*(ia for ia in idx.args if not isinstance(ia, ndindex.Newaxis)))
selection = idx.raw

# Check selection is supported
if any(
s.step is not None and s.step < 1 for s in selection if isinstance(s, slice)
):
if any(ia.step < 1 for ia in idx.args if isinstance(ia, ndindex.Slice)):
raise NotImplementedError(f"Slice step must be >= 1: {key}")
assert all(isinstance(s, (slice, list, Integral)) for s in selection)
where_list = [i for i, ind in enumerate(selection) if is_integer_list(ind)]
if len(where_list) > 1:
if not all(
isinstance(ia, (ndindex.Integer, ndindex.Slice, ndindex.IntegerArray))
for ia in idx.args
):
raise NotImplementedError(
"Only integer, slice, or integer array indexes are allowed."
)
if sum(1 for ia in idx.args if isinstance(ia, ndindex.IntegerArray)) > 1:
raise NotImplementedError("Only one integer array index is allowed.")

# Use a Zarr indexer just to find the resulting array shape and chunks
# Need to use an in-memory representation since the Zarr file has not been written yet
inmem = zarr.empty_like(x.zarray_maybe_lazy, store=zarr.storage.MemoryStore())
indexer = OrthogonalIndexer(selection, inmem)

def chunk_len_for_indexer(s):
if not isinstance(s, SliceDimIndexer):
return s.dim_chunk_len
return max(s.dim_chunk_len // s.step, 1)

def merged_chunk_len_for_indexer(s):
if not isinstance(s, SliceDimIndexer):
return s.dim_chunk_len
if s.step is None or s.step == 1:
return s.dim_chunk_len
if (s.dim_chunk_len // s.step) < 1:
return s.dim_chunk_len
# note that this may not be the same as s.dim_chunk_len
# Use ndindex to find the resulting array shape and chunks

def chunk_len_for_indexer(ia, c):
if not isinstance(ia, ndindex.Slice):
return c
return max(c // ia.step, 1)

def merged_chunk_len_for_indexer(ia, c):
if not isinstance(ia, ndindex.Slice):
return c
if ia.step == 1:
return c
if (c // ia.step) < 1:
return c
# note that this may not be the same as c
# but it is guaranteed to be a multiple of the corresponding
# value returned by chunk_len_for_indexer, which is required
# by merge_chunks
return (s.dim_chunk_len // s.step) * s.step
return (c // ia.step) * ia.step

shape = indexer.shape
shape = idx.newshape(x.shape)
if shape == x.shape:
# no op case
return x
dtype = x.dtype
chunks = tuple(
chunk_len_for_indexer(s)
for s in indexer.dim_indexers
if not isinstance(s, IntDimIndexer)
chunk_len_for_indexer(ia, c)
for ia, c in zip(idx.args, x.chunksize)
if not isinstance(ia, ndindex.Integer)
)

target_chunks = normalize_chunks(chunks, shape, dtype=dtype)
Expand All @@ -496,14 +487,14 @@ def merged_chunk_len_for_indexer(s):
# merge chunks for any dims with step > 1 so they are
# the same size as the input (or slightly smaller due to rounding)
merged_chunks = tuple(
merged_chunk_len_for_indexer(s)
for s in indexer.dim_indexers
if not isinstance(s, IntDimIndexer)
merged_chunk_len_for_indexer(ia, c)
for ia, c in zip(idx.args, x.chunksize)
if not isinstance(ia, ndindex.Integer)
)
if chunks != merged_chunks:
out = merge_chunks(out, merged_chunks)

for axis in where_none:
for axis in where_newaxis:
from cubed.array_api.manipulation_functions import expand_dims

out = expand_dims(out, axis=axis)
Expand Down Expand Up @@ -545,7 +536,8 @@ def _target_chunk_selection(target_chunks, idx, selection):
j = idx[i]
sel.append(slice(start[j], start[j + 1], step))
i += 1
elif is_integer_list(s):
# ndindex uses np.ndarray for integer arrays
elif isinstance(s, np.ndarray):
# find the cumulative chunk starts
target_chunk_starts = [0] + list(
accumulate(add, [c for c in target_chunks[i]])
Expand All @@ -554,9 +546,11 @@ def _target_chunk_selection(target_chunks, idx, selection):
j = idx[i]
sel.append(s[target_chunk_starts[j] : target_chunk_starts[j + 1]])
i += 1
else:
elif isinstance(s, int):
sel.append(s)
# don't increment i since integer indexes don't have a dimension in the target
else:
raise ValueError(f"Unsupported selection: {s}")
return tuple(sel)


Expand Down
37 changes: 30 additions & 7 deletions cubed/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,15 +216,16 @@ def test_negative(spec, executor):
@pytest.mark.parametrize(
"ind",
[
Ellipsis,
6,
(6, None), # add a new dimension
(None, 6), # add a new dimension
(6, xp.newaxis),
(xp.newaxis, 6),
slice(None),
slice(10),
slice(3, None),
slice(3, 3),
slice(3, 10),
(slice(10), None), # add a new dimension
(slice(10), xp.newaxis),
],
)
def test_index_1d(spec, ind):
Expand All @@ -235,17 +236,20 @@ def test_index_1d(spec, ind):
@pytest.mark.parametrize(
"ind",
[
Ellipsis,
(2, 3),
(None, 2, 3), # add a new dimension
(xp.newaxis, 2, 3),
(Ellipsis, slice(2, 4)),
(slice(None), slice(2, 2)),
(slice(None), slice(2, 4)),
(slice(3), slice(2, None)),
(slice(1, None), slice(4)),
(slice(1, 3), Ellipsis),
(slice(1, 1), slice(None)),
(slice(1, 3), slice(None)),
(None, slice(None), slice(2, 4)), # add a new dimension
(slice(None), None, slice(2, 4)), # add a new dimension
(slice(None), slice(2, 4), None), # add a new dimension
(xp.newaxis, slice(None), slice(2, 4)),
(slice(None), xp.newaxis, slice(2, 4)),
(slice(None), slice(2, 4), xp.newaxis),
(slice(None), 1),
(1, slice(2, 4)),
],
Expand All @@ -260,6 +264,25 @@ def test_index_2d(spec, ind):
assert_array_equal(a[ind].compute(), x[ind])


@pytest.mark.parametrize(
"ind",
[
Ellipsis,
(slice(None), slice(None)),
(slice(0, 4), slice(None)),
(slice(None), slice(0, 4)),
(slice(0, 4), slice(0, 4)),
],
)
def test_index_2d_no_op(spec, ind):
a = xp.asarray(
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
chunks=(2, 2),
spec=spec,
)
assert a is a[ind]


@pytest.mark.parametrize(
"shape, chunks, ind, new_chunks_expected",
[
Expand Down
28 changes: 22 additions & 6 deletions cubed/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def spec(tmp_path):
"ind",
[
[6, 7, 2, 9, 10],
([6, 7, 2, 9, 10], None), # add a new dimension
(None, [6, 7, 2, 9, 10]), # add a new dimension
([6, 7, 2, 9, 10], xp.newaxis),
(xp.newaxis, [6, 7, 2, 9, 10]),
],
)
def test_int_array_index_1d(spec, ind):
Expand All @@ -33,9 +33,9 @@ def test_int_array_index_1d(spec, ind):
[
(slice(None), [2, 1]),
([1, 2, 1], slice(None)),
(None, slice(None), [2, 1]),
(slice(None), None, [2, 1]),
(slice(None), [2, 1], None),
(xp.newaxis, slice(None), [2, 1]),
(slice(None), xp.newaxis, [2, 1]),
(slice(None), [2, 1], xp.newaxis),
],
)
def test_int_array_index_2d(spec, ind):
Expand All @@ -49,11 +49,27 @@ def test_int_array_index_2d(spec, ind):
assert_array_equal(b[ind].compute(), x[ind])


@pytest.mark.parametrize(
"ind",
[
(slice(None), [0, 1, 2, 3]),
([0, 1, 2, 3], slice(None)),
],
)
def test_int_array_index_2d_no_op(spec, ind):
a = xp.asarray(
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
chunks=(3, 3),
spec=spec,
)
assert a is a[ind]


def test_multiple_int_array_indexes(spec):
with pytest.raises(NotImplementedError):
a = xp.asarray(
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
chunks=(2, 2),
spec=spec,
)
a[[1, 2, 1], [2, 1]]
a[[1, 2, 1], [2, 1, 0]]

0 comments on commit 053810c

Please sign in to comment.