feat: Tokenization with attn mask (#30)

Closes #22

tazarov authored Sep 12, 2024
1 parent cc03df1 commit b42740f

Showing 7 changed files with 299 additions and 28 deletions.
10 changes: 7 additions & 3 deletions Makefile
@@ -11,10 +11,14 @@ python-cidist-local:
 	rm -rf build/lib.*
 	rm -rf build/temp.*
 	pip install cibuildwheel==2.19.1 auditwheel
-	CIBW_BEFORE_BUILD="make lib" \
+	CIBW_BEFORE_BUILD="make lib-test" \
 	CIBW_SKIP="pp* *musllinux*" \
-	CIBW_ARCHS_MACOS="x86_64" \
-	CIBW_PROJECT_REQUIRES_PYTHON=">=3.8,<3.9" \
+	CIBW_ARCHS_MACOS="arm64" \
+	CIBW_ARCHS_WINDOWS="AMD64" \
+	CIBW_ARCHS_LINUX="x86_64 aarch64" \
+	CIBW_PROJECT_REQUIRES_PYTHON=">=3.9,<3.10" \
+	CIBW_TEST_REQUIRES="pytest>=6.0.0 huggingface_hub" \
+	CIBW_TEST_COMMAND="python -m pytest {project}/bindings/python/tests/test" \
 	CI=1 \
 	python -m cibuildwheel --output-dir dist
28 changes: 23 additions & 5 deletions README.md
@@ -16,10 +16,28 @@ intended for advanced users that want a custom builds e.g. for GPU support.
 
 This project requires cmake to build.
 
+### Shared library
+
+To build the shared library run:
+
+```bash
+make lib
+```
+
+To run the tests:
+
 ```bash
-sysctl -a
-mkdir build
-cd build
-cmake -DLLAMA_FATAL_WARNINGS=ON -DGGML_METAL_EMBED_LIBRARY=ON -DLLAMA_CURL=ON -DGGML_RPC=ON -DBUILD_SHARED_LIBS=ON ..
-cmake --build . --config Release -j $(sysctl -n hw.logicalcpu)
+make lib-test
 ```
+
+### Python bindings
+
+To build the python bindings run:
+
+```bash
+python -m venv venv
+source venv/bin/activate
+make python-cidist-local
+```
+
+The above command will build the shared library, build the Python bindings package, and run the tests.
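For context, here is a minimal usage sketch of the resulting package, covering both the existing `embed` API and the `tokenize` API added in this commit. The model path is a placeholder; any BERT-like GGUF embedding model (e.g. all-MiniLM-L6-v2) should work:

```python
from llama_embedder import LlamaEmbedder, PoolingType, NormalizationType

# Placeholder path to a local GGUF embedding model (not part of this repo).
MODEL_PATH = "./all-MiniLM-L6-v2.Q4_0.gguf"

embedder = LlamaEmbedder(MODEL_PATH, pooling_type=PoolingType.MEAN)

# Embeddings (Euclidean normalization by default).
embeddings = embedder.embed(["Hello, world!"], norm=NormalizationType.EUCLIDEAN)
print(len(embeddings[0]))  # embedding dimension, e.g. 384 for MiniLM

# New in this commit: token IDs plus attention masks, padded to the
# model's context length when enable_padding=True.
data = embedder.tokenize(["Hello, world!"], enable_padding=True)
print(len(data[0].tokens), sum(data[0].attention_mask))
```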
47 changes: 42 additions & 5 deletions bindings/python/bindings.cpp
@@ -19,6 +19,14 @@ enum class PoolingType {
     LAST = 3,
 };
 
+class TokenizerData {
+public:
+    std::vector<int32_t> tokens;
+    std::vector<int32_t> attention_mask;
+
+    TokenizerData(const std::vector<int32_t>& tokens, const std::vector<int32_t>& attention_mask) : tokens(tokens), attention_mask(attention_mask) {}
+};
+
 
 class LlamaEmbedder {
 private:
@@ -39,23 +47,45 @@ class LlamaEmbedder {
         }
     }
 
-    std::vector<std::vector<float>> embed(const std::vector<std::string>& prompts, NormalizationType norm) {
-        std::vector<std::vector<float>> output;
+    std::vector<std::vector<float>> embed(const std::vector<std::string>& texts, NormalizationType norm) {
         if (!embedder) {
             throw std::runtime_error("Embedder is not initialized");
         }
-        ::embed(embedder, prompts, output, static_cast<int32_t>(norm));
+
+        if (texts.empty()) {
+            throw std::runtime_error("Texts are empty");
+        }
+        std::vector<std::vector<float>> output;
+        ::embed(embedder, texts, output, static_cast<int32_t>(norm));
         return output;
     }
 
     std::unordered_map<std::string, std::string> get_metadata() {
-        std::unordered_map<std::string, std::string> metadata;
         if (!embedder) {
             throw std::runtime_error("Embedder is not initialized");
         }
+        std::unordered_map<std::string, std::string> metadata;
         ::get_metadata(embedder, metadata);
         return metadata;
     }
+
+    std::vector<TokenizerData> tokenize(std::vector<std::string>& texts, const bool add_special_tokens = true, const bool parse_special = false, const bool enable_padding = false) {
+        std::vector<TokenizerData> final_output;
+        std::vector<llama_tokenizer_data> output;
+        if (!embedder) {
+            throw std::runtime_error("Embedder is not initialized");
+        }
+        if (texts.empty()) {
+            throw std::runtime_error("Texts are empty");
+        }
+        ::tokenize(embedder, texts, output, add_special_tokens, parse_special, enable_padding);
+
+        for (const auto& tokenizer_data : output) {
+            TokenizerData temp(tokenizer_data.tokens, tokenizer_data.attention_mask);
+            final_output.push_back(temp);
+        }
+        return final_output;
+    }
 };
 
 PYBIND11_MODULE(llama_embedder, m) {
@@ -75,11 +105,18 @@ py::enum_<PoolingType>(m, "PoolingType")
         .value("LAST", PoolingType::LAST)
         .export_values();
 
+    py::class_<TokenizerData>(m, "TokenizerData")
+        .def(py::init<const std::vector<int32_t>&, const std::vector<int32_t>&>(), py::arg("tokens"), py::arg("attention_mask"))
+        .def_readwrite("tokens", &TokenizerData::tokens)                   // Bind the tokens attribute
+        .def_readwrite("attention_mask", &TokenizerData::attention_mask); // Bind the attention_mask attribute
+
+
     py::class_<LlamaEmbedder>(m, "LlamaEmbedder")
         .def(py::init<const std::string&, PoolingType>(), py::arg("model_path"), py::arg("pooling_type") = PoolingType::MEAN) // Updated init
         .def("embed", &LlamaEmbedder::embed, "Create embeddings from prompts",
-             py::arg("prompts"), py::arg("norm") = NormalizationType::EUCLIDEAN)
+             py::arg("texts"), py::arg("norm") = NormalizationType::EUCLIDEAN)
         .def("get_metadata", &LlamaEmbedder::get_metadata, "Get metadata of the model")
+        .def("tokenize", &LlamaEmbedder::tokenize, "Tokenize the input texts", py::arg("texts"), py::arg("add_special_tokens") = true, py::arg("parse_special") = false, py::arg("enable_padding") = false)
        .def("__enter__", [](LlamaEmbedder& self) { return &self; })
        .def("__exit__", [](LlamaEmbedder& self, py::object exc_type, py::object exc_value, py::object traceback) {});
}
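Since `__enter__` and `__exit__` are bound, the embedder can also be used as a Python context manager. A small sketch (the model path is again a placeholder; note that `__exit__` is a no-op, so cleanup is left to the destructor):

```python
from llama_embedder import LlamaEmbedder

with LlamaEmbedder("./model.gguf") as embedder:  # placeholder path
    for item in embedder.tokenize(["How are you?"], enable_padding=True):
        # attention_mask holds 1 for real tokens and 0 for padding
        print(sum(item.attention_mask), "of", len(item.tokens), "tokens are real")
```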
54 changes: 54 additions & 0 deletions bindings/python/tests/test/test_basic.py
@@ -1,4 +1,5 @@
 import os
+from typing import List
 
 import pytest
 from llama_embedder import LlamaEmbedder, PoolingType, NormalizationType
@@ -23,6 +24,14 @@ def test_basic(get_model):
     assert len(embeddings[0]) == 384
 
 
+def test_embed_err_no_texts(get_model):
+    # Using without a context manager
+    embedder = LlamaEmbedder(get_model)
+    with pytest.raises(Exception) as e:
+        embedder.embed([])
+    assert str(e.value) == "Texts are empty"
+
+
 def test_metadata(get_model):
     # Using without a context manager
     embedder = LlamaEmbedder(get_model)
@@ -35,3 +44,48 @@ def test_metadata(get_model):
     assert metadata_dict["general.name"] == "all-MiniLM-L6-v2"
     assert metadata_dict["general.architecture"] == "bert"
     assert metadata_dict["bert.context_length"] == "512"
+
+
+def get_attn_mask_len(attn_mask: List[int]) -> int:
+    mask_size = 0
+    for mask_t in attn_mask:
+        if mask_t != 0:
+            mask_size += 1
+
+    return mask_size
+
+
+def test_tokenize(get_model):
+    # Using without a context manager
+    embedder = LlamaEmbedder(get_model)
+    tokenizer_data = embedder.tokenize(["Hello, world!", "How are you?"], enable_padding=True)
+    assert tokenizer_data is not None
+    assert len(tokenizer_data) == 2
+    assert len(tokenizer_data[0].tokens) == 512
+    assert len(tokenizer_data[0].attention_mask) == 512
+    assert get_attn_mask_len(tokenizer_data[0].attention_mask) == 6
+    assert len(tokenizer_data[1].tokens) == 512
+    assert len(tokenizer_data[1].attention_mask) == 512
+    assert get_attn_mask_len(tokenizer_data[1].attention_mask) == 6
+
+
+def test_tokenize_without_special_tokens(get_model):
+    # Using without a context manager
+    embedder = LlamaEmbedder(get_model)
+    tokenizer_data = embedder.tokenize(["Hello, world!", "How are you?"], add_special_tokens=False, enable_padding=True)
+    assert tokenizer_data is not None
+    assert len(tokenizer_data) == 2
+    assert len(tokenizer_data[0].tokens) == 512
+    assert len(tokenizer_data[0].attention_mask) == 512
+    assert get_attn_mask_len(tokenizer_data[0].attention_mask) == 4
+    assert len(tokenizer_data[1].tokens) == 512
+    assert len(tokenizer_data[1].attention_mask) == 512
+    assert get_attn_mask_len(tokenizer_data[1].attention_mask) == 4
+
+
+def test_tokenize_err_no_texts(get_model):
+    # Using without a context manager
+    embedder = LlamaEmbedder(get_model)
+    with pytest.raises(Exception) as e:
+        embedder.tokenize([])
+    assert str(e.value) == "Texts are empty"
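The expected mask lengths follow from BERT-style WordPiece tokenization: "Hello, world!" and "How are you?" each yield 4 content tokens, plus the [CLS] and [SEP] special tokens when add_special_tokens is true, giving 6; the remaining positions up to the 512-token context length are padding with mask value 0.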
77 changes: 71 additions & 6 deletions src/embedder.cpp
@@ -79,6 +79,35 @@ void my_log_callback(enum ggml_log_level level, const char *text, void *user_data) {
     // Do nothing, effectively silencing the log
 }
 
+// Function to generate an attention mask: 1 for real tokens, 0 for padding (assumes padding token id 0)
+std::vector<int32_t> generate_attention_mask(const std::vector<int>& token_ids, unsigned long max_length) {
+    std::vector<int32_t> attention_mask(max_length, 0); // Initialize mask with 0s
+
+    for (size_t i = 0; i < token_ids.size() && i < max_length; ++i) {
+        if (token_ids[i] != 0) {
+            attention_mask[i] = 1; // Set 1 for non-padding tokens (non-zero)
+        }
+    }
+
+    return attention_mask;
+}
+
+/// Function to pad token IDs with pad_token_id up to max_length
+std::vector<int> pad_tokens(const std::vector<int>& token_ids, unsigned long max_length,
+                            int pad_token_id = 0) {
+    std::vector<int> padded_token_ids;
+
+    // Add the actual tokens
+    padded_token_ids.insert(padded_token_ids.end(), token_ids.begin(), token_ids.end());
+
+    // Add padding if token size is still less than max_length
+    if (padded_token_ids.size() < max_length) {
+        padded_token_ids.resize(max_length, pad_token_id);
+    }
+
+    return padded_token_ids;
+}
+
 enum llama_pooling_type from_uint(const uint32_t pooling_type){
     switch (pooling_type) {
         case 0:
@@ -157,12 +186,46 @@ llama_embedder *init_embedder(const char *embedding_model, const uint32_t pooling_type) {
     return embedder;
 }
 
+void tokenize(llama_embedder *embedder, const std::vector<std::string>& texts, std::vector<llama_tokenizer_data> &output, const bool add_special_tokens, const bool parse_special, const bool enable_padding) {
+    if (!embedder) {
+        throw std::runtime_error("Error: Null pointer passed to tokenize function");
+    }
+    if (texts.empty()) {
+        fprintf(stderr, "Warn: empty texts.\n");
+        return;
+    }
+
+    char model_arch[1024];
+    size_t model_arch_size = sizeof(model_arch);
+    llama_model_meta_val_str(embedder->model, "general.architecture", model_arch, model_arch_size);
+    if (strcmp(model_arch, "bert") != 0) {
+        throw std::runtime_error("error: tokenize function is only supported for BERT-like models");
+    }
+
+    for (const auto &text : texts) {
+        auto tokens = ::llama_tokenize(embedder->context, text, add_special_tokens, parse_special);
+        char value[1024];
+        size_t value_size = sizeof(value);
+        llama_model_meta_val_str(embedder->model, "bert.context_length", value, value_size);
+        unsigned long max_length = tokens.size();
+        if (enable_padding) {
+            max_length = std::stoi(value);
+            memset(value, 0, value_size);
+            llama_model_meta_val_str(embedder->model, "tokenizer.ggml.padding_token_id", value, value_size);
+            int padding_token_id = std::stoi(value);
+            tokens = pad_tokens(tokens, max_length, padding_token_id);
+        }
+        auto attention_mask = generate_attention_mask(tokens, max_length);
+        output.push_back({tokens, attention_mask});
+    }
+}
+
 void get_metadata(llama_embedder *embedder, std::unordered_map<std::string, std::string> &output) {
     output = embedder->model_metadata;
 }
 
 
-void free_embedder(llama_embedder *embedder) {
+void free_embedder(llama_embedder *embedder) noexcept {
     if (embedder->model) {
         llama_free_model(embedder->model);
     }
@@ -174,12 +237,12 @@ void free_embedder(llama_embedder *embedder) {
 }
 
 // Creates embeddings from list of strings
-void embed(llama_embedder *embedder, const std::vector<std::string> prompts, std::vector<std::vector<float>> &output,
+void embed(llama_embedder *embedder, const std::vector<std::string> & texts, std::vector<std::vector<float>> & output,
            int32_t embd_norm) {
     if (!embedder) {
         throw std::runtime_error("Error: Null pointer passed to embed function");
     }
-    if (prompts.empty()){
+    if (texts.empty()){
         fprintf(stderr, "Warn: empty prompts.\n");
         return;
     }
@@ -198,8 +261,10 @@ void embed(llama_embedder *embedder, const std::vector<std::string> prompts, std::vector<std::vector<float>> &output,
 
     // tokenize the prompts and trim
     std::vector<std::vector<int32_t>> inputs;
-    for (const auto &prompt: prompts) {
-        auto inp = ::llama_tokenize(ctx, prompt, true, false);
+    for (const auto &prompt: texts) {
+        std::vector<llama_tokenizer_data> output_token_data;
+        ::tokenize(embedder, {prompt}, output_token_data);
+        auto inp = output_token_data[0].tokens;
         if (inp.size() > n_batch) {
             fprintf(stderr,
                     "%s: error: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
@@ -220,7 +285,7 @@ void embed(llama_embedder *embedder, const std::vector<std::string> prompts, std::vector<std::vector<float>> &output,
     }
 
     // initialize batch
-    const int n_prompts = prompts.size();
+    const int n_prompts = texts.size();
     struct llama_batch batch = llama_batch_init((long long int) n_batch, 0, 1);
 
     // count number of embeddings
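With this change, `embed` routes its input through the new shared `tokenize` path instead of calling `llama_tokenize` directly, so the token sequences used for embedding are exactly the ones exposed to callers alongside the attention masks.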
13 changes: 9 additions & 4 deletions src/embedder.h
@@ -27,10 +27,15 @@ struct llama_embedder {
     std::unordered_map<std::string, std::string> model_metadata;
 };
 
+struct llama_tokenizer_data {
+    std::vector<int32_t> tokens;
+    std::vector<int32_t> attention_mask;
+};
+
 extern "C" {
-EXPORT_SYMBOL llama_embedder * init_embedder(const char * embedding_model, const uint32_t pooling_type);
-EXPORT_SYMBOL void free_embedder(llama_embedder *embedder);
-EXPORT_SYMBOL void embed(llama_embedder * embedder, const std::vector<std::string> prompts, std::vector<std::vector<float>> &output, int32_t embd_norm);
-EXPORT_SYMBOL void get_metadata(llama_embedder * embedder, std::unordered_map<std::string, std::string> &output);
+EXPORT_SYMBOL llama_embedder * init_embedder(const char * embedding_model, uint32_t pooling_type) noexcept(false);
+EXPORT_SYMBOL void free_embedder(llama_embedder *embedder) noexcept;
+EXPORT_SYMBOL void embed(llama_embedder * embedder, const std::vector<std::string> & texts, std::vector<std::vector<float>> & output, int32_t embd_norm) noexcept(false);
+EXPORT_SYMBOL void get_metadata(llama_embedder * embedder, std::unordered_map<std::string, std::string> &output) noexcept(false);
+EXPORT_SYMBOL void tokenize(llama_embedder * embedder, const std::vector<std::string>& texts, std::vector<llama_tokenizer_data> &output, bool add_special_tokens = true, bool parse_special = false, bool enable_padding = false) noexcept(false);
 }
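The exported functions now also declare their exception behavior explicitly: `free_embedder` is `noexcept`, while the rest are marked `noexcept(false)` because they throw `std::runtime_error` on a null embedder, empty input, or a non-BERT architecture.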