Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support block_id for general_blockwise functions #593

Merged
merged 2 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 77 additions & 30 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,9 @@
from cubed.spec import spec_from_config
from cubed.storage.backend import open_backend_array
from cubed.types import T_RegularChunks, T_Shape
from cubed.utils import (
_concatenate2,
array_memory,
array_size,
get_item,
offset_to_block_id,
to_chunksize,
)
from cubed.utils import _concatenate2, array_memory, array_size, get_item
from cubed.utils import numblocks as compute_numblocks
from cubed.utils import offset_to_block_id, to_chunksize
from cubed.vendor.dask.array.core import normalize_chunks
from cubed.vendor.dask.array.utils import validate_axis
from cubed.vendor.dask.blockwise import broadcast_dimensions, lol_product
Expand Down Expand Up @@ -342,6 +337,77 @@ def general_blockwise(
target_paths=None,
extra_func_kwargs=None,
**kwargs,
) -> Union["Array", Tuple["Array", ...]]:
if has_keyword(func, "block_id"):
from cubed.array_api.creation_functions import offsets_virtual_array

# Create an array of index offsets with the same chunk structure as the args,
# which we convert to block ids (chunk coordinates) later.
array0 = arrays[0]
# note that primitive general_blockwise checks that all chunkss have same numblocks
numblocks = compute_numblocks(chunkss[0])
offsets = offsets_virtual_array(numblocks, array0.spec)
new_arrays = arrays + (offsets,)

def key_function_with_offset(key_function):
    """Extend ``key_function`` so each output key also requests an offsets block.

    The returned callable produces whatever ``key_function`` returns for the
    given output key, followed by one extra key that addresses the block of
    the ``offsets`` array sharing the output block's coordinates.
    """

    def wrap(out_key):
        # everything after the first element of the output key is treated as
        # the block coordinates, which are reused to address the offsets array
        block_coords = out_key[1:]
        offsets_key = (offsets.name,) + block_coords
        return key_function(out_key) + (offsets_key,)

    return wrap

def func_with_block_id(func):
    """Adapt ``func`` so that it is called with a ``block_id`` keyword argument.

    The wrapper expects its last positional argument to be the offset value
    for the block being processed. That value is converted to a block id via
    ``offset_to_block_id`` and the remaining positional arguments are
    forwarded to ``func`` unchanged.
    """

    def wrap(*a, **kw):
        # the trailing argument is the offset; int() converts it from a 0-d array
        block_id = offset_to_block_id(int(a[-1]), numblocks)
        leading_args = a[:-1]
        return func(*leading_args, block_id=block_id, **kw)

    return wrap

num_input_blocks = kwargs.pop("num_input_blocks", None)
if num_input_blocks is not None:
num_input_blocks = num_input_blocks + (1,) # for offsets array

return _general_blockwise(
func_with_block_id(func),
key_function_with_offset(key_function),
*new_arrays,
shapes=shapes,
dtypes=dtypes,
chunkss=chunkss,
target_stores=target_stores,
target_paths=target_paths,
extra_func_kwargs=extra_func_kwargs,
num_input_blocks=num_input_blocks,
**kwargs,
)

return _general_blockwise(
func,
key_function,
*arrays,
shapes=shapes,
dtypes=dtypes,
chunkss=chunkss,
target_stores=target_stores,
target_paths=target_paths,
extra_func_kwargs=extra_func_kwargs,
**kwargs,
)


def _general_blockwise(
func,
key_function,
*arrays,
shapes,
dtypes,
chunkss,
target_stores=None,
target_paths=None,
extra_func_kwargs=None,
**kwargs,
) -> Union["Array", Tuple["Array", ...]]:
assert len(arrays) > 0

Expand Down Expand Up @@ -504,12 +570,6 @@ def merged_chunk_len_for_indexer(ia, c):
if _is_chunk_aligned_selection(idx):
# use general_blockwise, which allows more opportunities for optimization than map_direct

from cubed.array_api.creation_functions import offsets_virtual_array

# general_blockwise doesn't support block_id, so emulate it ourselves
numblocks = tuple(map(len, target_chunks))
offsets = offsets_virtual_array(numblocks, x.spec)

def key_function(out_key):
out_coords = out_key[1:]

Expand All @@ -521,24 +581,17 @@ def key_function(out_key):
in_sel, x.zarray_maybe_lazy.shape, x.zarray_maybe_lazy.chunks
)

