Skip to content

Commit

Permalink
Implement virtual array indexing using ndindex
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Jun 17, 2024
1 parent d9fddc7 commit 201f950
Show file tree
Hide file tree
Showing 9 changed files with 62 additions and 91 deletions.
17 changes: 9 additions & 8 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
OrthogonalIndexer,
SliceDimIndexer,
is_integer_list,
is_slice,
replace_ellipsis,
)

Expand Down Expand Up @@ -409,7 +408,7 @@ def index(x, key):
key = (key,)

# No op case
if all(is_slice(ind) and ind == slice(None) for ind in key):
if all(isinstance(ind, slice) and ind == slice(None) for ind in key):
return x

# Remove None values, to be filled in with expand_dims at end
Expand All @@ -436,7 +435,9 @@ def index(x, key):
selection = replace_ellipsis(selection, x.shape)

# Check selection is supported
if any(s.step is not None and s.step < 1 for s in selection if is_slice(s)):
if any(
s.step is not None and s.step < 1 for s in selection if isinstance(s, slice)
):
raise NotImplementedError(f"Slice step must be >= 1: {key}")
assert all(isinstance(s, (slice, list, Integral)) for s in selection)
where_list = [i for i, ind in enumerate(selection) if is_integer_list(ind)]
Expand Down Expand Up @@ -490,7 +491,6 @@ def merged_chunk_len_for_indexer(s):
extra_projected_mem=extra_projected_mem,
target_chunks=target_chunks,
selection=selection,
advanced_indexing=len(where_list) > 0,
)

