Skip to content

Commit

Permalink
Add cudf::strings::find_re API (#16742)
Browse files Browse the repository at this point in the history
Adds the `cudf::strings::find_re` and `str.find_re` APIs to libcudf/pylibcudf/cudf. These functions return the character position where the pattern first matches in each row of the input column. If a match is not found, -1 is returned for the corresponding row.

Closes #16729

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Nghia Truong (https://github.com/ttnghia)
  - Matthew Murray (https://github.com/Matt711)
  - Bradley Dice (https://github.com/bdice)

URL: #16742
  • Loading branch information
davidwendt authored Oct 3, 2024
1 parent 2ec6cb3 commit 3faa3ee
Show file tree
Hide file tree
Showing 13 changed files with 237 additions and 7 deletions.
1 change: 1 addition & 0 deletions cpp/doxygen/regex.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ This page specifies which regular expression (regex) features are currently supp
- cudf::strings::extract()
- cudf::strings::extract_all_record()
- cudf::strings::findall()
- cudf::strings::find_re()
- cudf::strings::replace_re()
- cudf::strings::replace_with_backrefs()
- cudf::strings::split_re()
Expand Down
29 changes: 29 additions & 0 deletions cpp/include/cudf/strings/findall.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,35 @@ std::unique_ptr<column> findall(
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());

/**
* @brief Returns the starting character index of the first match for the given pattern
* in each row of the input column
*
* @code{.pseudo}
* Example:
* s = ["bunny", "rabbit", "hare", "dog"]
* p = regex_program::create("[be]")
* r = find_re(s, p)
* r is now [0, 2, 3, -1]
* @endcode
*
* In the example, 'b' is found at position 0 of "bunny" and position 2 of "rabbit",
* 'e' is found at position 3 of "hare", and "dog" contains neither character.
*
* A null output row occurs if the corresponding input row is null.
* A -1 is returned for rows that do not contain a match.
* The output column has the same number of rows as the input column;
* an empty input column produces an empty output column.
*
* See the @ref md_regex "Regex Features" page for details on patterns supported by this API.
*
* @param input Strings instance for this operation
* @param prog Regex program instance
* @param stream CUDA stream used for device memory operations and kernel launches
* @param mr Device memory resource used to allocate the returned column's device memory
* @return New column of `size_type` integers holding the match position for each row
*/
std::unique_ptr<column> find_re(
strings_column_view const& input,
regex_program const& prog,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());

/** @} */ // end of doxygen group
} // namespace strings
} // namespace CUDF_EXPORT cudf
46 changes: 46 additions & 0 deletions cpp/src/strings/search/findall.cu
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,43 @@ std::unique_ptr<column> findall(strings_column_view const& input,
mr);
}

namespace {
/**
 * @brief Per-row functor that reports where the regex first matches
 *
 * Invoked once per row with the row index, a device regex program and the
 * index of the calling thread. Produces:
 *  - 0 for a null row (a placeholder value only)
 *  - -1 when the pattern does not match anywhere in the row
 *  - otherwise the `first` member of the match returned by the regex program
 */
struct find_re_fn {
  column_device_view d_strings;  ///< Input strings column

  __device__ size_type operator()(size_type const idx,
                                  reprog_device const prog,
                                  int32_t const thread_idx) const
  {
    // Null rows are not searched at all.
    if (d_strings.is_null(idx)) { return 0; }

    auto const d_str = d_strings.element<string_view>(idx);
    // Search from the beginning of the row's string.
    auto const match = prog.find(thread_idx, d_str, d_str.begin());
    if (!match.has_value()) { return -1; }
    return match.value().first;
  }
};
}  // namespace

/**
 * @brief Builds a `size_type` column with the first-match position for each input row
 *
 * The output column is created with a copy of the input's null mask and the
 * input's null count. For a non-empty input, the regex program is prepared for
 * the device and `find_re_fn` is launched over all rows to fill the output.
 *
 * @param input Strings column to search
 * @param prog Regex program to search with
 * @param stream CUDA stream used for device memory operations and kernel launches
 * @param mr Device memory resource used to allocate the returned column
 * @return New column of `size_type` values, one per input row
 */
std::unique_ptr<column> find_re(strings_column_view const& input,
                                regex_program const& prog,
                                rmm::cuda_stream_view stream,
                                rmm::device_async_resource_ref mr)
{
  auto const num_rows = input.size();

  // Allocate the output up front so an empty input can return it as-is.
  auto output = make_numeric_column(data_type{type_to_id<size_type>()},
                                    num_rows,
                                    cudf::detail::copy_bitmask(input.parent(), stream, mr),
                                    input.null_count(),
                                    stream,
                                    mr);
  if (input.is_empty()) { return output; }

  auto d_positions    = output->mutable_view().data<size_type>();
  auto d_prog         = regex_device_builder::create_prog_device(prog, stream);
  auto const d_column = column_device_view::create(input.parent(), stream);
  launch_transform_kernel(find_re_fn{*d_column}, *d_prog, d_positions, num_rows, stream);

  return output;
}
} // namespace detail

