Implement concat using general_blockwise #607

Merged · 4 commits · Nov 7, 2024
127 changes: 101 additions & 26 deletions cubed/array_api/manipulation_functions.py
@@ -10,10 +10,10 @@
from cubed.core import squeeze # noqa: F401
from cubed.core import blockwise, rechunk, unify_chunks
from cubed.core.ops import (
_create_zarr_indexer,
elemwise,
general_blockwise,
map_blocks,
map_direct,
map_selection,
)
from cubed.utils import block_id_to_offset, get_item, offset_to_block_id, to_chunksize
@@ -95,6 +95,7 @@ def concat(arrays, /, *, axis=0, chunks=None):

# offsets along axis for the start of each array
offsets = [0] + list(tlz.accumulate(add, [a.shape[axis] for a in arrays]))
in_shapes = tuple(array.shape for array in arrays)

axis = validate_axis(axis, a.ndim)
shape = a.shape[:axis] + (offsets[-1],) + a.shape[axis + 1 :]
@@ -104,47 +105,93 @@
else:
chunks = normalize_chunks(chunks, shape=shape, dtype=dtype)

# memory allocated by reading one chunk from input array
# note that although the output chunk will overlap multiple input chunks,
# the chunks are read in series, reusing memory
extra_projected_mem = a.chunkmem
def key_function(out_key):
out_coords = out_key[1:]
block_id = out_coords

# determine the start and stop indexes for this block along the axis dimension
chunksize = to_chunksize(chunks)
start = block_id[axis] * chunksize[axis]
stop = start + chunksize[axis]
stop = min(stop, shape[axis])

# produce a key that has slices (except for axis dimension, which is replaced below)
idx = tuple(0 if i == axis else v for i, v in enumerate(block_id))
key = get_item(chunks, idx)

# find slices of the arrays
in_keys = []
for ai, sl in _array_slices(offsets, start, stop):
key = tuple(sl if i == axis else k for i, k in enumerate(key))

# use a Zarr BasicIndexer to convert this to input coordinates
a = arrays[ai]
indexer = _create_zarr_indexer(key, a.shape, a.chunksize)

in_keys.extend(
[(a.name,) + chunk_coords for (chunk_coords, _, _) in indexer]
)

return (iter(tuple(in_key for in_key in in_keys)),)

return map_direct(
num_input_blocks = (1,) * len(arrays)
iterable_input_blocks = (True,) * len(arrays)

# We have to mark this as fusable=False since the number of input args to
# the _read_concat_chunk function is *not* the same as the number of
# predecessor nodes in the DAG, and the fusion functions in blockwise
# assume they are the same. See https://github.com/cubed-dev/cubed/issues/414
# This also affects stack.
return general_blockwise(
_read_concat_chunk,
key_function,
*arrays,
shape=shape,
dtype=dtype,
chunks=chunks,
extra_projected_mem=extra_projected_mem,
shapes=[shape],
dtypes=[dtype],
chunkss=[chunks],
num_input_blocks=num_input_blocks,
iterable_input_blocks=iterable_input_blocks,
extra_func_kwargs=dict(dtype=dtype),
target_shape=shape,
target_chunks=chunks,
axis=axis,
offsets=offsets,
in_shapes=in_shapes,
function_nargs=1,
fusable=False,
)
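
For reference, a minimal usage sketch of the new code path (array values and chunk sizes are illustrative, and it assumes the standard cubed array API namespace):

```python
import cubed.array_api as xp

a = xp.asarray([[1, 2], [3, 4]], chunks=(1, 2))
b = xp.asarray([[5, 6]], chunks=(1, 2))

c = xp.concat([a, b], axis=0)  # now a single general_blockwise operation
print(c.compute())  # [[1 2] [3 4] [5 6]]
```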


def _read_concat_chunk(
x, *arrays, target_chunks=None, axis=None, offsets=None, block_id=None
arrays,
dtype=None,
target_shape=None,
target_chunks=None,
axis=None,
offsets=None,
in_shapes=None,
block_id=None,
):
# determine the start and stop indexes for this block along the axis dimension
chunks = target_chunks
chunksize = to_chunksize(chunks)
chunksize = to_chunksize(target_chunks)
start = block_id[axis] * chunksize[axis]
stop = start + x.shape[axis]

# produce a key that has slices (except for axis dimension, which is replaced below)
idx = tuple(0 if i == axis else v for i, v in enumerate(block_id))
key = get_item(chunks, idx)

# concatenate slices of the arrays
parts = []
for ai, sl in _array_slices(offsets, start, stop):
key = tuple(sl if i == axis else k for i, k in enumerate(key))
parts.append(arrays[ai].zarray[key])
return nxp.concat(parts, axis=axis)
stop = start + chunksize[axis]
stop = min(stop, target_shape[axis])

chunk_shape = tuple(ch[bi] for ch, bi in zip(target_chunks, block_id))
out = np.empty(chunk_shape, dtype=dtype)
for array, (lchunk_selection, lout_selection) in zip(
arrays,
_chunk_slices(
offsets, start, stop, target_chunks, chunksize, in_shapes, axis, block_id
),
):
out[lout_selection] = array[lchunk_selection]
return out
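
In effect, each output chunk is allocated once with np.empty and then filled piecewise from every input chunk that overlaps it. A worked sketch of that assembly step, with assumed sizes (a 5-row output chunk spanning two input arrays along axis 0):

```python
import numpy as np

out = np.empty((5, 4), dtype=np.int64)
out[0:3, :] = np.full((3, 4), 1)  # selection from the tail of arrays[0]
out[3:5, :] = np.full((2, 4), 2)  # selection from the head of arrays[1]
# out now holds the concatenated block: three rows of 1s, two rows of 2s
```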


def _array_slices(offsets, start, stop):
"""Return pairs of array index and slice to slice from start to stop in the concatenated array."""
"""Return pairs of array index and array slice to slice from start to stop in the concatenated array."""
slice_start = start
while slice_start < stop:
# find array that slice_start falls in
@@ -154,6 +201,33 @@ def _array_slices(offsets, start, stop):
slice_start = slice_stop
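
The rest of _array_slices is elided above. A sketch of the equivalent logic, inferred from the docstring and its call sites (this is not the library code, just an illustration of the contract):

```python
def array_slices(offsets, start, stop):
    """Yield (array index, local slice) pairs covering start..stop."""
    ai = 0
    slice_start = start
    while slice_start < stop:
        # advance to the array that slice_start falls in
        while offsets[ai + 1] <= slice_start:
            ai += 1
        slice_stop = min(stop, offsets[ai + 1])
        yield ai, slice(slice_start - offsets[ai], slice_stop - offsets[ai])
        slice_start = slice_stop

# three arrays of lengths 10, 15 and 5 -> offsets [0, 10, 25, 30]
assert list(array_slices([0, 10, 25, 30], 5, 27)) == [
    (0, slice(5, 10)),   # tail of the first array
    (1, slice(0, 15)),   # all of the second array
    (2, slice(0, 2)),    # head of the third array
]
```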


def _chunk_slices(
offsets, start, stop, target_chunks, chunksize, in_shapes, axis, block_id
):
"""Return pairs of chunk slices to slice input array chunks and output concatenated chunk."""

# an output chunk may have selections from more than one array, so we need an offset per array
arr_sel_offset = 0 # offset along axis

# produce a key that has slices (except for axis dimension, which is replaced below)
idx = tuple(0 if i == axis else v for i, v in enumerate(block_id))
key = get_item(target_chunks, idx)

for ai, sl in _array_slices(offsets, start, stop):
key = tuple(sl if i == axis else k for i, k in enumerate(key))
indexer = _create_zarr_indexer(key, in_shapes[ai], chunksize)
for _, lchunk_selection, lout_selection in indexer:
lout_selection_with_offset = tuple(
sl
if ax != axis
else slice(sl.start + arr_sel_offset, sl.stop + arr_sel_offset)
for ax, sl in enumerate(lout_selection)
)
yield lchunk_selection, lout_selection_with_offset

arr_sel_offset += lout_selection[axis].stop
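
The arr_sel_offset bookkeeping shifts each successive array's output selection along the axis so the pieces land back to back in the output chunk. A tiny illustration with assumed sizes:

```python
# array 0 contributes rows 0..3 of the output chunk, so its last
# lout_selection along the axis is slice(0, 3) and arr_sel_offset becomes 3;
# array 1's indexer then yields slice(0, 2), which is shifted to rows 3..5
sl = slice(0, 2)
arr_sel_offset = 3
shifted = slice(sl.start + arr_sel_offset, sl.stop + arr_sel_offset)
assert shifted == slice(3, 5)
```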


def expand_dims(x, /, *, axis):
if not isinstance(axis, tuple):
axis = (axis,)
@@ -420,6 +494,7 @@ def key_function(out_key):
dtypes=[dtype],
chunkss=[chunks],
axis=axis,
function_nargs=1,
fusable=False,
)
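
As the comment in concat notes, the same fusable=False treatment applies to stack, which has the same mismatch between function args and DAG predecessors. A usage sketch (shapes and chunks are illustrative):

```python
import cubed.array_api as xp

a = xp.ones((4, 4), chunks=(2, 2))
b = xp.zeros((4, 4), chunks=(2, 2))

s = xp.stack([a, b], axis=0)  # shape (2, 4, 4); not fused with its predecessors
```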

4 changes: 4 additions & 0 deletions cubed/core/ops.py
@@ -367,6 +367,9 @@ def wrap(*a, **kw):

return wrap

function_nargs = kwargs.pop("function_nargs", None)
if function_nargs is not None:
function_nargs = function_nargs + 1 # for offsets array
num_input_blocks = kwargs.pop("num_input_blocks", None)
if num_input_blocks is not None:
num_input_blocks = num_input_blocks + (1,) # for offsets array
@@ -386,6 +389,7 @@ def wrap(*a, **kw):
target_stores=target_stores,
target_paths=target_paths,
extra_func_kwargs=extra_func_kwargs,
function_nargs=function_nargs,
num_input_blocks=num_input_blocks,
iterable_input_blocks=iterable_input_blocks,
**kwargs,
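
This mirrors the num_input_blocks handling just below it: both values are bumped to account for the implicit offsets array passed as an extra input. A minimal sketch of the pattern (function name is illustrative):

```python
def _bump_for_offsets(kwargs):
    # pop optional kwargs and adjust them for the extra offsets-array input
    function_nargs = kwargs.pop("function_nargs", None)
    if function_nargs is not None:
        function_nargs += 1  # for offsets array
    num_input_blocks = kwargs.pop("num_input_blocks", None)
    if num_input_blocks is not None:
        num_input_blocks += (1,)  # for offsets array
    return function_nargs, num_input_blocks
```
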
6 changes: 3 additions & 3 deletions cubed/core/optimization.py
@@ -164,7 +164,7 @@ def can_fuse_predecessors(
# if node itself can't be fused then there is nothing to fuse
if not is_fusable(nodes[name]):
logger.debug(
"can't fuse %s since it is not a primitive operation, or it uses map_direct",
"can't fuse %s since it is not a primitive operation, or it uses an operation that can't be fused (concat or stack)",
name,
)
return False
@@ -224,10 +224,10 @@ def can_fuse_predecessors(
)
return False

# if a predecessor has no primitive op then just use None
predecessor_primitive_ops = [
nodes[pre]["primitive_op"]
nodes[pre]["primitive_op"] if can_fuse else None
for pre, _, can_fuse in predecessor_ops_and_arrays(dag, name)
if can_fuse
]
return can_fuse_multiple_primitive_ops(
name,
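
With this change, predecessors that cannot be fused contribute a None placeholder rather than being filtered out, so the list stays aligned with predecessor order. A toy illustration:

```python
# (name, can_fuse) pairs standing in for predecessor_ops_and_arrays output
preds = [("op1", True), ("op2", False), ("op3", True)]

predecessor_primitive_ops = [name if can_fuse else None for name, can_fuse in preds]
assert predecessor_primitive_ops == ["op1", None, "op3"]
```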