Skip to content

Commit

Permalink
add broadcast_if_needed to lingalg
Browse files Browse the repository at this point in the history
  • Loading branch information
ajpotts committed Oct 10, 2024
1 parent ace8019 commit 44a5234
Showing 1 changed file with 42 additions and 21 deletions.
63 changes: 42 additions & 21 deletions arkouda/numpy/_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from typing import TYPE_CHECKING, List, Sequence, Tuple, TypeVar, Union
from typing import cast as type_cast
from typing import no_type_check
from arkouda.groupbyclass import groupable

import numpy as np
from typeguard import typechecked

from arkouda.client import generic_msg
from arkouda.dtypes import str_ as akstr_
from arkouda.groupbyclass import GroupBy
from arkouda.groupbyclass import GroupBy, groupable
from arkouda.numpy.dtypes import DTypes, bigint
from arkouda.numpy.dtypes import bool_ as ak_bool
from arkouda.numpy.dtypes import dtype as akdtype
Expand All @@ -26,7 +26,13 @@
from arkouda.numpy.dtypes import uint64 as ak_uint64
from arkouda.pdarrayclass import all as ak_all
from arkouda.pdarrayclass import any as ak_any
from arkouda.pdarrayclass import argmax, create_pdarray, pdarray, sum
from arkouda.pdarrayclass import (
argmax,
broadcast_if_needed,
create_pdarray,
pdarray,
sum,
)
from arkouda.pdarraycreation import array, linspace, scalar_array
from arkouda.sorting import sort
from arkouda.strings import Strings
Expand Down Expand Up @@ -2292,10 +2298,8 @@ def array_equal(pda_a: pdarray, pda_b: pdarray, equal_nan: bool = False):


def putmask(
A : pdarray ,
mask : pdarray,
Values : pdarray
) : # doesn't return anything, as A is overwritten in place
A: pdarray, mask: pdarray, Values: pdarray
): # doesn't return anything, as A is overwritten in place
"""
Overwrites elements of A with elements from B based upon a mask array.
Similar to numpy.putmask, where mask = False, A retains its original value,
Expand Down Expand Up @@ -2599,18 +2603,26 @@ def matmul(pdaLeft: pdarray, pdaRight: pdarray):
"""
if pdaLeft.ndim != pdaRight.ndim:
raise ValueError("matmul requires matrices of matching rank.")

x1, x2, tmp_x1, tmp_x2 = broadcast_if_needed(pdaLeft, pdaRight)

cmd = f"matmul<{pdaLeft.dtype},{pdaRight.dtype},{pdaLeft.ndim}>"
args = {
"x1": pdaLeft,
"x2": pdaRight,
"x1": x1,
"x2": x2,
}
return create_pdarray(
generic_msg(
cmd=cmd,
args=args,
)
repMsg = generic_msg(
cmd=cmd,
args=args,
)

if tmp_x1:
del x1
if tmp_x2:
del x2

return create_pdarray(repMsg)


def vecdot(x1: pdarray, x2: pdarray):
"""
Expand Down Expand Up @@ -2647,16 +2659,25 @@ def vecdot(x1: pdarray, x2: pdarray):
raise ValueError("vecdot requires matrices of matching rank.")
if x1.ndim < 2:
raise ValueError("vector requires matrices of rank 2 or more.")

x1b, x2b, tmp_x1, tmp_x2 = broadcast_if_needed(x1, x2)

cmd = f"vecdot<{x1.dtype},{x2.dtype},{x1.ndim}>"
args = {
"x1": x1,
"x2": x2,
"x1": x1b,
"x2": x2b,
"bcShape": tuple(x1.shape),
"axis": 0,
}
return create_pdarray(
generic_msg(
cmd=cmd,
args=args,
)

repMsg = generic_msg(
cmd=cmd,
args=args,
)

if tmp_x1:
del x1
if tmp_x2:
del x2

return create_pdarray(repMsg)

0 comments on commit 44a5234

Please sign in to comment.