Skip to content

Commit

Permalink
Plumb pylibcudf strings contains_re through cudf_polars (#15918)
Browse files Browse the repository at this point in the history
This PR adds cudf-polars code for evaluating the `StringFunction.Contains` expression node.

Depends on #15880

Authors:
  - https://github.com/brandon-b-miller
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)

URL: #15918
  • Loading branch information
brandon-b-miller authored Jun 13, 2024
1 parent af09d3e commit 246d017
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 0 deletions.
51 changes: 51 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,13 +644,28 @@ def __init__(
self.options = options
self.name = name
self.children = children
self._validate_input()

def _validate_input(self):
    """
    Reject string functions and option combinations that are not handled.

    Reads ``self.name``, ``self.options`` and ``self.children``.

    Raises
    ------
    NotImplementedError
        If ``self.name`` is not one of the supported string functions, or
        if it is ``Contains`` in regex mode (``literal`` is falsy) and
        either ``strict`` is falsy or the pattern child is not a
        ``Literal``.
    """
    if self.name not in (
        pl_expr.StringFunction.Lowercase,
        pl_expr.StringFunction.Uppercase,
        pl_expr.StringFunction.EndsWith,
        pl_expr.StringFunction.StartsWith,
        pl_expr.StringFunction.Contains,
    ):
        raise NotImplementedError(f"String function {self.name}")
    if self.name == pl_expr.StringFunction.Contains:
        literal, strict = self.options
        # The restrictions below apply to the regex path only; a literal
        # (substring) search accepts any pattern expression.
        if not literal:
            if not strict:
                # The ``f`` prefix must be outside the quotes so that the
                # offending value is interpolated into the message.
                raise NotImplementedError(
                    f"{strict=} is not supported for regex contains"
                )
            if not isinstance(self.children[1], Literal):
                # The regex path in do_evaluate reads the pattern directly
                # from a Literal node, so anything else is rejected here.
                raise NotImplementedError(
                    "Regex contains only supports a scalar pattern"
                )

def do_evaluate(
self,
Expand All @@ -660,6 +675,26 @@ def do_evaluate(
mapping: Mapping[Expr, Column] | None = None,
) -> Column:
"""Evaluate this expression given a dataframe for context."""
# Contains branch: both paths below return directly from inside this block.
if self.name == pl_expr.StringFunction.Contains:
# First child is the column being searched, second is the pattern.
child, arg = self.children
column = child.evaluate(df, context=context, mapping=mapping)

# Second option (unpacked as ``strict`` in _validate_input) is not used here.
literal, _ = self.options
if literal:
# Literal (non-regex) search: the pattern is an evaluated expression.
pat = arg.evaluate(df, context=context, mapping=mapping)
# Pass the scalar form only when the pattern is flagged as scalar and
# its size differs from the input column; otherwise pass it as a column.
pattern = (
pat.obj_scalar
if pat.is_scalar and pat.obj.size() != column.obj.size()
else pat.obj
)
return Column(plc.strings.find.contains(column.obj, pattern))
else:
# Regex search: _validate_input raises NotImplementedError when the
# pattern child is not a Literal, so this assert restates that check.
assert isinstance(arg, Literal)
# Build the regex program from the literal's Python value, default flags.
prog = plc.strings.regex_program.RegexProgram.create(
arg.value.as_py(),
flags=plc.strings.regex_flags.RegexFlags.DEFAULT,
)
return Column(plc.strings.contains.contains_re(column.obj, prog))
columns = [
child.evaluate(df, context=context, mapping=mapping)
for child in self.children
Expand Down Expand Up @@ -691,6 +726,22 @@ def do_evaluate(
)
)
else:
columns = [
child.evaluate(df, context=context, mapping=mapping)
for child in self.children
]
if self.name == pl_expr.StringFunction.Lowercase:
(column,) = columns
return Column(plc.strings.case.to_lower(column.obj))
elif self.name == pl_expr.StringFunction.Uppercase:
(column,) = columns
return Column(plc.strings.case.to_upper(column.obj))
elif self.name == pl_expr.StringFunction.EndsWith:
column, suffix = columns
return Column(plc.strings.find.ends_with(column.obj, suffix.obj))
elif self.name == pl_expr.StringFunction.StartsWith:
column, suffix = columns
return Column(plc.strings.find.starts_with(column.obj, suffix.obj))
raise NotImplementedError(
f"StringFunction {self.name}"
) # pragma: no cover; handled by init raising
Expand Down
61 changes: 61 additions & 0 deletions python/cudf_polars/tests/test_string.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from functools import partial

import pytest

import polars as pl

from cudf_polars.callback import execute_with_cudf
from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.fixture
def ldf():
    """Provide a lazy frame with one string column ``a`` holding mixed-case values and nulls."""
    values = ["AbC", "de", "FGHI", "j", "kLm", "nOPq", None, "RsT", None, "uVw"]
    frame = pl.DataFrame({"a": values})
    return frame.lazy()


@pytest.mark.parametrize(
    "substr",
    ["A", "de", ".*", "^a", "^A", "[^a-z]", "[a-z]{3,}", "^[A-Z]{2,}", "j|u"],
)
def test_contains_regex(ldf, substr):
    """Check ``str.contains`` with a plain string pattern via ``assert_gpu_result_equal``."""
    matches = pl.col("a").str.contains(substr)
    assert_gpu_result_equal(ldf.select(matches))


@pytest.mark.parametrize(
    "literal",
    ["A", "de", "FGHI", "j", "kLm", "nOPq", "RsT", "uVw"],
)
def test_contains_literal(ldf, literal):
    """Check ``str.contains(..., literal=True)`` with a ``pl.lit`` pattern."""
    needle = pl.lit(literal)
    matches = pl.col("a").str.contains(needle, literal=True)
    assert_gpu_result_equal(ldf.select(matches))


def test_contains_column(ldf):
    """Check ``str.contains(..., literal=True)`` when the pattern is itself column ``a``."""
    haystack = pl.col("a")
    matches = haystack.str.contains(pl.col("a"), literal=True)
    assert_gpu_result_equal(ldf.select(matches))


@pytest.mark.parametrize("pat", ["["])
def test_contains_invalid(ldf, pat):
    """Expect ``ComputeError`` for a malformed pattern, both with a plain collect and with the cudf callback."""
    query = ldf.select(pl.col("a").str.contains(pat))
    cudf_callback = partial(execute_with_cudf, raise_on_fail=True)

    # Plain collect with no callback.
    with pytest.raises(pl.exceptions.ComputeError):
        query.collect()

    # Collect through the cudf callback, which is told to raise on failure.
    with pytest.raises(pl.exceptions.ComputeError):
        query.collect(post_opt_callback=cudf_callback)

0 comments on commit 246d017

Please sign in to comment.