diff --git a/arkouda/numpy/_numeric.py b/arkouda/numpy/_numeric.py index 1c61c9dcd2..8c01b3e7e3 100644 --- a/arkouda/numpy/_numeric.py +++ b/arkouda/numpy/_numeric.py @@ -3,13 +3,13 @@ from typing import TYPE_CHECKING, List, Sequence, Tuple, TypeVar, Union from typing import cast as type_cast from typing import no_type_check -from arkouda.groupbyclass import groupable + import numpy as np from typeguard import typechecked from arkouda.client import generic_msg from arkouda.dtypes import str_ as akstr_ -from arkouda.groupbyclass import GroupBy +from arkouda.groupbyclass import GroupBy, groupable from arkouda.numpy.dtypes import DTypes, bigint from arkouda.numpy.dtypes import bool_ as ak_bool from arkouda.numpy.dtypes import dtype as akdtype @@ -26,7 +26,13 @@ from arkouda.numpy.dtypes import uint64 as ak_uint64 from arkouda.pdarrayclass import all as ak_all from arkouda.pdarrayclass import any as ak_any -from arkouda.pdarrayclass import argmax, create_pdarray, pdarray, sum +from arkouda.pdarrayclass import ( + argmax, + broadcast_if_needed, + create_pdarray, + pdarray, + sum, +) from arkouda.pdarraycreation import array, linspace, scalar_array from arkouda.sorting import sort from arkouda.strings import Strings @@ -2292,10 +2298,8 @@ def array_equal(pda_a: pdarray, pda_b: pdarray, equal_nan: bool = False): def putmask( - A : pdarray , - mask : pdarray, - Values : pdarray - ) : # doesn't return anything, as A is overwritten in place + A: pdarray, mask: pdarray, Values: pdarray +): # doesn't return anything, as A is overwritten in place """ Overwrites elements of A with elements from B based upon a mask array. Similar to numpy.putmask, where mask = False, A retains its original value, @@ -2599,18 +2603,26 @@ def matmul(pdaLeft: pdarray, pdaRight: pdarray): """ if pdaLeft.ndim != pdaRight.ndim: raise ValueError("matmul requires matrices of matching rank.") + + x1, x2, tmp_x1, tmp_x2 = broadcast_if_needed(pdaLeft, pdaRight) + cmd = f"matmul<{pdaLeft.dtype},{pdaRight.dtype},{pdaLeft.ndim}>" args = { - "x1": pdaLeft, - "x2": pdaRight, + "x1": x1, + "x2": x2, } - return create_pdarray( - generic_msg( - cmd=cmd, - args=args, - ) + repMsg = generic_msg( + cmd=cmd, + args=args, ) + if tmp_x1: + del x1 + if tmp_x2: + del x2 + + return create_pdarray(repMsg) + def vecdot(x1: pdarray, x2: pdarray): """ @@ -2647,16 +2659,25 @@ def vecdot(x1: pdarray, x2: pdarray): raise ValueError("vecdot requires matrices of matching rank.") if x1.ndim < 2: raise ValueError("vector requires matrices of rank 2 or more.") + + x1b, x2b, tmp_x1, tmp_x2 = broadcast_if_needed(x1, x2) + cmd = f"vecdot<{x1.dtype},{x2.dtype},{x1.ndim}>" args = { - "x1": x1, - "x2": x2, + "x1": x1b, + "x2": x2b, "bcShape": tuple(x1.shape), "axis": 0, } - return create_pdarray( - generic_msg( - cmd=cmd, - args=args, - ) + + repMsg = generic_msg( + cmd=cmd, + args=args, ) + + if tmp_x1: + del x1 + if tmp_x2: + del x2 + + return create_pdarray(repMsg)