Skip to content

Commit

Permalink
Migrate NVText Byte Pair Encoding APIs to pylibcudf
Browse files Browse the repository at this point in the history
  • Loading branch information
Matt711 committed Oct 16, 2024
1 parent 319ec3b commit dad5328
Show file tree
Hide file tree
Showing 10 changed files with 197 additions and 41 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
================
byte_pair_encode
================

.. automodule:: pylibcudf.nvtext.byte_pair_encode
:members:
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ nvtext
generate_ngrams
jaccard
minhash
byte_pair_encode
45 changes: 9 additions & 36 deletions python/cudf/cudf/_lib/nvtext/byte_pair_encode.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3,49 +3,22 @@

from cudf.core.buffer import acquire_spill_lock

from libcpp.memory cimport unique_ptr
from libcpp.utility cimport move

from pylibcudf.libcudf.column.column cimport column
from pylibcudf.libcudf.column.column_view cimport column_view
from pylibcudf.libcudf.nvtext.byte_pair_encode cimport (
bpe_merge_pairs as cpp_bpe_merge_pairs,
byte_pair_encoding as cpp_byte_pair_encoding,
load_merge_pairs as cpp_load_merge_pairs,
)
from pylibcudf.libcudf.scalar.scalar cimport string_scalar

from cudf._lib.column cimport Column
from cudf._lib.scalar cimport DeviceScalar


cdef class BPEMergePairs:
cdef unique_ptr[cpp_bpe_merge_pairs] c_obj

def __cinit__(self, Column merge_pairs):
cdef column_view c_pairs = merge_pairs.view()
with nogil:
self.c_obj = move(cpp_load_merge_pairs(c_pairs))
from pylibcudf import nvtext
from pylibcudf.nvtext.byte_pair_encode import BPEMergePairs # no-cython-lint


