Skip to content

Commit

Permalink
Implement stack using general_blockwise
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Jan 20, 2024
1 parent 30df284 commit 0e54be8
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions cubed/array_api/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from cubed.backend_array_api import numpy_array_to_backend_array
from cubed.core import squeeze # noqa: F401
from cubed.core import blockwise, rechunk, unify_chunks
from cubed.core.ops import elemwise, map_blocks, map_direct
from cubed.core.ops import elemwise, general_blockwise, map_blocks, map_direct
from cubed.utils import get_item, to_chunksize
from cubed.vendor.dask.array.core import broadcast_chunks, normalize_chunks
from cubed.vendor.dask.array.reshape import reshape_rechunk
Expand Down Expand Up @@ -293,8 +293,16 @@ def stack(arrays, /, *, axis=0):
# (output is already catered for in blockwise)
extra_projected_mem = a.chunkmem

return map_direct(
array_names = [a.name for a in arrays]

def block_function(out_key):
out_coords = out_key[1:]
in_name = array_names[out_coords[axis]]
return ((in_name, *(out_coords[:axis] + out_coords[(axis + 1) :])),)

return general_blockwise(
_read_stack_chunk,
block_function,
*arrays,
shape=shape,
dtype=dtype,
Expand All @@ -304,9 +312,5 @@ def stack(arrays, /, *, axis=0):
)


def _read_stack_chunk(x, *arrays, axis=None, block_id=None):
array = arrays[block_id[axis]]
idx = tuple(v for i, v in enumerate(block_id) if i != axis)
out = array.zarray[get_item(array.chunks, idx)]
out = nxp.expand_dims(out, axis=axis)
return out
def _read_stack_chunk(array, axis=None):
return nxp.expand_dims(array, axis=axis)

0 comments on commit 0e54be8

Please sign in to comment.