Skip to content

Commit

Permalink
Closes #3864 max and min of bool to return bool like numpy (#3901)
Browse files Browse the repository at this point in the history
Co-authored-by: Amanda Potts <[email protected]>
  • Loading branch information
ajpotts and ajpotts authored Dec 3, 2024
1 parent 40d1fd8 commit d7ca013
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 14 deletions.
12 changes: 6 additions & 6 deletions src/ReductionMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -92,18 +92,18 @@ module ReductionMsg
}

@arkouda.registerCommand
proc maxAll(const ref x:[?d] ?t, skipNan: bool): reductionReturnType(t) throws
proc maxAll(const ref x:[?d] ?t, skipNan: bool): t throws
where t==int || t==real || t==uint(64) || t==bool
{
use SliceReductionOps;
return getMaxSlice(x, x.domain, skipNan);
}

@arkouda.registerCommand
proc max(const ref x:[?d] ?t, axis: list(int), skipNan: bool): [] reductionReturnType(t) throws
proc max(const ref x:[?d] ?t, axis: list(int), skipNan: bool): [] t throws
where t==int || t==real || t==uint(64) || t==bool {
use SliceReductionOps;
type opType = reductionReturnType(t);
type opType = t;
const (valid, axes) = validateNegativeAxes(axis, x.rank);
if !valid {
throw new Error("Invalid axis value(s) '%?' in slicing reduction".format(axis));
Expand All @@ -117,18 +117,18 @@ module ReductionMsg
}

@arkouda.registerCommand
proc minAll(const ref x:[?d] ?t, skipNan: bool): reductionReturnType(t) throws
proc minAll(const ref x:[?d] ?t, skipNan: bool): t throws
where t==int || t==real || t==uint(64) || t==bool
{
use SliceReductionOps;
return getMinSlice(x, x.domain, skipNan);
}

@arkouda.registerCommand
proc min(const ref x:[?d] ?t, axis: list(int), skipNan: bool): [] reductionReturnType(t) throws
proc min(const ref x:[?d] ?t, axis: list(int), skipNan: bool): [] t throws
where t==int || t==real || t==uint(64) || t==bool {
use SliceReductionOps;
type opType = reductionReturnType(t);
type opType = t;
const (valid, axes) = validateNegativeAxes(axis, x.rank);
if !valid {
throw new Error("Invalid axis value(s) '%?' in slicing reduction".format(axis));
Expand Down
8 changes: 0 additions & 8 deletions tests/pdarrayclass_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

# TODO: add unint8 to DTYPES


class TestPdarrayClass:

@pytest.mark.skip_if_max_rank_less_than(2)
Expand Down Expand Up @@ -141,14 +140,7 @@ def assert_reduction_ops_match(
np_op = getattr(numpy, op)
nda = pda.to_ndarray()

# TODO: remove cast when #3864 is resolved.
ak_result = ak_op(pda, axis=axis)
if op in ["max", "min"] and pda.dtype == ak.bool_:
if isinstance(ak_result, ak.pdarray):
ak_result = ak.cast(ak_result, dt=ak.bool_)
else:
ak_result = np.bool_(ak_result)

ak_assert_equivalent(ak_result, np_op(nda, axis=axis))

@pytest.mark.parametrize("op", INDEX_REDUCTION_OPS)
Expand Down

0 comments on commit d7ca013

Please sign in to comment.