Skip to content

Commit

Permalink
Array API Manipulation Function Improvements (Bears-R-Us#3056)
Browse files Browse the repository at this point in the history
* refactor some array manipulation commands to reduce unneeded comm with locale 0

Signed-off-by: Jeremiah Corrado <[email protected]>

* add implementation of tile, unstack, and repeat to manipulation-functions module. Implement rank-reducing indexing for ND arrays. Bug fixes and performance improvements in manipulation functions.

Signed-off-by: Jeremiah Corrado <[email protected]>

* fix flake8, mypy and python compat errors

Signed-off-by: Jeremiah Corrado <[email protected]>

* add error handling to expandDims and stack for cases where max array rank would be exceeded

Signed-off-by: Jeremiah Corrado <[email protected]>

* add special cases for flatten and unflatten to reshapeMsg

Signed-off-by: Jeremiah Corrado <[email protected]>

* replace rank reducing slices in flatten/unflatten with explicit put/get operations

Signed-off-by: Jeremiah Corrado <[email protected]>

* use AryUtil flatten/unflatten in SetMsg

Signed-off-by: Jeremiah Corrado <[email protected]>

---------

Signed-off-by: Jeremiah Corrado <[email protected]>
  • Loading branch information
jeremiah-corrado authored Apr 10, 2024
1 parent f8b4f3c commit 2cd6162
Show file tree
Hide file tree
Showing 13 changed files with 860 additions and 306 deletions.
8 changes: 8 additions & 0 deletions arkouda/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,15 @@
concat,
expand_dims,
flip,
moveaxis,
permute_dims,
repeat,
reshape,
roll,
squeeze,
stack,
tile,
unstack,
)

from ._searching_functions import argmax, argmin, nonzero, where
Expand Down Expand Up @@ -255,11 +259,15 @@
"concat",
"expand_dims",
"flip",
"moveaxis",
"permute_dims",
"repeat",
"reshape",
"roll",
"squeeze",
"stack",
"tile",
"unstack",
]

__all__ += ["argmax", "argmin", "nonzero", "where"]
Expand Down
124 changes: 73 additions & 51 deletions arkouda/array_api/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,13 @@
from typing import List, Optional, Tuple, Union, cast
from arkouda.client import generic_msg
from arkouda.pdarrayclass import create_pdarray
from arkouda.pdarraycreation import scalar_array
from arkouda.util import broadcast_dims

import numpy as np


def broadcast_arrays(*arrays: Array) -> List[Array]:
"""
Array API compatible wrapper for :py:func:`np.broadcast_arrays <numpy.broadcast_arrays>`.
See its docstring for more information.
"""

shapes = [a.shape for a in arrays]
bcShape = shapes[0]
for shape in shapes[1:]:
Expand Down Expand Up @@ -50,15 +45,9 @@ def broadcast_to(x: Array, /, shape: Tuple[int, ...]) -> Array:
raise ValueError(f"Failed to broadcast array: {e}")


# Note: the function name is different here
def concat(
arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0
) -> Array:
"""
Array API compatible wrapper for :py:func:`np.concatenate <numpy.concatenate>`.
See its docstring for more information.
"""
# TODO: type promotion across input arrays

