diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 813e667f..f41649a4 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -96,7 +96,6 @@ jobs: # not implemented array_api_tests/test_array_object.py::test_setitem array_api_tests/test_array_object.py::test_setitem_masking - array_api_tests/test_manipulation_functions.py::test_flip array_api_tests/test_sorting_functions.py array_api_tests/test_statistical_functions.py::test_std array_api_tests/test_statistical_functions.py::test_var diff --git a/api_status.md b/api_status.md index ae929c94..bc618839 100644 --- a/api_status.md +++ b/api_status.md @@ -59,7 +59,7 @@ This table shows which parts of the the [Array API](https://data-apis.org/array- | | `broadcast_to` | :white_check_mark: | | | | | `concat` | :white_check_mark: | | | | | `expand_dims` | :white_check_mark: | | | -| | `flip` | :x: | | Needs indexing with step=-1, [#114](https://github.com/cubed-dev/cubed/issues/114) | +| | `flip` | :white_check_mark: | | | | | `permute_dims` | :white_check_mark: | | | | | `repeat` | :x: | 2023.12 | | | | `reshape` | :white_check_mark: | | Partial implementation | diff --git a/cubed/__init__.py b/cubed/__init__.py index 0149cb6a..f0089f3b 100644 --- a/cubed/__init__.py +++ b/cubed/__init__.py @@ -279,6 +279,7 @@ broadcast_to, concat, expand_dims, + flip, moveaxis, permute_dims, reshape, @@ -292,6 +293,7 @@ "broadcast_to", "concat", "expand_dims", + "flip", "moveaxis", "permute_dims", "reshape", diff --git a/cubed/array_api/__init__.py b/cubed/array_api/__init__.py index dd709141..b290674b 100644 --- a/cubed/array_api/__init__.py +++ b/cubed/array_api/__init__.py @@ -221,6 +221,7 @@ broadcast_to, concat, expand_dims, + flip, moveaxis, permute_dims, reshape, @@ -234,6 +235,7 @@ "broadcast_to", "concat", "expand_dims", + "flip", "moveaxis", "permute_dims", "reshape", diff --git a/cubed/array_api/manipulation_functions.py b/cubed/array_api/manipulation_functions.py index aaceeba4..350f924b 100644 --- a/cubed/array_api/manipulation_functions.py +++ b/cubed/array_api/manipulation_functions.py @@ -172,6 +172,50 @@ def flatten(x): return reshape(x, (-1,)) +def flip(x, /, *, axis=None): + if axis is None: + axis = tuple(range(x.ndim)) # all axes + if not isinstance(axis, tuple): + axis = (axis,) + axis = validate_axis(axis, x.ndim) + return map_direct( + _flip, + x, + shape=x.shape, + dtype=x.dtype, + chunks=x.chunks, + extra_projected_mem=x.chunkmem, + target_chunks=x.chunks, + axis=axis, + ) + + +def _flip(x, *arrays, target_chunks=None, axis=None, block_id=None): + array = arrays[0].zarray # underlying Zarr array (or virtual array) + chunks = target_chunks + + # produce a key that has slices (except for axis dimensions, which are replaced below) + idx = tuple(0 if i == axis else v for i, v in enumerate(block_id)) + key = list(get_item(chunks, idx)) + + for ax in axis: + # determine the start and stop indexes for this block along the axis dimension + chunksize = to_chunksize(chunks) + start = block_id[ax] * chunksize[ax] + stop = start + x.shape[ax] + + # flip start and stop + axis_len = array.shape[ax] + start, stop = axis_len - stop, axis_len - start + + # replace with slice + key[ax] = slice(start, stop) + + key = tuple(key) + + return nxp.flip(array[key], axis=axis) + + def moveaxis( x, source, diff --git a/cubed/tests/test_array_api.py b/cubed/tests/test_array_api.py index 2be010f0..e715184d 100644 --- a/cubed/tests/test_array_api.py +++ b/cubed/tests/test_array_api.py @@ -492,6 +492,31 @@ def test_expand_dims(spec, executor): assert_array_equal(b.compute(executor=executor), np.expand_dims([1, 2, 3], 0)) +@pytest.mark.parametrize( + "shape, chunks, axis", + [ + ((10,), (4,), None), + ((10,), (4,), 0), + ((10, 7), (4, 3), None), + ((10, 7), (4, 3), 0), + ((10, 7), (4, 3), 1), + ((10, 7), (4, 3), (0, 1)), + ((10, 7), (4, 3), -1), + ], +) +def test_flip(executor, shape, chunks, axis): + x = np.random.randint(10, size=shape) + a = xp.asarray(x, chunks=chunks) + b = xp.flip(a, axis=axis) + + assert b.chunks == a.chunks + + assert_array_equal( + b.compute(executor=executor), + np.flip(x, axis=axis), + ) + + def test_moveaxis(spec): a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec) b = xp.moveaxis(a, [0, -1], [-1, 0])