Skip to content

Commit

Permalink
Fix count API issue about ignoring nan values (#17779)
Browse files Browse the repository at this point in the history
Fixes: #17768 

This PR fixes the `count` API so that, in pandas-compatible mode, it matches pandas' default dtype behavior, where `nan`s are not counted as valid elements. Outside pandas-compatible mode, `count` continues to treat `nan`s as valid, because that is what pandas nullable dtypes do.

Authors:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - Matthew Roeschke (https://github.com/mroeschke)

URL: #17779
  • Loading branch information
galipremsagar authored Jan 24, 2025
1 parent c57cb6e commit db7f1e3
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 4 deletions.
15 changes: 14 additions & 1 deletion python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6409,7 +6409,20 @@ def count(self, axis=0, numeric_only=False):
raise NotImplementedError("Only axis=0 is currently supported.")
length = len(self)
return Series._from_column(
as_column([length - col.null_count for col in self._columns]),
as_column(
[
length
- (
col.null_count
+ (
col.nan_count
if cudf.get_option("mode.pandas_compatible")
else 0
)
)
for col in self._columns
]
),
index=cudf.Index(self._column_names),
)

Expand Down
5 changes: 4 additions & 1 deletion python/cudf/cudf/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2712,7 +2712,10 @@ def count(self):
Parameters currently not supported is `level`.
"""
return self.valid_count
valid_count = self.valid_count
if cudf.get_option("mode.pandas_compatible"):
return valid_count - self._column.nan_count
return valid_count

@_performance_tracking
def mode(self, dropna=True):
Expand Down
4 changes: 2 additions & 2 deletions python/cudf/cudf/tests/input_output/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from cudf._fuzz_testing.utils import compare_dataframe


def test_parquet_long_list():
def test_parquet_long_list(tmpdir):
# This test generates int and string list columns, where each has a row that is very large.
# When generated by the cudf writer these long rows are contained on a single page,
# but when generated by pyarrow they span several pages.
Expand Down Expand Up @@ -41,7 +41,7 @@ def test_parquet_long_list():
)

# Write the table to a parquet file using pyarrow
file_name = "long_row_list_test.pq"
file_name = tmpdir.join("long_row_list_test.pq")
# https://arrow.apache.org/docs/python/generated/pyarrow.parquet.write_table.html
pq.write_table(
generated_table,
Expand Down
15 changes: 15 additions & 0 deletions python/cudf/cudf/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3019,3 +3019,18 @@ def test_roundtrip_series_plc_column(ps):
expect = cudf.Series(ps)
actual = cudf.Series.from_pylibcudf(*expect.to_pylibcudf())
assert_eq(expect, actual)


def test_series_dataframe_count_float():
    """Check ``count`` on float data holding both a null and a NaN.

    With ``mode.pandas_compatible`` enabled, ``Series.count`` and
    ``DataFrame.count`` are compared against a real pandas Series/DataFrame
    built from the same values.  With the option disabled, they are compared
    against the pandas nullable-dtype conversion of the same cudf object.
    """
    # Function-scope import so the reference object is genuinely pandas.
    import pandas as pd

    values = [1, 2, 3, None, np.nan, 10]
    gs = cudf.Series(values, nan_as_null=False)
    # The reference must come from pandas itself; building it with cudf
    # would only compare cudf against cudf.
    ps = pd.Series(values)

    with cudf.option_context("mode.pandas_compatible", True):
        assert_eq(ps.count(), gs.count())
        assert_eq(ps.to_frame().count(), gs.to_frame().count())
    with cudf.option_context("mode.pandas_compatible", False):
        assert_eq(gs.count(), gs.to_pandas(nullable=True).count())
        assert_eq(
            gs.to_frame().count(),
            gs.to_frame().to_pandas(nullable=True).count(),
        )

0 comments on commit db7f1e3

Please sign in to comment.