Skip to content

Commit

Permalink
Fix various .str methods for pandas compatibility (rapidsai#17782)
Browse files Browse the repository at this point in the history
closes rapidsai#17745
closes rapidsai#17748
closes rapidsai#17749
closes rapidsai#17751

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: rapidsai#17782
  • Loading branch information
mroeschke authored Jan 24, 2025
1 parent db7f1e3 commit e9bfab5
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 22 deletions.
39 changes: 26 additions & 13 deletions python/cudf/cudf/core/column/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,6 +1068,10 @@ def replace(
if regex and isinstance(pat, re.Pattern):
pat = pat.pattern

pa_repl = pa.scalar(repl)
if not pa.types.is_string(pa_repl.type):
raise TypeError(f"repl must be a str, not {type(repl).__name__}.")

# Pandas forces a non-regex replace when pat is a single character
with acquire_spill_lock():
if regex is True and len(pat) > 1:
Expand All @@ -1076,14 +1080,14 @@ def replace(
plc.strings.regex_program.RegexProgram.create(
pat, plc.strings.regex_flags.RegexFlags.DEFAULT
),
pa_scalar_to_plc_scalar(pa.scalar(repl)),
pa_scalar_to_plc_scalar(pa_repl),
n,
)
else:
plc_result = plc.strings.replace.replace(
self._column.to_pylibcudf(mode="read"),
pa_scalar_to_plc_scalar(pa.scalar(pat)),
pa_scalar_to_plc_scalar(pa.scalar(repl)),
pa_scalar_to_plc_scalar(pa_repl),
n,
)
result = Column.from_pylibcudf(plc_result)
Expand Down Expand Up @@ -2416,13 +2420,19 @@ def get(self, i: int = 0) -> SeriesOrIndex:
2 f
dtype: object
"""
str_lens = self.len()
if i < 0:
next_index = i - 1
step = -1
to_mask = str_lens < abs(i) # type: ignore[operator]
else:
next_index = i + 1
step = 1
return self.slice(i, next_index, step)
to_mask = str_lens <= i # type: ignore[operator]
result = self.slice(i, next_index, step)
if to_mask.any(): # type: ignore[union-attr]
result[to_mask] = cudf.NA # type: ignore[index]
return result

def get_json_object(
self,
Expand Down Expand Up @@ -3933,27 +3943,26 @@ def isspace(self) -> SeriesOrIndex:
def _starts_ends_with(
self,
method: Callable[[plc.Column, plc.Column | plc.Scalar], plc.Column],
pat: str | Sequence,
pat: str | tuple[str, ...],
) -> SeriesOrIndex:
if pat is None:
raise TypeError(
f"expected a string or a sequence-like object, not "
f"{type(pat).__name__}"
)
elif is_scalar(pat):
if isinstance(pat, str):
plc_pat = pa_scalar_to_plc_scalar(pa.scalar(pat, type=pa.string()))
else:
elif isinstance(pat, tuple) and all(isinstance(p, str) for p in pat):
plc_pat = column.as_column(pat, dtype="str").to_pylibcudf(
mode="read"
)
else:
raise TypeError(
f"expected a string or tuple, not {type(pat).__name__}"
)
with acquire_spill_lock():
plc_result = method(
self._column.to_pylibcudf(mode="read"), plc_pat
)
result = Column.from_pylibcudf(plc_result)
return self._return_or_inplace(result)

def endswith(self, pat: str | Sequence) -> SeriesOrIndex:
def endswith(self, pat: str | tuple[str, ...]) -> SeriesOrIndex:
"""
Test if the end of each string element matches a pattern.
Expand Down Expand Up @@ -3997,7 +4006,7 @@ def endswith(self, pat: str | Sequence) -> SeriesOrIndex:
"""
return self._starts_ends_with(plc.strings.find.ends_with, pat)

def startswith(self, pat: str | Sequence) -> SeriesOrIndex:
def startswith(self, pat: str | tuple[str, ...]) -> SeriesOrIndex:
"""
Test if the start of each string element matches a pattern.
Expand Down Expand Up @@ -4299,6 +4308,8 @@ def index(

if (result == -1).any():
raise ValueError("substring not found")
elif cudf.get_option("mode.pandas_compatible"):
return result.astype(np.dtype(np.int64))
else:
return result

Expand Down Expand Up @@ -4359,6 +4370,8 @@ def rindex(

if (result == -1).any():
raise ValueError("substring not found")
elif cudf.get_option("mode.pandas_compatible"):
return result.astype(np.dtype(np.int64))
else:
return result

Expand Down
54 changes: 45 additions & 9 deletions python/cudf/cudf/tests/test_string.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2018-2024, NVIDIA CORPORATION.
# Copyright (c) 2018-2025, NVIDIA CORPORATION.

import json
import re
Expand Down Expand Up @@ -2048,26 +2048,26 @@ def test_string_starts_ends(data, pat):
[
(
["abc", "xyz", "a", "ab", "123", "097"],
["abc", "x", "a", "b", "3", "7"],
("abc", "x", "a", "b", "3", "7"),
),
(["A B", "1.5", "3,000"], ["A ", ".", ","]),
(["23", "³", "⅕", ""], ["23", "³", "⅕", ""]),
([" ", "\t\r\n ", ""], ["d", "\n ", ""]),
(["A B", "1.5", "3,000"], ("A ", ".", ",")),
(["23", "³", "⅕", ""], ("23", "³", "⅕", "")),
([" ", "\t\r\n ", ""], ("d", "\n ", "")),
(
["$", "B", "Aab$", "$$ca", "C$B$", "cat"],
["$", "$", "a", "<", "(", "#"],
("$", "$", "a", "<", "(", "#"),
),
(
["line to be wrapped", "another line to be wrapped"],
["another", "wrapped"],
("another", "wrapped"),
),
(
["hello", "there", "world", "+1234", "-1234", None, "accént", ""],
["hsdjfk", None, "ll", "+", "-", "w", "-", "én"],
("hsdjfk", "", "ll", "+", "-", "w", "-", "én"),
),
(
["1. Ant. ", "2. Bee!\n", "3. Cat?\t", None],
["1. Ant. ", "2. Bee!\n", "3. Cat?\t", None],
("1. Ant. ", "2. Bee!\n", "3. Cat?\t", ""),
),
],
)
Expand Down Expand Up @@ -3539,3 +3539,39 @@ def test_string_reduction_error():
lfunc_args_and_kwargs=([], {"skipna": False}),
rfunc_args_and_kwargs=([], {"skipna": False}),
)


def test_getitem_out_of_bounds():
    """Compare ``.str[i]`` against pandas when ``i`` exceeds some lengths.

    The strings have lengths 3, 2 and 1, so both the positive index ``2``
    and the negative index ``-2`` fall outside at least one element.
    """
    values = ["123", "12", "1"]
    pandas_series = pd.Series(values)
    gpu_series = cudf.Series(values)

    # Exercise one forward and one backward position.
    for position in (2, -2):
        expected = pandas_series.str[position]
        result = gpu_series.str[position]
        assert_eq(result, expected)


@pytest.mark.parametrize("method", ["startswith", "endswith"])
@pytest.mark.parametrize("pat", [None, (1, 2), pd.Series([1])])
def test_startsendwith_invalid_pat(method, pat):
    """Expect ``TypeError`` from startswith/endswith for unsupported patterns.

    The patterns covered are ``None``, a tuple of non-strings and a pandas
    Series.
    """
    series = cudf.Series(["1"])
    with pytest.raises(TypeError):
        str_method = getattr(series.str, method)
        str_method(pat)


@pytest.mark.parametrize("method", ["rindex", "index"])
def test_index_int64_pandas_compat(method):
    """Compare ``.str.index`` / ``.str.rindex`` with pandas in compat mode.

    The cudf call runs with ``mode.pandas_compatible`` enabled and its result
    is compared against the pandas result for the same arguments.
    """
    strings = ["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF"]
    search_args = ("E", 4, 8)

    with cudf.option_context("mode.pandas_compatible", True):
        gpu_method = getattr(cudf.Series(strings).str, method)
        result = gpu_method(*search_args)

    pandas_method = getattr(pd.Series(strings).str, method)
    expected = pandas_method(*search_args)
    assert_eq(result, expected)


def test_replace_invalid_scalar_repl():
    """Expect ``TypeError`` when ``repl`` passed to ``.str.replace`` is an int."""
    series = cudf.Series(["1"])
    non_string_repl = 2
    with pytest.raises(TypeError):
        series.str.replace("1", non_string_repl)

0 comments on commit e9bfab5

Please sign in to comment.