Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add string.extract APIs to pylibcudf #16823

Merged
merged 6 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
=======
extract
=======

.. automodule:: pylibcudf.strings.extract
:members:
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ strings
capitalize
char_types
contains
extract
find
regex_flags
regex_program
Expand Down
34 changes: 6 additions & 28 deletions python/cudf/cudf/_lib/strings/extract.pyx
Original file line number Diff line number Diff line change
@@ -1,21 +1,12 @@
# Copyright (c) 2020-2024, NVIDIA CORPORATION.

from cython.operator cimport dereference
from libc.stdint cimport uint32_t
from libcpp.memory cimport unique_ptr
from libcpp.string cimport string
from libcpp.utility cimport move

from cudf.core.buffer import acquire_spill_lock

from pylibcudf.libcudf.column.column_view cimport column_view
from pylibcudf.libcudf.strings.extract cimport extract as cpp_extract
from pylibcudf.libcudf.strings.regex_flags cimport regex_flags
from pylibcudf.libcudf.strings.regex_program cimport regex_program
from pylibcudf.libcudf.table.table cimport table

from cudf._lib.column cimport Column
from cudf._lib.utils cimport data_from_unique_ptr

import pylibcudf as plc


@acquire_spill_lock()
Expand All @@ -26,21 +17,8 @@ def extract(Column source_strings, object pattern, uint32_t flags):
The returning data contains one row for each subject string,
and one column for each group.
"""
cdef unique_ptr[table] c_result
cdef column_view source_view = source_strings.view()

cdef string pattern_string = <string>str(pattern).encode()
cdef regex_flags c_flags = <regex_flags>flags
cdef unique_ptr[regex_program] c_prog

with nogil:
c_prog = move(regex_program.create(pattern_string, c_flags))
c_result = move(cpp_extract(
source_view,
dereference(c_prog)
))

return data_from_unique_ptr(
move(c_result),
column_names=range(0, c_result.get()[0].num_columns())
prog = plc.strings.regex_program.RegexProgram.create(str(pattern), flags)
plc_result = plc.strings.extract.extract(
source_strings.to_pylibcudf(mode="read"), prog
)
return dict(enumerate(Column.from_pylibcudf(col) for col in plc_result.columns()))
6 changes: 2 additions & 4 deletions python/cudf/cudf/core/column/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,11 +623,9 @@ def extract(
"unsupported value for `flags` parameter"
)

data, _ = libstrings.extract(self._column, pat, flags)
data = libstrings.extract(self._column, pat, flags)
if len(data) == 1 and expand is False:
data = next(iter(data.values()))
else:
data = data
_, data = data.popitem()
return self._return_or_inplace(data, expand=expand)

def contains(
Expand Down
8 changes: 6 additions & 2 deletions python/pylibcudf/pylibcudf/libcudf/strings/extract.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,9 @@ from pylibcudf.libcudf.table.table cimport table
cdef extern from "cudf/strings/extract.hpp" namespace "cudf::strings" nogil:

cdef unique_ptr[table] extract(
column_view source_strings,
regex_program) except +
column_view input,
regex_program prog) except +

cdef unique_ptr[column] extract_all_record(
column_view input,
regex_program prog) except +
4 changes: 2 additions & 2 deletions python/pylibcudf/pylibcudf/strings/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# the License.
# =============================================================================

set(cython_sources capitalize.pyx case.pyx char_types.pyx contains.pyx find.pyx regex_flags.pyx
regex_program.pyx repeat.pyx replace.pyx slice.pyx
set(cython_sources capitalize.pyx case.pyx char_types.pyx contains.pyx extract.pyx find.pyx
regex_flags.pyx regex_program.pyx repeat.pyx replace.pyx slice.pyx
)

set(linked_libraries cudf::cudf)
Expand Down
1 change: 1 addition & 0 deletions python/pylibcudf/pylibcudf/strings/__init__.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ from . cimport (
case,
char_types,
contains,
extract,
find,
regex_flags,
regex_program,
Expand Down
1 change: 1 addition & 0 deletions python/pylibcudf/pylibcudf/strings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
case,
char_types,
contains,
extract,
find,
regex_flags,
regex_program,
Expand Down
10 changes: 10 additions & 0 deletions python/pylibcudf/pylibcudf/strings/extract.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from pylibcudf.column cimport Column
from pylibcudf.strings.regex_program cimport RegexProgram
from pylibcudf.table cimport Table


cpdef Table extract(Column input, RegexProgram prog)

cpdef Column extract_all_record(Column input, RegexProgram prog)
76 changes: 76 additions & 0 deletions python/pylibcudf/pylibcudf/strings/extract.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from libcpp.memory cimport unique_ptr
from libcpp.utility cimport move
from pylibcudf.column cimport Column
from pylibcudf.libcudf.column.column cimport column
from pylibcudf.libcudf.strings cimport extract as cpp_extract
from pylibcudf.libcudf.table.table cimport table
from pylibcudf.strings.regex_program cimport RegexProgram
from pylibcudf.table cimport Table


cpdef Table extract(Column input, RegexProgram prog):
"""
Returns a table of strings columns where each column
corresponds to the matching group specified in the given
egex_program object.

For details, see :cpp:func:`cudf::strings::extract`.

Parameters
----------
input : Column
Strings instance for this operation
prog : RegexProgram
Regex program instance

Returns
-------
Table
Columns of strings extracted from the input column.
"""
cdef unique_ptr[table] c_result

with nogil:
c_result = move(
cpp_extract.extract(
input.view(),
prog.c_obj.get()[0]
)
)

return Table.from_libcudf(move(c_result))


cpdef Column extract_all_record(Column input, RegexProgram prog):
"""
Returns a lists column of strings where each string column
row corresponds to the matching group specified in the given
regex_program object.

For details, see :cpp:func:`cudf::strings::extract_all_record`.

Parameters
----------
input : Column
Strings instance for this operation
prog : RegexProgram
Regex program instance

Returns
-------
Column
Lists column containing strings extracted from the input column
"""
cdef unique_ptr[column] c_result

with nogil:
c_result = move(
cpp_extract.extract_all_record(
input.view(),
prog.c_obj.get()[0]
)
)

return Column.from_libcudf(move(c_result))
38 changes: 38 additions & 0 deletions python/pylibcudf/pylibcudf/tests/test_string_extract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

import pyarrow as pa
import pyarrow.compute as pc
import pylibcudf as plc


def test_extract():
pattern = "([ab])(\\d)"
pa_pattern = "(?P<letter>[ab])(?P<digit>\\d)"
arr = pa.array(["a1", "b2", "c3"])
plc_result = plc.strings.extract.extract(
plc.interop.from_arrow(arr),
plc.strings.regex_program.RegexProgram.create(
pattern, plc.strings.regex_flags.RegexFlags.DEFAULT
),
)
result = plc.interop.to_arrow(plc_result)
expected = pc.extract_regex(arr, pa_pattern)
for i, result_col in enumerate(result.itercolumns()):
expected_col = pa.chunked_array(expected.field(i))
assert result_col.fill_null("").equals(expected_col)


def test_extract_all_record():
pattern = "([ab])(\\d)"
arr = pa.array(["a1", "b2", "c3"])
plc_result = plc.strings.extract.extract_all_record(
plc.interop.from_arrow(arr),
plc.strings.regex_program.RegexProgram.create(
pattern, plc.strings.regex_flags.RegexFlags.DEFAULT
),
)
result = plc.interop.to_arrow(plc_result)
expected = pa.chunked_array(
[pa.array([["a", "1"], ["b", "2"], None], type=result.type)]
)
assert result.equals(expected)
Loading