// external API
Expand All @@ -139,5 +176,14 @@ std::unique_ptr<column> findall(strings_column_view const& input,
return detail::findall(input, prog, stream, mr);
}

// Public entry point: forwards every argument unchanged to detail::find_re.
std::unique_ptr<column> find_re(strings_column_view const& input,
                                regex_program const& prog,
                                rmm::cuda_stream_view stream,
                                rmm::device_async_resource_ref mr)
{
  CUDF_FUNC_RANGE();
  auto positions = detail::find_re(input, prog, stream, mr);
  return positions;
}

} // namespace strings
} // namespace cudf
1 change: 1 addition & 0 deletions cpp/tests/streams/strings/find_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,5 @@ TEST_F(StringsFindTest, Find)
auto const pattern = std::string("[a-z]");
auto const prog = cudf::strings::regex_program::create(pattern);
cudf::strings::findall(view, *prog, cudf::test::get_default_stream());
cudf::strings::find_re(view, *prog, cudf::test::get_default_stream());
}
35 changes: 29 additions & 6 deletions cpp/tests/strings/findall_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <cudf_test/base_fixture.hpp>
#include <cudf_test/column_utilities.hpp>
#include <cudf_test/column_wrapper.hpp>
#include <cudf_test/iterator_utilities.hpp>
#include <cudf_test/table_utilities.hpp>

#include <cudf/strings/findall.hpp>
Expand Down Expand Up @@ -149,6 +150,22 @@ TEST_F(StringsFindallTests, LargeRegex)
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected);
}

// Checks find_re() on a mix of matching, non-matching, empty and null rows.
TEST_F(StringsFindallTests, FindTest)
{
  // The same validity iterator is applied to the input and the expected output.
  auto const validity = cudf::test::iterators::null_at(5);
  cudf::test::strings_column_wrapper strings(
    {"3A", "May4", "Jan2021", "March", "A9BC", "", "", "abcdef ghijklm 12345"}, validity);
  auto const strings_view = cudf::strings_column_view(strings);

  // Pattern: one or more digits.
  auto const prog    = cudf::strings::regex_program::create(std::string("\\d+"));
  auto const results = cudf::strings::find_re(strings_view, *prog);

  // -1 marks rows with no digits.
  cudf::test::fixed_width_column_wrapper<cudf::size_type> expected({0, 3, 3, -1, 1, 0, -1, 15},
                                                                   validity);
  CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected);
}