return Array._new(
Expand All @@ -81,11 +70,6 @@ def concat(


def expand_dims(x: Array, /, *, axis: int) -> Array:
"""
Array API compatible wrapper for :py:func:`np.expand_dims <numpy.expand_dims>`.
See its docstring for more information.
"""
try:
return Array._new(
create_pdarray(
Expand All @@ -106,11 +90,6 @@ def expand_dims(x: Array, /, *, axis: int) -> Array:


def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
"""
Array API compatible wrapper for :py:func:`np.flip <numpy.flip>`.
See its docstring for more information.
"""
axisList = []
if axis is not None:
axisList = list(axis) if isinstance(axis, tuple) else [axis]
Expand All @@ -137,15 +116,22 @@ def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) ->
def moveaxis(
    x: Array, source: Union[int, Tuple[int, ...]], destination: Union[int, Tuple[int, ...]], /
) -> Array:
    """
    Move axes of an array to new positions.

    Axes that are not listed in ``source`` keep their relative order.

    Parameters
    ----------
    x : Array
        Input array.
    source : int or tuple of int
        Original position(s) of the axes to move. Negative values count from the last axis.
    destination : int or tuple of int
        Target position for each axis in ``source``. Must be the same kind (int or tuple)
        and, for tuples, the same length as ``source``.

    Returns
    -------
    Array
        Result of ``permute_dims`` applied with the computed axis order.

    Raises
    ------
    ValueError
        If ``source`` and ``destination`` are not both ints or both tuples, have different
        lengths, or contain repeated axes.
    IndexError
        If any axis is outside ``[-x.ndim, x.ndim)``.
    """
    if isinstance(source, tuple):
        if not isinstance(destination, tuple):
            raise ValueError("source and destination must both be tuples if source is a tuple")
        src, dst = list(source), list(destination)
    elif isinstance(destination, int):
        src, dst = [source], [destination]
    else:
        raise ValueError("source and destination must both be integers if source is an integer")

    if len(src) != len(dst):
        raise ValueError("source and destination must have the same number of elements")

    ndim = x.ndim
    for label, axes in (("source", src), ("destination", dst)):
        for i, ax in enumerate(axes):
            if not -ndim <= ax < ndim:
                raise IndexError(
                    f"{label} axis {ax} is out of bounds for array of dimension {ndim}"
                )
            # map negative axes onto their non-negative equivalents
            axes[i] = ax % ndim
        if len(set(axes)) != len(axes):
            raise ValueError(f"repeated axis in {label} argument")

    # Start from the axes that are staying put (in their original order), then insert each
    # moved axis at its destination. Inserting in ascending destination order guarantees
    # that earlier insertions are not shifted by later ones.
    order = [ax for ax in range(ndim) if ax not in src]
    for d, s in sorted(zip(dst, src)):
        order.insert(d, s)

    return permute_dims(x, axes=tuple(order))

def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
"""
Array API compatible wrapper for :py:func:`np.transpose <numpy.transpose>`.

See its docstring for more information.
"""
def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
try:
return Array._new(
create_pdarray(
Expand All @@ -166,17 +152,33 @@ def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:


def repeat(x: Array, repeats: Union[int, Array], /, *, axis: Optional[int] = None) -> Array:
    """
    Repeat elements of an array by issuing a ``repeatFlat{N}D`` server command.

    Parameters
    ----------
    x : Array
        Input array.
    repeats : int or Array
        Repetition count(s). An ``int`` is wrapped with ``scalar_array`` before being sent.
    axis : int, optional
        Only ``None`` is currently supported.

    Returns
    -------
    Array
        Array created from the server's reply.

    Raises
    ------
    NotImplementedError
        If ``axis`` is not ``None``.
    """
    # the server command takes the repeat counts as an array, so wrap a bare int first
    reps = Array._new(scalar_array(repeats)) if isinstance(repeats, int) else repeats

    if axis is not None:
        raise NotImplementedError("repeat with 'axis' argument is not yet implemented")

    reply = generic_msg(
        cmd=f"repeatFlat{x.ndim}D",
        args={
            "name": x._array,
            "repeats": reps._array,
        },
    )
    return Array._new(create_pdarray(cast(str, reply)))


def reshape(
x: Array, /, shape: Tuple[int, ...], *, copy: Optional[bool] = None
) -> Array:
"""
Array API compatible wrapper for :py:func:`np.reshape <numpy.reshape>`.
See its docstring for more information.
"""

# TODO: figure out copying semantics (currently always creates a copy)
try:
Expand Down Expand Up @@ -205,11 +207,6 @@ def roll(
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
) -> Array:
"""
Array API compatible wrapper for :py:func:`np.roll <numpy.roll>`.
See its docstring for more information.
"""
axisList = []
if axis is not None:
axisList = list(axis) if isinstance(axis, tuple) else [axis]
Expand Down Expand Up @@ -240,11 +237,6 @@ def roll(


def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
"""
Array API compatible wrapper for :py:func:`np.squeeze <numpy.squeeze>`.
See its docstring for more information.
"""
nAxes = len(axis) if isinstance(axis, tuple) else 1
try:
return Array._new(
Expand All @@ -267,11 +259,6 @@ def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:


def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array:
"""
Array API compatible wrapper for :py:func:`np.stack <numpy.stack>`.
See its docstring for more information.
"""
# TODO: type promotion across input arrays
return Array._new(
create_pdarray(
Expand All @@ -291,8 +278,43 @@ def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) ->


def tile(x: Array, repetitions: Tuple[int, ...], /) -> Array:
    """
    Tile an array by issuing a ``tile{N}D`` server command.

    The array rank and the length of ``repetitions`` are made equal before the command is
    sent: whichever is shorter is padded on the left with ones.

    Parameters
    ----------
    x : Array
        Input array.
    repetitions : tuple of int
        Number of repetitions along each dimension.

    Returns
    -------
    Array
        Array created from the server's reply.
    """
    # positive: more repetitions than dimensions; negative: fewer
    rank_gap = len(repetitions) - x.ndim

    xr, reps = x, repetitions
    if rank_gap > 0:
        # promote x by prepending singleton dimensions
        xr = reshape(x, (1,) * rank_gap + x.shape)
    elif rank_gap < 0:
        # leading dimensions without an explicit count are repeated once
        reps = (1,) * (-rank_gap) + repetitions

    reply = generic_msg(
        cmd=f"tile{xr.ndim}D",
        args={
            "name": xr._array,
            "reps": reps,
        },
    )
    return Array._new(create_pdarray(cast(str, reply)))


def unstack(x: Array, /, *, axis: int = 0) -> Tuple[Array, ...]:
    """
    Split an array along ``axis`` by issuing an ``unstack{N}D`` server command.

    Parameters
    ----------
    x : Array
        Input array.
    axis : int
        Axis along which to split. ``x.shape[axis]`` is sent as the number of arrays
        to return.

    Returns
    -------
    tuple of Array
        One array per ``"+"``-separated entry in the server's reply.
    """
    request = {
        "name": x._array,
        "axis": axis,
        "numReturnArrays": x.shape[axis],
    }
    reply = cast(str, generic_msg(cmd=f"unstack{x.ndim}D", args=request))

    # the reply packs the created arrays' descriptors into one "+"-delimited string
    return tuple(Array._new(create_pdarray(piece)) for piece in reply.split("+"))
49 changes: 43 additions & 6 deletions arkouda/pdarrayclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,18 +657,41 @@ def __getitem__(self, key):
return create_pdarray(repMsg)

if isinstance(key, tuple):
allScalar = True
if len(key) > self.ndim:
raise IndexError(f"too many indices ({len(key)}) for array with {self.ndim} dimensions")

# replace '...' with the appropriate number of ':'
elipsis_axis_idx = -1
for dim, k in enumerate(key):
if isinstance(k, type(Ellipsis)):
if elipsis_axis_idx != -1:
raise IndexError("array index can only have one ellipsis")
else:
elipsis_axis_idx = dim

if elipsis_axis_idx != -1:
key = tuple(
key[:elipsis_axis_idx] +
(slice(None),) * (self.ndim - len(key) + 1) +
key[(elipsis_axis_idx+1):]
)

# parse the key tuple
num_scalar = 0
scalar_axes = []
starts = []
stops = []
strides = []
for dim, k in enumerate(key):
if isinstance(k, slice):
allScalar = False
(start, stop, stride) = k.indices(self.shape[dim])
starts.append(start)
stops.append(stop)
strides.append(stride)
elif np.isscalar(k) and (resolve_scalar_dtype(k) in ["int64", "uint64"]):
num_scalar += 1
scalar_axes.append(dim)

if k < 0:
# Interpret negative key as offset from end of array
k += int(self.shape[dim])
Expand All @@ -678,15 +701,14 @@ def __getitem__(self, key):
)
else:
# treat this as a single element slice
# TODO: implement rank-reducing slices
starts.append(k)
stops.append(k + 1)
strides.append(1)
else:
raise IndexError(f"Unhandled key type: {k} ({type(k)})")

if allScalar:
# use simpler indexing (and return a scalar) if we got a tuple of only scalars
if num_scalar == len(key):
# all scalars: use simpler indexing (and return a scalar)
repMsg = generic_msg(
cmd=f"[int]{self.ndim}D",
args={
Expand All @@ -706,7 +728,22 @@ def __getitem__(self, key):
"strides": tuple(strides),
},
)
return create_pdarray(repMsg)
maybe_degen_arr = create_pdarray(repMsg)

if num_scalar > 0:
# reduce the array rank if there are any scalar indices
# note: squeeze requires the non-default ManipulationMsg server module
repMsg = generic_msg(
cmd=f"squeeze{maybe_degen_arr.ndim}Dx{maybe_degen_arr.ndim - num_scalar}D",
args={
"name": maybe_degen_arr,
"nAxes": num_scalar,
"axes": scalar_axes,
},
)
return create_pdarray(repMsg)
else:
return maybe_degen_arr

if isinstance(key, pdarray) and self.ndim == 1:
kind, _ = translate_np_dtype(key.dtype)
Expand Down
Loading

0 comments on commit 2cd6162

Please sign in to comment.