Skip to content

Commit

Permalink
Closes #3300: shape function (#3900)
Browse files Browse the repository at this point in the history
Co-authored-by: Amanda Potts <[email protected]>
  • Loading branch information
ajpotts and ajpotts authored Nov 20, 2024
1 parent 0f5f0a3 commit 19d1ca9
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 0 deletions.
1 change: 1 addition & 0 deletions arkouda/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,5 @@
from arkouda.numpy.rec import *

from ._numeric import *
from ._utils import *
from ._manipulation_functions import *
51 changes: 51 additions & 0 deletions arkouda/numpy/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Iterable, Tuple, Union

from numpy import ndarray

from arkouda.numpy.dtypes import all_scalars, isSupportedDType
from arkouda.pdarrayclass import pdarray
from arkouda.strings import Strings

__all__ = ["shape"]


def shape(a: Union[pdarray, Strings, all_scalars]) -> Tuple:
"""
Return the shape of an array.
Parameters
----------
a : pdarray
Input array.
Returns
-------
shape : tuple of ints
The elements of the shape tuple give the lengths of the
corresponding array dimensions.
Examples
--------
>>> import arkouda as ak
>>> ak.shape(ak.eye(3,2))
(3, 2)
>>> ak.shape([[1, 3]])
(1, 2)
>>> ak.shape([0])
(1,)
>>> ak.shape(0)
()
"""
if isinstance(a, (pdarray, Strings, ndarray, Iterable)) and not isinstance(a, str):
try:
result = a.shape
except AttributeError:
from arkouda import array

result = array(a).shape
return result
elif isSupportedDType(a):
return ()
else:
raise TypeError("shape requires type pdarray, ndarray, Iterable, or numeric scalar.")
5 changes: 5 additions & 0 deletions arkouda/numpy/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
"intTypes",
"int_scalars",
"isSupportedBool",
"isSupportedDType",
"isSupportedFloat",
"isSupportedInt",
"isSupportedNumber",
Expand Down Expand Up @@ -320,6 +321,10 @@ def isSupportedBool(num):
return isinstance(num, ARKOUDA_SUPPORTED_BOOLS)


def isSupportedDType(scalar):
return isinstance(scalar, ARKOUDA_SUPPORTED_DTYPES)


def resolve_scalar_dtype(val: object) -> str:
"""
Try to infer what dtype arkouda_server should treat val as.
Expand Down
23 changes: 23 additions & 0 deletions tests/numpy/utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest

import arkouda as ak
import numpy as np


class TestFromNumericFunctions:

@pytest.mark.parametrize("x", [0, [0], [1, 2, 3], np.ndarray([0, 1, 2]), [[1, 3]], np.eye(3, 2)])
def test_shape(self, x):
assert ak.shape(x) == np.shape(x)

def test_shape_pdarray(self):
a = ak.arange(5)
assert ak.shape(a) == np.shape(a.to_ndarray())

def test_shape_strings(self):
a = ak.array(["a", "b", "c"])
assert ak.shape(a) == np.shape(a.to_ndarray())

@pytest.mark.skip_if_max_rank_less_than(2)
def test_shape_multidim_pdarray(self):
assert ak.shape(ak.eye(3, 2)) == (3, 2)

0 comments on commit 19d1ca9

Please sign in to comment.