offset_in_key = ((offsets.name,) + out_coords,)
return (
tuple((x.name,) + chunk_coords for (chunk_coords, _, _) in indexer)
+ offset_in_key
return tuple(
(x.name,) + chunk_coords for (chunk_coords, _, _) in indexer
)

# since selection is chunk-aligned, we know that we only read one block of x
num_input_blocks = (1, 1) # x, offsets

out = general_blockwise(
_assemble_index_chunk,
key_function,
x,
offsets,
shapes=[shape],
dtypes=[x.dtype],
chunkss=[target_chunks],
num_input_blocks=num_input_blocks,
target_chunks=target_chunks,
selection=selection,
in_shape=x.shape,
Expand Down Expand Up @@ -622,14 +675,8 @@ def _assemble_index_chunk(
selection=None,
in_shape=None,
in_chunksize=None,
block_id=None,
):
# last array contains the offset for the block_id
offset = int(arrs[-1]) # convert from 0-d array
numblocks = tuple(map(len, target_chunks))
block_id = offset_to_block_id(offset, numblocks)

arrs = arrs[:-1] # drop offset array

# compute the selection on x required to get the relevant chunk for out_coords
out_coords = block_id
in_sel = _target_chunk_selection(target_chunks, out_coords, selection)
Expand Down
22 changes: 14 additions & 8 deletions cubed/primitive/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,9 @@
from cubed.runtime.types import CubedPipeline
from cubed.storage.zarr import T_ZarrArray, lazy_zarr_array
from cubed.types import T_Chunks, T_DType, T_Shape, T_Store
from cubed.utils import (
array_memory,
chunk_memory,
get_item,
map_nested,
split_into,
to_chunksize,
)
from cubed.utils import array_memory, chunk_memory, get_item, map_nested
from cubed.utils import numblocks as compute_numblocks
from cubed.utils import split_into, to_chunksize
from cubed.vendor.dask.array.core import normalize_chunks
from cubed.vendor.dask.blockwise import _get_coord_mapping, _make_dims, lol_product
from cubed.vendor.dask.core import flatten
Expand Down Expand Up @@ -261,6 +256,8 @@ def general_blockwise(
"""A more general form of ``blockwise`` that uses a function to specify the block
mapping, rather than an index notation, and which supports multiple outputs.

For multiple outputs, all output arrays must have matching numblocks.

Parameters
----------
func : callable
Expand Down Expand Up @@ -308,9 +305,18 @@ def general_blockwise(
output_chunk_memory = 0
target_array = []

numblocks0 = None
for i, target_store in enumerate(target_stores):
chunks_normal = normalize_chunks(chunkss[i], shape=shapes[i], dtype=dtypes[i])
chunksize = to_chunksize(chunks_normal)
if numblocks0 is None:
numblocks0 = compute_numblocks(chunks_normal)
else:
numblocks = compute_numblocks(chunks_normal)
if numblocks != numblocks0:
raise ValueError(
f"All outputs must have matching number of blocks in each dimension. Chunks specified: {chunkss}"
)
if isinstance(target_store, zarr.Array):
ta = target_store
else:
Expand Down
39 changes: 39 additions & 0 deletions cubed/tests/primitive/test_blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,45 @@ def block_function(out_key):
assert_array_equal(res2[:], -np.sqrt(input))


def test_blockwise_multiple_outputs_fails_different_numblocks(tmp_path):
    """Requesting two outputs whose block grids differ must raise ValueError."""
    input_name = "x"
    memory_limit = 1000
    output_stores = [tmp_path / "target1.zarr", tmp_path / "target2.zarr"]

    source = create_zarr(
        [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
        dtype=int,
        chunks=(2, 2),
        store=tmp_path / "source.zarr",
    )

    def sqrts(x):
        # one yielded value per output array
        yield np.sqrt(x)
        yield -np.sqrt(x)

    def block_function(out_key):
        # each output block reads the input block with the same coordinates
        coords = out_key[1:]
        return ((input_name, *coords),)

    expected_message = (
        "All outputs must have matching number of blocks in each dimension"
    )
    with pytest.raises(ValueError, match=expected_message):
        general_blockwise(
            sqrts,
            block_function,
            source,
            allowed_mem=memory_limit,
            reserved_mem=0,
            target_stores=output_stores,
            shapes=[(3, 3), (3, 3)],
            dtypes=[float, float],
            chunkss=[(2, 2), (4, 2)],  # numblocks differ
            in_names=[input_name],
        )


def test_make_blockwise_key_function_map():
func = lambda x: 0

Expand Down
4 changes: 4 additions & 0 deletions cubed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ def to_chunksize(chunkset: T_RectangularChunks) -> T_RegularChunks:
return tuple(max(c[0], 1) for c in chunkset)


def numblocks(chunks: T_RectangularChunks) -> Tuple[int, ...]:
    """Return the number of blocks along each dimension of ``chunks``.

    ``chunks`` holds one sequence of chunk lengths per dimension; the result
    holds the length of each of those sequences, in the same order.
    """
    return tuple(len(dim_chunks) for dim_chunks in chunks)


@dataclass
class StackSummary:
"""Like Python's ``FrameSummary``, but with module information."""
Expand Down
Loading