-
Notifications
You must be signed in to change notification settings - Fork 916
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add BytePairEncoder class to cuDF (#13891)
Adds a new BytePairEncoding class to cuDF ``` >>> import cudf >>> from cudf.core.byte_pair_encoding import BytePairEncoder >>> mps = cudf.read_text('merges.txt', delimiter='\n', strip_delimiters=True) >>> bpe = BytePairEncoder(mps) >>> str_series = cudf.Series(['This is a sentence', 'thisisit']) >>> bpe(str_series) 0 This is a sent ence 1 this is it dtype: object ``` This class wraps the existing `nvtext::byte_pair_encoding` APIs to load the merge-pairs data and encode a column of strings. Authors: - David Wendt (https://github.com/davidwendt) Approvers: - Bradley Dice (https://github.com/bdice) URL: #13891
- Loading branch information
1 parent
7f3fba1
commit b0c1b7b
Showing
5 changed files
with
176 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION. | ||
|
||
from libcpp.memory cimport unique_ptr | ||
from libcpp.string cimport string | ||
|
||
from cudf._lib.cpp.column.column cimport column | ||
from cudf._lib.cpp.column.column_view cimport column_view | ||
from cudf._lib.cpp.scalar.scalar cimport string_scalar | ||
|
||
|
||
cdef extern from "nvtext/byte_pair_encoding.hpp" namespace "nvtext" nogil: | ||
|
||
cdef struct bpe_merge_pairs "nvtext::bpe_merge_pairs": | ||
pass | ||
|
||
cdef unique_ptr[bpe_merge_pairs] load_merge_pairs( | ||
const column_view &merge_pairs | ||
) except + | ||
|
||
cdef unique_ptr[column] byte_pair_encoding( | ||
const column_view &strings, | ||
const bpe_merge_pairs &merge_pairs, | ||
const string_scalar &separator | ||
) except + |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION. | ||
|
||
|
||
from cudf.core.buffer import acquire_spill_lock | ||
|
||
from libcpp.memory cimport unique_ptr | ||
from libcpp.utility cimport move | ||
|
||
from cudf._lib.column cimport Column | ||
from cudf._lib.cpp.column.column cimport column | ||
from cudf._lib.cpp.column.column_view cimport column_view | ||
from cudf._lib.cpp.nvtext.byte_pair_encode cimport ( | ||
bpe_merge_pairs as cpp_bpe_merge_pairs, | ||
byte_pair_encoding as cpp_byte_pair_encoding, | ||
load_merge_pairs as cpp_load_merge_pairs, | ||
) | ||
from cudf._lib.cpp.scalar.scalar cimport string_scalar | ||
from cudf._lib.scalar cimport DeviceScalar | ||
|
||
|
||
cdef class BPEMergePairs: | ||
cdef unique_ptr[cpp_bpe_merge_pairs] c_obj | ||
|
||
def __cinit__(self, Column merge_pairs): | ||
cdef column_view c_pairs = merge_pairs.view() | ||
with nogil: | ||
self.c_obj = move(cpp_load_merge_pairs(c_pairs)) | ||
|
||
|
||
@acquire_spill_lock() | ||
def byte_pair_encoding( | ||
Column strings, | ||
BPEMergePairs merge_pairs, | ||
object separator | ||
): | ||
cdef column_view c_strings = strings.view() | ||
cdef DeviceScalar d_separator = separator.device_value | ||
cdef const string_scalar* c_separator = <const string_scalar*>d_separator\ | ||
.get_raw_ptr() | ||
cdef unique_ptr[column] c_result | ||
with nogil: | ||
c_result = move( | ||
cpp_byte_pair_encoding( | ||
c_strings, | ||
merge_pairs.c_obj.get()[0], | ||
c_separator[0] | ||
) | ||
) | ||
|
||
return Column.from_unique_ptr(move(c_result)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION. | ||
|
||
from __future__ import annotations | ||
|
||
import cudf | ||
from cudf._lib.nvtext.byte_pair_encode import ( | ||
BPEMergePairs as cpp_merge_pairs, | ||
byte_pair_encoding as cpp_byte_pair_encoding, | ||
) | ||
|
||
|
||
class BytePairEncoder: | ||
""" | ||
Given a merge pairs strings series, performs byte pair encoding on | ||
a strings series using the provided separator. | ||
Parameters | ||
---------- | ||
merges_pairs : str | ||
Strings column of merge pairs | ||
Returns | ||
------- | ||
BytePairEncoder | ||
""" | ||
|
||
def __init__(self, merges_pair: "cudf.Series"): | ||
self.merge_pairs = cpp_merge_pairs(merges_pair._column) | ||
|
||
def __call__(self, text, separator: str = " "): | ||
""" | ||
Parameters | ||
---------- | ||
text : cudf string series | ||
The strings to be encoded. | ||
Returns | ||
------- | ||
Encoded strings | ||
Examples | ||
-------- | ||
>>> import cudf | ||
>>> from cudf.core.byte_pair_encoding import BytePairEncoder | ||
>>> mps = cudf.Series(["e n", "i t", "i s", "e s", "en t", | ||
... "c e", "es t", "en ce", "T h", "Th is", | ||
... "t est", "s ent", "t h", "th is"]) | ||
>>> bpe = BytePairEncoder(mps) | ||
>>> str_series = cudf.Series(['This is the sentence', 'thisisit']) | ||
>>> bpe(str_series) | ||
0 This is a sent ence | ||
1 this is it | ||
dtype: object | ||
""" | ||
sep = cudf.Scalar(separator, dtype="str") | ||
result = cpp_byte_pair_encoding(text._column, self.merge_pairs, sep) | ||
|
||
return cudf.Series(result) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters