expose encodeSpecialTokens functionality
Daulet Zhanguzin committed Apr 10, 2024
1 parent b14c306 commit 3a615e6
Showing 5 changed files with 68 additions and 15 deletions.
11 changes: 9 additions & 2 deletions src/lib.rs
@@ -3,6 +3,11 @@ use std::path::PathBuf;
use std::ptr;
use tokenizers::tokenizer::Tokenizer;

#[repr(C)]
pub struct TokenizerOptions {
encode_special_tokens: bool,
}

#[repr(C)]
pub struct Buffer {
ids: *mut u32,
@@ -14,12 +19,14 @@ pub struct Buffer {
}

#[no_mangle]
pub extern "C" fn from_bytes(bytes: *const u8, len: u32) -> *mut Tokenizer {
pub extern "C" fn from_bytes(bytes: *const u8, len: u32, opts: &TokenizerOptions) -> *mut Tokenizer {
let bytes_slice = unsafe { std::slice::from_raw_parts(bytes, len as usize) };
-let tokenizer = Tokenizer::from_bytes(bytes_slice).expect("failed to create tokenizer");
+let mut tokenizer = Tokenizer::from_bytes(bytes_slice).expect("failed to create tokenizer");
+tokenizer.set_encode_special_tokens(opts.encode_special_tokens);
Box::into_raw(Box::new(tokenizer))
}

// TODO merge with from_bytes and pass truncation params as an argument to TokenizerOptions
#[no_mangle]
pub extern "C" fn from_bytes_with_truncation(bytes: *const u8, len: u32, max_len: usize, dir: u8) -> *mut Tokenizer {
let bytes_slice = unsafe { std::slice::from_raw_parts(bytes, len as usize) };
8 changes: 7 additions & 1 deletion test/data/sentence-transformers-labse.json
@@ -146,13 +146,19 @@
"max_input_chars_per_word": 100,
"vocab": {
"[PAD]": 0,
"[CLS]":101,
"[SEP]":102,
"brown": 51775,
"fox": 193284,
"jumps": 333915,
"over": 15444,
"the": 14985,
"lazy": 221123,
"dog": 22452
"dog": 22452,
"[":164,
"CLS":304910,
"]":166,
"SEP":211703
}
}
}
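The added vocabulary entries give the pieces of the bracketed special tokens ("[", "CLS", "SEP", "]") ids of their own, so that when special-token encoding is enabled the literal text "[CLS]" and "[SEP]" can be tokenized as ordinary input; the resulting ids are exactly the ones asserted in TestEncodeSpecialTokens below.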
41 changes: 30 additions & 11 deletions tokenizer.go
@@ -19,6 +19,18 @@ type Tokenizer struct {
tokenizer unsafe.Pointer
}

type tokenizerOpts struct {
encodeSpecialTokens C.bool
}

type TokenizerOption func(to *tokenizerOpts)

func WithEncodeSpecialTokens() TokenizerOption {
return func(to *tokenizerOpts) {
to.encodeSpecialTokens = C.bool(true)
}
}

type TruncationDirection int

const (
@@ -28,8 +40,15 @@ const (

var _ io.Closer = (*Tokenizer)(nil)

-func FromBytes(data []byte) (*Tokenizer, error) {
-tokenizer := C.from_bytes((*C.uchar)(unsafe.Pointer(&data[0])), C.uint(len(data)))
+func FromBytes(data []byte, opts ...TokenizerOption) (*Tokenizer, error) {
+allOpts := &tokenizerOpts{
+// by default, we do not encode special tokens
+encodeSpecialTokens: C.bool(false),
+}
+for _, opt := range opts {
+opt(allOpts)
+}
+tokenizer := C.from_bytes((*C.uchar)(unsafe.Pointer(&data[0])), C.uint(len(data)), (*C.struct_TokenizerOptions)(unsafe.Pointer(allOpts)))
return &Tokenizer{tokenizer: tokenizer}, nil
}

@@ -62,7 +81,7 @@ type Encoding struct {
Tokens []string
}

-type EncodeOptions struct {
+type encodeOpts struct {
AddSpecialTokens C.bool

ReturnTypeIDs C.bool
@@ -71,7 +90,7 @@ type EncodeOptions struct {
ReturnAttentionMask C.bool
}

-type EncodeOption func(eo *EncodeOptions)
+type EncodeOption func(eo *encodeOpts)

func uintVecToSlice(arrPtr *C.uint, len int) []uint32 {
arr := unsafe.Slice(arrPtr, len)
@@ -85,7 +104,7 @@ func uintVecToSlice(arrPtr *C.uint, len int) []uint32 {
func (t *Tokenizer) Encode(str string, addSpecialTokens bool) ([]uint32, []string) {
cStr := C.CString(str)
defer C.free(unsafe.Pointer(cStr))
-options := EncodeOptions{
+options := encodeOpts{
AddSpecialTokens: C.bool(addSpecialTokens),
ReturnTokens: C.bool(true),
}
@@ -109,7 +128,7 @@ func (t *Tokenizer) Encode(str string, addSpecialTokens bool) ([]uint32, []string) {
}

func WithReturnAllAttributes() EncodeOption {
-return func(eo *EncodeOptions) {
+return func(eo *encodeOpts) {
eo.ReturnTypeIDs = C.bool(true)
eo.ReturnSpecialTokensMask = C.bool(true)
eo.ReturnAttentionMask = C.bool(true)
@@ -118,25 +137,25 @@ func WithReturnAllAttributes() EncodeOption {
}

func WithReturnTypeIDs() EncodeOption {
-return func(eo *EncodeOptions) {
+return func(eo *encodeOpts) {
eo.ReturnTypeIDs = C.bool(true)
}
}

func WithReturnSpecialTokensMask() EncodeOption {
-return func(eo *EncodeOptions) {
+return func(eo *encodeOpts) {
eo.ReturnSpecialTokensMask = C.bool(true)
}
}

func WithReturnTokens() EncodeOption {
-return func(eo *EncodeOptions) {
+return func(eo *encodeOpts) {
eo.ReturnTokens = C.bool(true)
}
}

func WithReturnAttentionMask() EncodeOption {
-return func(eo *EncodeOptions) {
+return func(eo *encodeOpts) {
eo.ReturnAttentionMask = C.bool(true)
}
}
@@ -145,7 +164,7 @@ func (t *Tokenizer) EncodeWithOptions(str string, addSpecialTokens bool, opts ...
cStr := C.CString(str)
defer C.free(unsafe.Pointer(cStr))

-encOptions := EncodeOptions{
+encOptions := encodeOpts{
AddSpecialTokens: C.bool(addSpecialTokens),
}
for _, opt := range opts {
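Taken together, the Go-side changes let callers opt in through a functional option. Below is a minimal usage sketch, not part of the commit: it assumes the module import path github.com/daulet/tokenizers and reads the LaBSE fixture updated above; the expected id sequences are the ones asserted in the new test.

package main

import (
    "fmt"
    "os"

    "github.com/daulet/tokenizers"
)

func main() {
    data, err := os.ReadFile("./test/data/sentence-transformers-labse.json")
    if err != nil {
        panic(err)
    }

    // Default behavior: "[CLS]" and "[SEP]" in the input are matched as the
    // special tokens themselves (ids 101 and 102 in the fixture above).
    tk, err := tokenizers.FromBytes(data)
    if err != nil {
        panic(err)
    }
    ids, _ := tk.Encode("[CLS]fox[SEP]", false)
    fmt.Println(ids) // [101 193284 102]
    tk.Close()

    // With the new option, the same text is tokenized as ordinary input.
    tk, err = tokenizers.FromBytes(data, tokenizers.WithEncodeSpecialTokens())
    if err != nil {
        panic(err)
    }
    ids, _ = tk.Encode("[CLS]fox[SEP]", false)
    fmt.Println(ids) // [164 304910 166 193284 164 211703 166]
    tk.Close()
}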
17 changes: 17 additions & 0 deletions tokenizer_test.go
@@ -144,6 +144,23 @@ func TestEncodeWithAndWithoutOptions(t *testing.T) {
}
}

func TestEncodeSpecialTokens(t *testing.T) {
tk, err := tokenizers.FromBytes(embeddedBytes)
require.NoError(t, err)
// special-token encoding is disabled by default: input text that matches
// a special token is mapped to that special token's id
ids, _ := tk.Encode("[CLS]fox[SEP]", false)
assert.Equal(t, []uint32{101, 193284, 102}, ids)
tk.Close()

tk, err = tokenizers.FromBytes(embeddedBytes, tokenizers.WithEncodeSpecialTokens())
require.NoError(t, err)
ids, _ = tk.Encode("[CLS]fox[SEP]", false)
// assert that special tokens 101 and 102 are not present
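// instead, "[CLS]" and "[SEP]" are split into the plain pieces "[", "CLS", "]",
// "SEP" using the vocab entries added to the LaBSE fixture (164, 304910, 166, 211703)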
assert.Equal(t, []uint32{164, 304910, 166, 193284, 164, 211703, 166}, ids)
tk.Close()
}

func TestEncodeOptions(t *testing.T) {
tk, err := tokenizers.FromFile("./test/data/bert-base-uncased.json")
require.NoError(t, err)
6 changes: 5 additions & 1 deletion tokenizers.h
@@ -9,6 +9,10 @@ struct EncodeOptions {
bool return_attention_mask;
};

struct TokenizerOptions {
bool encode_special_tokens;
};

struct Buffer {
uint32_t *ids;
uint32_t *type_ids;
@@ -18,7 +22,7 @@ struct Buffer {
uint32_t len;
};

-void *from_bytes(const uint8_t *config, uint32_t len);
+void *from_bytes(const uint8_t *config, uint32_t len, const struct TokenizerOptions *options);

void *from_bytes_with_truncation(const uint8_t *config, uint32_t len, uint32_t max_len, uint8_t direction);

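Note that FromBytes in tokenizer.go hands its options to this signature via a raw pointer cast, so the Go tokenizerOpts struct and struct TokenizerOptions above must stay layout-compatible (currently a single bool each). A hypothetical compile-time guard, not part of this commit, that could sit next to the existing declarations in tokenizer.go:

// hypothetical check (assumes package tokenizers with the existing cgo and
// unsafe imports): the build breaks if the two option structs drift in size
var _ [unsafe.Sizeof(tokenizerOpts{})]byte = [C.sizeof_struct_TokenizerOptions]byte{}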
