From d7ca0133412ae394eaed9bbc405308040dddb83c Mon Sep 17 00:00:00 2001 From: ajpotts Date: Tue, 3 Dec 2024 15:45:36 -0500 Subject: [PATCH] Closes #3864 max and min of bool to return bool like numpy (#3901) Co-authored-by: Amanda Potts --- src/ReductionMsg.chpl | 12 ++++++------ tests/pdarrayclass_test.py | 8 -------- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/src/ReductionMsg.chpl b/src/ReductionMsg.chpl index 588b01e2e8..28ec901d52 100644 --- a/src/ReductionMsg.chpl +++ b/src/ReductionMsg.chpl @@ -92,7 +92,7 @@ module ReductionMsg } @arkouda.registerCommand - proc maxAll(const ref x:[?d] ?t, skipNan: bool): reductionReturnType(t) throws + proc maxAll(const ref x:[?d] ?t, skipNan: bool): t throws where t==int || t==real || t==uint(64) || t==bool { use SliceReductionOps; @@ -100,10 +100,10 @@ module ReductionMsg } @arkouda.registerCommand - proc max(const ref x:[?d] ?t, axis: list(int), skipNan: bool): [] reductionReturnType(t) throws + proc max(const ref x:[?d] ?t, axis: list(int), skipNan: bool): [] t throws where t==int || t==real || t==uint(64) || t==bool { use SliceReductionOps; - type opType = reductionReturnType(t); + type opType = t; const (valid, axes) = validateNegativeAxes(axis, x.rank); if !valid { throw new Error("Invalid axis value(s) '%?' in slicing reduction".format(axis)); @@ -117,7 +117,7 @@ module ReductionMsg } @arkouda.registerCommand - proc minAll(const ref x:[?d] ?t, skipNan: bool): reductionReturnType(t) throws + proc minAll(const ref x:[?d] ?t, skipNan: bool): t throws where t==int || t==real || t==uint(64) || t==bool { use SliceReductionOps; @@ -125,10 +125,10 @@ module ReductionMsg } @arkouda.registerCommand - proc min(const ref x:[?d] ?t, axis: list(int), skipNan: bool): [] reductionReturnType(t) throws + proc min(const ref x:[?d] ?t, axis: list(int), skipNan: bool): [] t throws where t==int || t==real || t==uint(64) || t==bool { use SliceReductionOps; - type opType = reductionReturnType(t); + type opType = t; const (valid, axes) = validateNegativeAxes(axis, x.rank); if !valid { throw new Error("Invalid axis value(s) '%?' in slicing reduction".format(axis)); diff --git a/tests/pdarrayclass_test.py b/tests/pdarrayclass_test.py index 76351b17bc..d017d27d5b 100644 --- a/tests/pdarrayclass_test.py +++ b/tests/pdarrayclass_test.py @@ -19,7 +19,6 @@ # TODO: add unint8 to DTYPES - class TestPdarrayClass: @pytest.mark.skip_if_max_rank_less_than(2) @@ -141,14 +140,7 @@ def assert_reduction_ops_match( np_op = getattr(numpy, op) nda = pda.to_ndarray() - # TODO: remove cast when #3864 is resolved. ak_result = ak_op(pda, axis=axis) - if op in ["max", "min"] and pda.dtype == ak.bool_: - if isinstance(ak_result, ak.pdarray): - ak_result = ak.cast(ak_result, dt=ak.bool_) - else: - ak_result = np.bool_(ak_result) - ak_assert_equivalent(ak_result, np_op(nda, axis=axis)) @pytest.mark.parametrize("op", INDEX_REDUCTION_OPS)