From 2d3440f16da943b145b4b6975ba542fdc0d8d66e Mon Sep 17 00:00:00 2001 From: Tom White Date: Wed, 9 Oct 2024 08:07:43 +0100 Subject: [PATCH] Support `block_id` for `general_blockwise` functions (#593) * Ensure numblocks match for multiple outputs in general blockwise * Support block_id for general_blockwise --- cubed/core/ops.py | 107 +++++++++++++++++------- cubed/primitive/blockwise.py | 22 +++-- cubed/tests/primitive/test_blockwise.py | 39 +++++++++ cubed/utils.py | 4 + 4 files changed, 134 insertions(+), 38 deletions(-) diff --git a/cubed/core/ops.py b/cubed/core/ops.py index 6224c112..7e16d0d0 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -26,14 +26,9 @@ from cubed.spec import spec_from_config from cubed.storage.backend import open_backend_array from cubed.types import T_RegularChunks, T_Shape -from cubed.utils import ( - _concatenate2, - array_memory, - array_size, - get_item, - offset_to_block_id, - to_chunksize, -) +from cubed.utils import _concatenate2, array_memory, array_size, get_item +from cubed.utils import numblocks as compute_numblocks +from cubed.utils import offset_to_block_id, to_chunksize from cubed.vendor.dask.array.core import normalize_chunks from cubed.vendor.dask.array.utils import validate_axis from cubed.vendor.dask.blockwise import broadcast_dimensions, lol_product @@ -342,6 +337,77 @@ def general_blockwise( target_paths=None, extra_func_kwargs=None, **kwargs, +) -> Union["Array", Tuple["Array", ...]]: + if has_keyword(func, "block_id"): + from cubed.array_api.creation_functions import offsets_virtual_array + + # Create an array of index offsets with the same chunk structure as the args, + # which we convert to block ids (chunk coordinates) later. + array0 = arrays[0] + # note that primitive general_blockwise checks that all chunkss have same numblocks + numblocks = compute_numblocks(chunkss[0]) + offsets = offsets_virtual_array(numblocks, array0.spec) + new_arrays = arrays + (offsets,) + + def key_function_with_offset(key_function): + def wrap(out_key): + out_coords = out_key[1:] + offset_in_key = ((offsets.name,) + out_coords,) + return key_function(out_key) + offset_in_key + + return wrap + + def func_with_block_id(func): + def wrap(*a, **kw): + offset = int(a[-1]) # convert from 0-d array + block_id = offset_to_block_id(offset, numblocks) + return func(*a[:-1], block_id=block_id, **kw) + + return wrap + + num_input_blocks = kwargs.pop("num_input_blocks", None) + if num_input_blocks is not None: + num_input_blocks = num_input_blocks + (1,) # for offsets array + + return _general_blockwise( + func_with_block_id(func), + key_function_with_offset(key_function), + *new_arrays, + shapes=shapes, + dtypes=dtypes, + chunkss=chunkss, + target_stores=target_stores, + target_paths=target_paths, + extra_func_kwargs=extra_func_kwargs, + num_input_blocks=num_input_blocks, + **kwargs, + ) + + return _general_blockwise( + func, + key_function, + *arrays, + shapes=shapes, + dtypes=dtypes, + chunkss=chunkss, + target_stores=target_stores, + target_paths=target_paths, + extra_func_kwargs=extra_func_kwargs, + **kwargs, + ) + + +def _general_blockwise( + func, + key_function, + *arrays, + shapes, + dtypes, + chunkss, + target_stores=None, + target_paths=None, + extra_func_kwargs=None, + **kwargs, ) -> Union["Array", Tuple["Array", ...]]: assert len(arrays) > 0 @@ -504,12 +570,6 @@ def merged_chunk_len_for_indexer(ia, c): if _is_chunk_aligned_selection(idx): # use general_blockwise, which allows more opportunities for optimization than map_direct - from cubed.array_api.creation_functions import offsets_virtual_array - - # general_blockwise doesn't support block_id, so emulate it ourselves - numblocks = tuple(map(len, target_chunks)) - offsets = offsets_virtual_array(numblocks, x.spec) - def key_function(out_key): out_coords = out_key[1:] @@ -521,24 +581,17 @@ def key_function(out_key): in_sel, x.zarray_maybe_lazy.shape, x.zarray_maybe_lazy.chunks ) - offset_in_key = ((offsets.name,) + out_coords,) - return ( - tuple((x.name,) + chunk_coords for (chunk_coords, _, _) in indexer) - + offset_in_key + return tuple( + (x.name,) + chunk_coords for (chunk_coords, _, _) in indexer ) - # since selection is chunk-aligned, we know that we only read one block of x - num_input_blocks = (1, 1) # x, offsets - out = general_blockwise( _assemble_index_chunk, key_function, x, - offsets, shapes=[shape], dtypes=[x.dtype], chunkss=[target_chunks], - num_input_blocks=num_input_blocks, target_chunks=target_chunks, selection=selection, in_shape=x.shape, @@ -622,14 +675,8 @@ def _assemble_index_chunk( selection=None, in_shape=None, in_chunksize=None, + block_id=None, ): - # last array contains the offset for the block_id - offset = int(arrs[-1]) # convert from 0-d array - numblocks = tuple(map(len, target_chunks)) - block_id = offset_to_block_id(offset, numblocks) - - arrs = arrs[:-1] # drop offset array - # compute the selection on x required to get the relevant chunk for out_coords out_coords = block_id in_sel = _target_chunk_selection(target_chunks, out_coords, selection) diff --git a/cubed/primitive/blockwise.py b/cubed/primitive/blockwise.py index d36b5fbb..1d6618c2 100644 --- a/cubed/primitive/blockwise.py +++ b/cubed/primitive/blockwise.py @@ -18,14 +18,9 @@ from cubed.runtime.types import CubedPipeline from cubed.storage.zarr import T_ZarrArray, lazy_zarr_array from cubed.types import T_Chunks, T_DType, T_Shape, T_Store -from cubed.utils import ( - array_memory, - chunk_memory, - get_item, - map_nested, - split_into, - to_chunksize, -) +from cubed.utils import array_memory, chunk_memory, get_item, map_nested +from cubed.utils import numblocks as compute_numblocks +from cubed.utils import split_into, to_chunksize from cubed.vendor.dask.array.core import normalize_chunks from cubed.vendor.dask.blockwise import _get_coord_mapping, _make_dims, lol_product from cubed.vendor.dask.core import flatten @@ -261,6 +256,8 @@ def general_blockwise( """A more general form of ``blockwise`` that uses a function to specify the block mapping, rather than an index notation, and which supports multiple outputs. + For multiple outputs, all output arrays must have matching numblocks. + Parameters ---------- func : callable @@ -308,9 +305,18 @@ def general_blockwise( output_chunk_memory = 0 target_array = [] + numblocks0 = None for i, target_store in enumerate(target_stores): chunks_normal = normalize_chunks(chunkss[i], shape=shapes[i], dtype=dtypes[i]) chunksize = to_chunksize(chunks_normal) + if numblocks0 is None: + numblocks0 = compute_numblocks(chunks_normal) + else: + numblocks = compute_numblocks(chunks_normal) + if numblocks != numblocks0: + raise ValueError( + f"All outputs must have matching number of blocks in each dimension. Chunks specified: {chunkss}" + ) if isinstance(target_store, zarr.Array): ta = target_store else: diff --git a/cubed/tests/primitive/test_blockwise.py b/cubed/tests/primitive/test_blockwise.py index a5886d6e..fe091e07 100644 --- a/cubed/tests/primitive/test_blockwise.py +++ b/cubed/tests/primitive/test_blockwise.py @@ -285,6 +285,45 @@ def block_function(out_key): assert_array_equal(res2[:], -np.sqrt(input)) +def test_blockwise_multiple_outputs_fails_different_numblocks(tmp_path): + source = create_zarr( + [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + dtype=int, + chunks=(2, 2), + store=tmp_path / "source.zarr", + ) + allowed_mem = 1000 + target_store1 = tmp_path / "target1.zarr" + target_store2 = tmp_path / "target2.zarr" + + in_name = "x" + + def sqrts(x): + yield np.sqrt(x) + yield -np.sqrt(x) + + def block_function(out_key): + out_coords = out_key[1:] + return ((in_name, *out_coords),) + + with pytest.raises( + ValueError, + match="All outputs must have matching number of blocks in each dimension", + ): + general_blockwise( + sqrts, + block_function, + source, + allowed_mem=allowed_mem, + reserved_mem=0, + target_stores=[target_store1, target_store2], + shapes=[(3, 3), (3, 3)], + dtypes=[float, float], + chunkss=[(2, 2), (4, 2)], # numblocks differ + in_names=[in_name], + ) + + def test_make_blockwise_key_function_map(): func = lambda x: 0 diff --git a/cubed/utils.py b/cubed/utils.py index 16b8a53d..e5837134 100644 --- a/cubed/utils.py +++ b/cubed/utils.py @@ -147,6 +147,10 @@ def to_chunksize(chunkset: T_RectangularChunks) -> T_RegularChunks: return tuple(max(c[0], 1) for c in chunkset) +def numblocks(chunks: T_RectangularChunks) -> Tuple[int, ...]: + return tuple(map(len, chunks)) + + @dataclass class StackSummary: """Like Python's ``FrameSummary``, but with module information."""