Skip to content

Commit

Permalink
Deprecate misspelled aggegrate_func in favour of aggregate_func (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored Apr 12, 2024
1 parent c2283f3 commit 30385c1
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 10 deletions.
2 changes: 1 addition & 1 deletion cubed/array_api/statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def mean(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None
x,
_mean_func,
combine_func=_mean_combine,
aggegrate_func=_mean_aggregate,
aggregate_func=_mean_aggregate,
axis=axis,
intermediate_dtype=intermediate_dtype,
dtype=dtype,
Expand Down
25 changes: 17 additions & 8 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from numbers import Integral, Number
from operator import add
from typing import TYPE_CHECKING, Any, Sequence, Union
from warnings import warn

import numpy as np
import zarr
Expand Down Expand Up @@ -886,7 +887,8 @@ def reduction(
x: "Array",
func,
combine_func=None,
aggegrate_func=None,
aggegrate_func=None, # typo, will removed in next release
aggregate_func=None,
axis=None,
intermediate_dtype=None,
dtype=None,
Expand All @@ -896,12 +898,19 @@ def reduction(
extra_func_kwargs=None,
) -> "Array":
"""Apply a function to reduce an array along one or more axes."""
if aggegrate_func is not None and aggregate_func is None:
warn(
"`aggegrate_func` is deprecated, please use `aggregate_func` instead",
DeprecationWarning,
stacklevel=2,
)
aggregate_func = aggegrate_func
if use_new_impl:
return reduction_new(
x,
func,
combine_func,
aggegrate_func,
aggregate_func,
axis,
intermediate_dtype,
dtype,
Expand Down Expand Up @@ -983,8 +992,8 @@ def reduction(
extra_func_kwargs=extra_func_kwargs,
)

if aggegrate_func is not None:
result = map_blocks(aggegrate_func, result, dtype=dtype)
if aggregate_func is not None:
result = map_blocks(aggregate_func, result, dtype=dtype)

if not keepdims:
axis_to_squeeze = tuple(i for i in axis if result.shape[i] == 1)
Expand All @@ -1002,7 +1011,7 @@ def reduction_new(
x: "Array",
func,
combine_func=None,
aggegrate_func=None,
aggregate_func=None,
axis=None,
intermediate_dtype=None,
dtype=None,
Expand Down Expand Up @@ -1043,8 +1052,8 @@ def reduction_new(
)

# aggregate final chunks
if aggegrate_func is not None:
result = map_blocks(aggegrate_func, result, dtype=dtype)
if aggregate_func is not None:
result = map_blocks(aggregate_func, result, dtype=dtype)

if not keepdims:
axis_to_squeeze = tuple(i for i in axis if result.shape[i] == 1)
Expand Down Expand Up @@ -1213,7 +1222,7 @@ def arg_reduction(
out,
_arg_func,
combine_func=partial(_arg_combine, arg_func=arg_func),
aggegrate_func=_arg_aggregate,
aggregate_func=_arg_aggregate,
axis=axis,
intermediate_dtype=intermediate_dtype,
dtype=dtype,
Expand Down
2 changes: 1 addition & 1 deletion cubed/nan_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def nanmean(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=N
x,
_nanmean_func,
combine_func=_nanmean_combine,
aggegrate_func=_nanmean_aggregate,
aggregate_func=_nanmean_aggregate,
axis=axis,
intermediate_dtype=intermediate_dtype,
dtype=dtype,
Expand Down

0 comments on commit 30385c1

Please sign in to comment.