From d28db0a0562ce91c06768b260974f8a9f00dff79 Mon Sep 17 00:00:00 2001 From: Tom White Date: Wed, 15 Nov 2023 10:12:39 +0000 Subject: [PATCH] Fix mean intermediate dtypes --- cubed/array_api/statistical_functions.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/cubed/array_api/statistical_functions.py b/cubed/array_api/statistical_functions.py index d21175103..b4ff76946 100644 --- a/cubed/array_api/statistical_functions.py +++ b/cubed/array_api/statistical_functions.py @@ -35,6 +35,7 @@ def mean(x, /, *, axis=None, keepdims=False): # outputs. dtype = x.dtype intermediate_dtype = [("n", np.int64), ("total", np.float64)] + extra_func_kwargs = dict(dtype=intermediate_dtype) return reduction( x, _mean_func, @@ -44,18 +45,21 @@ def mean(x, /, *, axis=None, keepdims=False): intermediate_dtype=intermediate_dtype, dtype=dtype, keepdims=keepdims, + extra_func_kwargs=extra_func_kwargs, ) def _mean_func(a, **kwargs): - n = _numel(a, **kwargs) - total = np.sum(a, **kwargs) + dtype = dict(kwargs.pop("dtype")) + n = _numel(a, dtype=dtype["n"], **kwargs) + total = np.sum(a, dtype=dtype["total"], **kwargs) return {"n": n, "total": total} def _mean_combine(a, **kwargs): - n = np.sum(a["n"], **kwargs) - total = np.sum(a["total"], **kwargs) + dtype = dict(kwargs.pop("dtype")) + n = np.sum(a["n"], dtype=dtype["n"], **kwargs) + total = np.sum(a["total"], dtype=dtype["total"], **kwargs) return {"n": n, "total": total}