# merge chunks for any dims with step > 1 so they are
Expand All @@ -516,13 +516,14 @@ def _read_index_chunk(
*arrays,
target_chunks=None,
selection=None,
advanced_indexing=None,
block_id=None,
):
array = arrays[0].zarray
if advanced_indexing:
array = array.oindex
idx = block_id
# Note that since we only have a maximum of one integer array index
# we don't need to use Zarr orthogonal indexing, since it is
# "available directly on the array" according to
# https://zarr.readthedocs.io/en/stable/tutorial.html#orthogonal-indexing
out = array[_target_chunk_selection(target_chunks, idx, selection)]
out = numpy_array_to_backend_array(out)
return out
Expand All @@ -535,7 +536,7 @@ def _target_chunk_selection(target_chunks, idx, selection):
sel = []
i = 0 # index into target_chunks and idx
for s in selection:
if is_slice(s):
if isinstance(s, slice):
offset = s.start or 0
step = s.step if s.step is not None else 1
start = tuple(
Expand Down
2 changes: 2 additions & 0 deletions cubed/runtime/executors/modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"donfig",
"fsspec",
"mypy_extensions", # for rechunker
"ndindex",
"networkx",
"pytest-mock", # TODO: only needed for tests
"s3fs",
Expand All @@ -52,6 +53,7 @@
"donfig",
"fsspec",
"mypy_extensions", # for rechunker
"ndindex",
"networkx",
"pytest-mock", # TODO: only needed for tests
"gcsfs",
Expand Down
95 changes: 24 additions & 71 deletions cubed/storage/virtual.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
from typing import Any

import numpy as np
import zarr
from zarr.indexing import BasicIndexer, is_slice
from ndindex import ndindex

from cubed.backend_array_api import backend_array_to_numpy_array
from cubed.backend_array_api import namespace as nxp
from cubed.backend_array_api import numpy_array_to_backend_array
from cubed.types import T_DType, T_RegularChunks, T_Shape
Expand All @@ -21,31 +19,21 @@ def __init__(
dtype: T_DType,
chunks: T_RegularChunks,
):
# use an empty in-memory Zarr array as a template since it normalizes its properties
template = zarr.empty(
shape, dtype=dtype, chunks=chunks, store=zarr.storage.MemoryStore()
)
self.shape = template.shape
self.dtype = template.dtype
self.chunks = template.chunks
self.template = template
self.shape = shape
self.dtype = np.dtype(dtype)
self.chunks = chunks

def __getitem__(self, key):
if not isinstance(key, tuple):
key = (key,)
indexer = BasicIndexer(key, self.template)
idx = ndindex[key]
newshape = idx.newshape(self.shape)
# use broadcast trick so array chunks only occupy a single value in memory
return broadcast_trick(nxp.empty)(indexer.shape, dtype=self.dtype)
return broadcast_trick(nxp.empty)(newshape, dtype=self.dtype)

@property
def chunkmem(self):
# take broadcast trick into account
return array_memory(self.dtype, (1,))

@property
def oindex(self):
return self.template.oindex


class VirtualFullArray:
"""An array that is never materialized (in memory or on disk) and contains a single fill value."""
Expand All @@ -57,53 +45,29 @@ def __init__(
chunks: T_RegularChunks,
fill_value: Any = None,
):
# use an empty in-memory Zarr array as a template since it normalizes its properties
template = zarr.full(
shape,
fill_value,
dtype=dtype,
chunks=chunks,
store=zarr.storage.MemoryStore(),
)
self.shape = template.shape
self.dtype = template.dtype
self.chunks = template.chunks
self.template = template
self.shape = shape
self.dtype = np.dtype(dtype)
self.chunks = chunks
self.fill_value = fill_value

def __getitem__(self, key):
if not isinstance(key, tuple):
key = (key,)
indexer = BasicIndexer(key, self.template)
idx = ndindex[key]
newshape = idx.newshape(self.shape)
# use broadcast trick so array chunks only occupy a single value in memory
return broadcast_trick(nxp.full)(
indexer.shape, fill_value=self.fill_value, dtype=self.dtype
newshape, fill_value=self.fill_value, dtype=self.dtype
)

@property
def chunkmem(self):
# take broadcast trick into account
return array_memory(self.dtype, (1,))

@property
def oindex(self):
return self.template.oindex


class VirtualOffsetsArray:
"""An array that is never materialized (in memory or on disk) and contains sequentially incrementing integers."""

def __init__(self, shape: T_Shape):
dtype = nxp.int32
chunks = (1,) * len(shape)
# use an empty in-memory Zarr array as a template since it normalizes its properties
template = zarr.empty(
shape, dtype=dtype, chunks=chunks, store=zarr.storage.MemoryStore()
)
self.shape = template.shape
self.dtype = template.dtype
self.chunks = template.chunks
self.ndim = template.ndim
self.shape = shape
self.dtype = np.dtype(dtype)
self.chunks = chunks

def __getitem__(self, key):
if key == () and self.shape == ():
Expand All @@ -127,28 +91,13 @@ def __init__(
f"Size of in memory array is {memory_repr(array.nbytes)} which exceeds maximum of {memory_repr(max_nbytes)}. Consider loading the array from storage using `from_array`."
)
self.array = array
# use an in-memory Zarr array as a template since it normalizes its properties
# and is needed for oindex
template = zarr.empty(
array.shape,
dtype=array.dtype,
chunks=chunks,
store=zarr.storage.MemoryStore(),
)
self.shape = template.shape
self.dtype = template.dtype
self.chunks = template.chunks
self.template = template
if array.size > 0:
template[...] = backend_array_to_numpy_array(array)
self.shape = array.shape
self.dtype = array.dtype
self.chunks = chunks

def __getitem__(self, key):
return self.array.__getitem__(key)

@property
def oindex(self):
return self.template.oindex


def _key_to_index_tuple(selection):
if isinstance(selection, slice):
Expand All @@ -158,7 +107,11 @@ def _key_to_index_tuple(selection):
for s in selection:
if isinstance(s, Integral):
sel.append(s)
elif is_slice(s) and s.stop == s.start + 1 and (s.step is None or s.step == 1):
elif (
isinstance(s, slice)
and s.stop == s.start + 1
and (s.step is None or s.step == 1)
):
sel.append(s.start)
else:
raise NotImplementedError(f"Offset selection not supported: {selection}")
Expand Down
24 changes: 16 additions & 8 deletions cubed/storage/zarr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from operator import mul
from typing import Optional, Union

import numpy as np
import zarr
from toolz import reduce

from cubed.types import T_DType, T_RegularChunks, T_Shape, T_Store

Expand All @@ -23,18 +26,23 @@ def __init__(
**kwargs,
):
"""Create a Zarr array lazily in memory."""
# use an empty in-memory Zarr array as a template since it normalizes its properties
template = zarr.empty(
shape, dtype=dtype, chunks=chunks, store=zarr.storage.MemoryStore()
)
self.shape = template.shape
self.dtype = template.dtype
self.chunks = template.chunks
self.nbytes = template.nbytes
self.shape = shape
self.dtype = np.dtype(dtype)
self.chunks = chunks
self.store = store
self.path = path
self.kwargs = kwargs

@property
def size(self):
"""Number of elements in the array."""
return reduce(mul, self.shape, 1)

@property
def nbytes(self) -> int:
"""Number of bytes in array"""
return self.size * self.dtype.itemsize

def create(self, mode: str = "w-") -> zarr.Array:
"""Create the Zarr array in storage.
Expand Down
1 change: 1 addition & 0 deletions cubed/tests/runtime/test_modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"donfig",
"fsspec",
"mypy_extensions", # for rechunker
"ndindex",
"networkx",
"pytest-mock", # TODO: only needed for tests
"s3fs",
Expand Down
10 changes: 6 additions & 4 deletions cubed/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ def spec(tmp_path):
],
)
def test_int_array_index_1d(spec, ind):
a = xp.arange(12, chunks=(4,), spec=spec)
assert_array_equal(a[ind].compute(), np.arange(12)[ind])
a = xp.arange(12, chunks=(3,), spec=spec)
b = a.rechunk((4,)) # force materialization to test indexing against zarr
assert_array_equal(b[ind].compute(), np.arange(12)[ind])


@pytest.mark.parametrize(
Expand All @@ -40,11 +41,12 @@ def test_int_array_index_1d(spec, ind):
def test_int_array_index_2d(spec, ind):
a = xp.asarray(
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
chunks=(2, 2),
chunks=(3, 3),
spec=spec,
)
b = a.rechunk((2, 2)) # force materialization to test indexing against zarr
x = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])
assert_array_equal(a[ind].compute(), x[ind])
assert_array_equal(b[ind].compute(), x[ind])


def test_multiple_int_array_indexes(spec):
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"donfig",
"fsspec",
"mypy_extensions", # for rechunker
"ndindex",
"networkx != 2.8.3, != 2.8.4, != 2.8.5, != 2.8.6, != 2.8.7, != 2.8.8, != 3.0.*, != 3.1.*, != 3.2.*",
"numpy >= 1.22",
"tenacity",
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ aiostream
array-api-compat
fsspec
mypy_extensions # for rechunker
ndindex
networkx != 2.8.3, != 2.8.4, != 2.8.5, != 2.8.6, != 2.8.7, != 2.8.8, != 3.0.*, != 3.1.*, != 3.2.*
numpy >= 1.22
tenacity
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ ignore_missing_imports = True
ignore_missing_imports = True
[mypy-matplotlib.*]
ignore_missing_imports = True
[mypy-ndindex.*]
ignore_missing_imports = True
[mypy-networkx.*]
ignore_missing_imports = True
[mypy-numpy.*]
Expand Down

0 comments on commit 201f950

Please sign in to comment.