@acquire_spill_lock()
def byte_pair_encoding(
Column strings,
BPEMergePairs merge_pairs,
object merge_pairs,
object separator
):
cdef column_view c_strings = strings.view()
cdef DeviceScalar d_separator = separator.device_value
cdef const string_scalar* c_separator = <const string_scalar*>d_separator\
.get_raw_ptr()
cdef unique_ptr[column] c_result
with nogil:
c_result = move(
cpp_byte_pair_encoding(
c_strings,
merge_pairs.c_obj.get()[0],
c_separator[0]
)
return Column.from_pylibcudf(
nvtext.byte_pair_encode.byte_pair_encoding(
strings.to_pylibcudf(mode="read"),
merge_pairs,
separator.device_value.c_value
)

return Column.from_unique_ptr(move(c_result))
)
4 changes: 3 additions & 1 deletion python/cudf/cudf/core/byte_pair_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ class BytePairEncoder:
"""

def __init__(self, merges_pair: "cudf.Series"):
self.merge_pairs = cpp_merge_pairs(merges_pair._column)
self.merge_pairs = cpp_merge_pairs(
merges_pair._column.to_pylibcudf(mode="read")
)

def __call__(self, text, separator: str = " ") -> cudf.Series:
"""
Expand Down
4 changes: 3 additions & 1 deletion python/pylibcudf/pylibcudf/nvtext/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# the License.
# =============================================================================

set(cython_sources edit_distance.pyx generate_ngrams.pyx jaccard.pyx minhash.pyx)
set(cython_sources edit_distance.pyx generate_ngrams.pyx jaccard.pyx minhash.pyx
byte_pair_encode.pyx
)

set(linked_libraries cudf::cudf)
rapids_cython_create_modules(
Expand Down
11 changes: 9 additions & 2 deletions python/pylibcudf/pylibcudf/nvtext/__init__.pxd
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from . cimport edit_distance, generate_ngrams, jaccard, minhash
from . cimport (
byte_pair_encode,
edit_distance,
generate_ngrams,
jaccard,
minhash,
)

__all__ = [
"edit_distance",
"generate_ngrams",
"jaccard",
"minhash"
"minhash",
"byte_pair_encode"
]
9 changes: 8 additions & 1 deletion python/pylibcudf/pylibcudf/nvtext/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from . import edit_distance, generate_ngrams, jaccard, minhash
from . import (
byte_pair_encode,
edit_distance,
generate_ngrams,
jaccard,
minhash,
)

__all__ = [
"edit_distance",
"generate_ngrams",
"jaccard",
"minhash",
"byte_pair_encode",
]
18 changes: 18 additions & 0 deletions python/pylibcudf/pylibcudf/nvtext/byte_pair_encode.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from libcpp.memory cimport unique_ptr
from pylibcudf.column cimport Column
from pylibcudf.libcudf.nvtext.byte_pair_encode cimport bpe_merge_pairs
from pylibcudf.scalar cimport Scalar


# Extension type wrapping a libcudf ``bpe_merge_pairs`` object; ``c_obj``
# holds that object through a ``unique_ptr``.
cdef class BPEMergePairs:
cdef unique_ptr[bpe_merge_pairs] c_obj

# ``separator=*`` declares an optional argument; the default value itself
# is given by the matching definition in byte_pair_encode.pyx.
cpdef Column byte_pair_encoding(
Column input,
BPEMergePairs merge_pairs,
Scalar separator=*
)

# Builds a ``BPEMergePairs`` object from a column of merge pairs.
cpdef BPEMergePairs load_merge_pairs(Column input)
88 changes: 88 additions & 0 deletions python/pylibcudf/pylibcudf/nvtext/byte_pair_encode.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from cython.operator cimport dereference
from libcpp.memory cimport unique_ptr
from libcpp.utility cimport move
from pylibcudf.column cimport Column
from pylibcudf.libcudf.column.column cimport column
from pylibcudf.libcudf.column.column_view cimport column_view
from pylibcudf.libcudf.nvtext.byte_pair_encode cimport (
byte_pair_encoding as cpp_byte_pair_encoding,
load_merge_pairs as cpp_load_merge_pairs,
)
from pylibcudf.libcudf.scalar.scalar cimport string_scalar
from pylibcudf.libcudf.scalar.scalar_factories cimport (
make_string_scalar as cpp_make_string_scalar,
)
from pylibcudf.scalar cimport Scalar


cdef class BPEMergePairs:
"""The table of merge pairs for the BPE encoder.

For details, see :cpp:class:`cudf::nvtext::bpe_merge_pairs`.

Parameters
----------
merge_pairs : Column
Column from which the merge-pairs table is loaded.
"""
def __cinit__(self, Column merge_pairs):
cdef column_view c_pairs = merge_pairs.view()
# The libcudf load call runs with the GIL released; its result is
# moved into ``c_obj``.
with nogil:
self.c_obj = move(cpp_load_merge_pairs(c_pairs))

cpdef Column byte_pair_encoding(
Column input,
BPEMergePairs merge_pairs,
Scalar separator=None
):
"""
Byte pair encode the input strings.

For details, see :cpp:func:`cudf::nvtext::byte_pair_encoding`.

Parameters
----------
input : Column
Strings to encode.
merge_pairs : BPEMergePairs
Substrings to rebuild each string on.
separator : Scalar, optional
String used to build the output after encoding. Default is a space.

Returns
-------
Column
An encoded column of strings.
"""
cdef unique_ptr[column] c_result

# No separator supplied: build a string scalar holding a single space.
if separator is None:
separator = Scalar.from_libcudf(
cpp_make_string_scalar(" ".encode())
)

# The libcudf call runs with the GIL released.
# NOTE(review): the separator's underlying scalar is cast to
# ``string_scalar`` without a type check here; presumably callers must
# pass a string Scalar. Confirm.
with nogil:
c_result = move(
cpp_byte_pair_encoding(
input.view(),
dereference(merge_pairs.c_obj.get()),
dereference(<const string_scalar*>separator.c_obj.get()),
)
)

return Column.from_libcudf(move(c_result))

cpdef BPEMergePairs load_merge_pairs(Column input):
"""
Create a nvtext::bpe_merge_pairs from a strings column.

For details, see :cpp:func:`cudf::nvtext::load_merge_pairs`.

Parameters
----------
input : Column
Column containing the unique merge pairs.

Returns
-------
BPEMergePairs
A ``BPEMergePairs`` object built from ``input``.
"""
return BPEMergePairs(input)
52 changes: 52 additions & 0 deletions python/pylibcudf/pylibcudf/tests/test_nvtext_byte_pair_encode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

import pyarrow as pa
import pylibcudf as plc
import pytest
from utils import assert_column_eq


@pytest.fixture(scope="module")
def input_col():
    """Return the pyarrow array of merge pairs shared by this module's tests."""
    merge_pairs = [
        "e n",
        "i t",
        "i s",
        "e s",
        "en t",
        "c e",
        "es t",
        "en ce",
        "t est",
        "s ent",
    ]
    return pa.array(merge_pairs)


@pytest.mark.parametrize(
    "separator", [None, plc.interop.from_arrow(pa.scalar("e"))]
)
def test_byte_pair_encoding(input_col, separator):
    """Check the encoded output with the default and an explicit separator."""
    bpe = plc.nvtext.byte_pair_encode

    plc_col = plc.interop.from_arrow(
        pa.array(["test sentence", "thisis test"])
    )
    merge_pairs = bpe.load_merge_pairs(plc.interop.from_arrow(input_col))
    result = bpe.byte_pair_encoding(plc_col, merge_pairs, separator)

    # With no separator passed, the expected tokens are joined by a space;
    # otherwise they are joined by the supplied "e" scalar.
    expected = (
        pa.array(["test sent ence", "t h is is test"])
        if separator is None
        else pa.array(["teste esenteence", "teheiseise etest"])
    )
    assert_column_eq(result, expected)


def test_load_merge_pairs(input_col):
    """Check that ``load_merge_pairs`` returns a ``BPEMergePairs`` instance."""
    bpe = plc.nvtext.byte_pair_encode
    merge_pairs_column = plc.interop.from_arrow(input_col)
    result = bpe.load_merge_pairs(merge_pairs_column)
    assert isinstance(result, bpe.BPEMergePairs)

0 comments on commit dad5328

Please sign in to comment.