Skip to content

Commit f0a3dfe

Browse files
authored
Fix rolling(min_periods=) with int and null data with mode.pandas_compat (rapidsai#17822)
closes rapidsai#17786 Authors: - Matthew Roeschke (https://github.com/mroeschke) - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) - GALI PREM SAGAR (https://github.com/galipremsagar) - Lawrence Mitchell (https://github.com/wence-) URL: rapidsai#17822
1 parent fa20521 commit f0a3dfe

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

python/cudf/cudf/core/window/rolling.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
if TYPE_CHECKING:
2525
from cudf.core.column.column import ColumnBase
26+
from cudf.core.indexed_frame import IndexedFrame
2627

2728

2829
class _RollingBase:
@@ -205,7 +206,7 @@ class Rolling(GetAttrGetItemMixin, _RollingBase, Reducible):
205206

206207
def __init__(
207208
self,
208-
obj,
209+
obj: IndexedFrame,
209210
window,
210211
min_periods=None,
211212
center: bool = False,
@@ -216,7 +217,9 @@ def __init__(
216217
step: int | None = None,
217218
method: str = "single",
218219
):
219-
self.obj = obj
220+
if cudf.get_option("mode.pandas_compatible"):
221+
obj = obj.nans_to_nulls()
222+
self.obj = obj # type: ignore[assignment]
220223
self.window = window
221224
self.min_periods = min_periods
222225
self.center = center

python/cudf/cudf/tests/test_rolling.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2021-2024, NVIDIA CORPORATION.
1+
# Copyright (c) 2021-2025, NVIDIA CORPORATION.
22

33
import math
44

@@ -517,3 +517,16 @@ def test_rolling_series():
517517
actual = df.groupby("b")["a"].rolling(5).mean()
518518

519519
assert_eq(expected, actual)
520+
521+
522+
@pytest.mark.parametrize("klass", ["DataFrame", "Series"])
523+
def test_pandas_compat_int_nan_min_periods(klass):
524+
data = [None, 1, 2, None, 4, 6, 11]
525+
with cudf.option_context("mode.pandas_compatible", True):
526+
result = getattr(cudf, klass)(data).rolling(2, min_periods=1).sum()
527+
expected = getattr(pd, klass)(data).rolling(2, min_periods=1).sum()
528+
assert_eq(result, expected)
529+
530+
result = getattr(cudf, klass)(data).rolling(2, min_periods=1).sum()
531+
expected = getattr(cudf, klass)([None, 1, 3, 2, 4, 10, 17])
532+
assert_eq(result, expected)

0 commit comments

Comments
 (0)