diff --git a/cubed/array_api/manipulation_functions.py b/cubed/array_api/manipulation_functions.py index 21fe91ca7..c7707e37e 100644 --- a/cubed/array_api/manipulation_functions.py +++ b/cubed/array_api/manipulation_functions.py @@ -11,7 +11,7 @@ from cubed.backend_array_api import numpy_array_to_backend_array from cubed.core import squeeze # noqa: F401 from cubed.core import blockwise, rechunk, unify_chunks -from cubed.core.ops import elemwise, map_blocks, map_direct +from cubed.core.ops import elemwise, general_blockwise, map_blocks, map_direct from cubed.utils import get_item, to_chunksize from cubed.vendor.dask.array.core import broadcast_chunks, normalize_chunks from cubed.vendor.dask.array.reshape import reshape_rechunk @@ -293,8 +293,16 @@ def stack(arrays, /, *, axis=0): # (output is already catered for in blockwise) extra_projected_mem = a.chunkmem - return map_direct( + array_names = [a.name for a in arrays] + + def block_function(out_key): + out_coords = out_key[1:] + in_name = array_names[out_coords[axis]] + return ((in_name, *(out_coords[:axis] + out_coords[(axis + 1) :])),) + + return general_blockwise( _read_stack_chunk, + block_function, *arrays, shape=shape, dtype=dtype, @@ -304,9 +312,5 @@ def stack(arrays, /, *, axis=0): ) -def _read_stack_chunk(x, *arrays, axis=None, block_id=None): - array = arrays[block_id[axis]] - idx = tuple(v for i, v in enumerate(block_id) if i != axis) - out = array.zarray[get_item(array.chunks, idx)] - out = nxp.expand_dims(out, axis=axis) - return out +def _read_stack_chunk(array, axis=None): + return nxp.expand_dims(array, axis=axis)