Skip to content

Commit

Permalink
Closes 3878 - refactors rounding functions to new interface, pulls ha…
Browse files Browse the repository at this point in the history
…sh function into their own procs (Bears-R-Us#3898)

* Rebased

* Rebasing

* Minor neatening

* Minor spacing edits

* Fixed flake8 issues

* Removed accidental duplicates

* Addresses comments

* Refining multi-dim test of isinf_isfinite

* Closes Bears-R-Us#3896 PytestUnknownMarkWarning for pytest.mark.skip_if_nl_greater_than and pytest.mark.skip_if_nl_less_than (Bears-R-Us#3897)

Co-authored-by: Amanda Potts <[email protected]>

* cleanup

---------

Co-authored-by: drculhane <[email protected]>
Co-authored-by: ajpotts <[email protected]>
Co-authored-by: Amanda Potts <[email protected]>
  • Loading branch information
4 people authored Nov 20, 2024
1 parent cd596ff commit 15bd176
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 123 deletions.
51 changes: 22 additions & 29 deletions arkouda/numpy/_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import TYPE_CHECKING, List, Sequence, Tuple, TypeVar, Union
from typing import cast as type_cast
from typing import no_type_check

import numpy as np
from typeguard import typechecked

Expand All @@ -26,13 +25,7 @@
from arkouda.numpy.dtypes import _datatype_check
from arkouda.pdarrayclass import all as ak_all
from arkouda.pdarrayclass import any as ak_any
from arkouda.pdarrayclass import (
argmax,
broadcast_if_needed,
create_pdarray,
pdarray,
sum,
)
from arkouda.pdarrayclass import argmax, broadcast_if_needed, create_pdarray, pdarray, sum
from arkouda.pdarraycreation import array, linspace, scalar_array
from arkouda.sorting import sort
from arkouda.strings import Strings
Expand Down Expand Up @@ -281,10 +274,10 @@ def ceil(pda: pdarray) -> pdarray:
>>> ak.ceil(ak.linspace(1.1,5.5,5))
array([2, 3, 4, 5, 6])
"""
_datatype_check(pda.dtype, [float], 'ceil')
repMsg = generic_msg(
cmd=f"efunc{pda.ndim}D",
cmd=f"ceil<{pda.dtype},{pda.ndim}>",
args={
"func": "ceil",
"array": pda,
},
)
Expand Down Expand Up @@ -315,11 +308,11 @@ def floor(pda: pdarray) -> pdarray:
>>> ak.floor(ak.linspace(1.1,5.5,5))
array([1, 2, 3, 4, 5])
"""
_datatype_check(pda.dtype, [float], 'floor')
repMsg = generic_msg(
cmd=f"efunc{pda.ndim}D",
cmd=f"floor<{pda.dtype},{pda.ndim}>",
args={
"func": "floor",
"array": pda,
"pda": pda,
},
)
return create_pdarray(type_cast(str, repMsg))
Expand Down Expand Up @@ -349,11 +342,11 @@ def round(pda: pdarray) -> pdarray:
>>> ak.round(ak.array([1.1, 2.5, 3.14159]))
array([1, 3, 3])
"""
_datatype_check(pda.dtype, [float], 'round')
repMsg = generic_msg(
cmd=f"efunc{pda.ndim}D",
cmd=f"round<{pda.dtype},{pda.ndim}>",
args={
"func": "round",
"array": pda,
"pda": pda,
},
)
return create_pdarray(type_cast(str, repMsg))
Expand Down Expand Up @@ -383,10 +376,10 @@ def trunc(pda: pdarray) -> pdarray:
>>> ak.trunc(ak.array([1.1, 2.5, 3.14159]))
array([1, 2, 3])
"""
_datatype_check(pda.dtype, [float], 'trunc')
repMsg = generic_msg(
cmd=f"efunc{pda.ndim}D",
cmd=f"trunc<{pda.dtype},{pda.ndim}>",
args={
"func": "trunc",
"array": pda,
},
)
Expand Down Expand Up @@ -1362,7 +1355,7 @@ def rad2deg(pda: pdarray, where: Union[bool, pdarray] = True) -> pdarray:
elif where is False:
return pda
else:
return _merge_where(pda[:], where, 180 * (pda[where] / np.pi))
return _merge_where(pda[:], where, 180*(pda[where]/np.pi))


@typechecked
Expand Down Expand Up @@ -1394,7 +1387,7 @@ def deg2rad(pda: pdarray, where: Union[bool, pdarray] = True) -> pdarray:
elif where is False:
return pda
else:
return _merge_where(pda[:], where, (np.pi * pda[where] / 180))
return _merge_where(pda[:], where, (np.pi*pda[where]/180))


def _hash_helper(a):
Expand Down Expand Up @@ -1521,13 +1514,14 @@ def hash(
def _hash_single(pda: pdarray, full: bool = True):
if pda.dtype == bigint:
return hash(pda.bigint_to_uint_arrays())
_datatype_check(pda.dtype, [float, int, ak_uint64], 'hash')
hname = "hash128" if full else "hash64"
repMsg = type_cast(
str,
generic_msg(
cmd=f"efunc{pda.ndim}D",
cmd=f"{hname}<{pda.dtype},{pda.ndim}>",
args={
"func": "hash128" if full else "hash64",
"array": pda,
"x": pda,
},
),
)
Expand Down Expand Up @@ -2588,19 +2582,18 @@ def matmul(pdaLeft: pdarray, pdaRight: pdarray):
"""
if pdaLeft.ndim != pdaRight.ndim:
raise ValueError("matmul requires matrices of matching rank.")

cmd = f"matmul<{pdaLeft.dtype},{pdaRight.dtype},{pdaLeft.ndim}>"
args = {
"x1": pdaLeft,
"x2": pdaRight,
}
repMsg = generic_msg(
cmd=cmd,
args=args,
return create_pdarray(
generic_msg(
cmd=cmd,
args=args,
)
)

return create_pdarray(repMsg)


def vecdot(x1: pdarray, x2: pdarray):
"""
Expand Down
128 changes: 53 additions & 75 deletions src/EfuncMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,59 @@ module EfuncMsg
@arkouda.registerCommand(name="isfinite")
proc isfinite_ (pda : [?d] real) : [d] bool { return (isFinite(pda)) ; }

@arkouda.registerCommand (name="floor")
proc floor_ (pda : [?d] ?t) : [d] real throws
where (t==real) { return floor(pda); }

@arkouda.registerCommand (name="ceil")
proc ceil_ (pda : [?d] ?t) : [d] real throws
where (t==real) { return ceil(pda); }

@arkouda.registerCommand (name="round")
proc round_ (pda : [?d] ?t) : [d] real throws
where (t==real) { return round(pda); }

@arkouda.registerCommand (name="trunc")
proc trunc_ (pda : [?d] ?t) : [d] real throws
where (t==real) { return trunc(pda); }

// Hashes are more of a challenge to unhook from the old interface, but they
// have been pulled out into their own functions.

@arkouda.instantiateAndRegister
proc hash64 (cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int) : MsgTuple throws
where ((array_dtype==real || array_dtype==int || array_dtype==uint) && array_nd==1) {
const efunc = msgArgs.getValueOf("x"),
e = st[msgArgs["x"]]: SymEntry(array_dtype,array_nd);
const rname = st.nextName();
overMemLimit(numBytes(array_dtype)*e.size);
var a = st.addEntry(rname, e.tupShape, uint);
forall (ai, x) in zip (a.a, e.a) {
ai = sipHash64(x) : uint ;
}
var repMsg = "created " + st.attrib(rname);
eLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}

@arkouda.instantiateAndRegister
proc hash128 (cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int) : MsgTuple throws
where ((array_dtype==real || array_dtype==int || array_dtype==uint) && array_nd==1) {
const efunc = msgArgs.getValueOf("x"),
e = st[msgArgs["x"]]: SymEntry(array_dtype,array_nd);
const rname = st.nextName();
var rname2 = st.nextName();
overMemLimit(numBytes(array_dtype) * e.size * 2);
var a1 = st.addEntry(rname2, e.tupShape, uint);
var a2 = st.addEntry(rname, e.tupShape, uint);
forall (a1i, a2i, x) in zip(a1.a, a2.a, e.a) {
(a1i, a2i) = sipHash128(x): (uint, uint);
}
var repMsg = "created " + st.attrib(rname2) + "+";
repMsg += "created " + st.attrib(rname);
eLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}

// End of rewrite section -- delete this comment after all of EfuncMsg is rewritten.

Expand All @@ -162,9 +215,6 @@ module EfuncMsg
ref ea = e.a;
select efunc
{
when "round" {
st.addEntry(rname, new shared SymEntry(ea));
}
when "sgn" {
st.addEntry(rname, new shared SymEntry(sgn(ea)));
}
Expand All @@ -190,25 +240,6 @@ module EfuncMsg
return new MsgTuple(errorMsg, MsgType.ERROR);
}
}
when "hash64" {
overMemLimit(numBytes(int) * e.size);
var a = st.addEntry(rname, e.tupShape, uint);
forall (ai, x) in zip(a.a, e.a) {
ai = sipHash64(x): uint;
}
}
when "hash128" {
overMemLimit(numBytes(int) * e.size * 2);
var rname2 = st.nextName();
var a1 = st.addEntry(rname2, e.tupShape, uint);
var a2 = st.addEntry(rname, e.tupShape, uint);
forall (a1i, a2i, x) in zip(a1.a, a2.a, e.a) {
(a1i, a2i) = sipHash128(x): (uint, uint);
}
// Put first array's attrib in repMsg and let common
// code append second array's attrib
repMsg += "created " + st.attrib(rname2) + "+";
}
when "popcount" {
st.addEntry(rname, new shared SymEntry(popCount(ea)));
}
Expand Down Expand Up @@ -236,18 +267,6 @@ module EfuncMsg
ref ea = e.a;
select efunc
{
when "ceil" {
st.addEntry(rname, new shared SymEntry(ceil(ea)));
}
when "floor" {
st.addEntry(rname, new shared SymEntry(floor(ea)));
}
when "round" {
st.addEntry(rname, new shared SymEntry(round(ea)));
}
when "trunc" {
st.addEntry(rname, new shared SymEntry(trunc(ea)));
}
when "sgn" {
st.addEntry(rname, new shared SymEntry(sgn(ea)));
}
Expand All @@ -273,25 +292,6 @@ module EfuncMsg
return new MsgTuple(errorMsg, MsgType.ERROR);
}
}
when "hash64" {
overMemLimit(numBytes(real) * e.size);
var a = st.addEntry(rname, e.tupShape, uint);
forall (ai, x) in zip(a.a, e.a) {
ai = sipHash64(x): uint;
}
}
when "hash128" {
overMemLimit(numBytes(real) * e.size * 2);
var rname2 = st.nextName();
var a1 = st.addEntry(rname2, e.tupShape, uint);
var a2 = st.addEntry(rname, e.tupShape, uint);
forall (a1i, a2i, x) in zip(a1.a, a2.a, e.a) {
(a1i, a2i) = sipHash128(x): (uint, uint);
}
// Put first array's attrib in repMsg and let common
// code append second array's attrib
repMsg += "created " + st.attrib(rname2) + "+";
}
otherwise {
var errorMsg = notImplementedError(pn,efunc,gEnt.dtype);
eLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
Expand Down Expand Up @@ -353,9 +353,6 @@ module EfuncMsg
when "ctz" {
st.addEntry(rname, new shared SymEntry(ctz(ea)));
}
when "round" {
st.addEntry(rname, new shared SymEntry(ea));
}
when "sgn" {
st.addEntry(rname, new shared SymEntry(sgn(ea)));
}
Expand Down Expand Up @@ -384,25 +381,6 @@ module EfuncMsg
when "parity" {
st.addEntry(rname, new shared SymEntry(parity(ea)));
}
when "hash64" {
overMemLimit(numBytes(uint) * e.size);
var a = st.addEntry(rname, e.tupShape, uint);
forall (ai, x) in zip(a.a, e.a) {
ai = sipHash64(x): uint;
}
}
when "hash128" {
overMemLimit(numBytes(uint) * e.size * 2);
var rname2 = st.nextName();
var a1 = st.addEntry(rname2, e.tupShape, uint);
var a2 = st.addEntry(rname, e.tupShape, uint);
forall (a1i, a2i, x) in zip(a1.a, a2.a, e.a) {
(a1i, a2i) = sipHash128(x): (uint, uint);
}
// Put first array's attrib in repMsg and let common
// code append second array's attrib
repMsg += "created " + st.attrib(rname2) + "+";
}
when "not" {
st.addEntry(rname, new shared SymEntry(!e.a));
}
Expand Down
Loading

0 comments on commit 15bd176

Please sign in to comment.