TEST_F(StringsFindallTests, NoMatches)
{
cudf::test::strings_column_wrapper input({"abc\nfff\nabc", "fff\nabc\nlll", "abc", "", "abc\n"});
Expand All @@ -169,10 +186,16 @@ TEST_F(StringsFindallTests, EmptyTest)
auto prog = cudf::strings::regex_program::create(pattern);

cudf::test::strings_column_wrapper input;
auto sv = cudf::strings_column_view(input);
auto results = cudf::strings::findall(sv, *prog);

using LCW = cudf::test::lists_column_wrapper<cudf::string_view>;
LCW expected;
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected);
auto sv = cudf::strings_column_view(input);
{
auto results = cudf::strings::findall(sv, *prog);
using LCW = cudf::test::lists_column_wrapper<cudf::string_view>;
LCW expected;
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected);
}
{
auto results = cudf::strings::find_re(sv, *prog);
auto expected = cudf::test::fixed_width_column_wrapper<cudf::size_type>{};
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected);
}
}
2 changes: 1 addition & 1 deletion python/cudf/cudf/_lib/strings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
startswith_multiple,
)
from cudf._lib.strings.find_multiple import find_multiple
from cudf._lib.strings.findall import findall
from cudf._lib.strings.findall import find_re, findall
from cudf._lib.strings.json import GetJsonObjectOptions, get_json_object
from cudf._lib.strings.padding import center, ljust, pad, rjust, zfill
from cudf._lib.strings.repeat import repeat_scalar, repeat_sequence
Expand Down
16 changes: 16 additions & 0 deletions python/cudf/cudf/_lib/strings/findall.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,19 @@ def findall(Column source_strings, object pattern, uint32_t flags):
prog,
)
return Column.from_pylibcudf(plc_result)


@acquire_spill_lock()
def find_re(Column source_strings, object pattern, uint32_t flags):
    """
    Return, for each element of source_strings, the character position
    at which the pattern first matches.

    The pattern is converted with ``str()`` and compiled together with
    ``flags`` into a pylibcudf ``RegexProgram`` before the search runs.
    """
    regex_prog = plc.strings.regex_program.RegexProgram.create(
        str(pattern), flags
    )
    return Column.from_pylibcudf(
        plc.strings.findall.find_re(
            source_strings.to_pylibcudf(mode="read"),
            regex_prog,
        )
    )
40 changes: 40 additions & 0 deletions python/cudf/cudf/core/column/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -3626,6 +3626,46 @@ def findall(self, pat: str, flags: int = 0) -> SeriesOrIndex:
data = libstrings.findall(self._column, pat, flags)
return self._return_or_inplace(data)

def find_re(self, pat: str, flags: int = 0) -> SeriesOrIndex:
    """
    Find first occurrence of pattern or regular expression in the
    Series/Index.

    Parameters
    ----------
    pat : str
        Pattern or regular expression. If a compiled ``re.Pattern``
        is passed instead, its pattern text is used and ``flags`` is
        replaced by the pattern's own flags with ``re.U`` cleared.
    flags : int, default 0 (no flags)
        Flags to pass through to the regex engine (e.g. re.MULTILINE)

    Returns
    -------
    Series
        A Series of position values where the pattern first matches
        each string. Strings without a match yield -1.

    Raises
    ------
    NotImplementedError
        If ``flags`` is rejected by ``_is_supported_regex_flags``.

    Examples
    --------
    >>> import cudf
    >>> s = cudf.Series(['Lion', 'Monkey', 'Rabbit', 'Cat'])
    >>> s.str.find_re('[ti]')
    0    1
    1   -1
    2    4
    3    2
    dtype: int32
    """
    # Unpack a compiled pattern into its text and flags.
    if isinstance(pat, re.Pattern):
        pat, flags = pat.pattern, pat.flags & ~re.U

    if not _is_supported_regex_flags(flags):
        raise NotImplementedError("Unsupported value for `flags` parameter")

    return self._return_or_inplace(
        libstrings.find_re(self._column, pat, flags)
    )

