diff --git a/cubed/core/ops.py b/cubed/core/ops.py index deba169c5..3253fbea0 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -16,7 +16,6 @@ OrthogonalIndexer, SliceDimIndexer, is_integer_list, - is_slice, replace_ellipsis, ) @@ -408,7 +407,7 @@ def index(x, key): key = (key,) # No op case - if all(is_slice(ind) and ind == slice(None) for ind in key): + if all(isinstance(ind, slice) and ind == slice(None) for ind in key): return x # Remove None values, to be filled in with expand_dims at end @@ -435,7 +434,9 @@ def index(x, key): selection = replace_ellipsis(selection, x.shape) # Check selection is supported - if any(s.step is not None and s.step < 1 for s in selection if is_slice(s)): + if any( + s.step is not None and s.step < 1 for s in selection if isinstance(s, slice) + ): raise NotImplementedError(f"Slice step must be >= 1: {key}") assert all(isinstance(s, (slice, list, Integral)) for s in selection) where_list = [i for i, ind in enumerate(selection) if is_integer_list(ind)] @@ -489,7 +490,6 @@ def merged_chunk_len_for_indexer(s): extra_projected_mem=extra_projected_mem, target_chunks=target_chunks, selection=selection, - advanced_indexing=len(where_list) > 0, ) # merge chunks for any dims with step > 1 so they are @@ -515,13 +515,14 @@ def _read_index_chunk( *arrays, target_chunks=None, selection=None, - advanced_indexing=None, block_id=None, ): array = arrays[0].zarray - if advanced_indexing: - array = array.oindex idx = block_id + # Note that since we only have a maximum of one integer array index + # we don't need to use Zarr orthogonal indexing, since it is + # "available directly on the array" according to + # https://zarr.readthedocs.io/en/stable/tutorial.html#orthogonal-indexing out = array[_target_chunk_selection(target_chunks, idx, selection)] out = numpy_array_to_backend_array(out) return out @@ -534,7 +535,7 @@ def _target_chunk_selection(target_chunks, idx, selection): sel = [] i = 0 # index into target_chunks and idx for s in selection: - if is_slice(s): + if isinstance(s, slice): offset = s.start or 0 step = s.step if s.step is not None else 1 start = tuple( diff --git a/cubed/runtime/executors/modal.py b/cubed/runtime/executors/modal.py index df81fce2a..9a3fe40a7 100644 --- a/cubed/runtime/executors/modal.py +++ b/cubed/runtime/executors/modal.py @@ -21,6 +21,7 @@ "donfig", "fsspec", "mypy_extensions", # for rechunker + "ndindex", "networkx", "pytest-mock", # TODO: only needed for tests "s3fs", @@ -35,6 +36,7 @@ "donfig", "fsspec", "mypy_extensions", # for rechunker + "ndindex", "networkx", "pytest-mock", # TODO: only needed for tests "gcsfs", diff --git a/cubed/storage/virtual.py b/cubed/storage/virtual.py index a919321d1..4922f67f2 100644 --- a/cubed/storage/virtual.py +++ b/cubed/storage/virtual.py @@ -2,10 +2,8 @@ from typing import Any import numpy as np -import zarr -from zarr.indexing import BasicIndexer, is_slice +from ndindex import ndindex -from cubed.backend_array_api import backend_array_to_numpy_array from cubed.backend_array_api import namespace as nxp from cubed.backend_array_api import numpy_array_to_backend_array from cubed.types import T_DType, T_RegularChunks, T_Shape @@ -21,25 +19,15 @@ def __init__( dtype: T_DType, chunks: T_RegularChunks, ): - # use an empty in-memory Zarr array as a template since it normalizes its properties - template = zarr.empty( - shape, dtype=dtype, chunks=chunks, store=zarr.storage.MemoryStore() - ) - self.shape = template.shape - self.dtype = template.dtype - self.chunks = template.chunks - self.template = template + self.shape = shape + self.dtype = np.dtype(dtype) + self.chunks = chunks def __getitem__(self, key): - if not isinstance(key, tuple): - key = (key,) - indexer = BasicIndexer(key, self.template) + idx = ndindex[key] + newshape = idx.newshape(self.shape) # use broadcast trick so array chunks only occupy a single value in memory - return broadcast_trick(nxp.empty)(indexer.shape, dtype=self.dtype) - - @property - def oindex(self): - return self.template.oindex + return broadcast_trick(nxp.empty)(newshape, dtype=self.dtype) class VirtualFullArray: @@ -52,33 +40,19 @@ def __init__( chunks: T_RegularChunks, fill_value: Any = None, ): - # use an empty in-memory Zarr array as a template since it normalizes its properties - template = zarr.full( - shape, - fill_value, - dtype=dtype, - chunks=chunks, - store=zarr.storage.MemoryStore(), - ) - self.shape = template.shape - self.dtype = template.dtype - self.chunks = template.chunks - self.template = template + self.shape = shape + self.dtype = np.dtype(dtype) + self.chunks = chunks self.fill_value = fill_value def __getitem__(self, key): - if not isinstance(key, tuple): - key = (key,) - indexer = BasicIndexer(key, self.template) + idx = ndindex[key] + newshape = idx.newshape(self.shape) # use broadcast trick so array chunks only occupy a single value in memory return broadcast_trick(nxp.full)( - indexer.shape, fill_value=self.fill_value, dtype=self.dtype + newshape, fill_value=self.fill_value, dtype=self.dtype ) - @property - def oindex(self): - return self.template.oindex - class VirtualOffsetsArray: """An array that is never materialized (in memory or on disk) and contains sequentially incrementing integers.""" @@ -86,14 +60,9 @@ class VirtualOffsetsArray: def __init__(self, shape: T_Shape): dtype = nxp.int32 chunks = (1,) * len(shape) - # use an empty in-memory Zarr array as a template since it normalizes its properties - template = zarr.empty( - shape, dtype=dtype, chunks=chunks, store=zarr.storage.MemoryStore() - ) - self.shape = template.shape - self.dtype = template.dtype - self.chunks = template.chunks - self.ndim = template.ndim + self.shape = shape + self.dtype = np.dtype(dtype) + self.chunks = chunks def __getitem__(self, key): if key == () and self.shape == (): @@ -117,28 +86,13 @@ def __init__( f"Size of in memory array is {memory_repr(array.nbytes)} which exceeds maximum of {memory_repr(max_nbytes)}. Consider loading the array from storage using `from_array`." ) self.array = array - # use an in-memory Zarr array as a template since it normalizes its properties - # and is needed for oindex - template = zarr.empty( - array.shape, - dtype=array.dtype, - chunks=chunks, - store=zarr.storage.MemoryStore(), - ) - self.shape = template.shape - self.dtype = template.dtype - self.chunks = template.chunks - self.template = template - if array.size > 0: - template[...] = backend_array_to_numpy_array(array) + self.shape = array.shape + self.dtype = array.dtype + self.chunks = chunks def __getitem__(self, key): return self.array.__getitem__(key) - @property - def oindex(self): - return self.template.oindex - def _key_to_index_tuple(selection): if isinstance(selection, slice): @@ -148,7 +102,11 @@ def _key_to_index_tuple(selection): for s in selection: if isinstance(s, Integral): sel.append(s) - elif is_slice(s) and s.stop == s.start + 1 and (s.step is None or s.step == 1): + elif ( + isinstance(s, slice) + and s.stop == s.start + 1 + and (s.step is None or s.step == 1) + ): sel.append(s.start) else: raise NotImplementedError(f"Offset selection not supported: {selection}") diff --git a/cubed/storage/zarr.py b/cubed/storage/zarr.py index af048d78e..136740b26 100644 --- a/cubed/storage/zarr.py +++ b/cubed/storage/zarr.py @@ -1,6 +1,9 @@ +from operator import mul from typing import Optional, Union +import numpy as np import zarr +from toolz import reduce from cubed.types import T_DType, T_RegularChunks, T_Shape, T_Store @@ -23,18 +26,23 @@ def __init__( **kwargs, ): """Create a Zarr array lazily in memory.""" - # use an empty in-memory Zarr array as a template since it normalizes its properties - template = zarr.empty( - shape, dtype=dtype, chunks=chunks, store=zarr.storage.MemoryStore() - ) - self.shape = template.shape - self.dtype = template.dtype - self.chunks = template.chunks - self.nbytes = template.nbytes + self.shape = shape + self.dtype = np.dtype(dtype) + self.chunks = chunks self.store = store self.path = path self.kwargs = kwargs + @property + def size(self): + """Number of elements in the array.""" + return reduce(mul, self.shape, 1) + + @property + def nbytes(self) -> int: + """Number of bytes in array""" + return self.size * self.dtype.itemsize + def create(self, mode: str = "w-") -> zarr.Array: """Create the Zarr array in storage. diff --git a/cubed/tests/runtime/test_modal_async.py b/cubed/tests/runtime/test_modal_async.py index 3b4b13ec0..7efae3064 100644 --- a/cubed/tests/runtime/test_modal_async.py +++ b/cubed/tests/runtime/test_modal_async.py @@ -23,6 +23,7 @@ "donfig", "fsspec", "mypy_extensions", # for rechunker + "ndindex", "networkx", "pytest-mock", # TODO: only needed for tests "s3fs", diff --git a/cubed/tests/test_indexing.py b/cubed/tests/test_indexing.py index fc7b3d9f1..e9ac27e52 100644 --- a/cubed/tests/test_indexing.py +++ b/cubed/tests/test_indexing.py @@ -23,8 +23,9 @@ def spec(tmp_path): ], ) def test_int_array_index_1d(spec, ind): - a = xp.arange(12, chunks=(4,), spec=spec) - assert_array_equal(a[ind].compute(), np.arange(12)[ind]) + a = xp.arange(12, chunks=(3,), spec=spec) + b = a.rechunk((4,)) # force materialization to test indexing against zarr + assert_array_equal(b[ind].compute(), np.arange(12)[ind]) @pytest.mark.parametrize( @@ -40,11 +41,12 @@ def test_int_array_index_1d(spec, ind): def test_int_array_index_2d(spec, ind): a = xp.asarray( [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], - chunks=(2, 2), + chunks=(3, 3), spec=spec, ) + b = a.rechunk((2, 2)) # force materialization to test indexing against zarr x = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]) - assert_array_equal(a[ind].compute(), x[ind]) + assert_array_equal(b[ind].compute(), x[ind]) def test_multiple_int_array_indexes(spec): diff --git a/pyproject.toml b/pyproject.toml index 78457a240..b86d45d1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "donfig", "fsspec", "mypy_extensions", # for rechunker + "ndindex", "networkx < 2.8.3", "numpy >= 1.22", "tenacity", diff --git a/requirements.txt b/requirements.txt index 381ac93ce..32f184807 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ aiostream array-api-compat fsspec mypy_extensions # for rechunker +ndindex networkx < 2.8.3 numpy >= 1.22 tenacity diff --git a/setup.cfg b/setup.cfg index 62965f083..8606c1e8b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,6 +44,8 @@ ignore_missing_imports = True ignore_missing_imports = True [mypy-matplotlib.*] ignore_missing_imports = True +[mypy-ndindex.*] +ignore_missing_imports = True [mypy-networkx.*] ignore_missing_imports = True [mypy-numpy.*]