Skip to content

Commit

Permalink
Implement virtual array indexing using ndindex
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Apr 5, 2024
1 parent c2283f3 commit db71acf
Show file tree
Hide file tree
Showing 9 changed files with 62 additions and 86 deletions.
17 changes: 9 additions & 8 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
OrthogonalIndexer,
SliceDimIndexer,
is_integer_list,
is_slice,
replace_ellipsis,
)

Expand Down Expand Up @@ -408,7 +407,7 @@ def index(x, key):
key = (key,)

# No op case
if all(is_slice(ind) and ind == slice(None) for ind in key):
if all(isinstance(ind, slice) and ind == slice(None) for ind in key):
return x

# Remove None values, to be filled in with expand_dims at end
Expand All @@ -435,7 +434,9 @@ def index(x, key):
selection = replace_ellipsis(selection, x.shape)

# Check selection is supported
if any(s.step is not None and s.step < 1 for s in selection if is_slice(s)):
if any(
s.step is not None and s.step < 1 for s in selection if isinstance(s, slice)
):
raise NotImplementedError(f"Slice step must be >= 1: {key}")
assert all(isinstance(s, (slice, list, Integral)) for s in selection)
where_list = [i for i, ind in enumerate(selection) if is_integer_list(ind)]
Expand Down Expand Up @@ -489,7 +490,6 @@ def merged_chunk_len_for_indexer(s):
extra_projected_mem=extra_projected_mem,
target_chunks=target_chunks,
selection=selection,
advanced_indexing=len(where_list) > 0,
)

