Skip to content

Commit

Permalink
Add moveaxis
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Sep 26, 2023
1 parent 9f73a02 commit 0469b55
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 0 deletions.
2 changes: 2 additions & 0 deletions cubed/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@
broadcast_to,
concat,
expand_dims,
moveaxis,
permute_dims,
reshape,
squeeze,
Expand All @@ -232,6 +233,7 @@
"broadcast_to",
"concat",
"expand_dims",
"moveaxis",
"permute_dims",
"reshape",
"squeeze",
Expand Down
26 changes: 26 additions & 0 deletions cubed/array_api/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,32 @@ def flatten(x):
return reshape(x, (-1,))


def moveaxis(
x,
source,
destination,
/,
):
# From NumPy: https://github.com/numpy/numpy/blob/a4120979d216cce00dcee511aad70bf7b45ef6e0/numpy/core/numeric.py#L1389-L1457
from numpy.core.numeric import normalize_axis_tuple

source = normalize_axis_tuple(source, x.ndim, "source")
destination = normalize_axis_tuple(destination, x.ndim, "destination")
if len(source) != len(destination):
raise ValueError(
"`source` and `destination` arguments must have "
"the same number of elements"
)

order = [n for n in range(x.ndim) if n not in source]

for dest, src in sorted(zip(destination, source)):
order.insert(dest, src)

result = permute_dims(x, order)
return result


def permute_dims(x, /, axes):
# From dask transpose
if axes:
Expand Down
9 changes: 9 additions & 0 deletions cubed/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,15 @@ def test_expand_dims(spec, executor):
assert_array_equal(b.compute(executor=executor), np.expand_dims([1, 2, 3], 0))


def test_moveaxis(spec):
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
b = xp.moveaxis(a, [0, -1], [-1, 0])
assert_array_equal(
b.compute(),
np.moveaxis(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), [0, -1], [-1, 0]),
)


def test_permute_dims(spec, executor):
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
b = xp.permute_dims(a, (1, 0))
Expand Down

0 comments on commit 0469b55

Please sign in to comment.