diff --git a/cubed/array_api/__init__.py b/cubed/array_api/__init__.py index 4dfdc500..8ef442ef 100644 --- a/cubed/array_api/__init__.py +++ b/cubed/array_api/__init__.py @@ -221,6 +221,7 @@ broadcast_to, concat, expand_dims, + moveaxis, permute_dims, reshape, squeeze, @@ -232,6 +233,7 @@ "broadcast_to", "concat", "expand_dims", + "moveaxis", "permute_dims", "reshape", "squeeze", diff --git a/cubed/array_api/manipulation_functions.py b/cubed/array_api/manipulation_functions.py index 3f6dea47..64662f87 100644 --- a/cubed/array_api/manipulation_functions.py +++ b/cubed/array_api/manipulation_functions.py @@ -159,6 +159,32 @@ def flatten(x): return reshape(x, (-1,)) +def moveaxis( + x, + source, + destination, + /, +): + # From NumPy: https://github.com/numpy/numpy/blob/a4120979d216cce00dcee511aad70bf7b45ef6e0/numpy/core/numeric.py#L1389-L1457 + from numpy.core.numeric import normalize_axis_tuple + + source = normalize_axis_tuple(source, x.ndim, "source") + destination = normalize_axis_tuple(destination, x.ndim, "destination") + if len(source) != len(destination): + raise ValueError( + "`source` and `destination` arguments must have " + "the same number of elements" + ) + + order = [n for n in range(x.ndim) if n not in source] + + for dest, src in sorted(zip(destination, source)): + order.insert(dest, src) + + result = permute_dims(x, order) + return result + + def permute_dims(x, /, axes): # From dask transpose if axes: diff --git a/cubed/tests/test_array_api.py b/cubed/tests/test_array_api.py index 3e30f6b2..365ada4b 100644 --- a/cubed/tests/test_array_api.py +++ b/cubed/tests/test_array_api.py @@ -439,6 +439,15 @@ def test_expand_dims(spec, executor): assert_array_equal(b.compute(executor=executor), np.expand_dims([1, 2, 3], 0)) +def test_moveaxis(spec): + a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec) + b = xp.moveaxis(a, [0, -1], [-1, 0]) + assert_array_equal( + b.compute(), + np.moveaxis(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), [0, -1], [-1, 0]), + ) + + def test_permute_dims(spec, executor): a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec) b = xp.permute_dims(a, (1, 0))