Skip to content

Commit

Permalink
feat: Weighted mean pooling
Browse files Browse the repository at this point in the history
  • Loading branch information
tazarov committed Sep 13, 2024
1 parent b42740f commit 397074f
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions src/embedder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,43 @@ llama_embedder *init_embedder(const char *embedding_model, const uint32_t poolin
return embedder;
}

/**
 * Pools token embeddings into one vector per batch item, weighting every
 * token by its attention-mask value.
 *
 * For batch item i and hidden dimension k the result is
 *   sum_j(last_hidden_state[i][j][k] * attention_mask[i][j])
 *     / max(sum_j(attention_mask[i][j]), 1e-9)
 * where j runs over the first seq_length tokens. seq_length and hidden_dim
 * are taken from the first batch item.
 *
 * @param last_hidden_state Token embeddings indexed as [batch][token][dim].
 * @param attention_mask    Per-token weights indexed as [batch][token].
 * @param max_length        Currently unused; kept so existing callers keep
 *                          compiling.
 * @return One pooled vector of hidden_dim floats per batch item.
 * @throws std::invalid_argument if either input is empty, the batch sizes
 *         differ, the first item has no tokens, or any item has fewer
 *         tokens / mask entries / dimensions than the first item implies.
 */
std::vector<std::vector<float>> weighted_mean_pooling(
        const std::vector<std::vector<std::vector<float>>>& last_hidden_state,
        const std::vector<std::vector<int>>& attention_mask, const int max_length) {
    (void)max_length;  // not needed by the computation; silences unused-parameter warnings

    if (last_hidden_state.empty() || attention_mask.empty() ||
        last_hidden_state.size() != attention_mask.size()) {
        throw std::invalid_argument("Invalid input sizes");
    }

    // Reading last_hidden_state[0][0] on an empty sequence would be undefined behaviour.
    if (last_hidden_state[0].empty()) {
        throw std::invalid_argument("Invalid input sizes: first batch item has no tokens");
    }

    const size_t batch_size = last_hidden_state.size();
    const size_t seq_length = last_hidden_state[0].size();
    const size_t hidden_dim = last_hidden_state[0][0].size();

    std::vector<std::vector<float>> result(batch_size, std::vector<float>(hidden_dim, 0.0f));

    for (size_t i = 0; i < batch_size; ++i) {
        const auto& tokens = last_hidden_state[i];
        const auto& mask = attention_mask[i];

        // operator[] is unchecked, so reject any item that is too short for
        // the dimensions derived from the first item before indexing into it.
        if (tokens.size() < seq_length || mask.size() < seq_length) {
            throw std::invalid_argument(
                "Invalid input sizes: batch item shorter than sequence length");
        }

        std::vector<float> sum(hidden_dim, 0.0f);
        float mask_sum = 0.0f;

        for (size_t j = 0; j < seq_length; ++j) {
            const auto& token = tokens[j];
            if (token.size() < hidden_dim) {
                throw std::invalid_argument(
                    "Invalid input sizes: token embedding shorter than hidden dimension");
            }

            const auto mask_value = static_cast<float>(mask[j]);
            for (size_t k = 0; k < hidden_dim; ++k) {
                sum[k] += token[k] * mask_value;
            }
            mask_sum += mask_value;
        }

        // Normalize, with a minimum value to prevent division by zero
        // (a fully masked item therefore pools to all zeros).
        const double norm_factor = std::max(mask_sum, 1e-9f);
        for (size_t k = 0; k < hidden_dim; ++k) {
            result[i][k] = static_cast<float>(sum[k] / norm_factor);
        }
    }

    return result;
}

void tokenize(llama_embedder *embedder, const std::vector<std::string>& texts, std::vector<llama_tokenizer_data> &output,const bool add_special_tokens, const bool parse_special, const bool enable_padding) {
if (!embedder) {
throw std::runtime_error("Error: Null pointer passed to tokenize function");
Expand Down

0 comments on commit 397074f

Please sign in to comment.