# merge chunks for any dims with step > 1 so they are
Expand All @@ -515,13 +515,14 @@ def _read_index_chunk(
*arrays,
target_chunks=None,
selection=None,
advanced_indexing=None,
block_id=None,
):
array = arrays[0].zarray
if advanced_indexing:
array = array.oindex
idx = block_id
# Note: at most one integer array index is supported, so there is no need
# to go through Zarr's explicit orthogonal indexing (`oindex`); in that case
# orthogonal selection is "available directly on the array" according to
# https://zarr.readthedocs.io/en/stable/tutorial.html#orthogonal-indexing
out = array[_target_chunk_selection(target_chunks, idx, selection)]
out = numpy_array_to_backend_array(out)
return out
Expand All @@ -534,7 +535,7 @@ def _target_chunk_selection(target_chunks, idx, selection):
sel = []
i = 0 # index into target_chunks and idx
for s in selection:
if is_slice(s):
if isinstance(s, slice):
offset = s.start or 0
step = s.step if s.step is not None else 1
start = tuple(
Expand Down
2 changes: 2 additions & 0 deletions cubed/runtime/executors/modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"donfig",
"fsspec",
"mypy_extensions", # for rechunker
"ndindex",
"networkx",
"pytest-mock", # TODO: only needed for tests
"s3fs",
Expand All @@ -35,6 +36,7 @@
"donfig",
"fsspec",
"mypy_extensions", # for rechunker
"ndindex",
"networkx",
"pytest-mock", # TODO: only needed for tests
"gcsfs",
Expand Down
90 changes: 24 additions & 66 deletions cubed/storage/virtual.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
from typing import Any

import numpy as np
import zarr
from zarr.indexing import BasicIndexer, is_slice
from ndindex import ndindex

from cubed.backend_array_api import backend_array_to_numpy_array
from cubed.backend_array_api import namespace as nxp
from cubed.backend_array_api import numpy_array_to_backend_array
from cubed.types import T_DType, T_RegularChunks, T_Shape
Expand All @@ -21,25 +19,15 @@ def __init__(
dtype: T_DType,
chunks: T_RegularChunks,
):
# use an empty in-memory Zarr array as a template since it normalizes its properties
template = zarr.empty(
shape, dtype=dtype, chunks=chunks, store=zarr.storage.MemoryStore()
)
self.shape = template.shape
self.dtype = template.dtype
self.chunks = template.chunks
self.template = template
self.shape = shape
self.dtype = np.dtype(dtype)
self.chunks = chunks

def __getitem__(self, key):
if not isinstance(key, tuple):
key = (key,)
indexer = BasicIndexer(key, self.template)
idx = ndindex[key]
newshape = idx.newshape(self.shape)
# use broadcast trick so array chunks only occupy a single value in memory
return broadcast_trick(nxp.empty)(indexer.shape, dtype=self.dtype)

@property
def oindex(self):
return self.template.oindex
return broadcast_trick(nxp.empty)(newshape, dtype=self.dtype)


class VirtualFullArray:
Expand All @@ -52,48 +40,29 @@ def __init__(
chunks: T_RegularChunks,
fill_value: Any = None,
):
# use an empty in-memory Zarr array as a template since it normalizes its properties
template = zarr.full(
shape,
fill_value,
dtype=dtype,
chunks=chunks,
store=zarr.storage.MemoryStore(),
)
self.shape = template.shape
self.dtype = template.dtype
self.chunks = template.chunks
self.template = template
self.shape = shape
self.dtype = np.dtype(dtype)
self.chunks = chunks
self.fill_value = fill_value

def __getitem__(self, key):
if not isinstance(key, tuple):
key = (key,)
indexer = BasicIndexer(key, self.template)
idx = ndindex[key]
newshape = idx.newshape(self.shape)
# use broadcast trick so array chunks only occupy a single value in memory
return broadcast_trick(nxp.full)(
indexer.shape, fill_value=self.fill_value, dtype=self.dtype
newshape, fill_value=self.fill_value, dtype=self.dtype
)

@property
def oindex(self):
return self.template.oindex


class VirtualOffsetsArray:
"""An array that is never materialized (in memory or on disk) and contains sequentially incrementing integers."""

def __init__(self, shape: T_Shape):
dtype = nxp.int32
chunks = (1,) * len(shape)
# use an empty in-memory Zarr array as a template since it normalizes its properties
template = zarr.empty(
shape, dtype=dtype, chunks=chunks, store=zarr.storage.MemoryStore()
)
self.shape = template.shape
self.dtype = template.dtype
self.chunks = template.chunks
self.ndim = template.ndim
self.shape = shape
self.dtype = np.dtype(dtype)
self.chunks = chunks

def __getitem__(self, key):
if key == () and self.shape == ():
Expand All @@ -117,28 +86,13 @@ def __init__(
f"Size of in memory array is {memory_repr(array.nbytes)} which exceeds maximum of {memory_repr(max_nbytes)}. Consider loading the array from storage using `from_array`."
)
self.array = array
# use an in-memory Zarr array as a template since it normalizes its properties
# and is needed for oindex
template = zarr.empty(
array.shape,
dtype=array.dtype,
chunks=chunks,
store=zarr.storage.MemoryStore(),
)
self.shape = template.shape
self.dtype = template.dtype
self.chunks = template.chunks
self.template = template
if array.size > 0:
template[...] = backend_array_to_numpy_array(array)
self.shape = array.shape
self.dtype = array.dtype
self.chunks = chunks

def __getitem__(self, key):
return self.array.__getitem__(key)

@property
def oindex(self):
return self.template.oindex


def _key_to_index_tuple(selection):
if isinstance(selection, slice):
Expand All @@ -148,7 +102,11 @@ def _key_to_index_tuple(selection):
for s in selection:
if isinstance(s, Integral):
sel.append(s)
elif is_slice(s) and s.stop == s.start + 1 and (s.step is None or s.step == 1):
elif (
isinstance(s, slice)
and s.stop == s.start + 1
and (s.step is None or s.step == 1)
):
sel.append(s.start)
else:
raise NotImplementedError(f"Offset selection not supported: {selection}")
Expand Down
24 changes: 16 additions & 8 deletions cubed/storage/zarr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from operator import mul
from typing import Optional, Union

import numpy as np
import zarr
from toolz import reduce

from cubed.types import T_DType, T_RegularChunks, T_Shape, T_Store

Expand All @@ -23,18 +26,23 @@ def __init__(
**kwargs,
):
"""Create a Zarr array lazily in memory."""
# use an empty in-memory Zarr array as a template since it normalizes its properties
template = zarr.empty(
shape, dtype=dtype, chunks=chunks, store=zarr.storage.MemoryStore()
)
self.shape = template.shape
self.dtype = template.dtype
self.chunks = template.chunks
self.nbytes = template.nbytes
self.shape = shape
self.dtype = np.dtype(dtype)
self.chunks = chunks
self.store = store
self.path = path
self.kwargs = kwargs

@property
def size(self):
"""Number of elements in the array (product of the lengths in ``shape``; 1 when ``shape`` is empty)."""
return reduce(mul, self.shape, 1)

@property
def nbytes(self) -> int:
"""Number of bytes in the array, computed as ``size * dtype.itemsize``."""
return self.size * self.dtype.itemsize

def create(self, mode: str = "w-") -> zarr.Array:
"""Create the Zarr array in storage.
Expand Down
1 change: 1 addition & 0 deletions cubed/tests/runtime/test_modal_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"donfig",
"fsspec",
"mypy_extensions", # for rechunker
"ndindex",
"networkx",
"pytest-mock", # TODO: only needed for tests
"s3fs",
Expand Down
10 changes: 6 additions & 4 deletions cubed/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ def spec(tmp_path):
],
)
def test_int_array_index_1d(spec, ind):
a = xp.arange(12, chunks=(4,), spec=spec)
assert_array_equal(a[ind].compute(), np.arange(12)[ind])
a = xp.arange(12, chunks=(3,), spec=spec)
b = a.rechunk((4,)) # force materialization to test indexing against zarr
assert_array_equal(b[ind].compute(), np.arange(12)[ind])


@pytest.mark.parametrize(
Expand All @@ -40,11 +41,12 @@ def test_int_array_index_1d(spec, ind):
def test_int_array_index_2d(spec, ind):
a = xp.asarray(
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
chunks=(2, 2),
chunks=(3, 3),
spec=spec,
)
b = a.rechunk((2, 2)) # force materialization to test indexing against zarr
x = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])
assert_array_equal(a[ind].compute(), x[ind])
assert_array_equal(b[ind].compute(), x[ind])


def test_multiple_int_array_indexes(spec):
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies = [
"donfig",
"fsspec",
"mypy_extensions", # for rechunker
"ndindex",
"networkx < 2.8.3",
"numpy >= 1.22",
"tenacity",
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ aiostream
array-api-compat
fsspec
mypy_extensions # for rechunker
ndindex
networkx < 2.8.3
numpy >= 1.22
tenacity
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ ignore_missing_imports = True
ignore_missing_imports = True
[mypy-matplotlib.*]
ignore_missing_imports = True
[mypy-ndindex.*]
ignore_missing_imports = True
[mypy-networkx.*]
ignore_missing_imports = True
[mypy-numpy.*]
Expand Down

0 comments on commit db71acf

Please sign in to comment.