def find_multiple(self, patterns: SeriesOrIndex) -> cudf.Series:
"""
Find all first occurrences of patterns in the Series/Index.
Expand Down
20 changes: 20 additions & 0 deletions python/cudf/cudf/tests/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -1899,6 +1899,26 @@ def test_string_findall(pat, flags):
assert_eq(expected, actual)


@pytest.mark.parametrize(
    "pat, flags, pos",
    [
        # Plain substrings: -1 marks rows without a match.
        ("Monkey", 0, [-1, 0, -1, -1]),
        ("on", 0, [2, 1, -1, 1]),
        ("bit", 0, [-1, -1, 3, -1]),
        # '$' only matches before the embedded newline with MULTILINE.
        ("on$", 0, [2, -1, -1, -1]),
        ("on$", re.MULTILINE, [2, -1, -1, 1]),
        # '.' only crosses the embedded newline with DOTALL.
        ("o.*k", re.DOTALL, [-1, 1, -1, 1]),
    ],
)
def test_string_find_re(pat, flags, pos):
    """Check str.find_re against hand-computed first-match positions."""
    strings = cudf.Series(["Lion", "Monkey", "Rabbit", "Don\nkey"])

    got = strings.str.find_re(pat, flags)
    want = pd.Series(pos, dtype=np.int32)
    assert_eq(want, got)


def test_string_replace_multi():
ps = pd.Series(["hello", "goodbye"])
gs = cudf.Series(["hello", "goodbye"])
Expand Down
4 changes: 4 additions & 0 deletions python/pylibcudf/pylibcudf/libcudf/strings/findall.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ cdef extern from "cudf/strings/findall.hpp" namespace "cudf::strings" nogil:
cdef unique_ptr[column] findall(
column_view input,
regex_program prog) except +

# Declaration of cudf::strings::find_re from "cudf/strings/findall.hpp".
# Takes the strings column and a regex program and returns an owning
# pointer to the resulting column; `except +` lets Cython convert a thrown
# C++ exception into a Python exception.
cdef unique_ptr[column] find_re(
column_view input,
regex_program prog) except +
1 change: 1 addition & 0 deletions python/pylibcudf/pylibcudf/strings/findall.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ from pylibcudf.column cimport Column
from pylibcudf.strings.regex_program cimport RegexProgram


# Returns a Column built from `input` and the compiled regex `pattern`;
# same signature shape as `findall` below.
cpdef Column find_re(Column input, RegexProgram pattern)
cpdef Column findall(Column input, RegexProgram pattern)
32 changes: 32 additions & 0 deletions python/pylibcudf/pylibcudf/strings/findall.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,35 @@ cpdef Column findall(Column input, RegexProgram pattern):
)

return Column.from_libcudf(move(c_result))


cpdef Column find_re(Column input, RegexProgram pattern):
    """
    Returns character positions where the pattern first matches
    the elements in input strings.

    For details, see :cpp:func:`cudf::strings::find_re`

    Parameters
    ----------
    input : Column
        Strings instance for this operation
    pattern : RegexProgram
        Regex pattern

    Returns
    -------
    Column
        New column of integers
    """
    cdef unique_ptr[column] c_positions

    # The libcudf call is made with the GIL released.
    with nogil:
        c_positions = move(
            cpp_findall.find_re(input.view(), pattern.c_obj.get()[0])
        )

    return Column.from_libcudf(move(c_positions))
17 changes: 17 additions & 0 deletions python/pylibcudf/pylibcudf/tests/test_string_findall.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,20 @@ def test_findall():
type=pa_result.type,
)
assert_column_eq(result, expected)


def test_find_re():
    """find_re reports the first position matching 'e' or 'b', or -1."""
    strings = pa.array(["bunny", "rabbit", "hare", "dog"])
    plc_strings = plc.interop.from_arrow(strings)
    prog = plc.strings.regex_program.RegexProgram.create(
        "[eb]", plc.strings.regex_flags.RegexFlags.DEFAULT
    )

    result = plc.strings.findall.find_re(plc_strings, prog)

    # Build the expected array with whatever arrow type the result maps to.
    result_type = plc.interop.to_arrow(result).type
    expected = pa.array([0, 2, 3, -1], type=result_type)
    assert_column_eq(result, expected)

0 comments on commit 3faa3ee

Please sign in to comment.