From f6e1a005f0139d2b2bdc27c5faf6a6dbbf85b1c0 Mon Sep 17 00:00:00 2001 From: Amanda Potts Date: Fri, 15 Nov 2024 16:32:52 -0500 Subject: [PATCH] Closes #3864 max and min of bool to return bool like numpy --- src/ReductionMsg.chpl | 12 ++++++------ tests/pdarrayclass_test.py | 8 -------- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/src/ReductionMsg.chpl b/src/ReductionMsg.chpl index 7b3311eb42..c8b594ab9a 100644 --- a/src/ReductionMsg.chpl +++ b/src/ReductionMsg.chpl @@ -92,7 +92,7 @@ module ReductionMsg } @arkouda.registerCommand - proc maxAll(const ref x:[?d] ?t, skipNan: bool): reductionReturnType(t) throws + proc maxAll(const ref x:[?d] ?t, skipNan: bool): t throws where t==int || t==real || t==uint(64) || t==bool { use SliceReductionOps; @@ -100,10 +100,10 @@ module ReductionMsg } @arkouda.registerCommand - proc max(const ref x:[?d] ?t, axis: list(int), skipNan: bool): [] reductionReturnType(t) throws + proc max(const ref x:[?d] ?t, axis: list(int), skipNan: bool): [] t throws where t==int || t==real || t==uint(64) || t==bool { use SliceReductionOps; - type opType = reductionReturnType(t); + type opType = t; const (valid, axes) = validateNegativeAxes(axis, x.rank); if !valid { throw new Error("Invalid axis value(s) '%?' in slicing reduction".format(axis)); @@ -117,7 +117,7 @@ module ReductionMsg } @arkouda.registerCommand - proc minAll(const ref x:[?d] ?t, skipNan: bool): reductionReturnType(t) throws + proc minAll(const ref x:[?d] ?t, skipNan: bool): t throws where t==int || t==real || t==uint(64) || t==bool { use SliceReductionOps; @@ -125,10 +125,10 @@ module ReductionMsg } @arkouda.registerCommand - proc min(const ref x:[?d] ?t, axis: list(int), skipNan: bool): [] reductionReturnType(t) throws + proc min(const ref x:[?d] ?t, axis: list(int), skipNan: bool): [] t throws where t==int || t==real || t==uint(64) || t==bool { use SliceReductionOps; - type opType = reductionReturnType(t); + type opType = t; const (valid, axes) = validateNegativeAxes(axis, x.rank); if !valid { throw new Error("Invalid axis value(s) '%?' in slicing reduction".format(axis)); diff --git a/tests/pdarrayclass_test.py b/tests/pdarrayclass_test.py index bef07a396a..e33776a501 100644 --- a/tests/pdarrayclass_test.py +++ b/tests/pdarrayclass_test.py @@ -16,7 +16,6 @@ # TODO: add unint8 to DTYPES - class TestPdarrayClass: @pytest.mark.skip_if_max_rank_less_than(2) @@ -139,14 +138,7 @@ def assert_reduction_ops_match( np_op = getattr(numpy, op) nda = pda.to_ndarray() - # TODO: remove cast when #3864 is resolved. ak_result = ak_op(pda, axis=axis) - if op in ["max", "min"] and pda.dtype == ak.bool_: - if isinstance(ak_result, ak.pdarray): - ak_result = ak.cast(ak_result, dt=ak.bool_) - else: - ak_result = np.bool_(ak_result) - ak_assert_equivalent(ak_result, np_op(nda, axis=axis)) @pytest.mark.parametrize("op", REDUCTION_OPS)