Skip to content

Commit

Permalink
Add min_weight param to rolling_exp functions (#8285)
Browse files Browse the repository at this point in the history
* Add `min_weight` param to `rolling_exp` functions

* whatsnew
  • Loading branch information
max-sixty authored Oct 14, 2023
1 parent 8f7e8b5 commit dafd726
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 24 deletions.
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ New Features
the ``other`` parameter, passing the object as the only argument. Previously,
this was only valid for the ``cond`` parameter. (:issue:`8255`)
By `Maximilian Roos <https://github.com/max-sixty>`_.
- ``.rolling_exp`` functions can now take a ``min_weight`` parameter, to only
output values when there are sufficient recent non-nan values.
``numbagg>=0.3.1`` is required. (:pull:`8285`)
By `Maximilian Roos <https://github.com/max-sixty>`_.
- :py:meth:`DataArray.sortby` & :py:meth:`Dataset.sortby` accept a callable for
the ``variables`` parameter, passing the object as the only argument.
By `Maximilian Roos <https://github.com/max-sixty>`_.
Expand Down
55 changes: 31 additions & 24 deletions xarray/core/rolling_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,15 @@
from xarray.core.computation import apply_ufunc
from xarray.core.options import _get_keep_attrs
from xarray.core.pdcompat import count_not_none
from xarray.core.pycompat import is_duck_dask_array
from xarray.core.types import T_DataWithCoords, T_DuckArray
from xarray.core.types import T_DataWithCoords

try:
import numbagg
from numbagg import move_exp_nanmean, move_exp_nansum

has_numbagg = numbagg.__version__
except ImportError:
has_numbagg = False


def _get_alpha(
Expand All @@ -25,26 +32,6 @@ def _get_alpha(
return 1 / (1 + com)


def move_exp_nanmean(array: T_DuckArray, *, axis: int, alpha: float) -> np.ndarray:
if is_duck_dask_array(array):
raise TypeError("rolling_exp is not currently support for dask-like arrays")
import numbagg

# No longer needed in numbag > 0.2.0; remove in time
if axis == ():
return array.astype(np.float64)
else:
return numbagg.move_exp_nanmean(array, axis=axis, alpha=alpha)


def move_exp_nansum(array: T_DuckArray, *, axis: int, alpha: float) -> np.ndarray:
if is_duck_dask_array(array):
raise TypeError("rolling_exp is not currently supported for dask-like arrays")
import numbagg

return numbagg.move_exp_nansum(array, axis=axis, alpha=alpha)


def _get_center_of_mass(
comass: float | None,
span: float | None,
Expand Down Expand Up @@ -110,11 +97,31 @@ def __init__(
obj: T_DataWithCoords,
windows: Mapping[Any, int | float],
window_type: str = "span",
min_weight: float = 0.0,
):
if has_numbagg is False:
raise ImportError(
"numbagg >= 0.2.1 is required for rolling_exp but currently numbagg is not installed"
)
elif has_numbagg < "0.2.1":
raise ImportError(
f"numbagg >= 0.2.1 is required for rolling_exp but currently version {has_numbagg} is installed"
)
elif has_numbagg < "0.3.1" and min_weight > 0:
raise ImportError(
f"numbagg >= 0.3.1 is required for `min_weight > 0` but currently version {has_numbagg} is installed"
)

self.obj: T_DataWithCoords = obj
dim, window = next(iter(windows.items()))
self.dim = dim
self.alpha = _get_alpha(**{window_type: window})
self.min_weight = min_weight
# Don't pass min_weight=0 so we can support older versions of numbagg
kwargs = dict(alpha=self.alpha, axis=-1)
if min_weight > 0:
kwargs["min_weight"] = min_weight
self.kwargs = kwargs

def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
"""
Expand Down Expand Up @@ -145,7 +152,7 @@ def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
move_exp_nanmean,
self.obj,
input_core_dims=[[self.dim]],
kwargs=dict(alpha=self.alpha, axis=-1),
kwargs=self.kwargs,
output_core_dims=[[self.dim]],
keep_attrs=keep_attrs,
on_missing_core_dim="copy",
Expand Down Expand Up @@ -181,7 +188,7 @@ def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
move_exp_nansum,
self.obj,
input_core_dims=[[self.dim]],
kwargs=dict(alpha=self.alpha, axis=-1),
kwargs=self.kwargs,
output_core_dims=[[self.dim]],
keep_attrs=keep_attrs,
on_missing_core_dim="copy",
Expand Down

0 comments on commit dafd726

Please sign in to comment.