Skip to content

Commit

Permalink
Fix mean intermediate dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Nov 15, 2023
1 parent 41dcd05 commit d28db0a
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions cubed/array_api/statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def mean(x, /, *, axis=None, keepdims=False):
# outputs.
dtype = x.dtype
intermediate_dtype = [("n", np.int64), ("total", np.float64)]
extra_func_kwargs = dict(dtype=intermediate_dtype)
return reduction(
x,
_mean_func,
Expand All @@ -44,18 +45,21 @@ def mean(x, /, *, axis=None, keepdims=False):
intermediate_dtype=intermediate_dtype,
dtype=dtype,
keepdims=keepdims,
extra_func_kwargs=extra_func_kwargs,
)


def _mean_func(a, **kwargs):
n = _numel(a, **kwargs)
total = np.sum(a, **kwargs)
dtype = dict(kwargs.pop("dtype"))
n = _numel(a, dtype=dtype["n"], **kwargs)
total = np.sum(a, dtype=dtype["total"], **kwargs)
return {"n": n, "total": total}


def _mean_combine(a, **kwargs):
n = np.sum(a["n"], **kwargs)
total = np.sum(a["total"], **kwargs)
dtype = dict(kwargs.pop("dtype"))
n = np.sum(a["n"], dtype=dtype["n"], **kwargs)
total = np.sum(a["total"], dtype=dtype["total"], **kwargs)
return {"n": n, "total": total}


Expand Down

0 comments on commit d28db0a

Please sign in to comment.