Skip to content

Commit

Permalink
Implement flip using map_selection
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Oct 28, 2024
1 parent c1df109 commit 889227c
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 28 deletions.
70 changes: 42 additions & 28 deletions cubed/array_api/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
from cubed.backend_array_api import namespace as nxp
from cubed.core import squeeze # noqa: F401
from cubed.core import blockwise, rechunk, unify_chunks
from cubed.core.ops import elemwise, general_blockwise, map_blocks, map_direct
from cubed.core.ops import (
elemwise,
general_blockwise,
map_blocks,
map_direct,
map_selection,
)
from cubed.utils import block_id_to_offset, get_item, offset_to_block_id, to_chunksize
from cubed.vendor.dask.array.core import broadcast_chunks, normalize_chunks
from cubed.vendor.dask.array.reshape import reshape_rechunk
Expand Down Expand Up @@ -178,42 +184,50 @@ def flip(x, /, *, axis=None):
if not isinstance(axis, tuple):
axis = (axis,)
axis = validate_axis(axis, x.ndim)
return map_direct(
_flip,

def selection_function(out_key):
out_coords = out_key[1:]
block_id = out_coords

# produce a key that has slices (except for axis dimensions, which are replaced below)
idx = tuple(0 if i == axis else v for i, v in enumerate(block_id))
key = list(get_item(x.chunks, idx))

for ax in axis:
# determine the start and stop indexes for this block along the axis dimension
start = block_id[ax] * x.chunksize[ax]
stop = start + x.chunksize[ax]
stop = min(stop, x.shape[ax])

# flip start and stop
axis_len = x.shape[ax]
start, stop = axis_len - stop, axis_len - start

# replace with slice
key[ax] = slice(start, stop)

return tuple(key)

max_num_input_blocks = _flip_num_input_blocks(axis, x.shape, x.chunksize)

return map_selection(
nxp.flip,
selection_function,
x,
shape=x.shape,
dtype=x.dtype,
chunks=x.chunks,
extra_projected_mem=x.chunkmem,
target_chunks=x.chunks,
max_num_input_blocks=max_num_input_blocks,
axis=axis,
)


def _flip(x, *arrays, target_chunks=None, axis=None, block_id=None):
array = arrays[0].zarray # underlying Zarr array (or virtual array)
chunks = target_chunks

# produce a key that has slices (except for axis dimensions, which are replaced below)
idx = tuple(0 if i == axis else v for i, v in enumerate(block_id))
key = list(get_item(chunks, idx))

def _flip_num_input_blocks(axis, shape, chunksizes):
num = 1
for ax in axis:
# determine the start and stop indexes for this block along the axis dimension
chunksize = to_chunksize(chunks)
start = block_id[ax] * chunksize[ax]
stop = start + x.shape[ax]

# flip start and stop
axis_len = array.shape[ax]
start, stop = axis_len - stop, axis_len - start

# replace with slice
key[ax] = slice(start, stop)

key = tuple(key)

return nxp.flip(array[key], axis=axis)
if shape[ax] % chunksizes[ax] != 0:
num *= 2
return num


def moveaxis(
Expand Down
2 changes: 2 additions & 0 deletions cubed/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,11 +517,13 @@ def test_expand_dims(spec, executor):
[
((10,), (4,), None),
((10,), (4,), 0),
((10,), (5,), 0),
((10, 7), (4, 3), None),
((10, 7), (4, 3), 0),
((10, 7), (4, 3), 1),
((10, 7), (4, 3), (0, 1)),
((10, 7), (4, 3), -1),
((10, 7), (5, 3), (0, 1)),
],
)
def test_flip(executor, shape, chunks, axis):
Expand Down

0 comments on commit 889227c

Please sign in to comment.