Skip to content

Commit

Permalink
Implement flip (#528)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored Aug 1, 2024
1 parent 9505005 commit db03d62
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 2 deletions.
1 change: 0 additions & 1 deletion .github/workflows/array-api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ jobs:
# not implemented
array_api_tests/test_array_object.py::test_setitem
array_api_tests/test_array_object.py::test_setitem_masking
array_api_tests/test_manipulation_functions.py::test_flip
array_api_tests/test_sorting_functions.py
array_api_tests/test_statistical_functions.py::test_std
array_api_tests/test_statistical_functions.py::test_var
Expand Down
2 changes: 1 addition & 1 deletion api_status.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ This table shows which parts of the the [Array API](https://data-apis.org/array-
| | `broadcast_to` | :white_check_mark: | | |
| | `concat` | :white_check_mark: | | |
| | `expand_dims` | :white_check_mark: | | |
| | `flip` | :x: | | Needs indexing with step=-1, [#114](https://github.com/cubed-dev/cubed/issues/114) |
| | `flip` | :white_check_mark: | | |
| | `permute_dims` | :white_check_mark: | | |
| | `repeat` | :x: | 2023.12 | |
| | `reshape` | :white_check_mark: | | Partial implementation |
Expand Down
2 changes: 2 additions & 0 deletions cubed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@
broadcast_to,
concat,
expand_dims,
flip,
moveaxis,
permute_dims,
reshape,
Expand All @@ -292,6 +293,7 @@
"broadcast_to",
"concat",
"expand_dims",
"flip",
"moveaxis",
"permute_dims",
"reshape",
Expand Down
2 changes: 2 additions & 0 deletions cubed/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@
broadcast_to,
concat,
expand_dims,
flip,
moveaxis,
permute_dims,
reshape,
Expand All @@ -234,6 +235,7 @@
"broadcast_to",
"concat",
"expand_dims",
"flip",
"moveaxis",
"permute_dims",
"reshape",
Expand Down
44 changes: 44 additions & 0 deletions cubed/array_api/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,50 @@ def flatten(x):
return reshape(x, (-1,))


def flip(x, /, *, axis=None):
if axis is None:
axis = tuple(range(x.ndim)) # all axes
if not isinstance(axis, tuple):
axis = (axis,)
axis = validate_axis(axis, x.ndim)
return map_direct(
_flip,
x,
shape=x.shape,
dtype=x.dtype,
chunks=x.chunks,
extra_projected_mem=x.chunkmem,
target_chunks=x.chunks,
axis=axis,
)


def _flip(x, *arrays, target_chunks=None, axis=None, block_id=None):
array = arrays[0].zarray # underlying Zarr array (or virtual array)
chunks = target_chunks

# produce a key that has slices (except for axis dimensions, which are replaced below)
idx = tuple(0 if i == axis else v for i, v in enumerate(block_id))
key = list(get_item(chunks, idx))

for ax in axis:
# determine the start and stop indexes for this block along the axis dimension
chunksize = to_chunksize(chunks)
start = block_id[ax] * chunksize[ax]
stop = start + x.shape[ax]

# flip start and stop
axis_len = array.shape[ax]
start, stop = axis_len - stop, axis_len - start

# replace with slice
key[ax] = slice(start, stop)

key = tuple(key)

return nxp.flip(array[key], axis=axis)


def moveaxis(
x,
source,
Expand Down
25 changes: 25 additions & 0 deletions cubed/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,31 @@ def test_expand_dims(spec, executor):
assert_array_equal(b.compute(executor=executor), np.expand_dims([1, 2, 3], 0))


@pytest.mark.parametrize(
"shape, chunks, axis",
[
((10,), (4,), None),
((10,), (4,), 0),
((10, 7), (4, 3), None),
((10, 7), (4, 3), 0),
((10, 7), (4, 3), 1),
((10, 7), (4, 3), (0, 1)),
((10, 7), (4, 3), -1),
],
)
def test_flip(executor, shape, chunks, axis):
x = np.random.randint(10, size=shape)
a = xp.asarray(x, chunks=chunks)
b = xp.flip(a, axis=axis)

assert b.chunks == a.chunks

assert_array_equal(
b.compute(executor=executor),
np.flip(x, axis=axis),
)


def test_moveaxis(spec):
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
b = xp.moveaxis(a, [0, -1], [-1, 0])
Expand Down

0 comments on commit db03d62

Please sign in to comment.