diff --git a/src/embedder.cpp b/src/embedder.cpp
index 8753e18..9d5ef00 100644
--- a/src/embedder.cpp
+++ b/src/embedder.cpp
@@ -186,6 +186,43 @@ llama_embedder *init_embedder(const char *embedding_model, const uint32_t poolin
     return embedder;
 }
 
+std::vector<std::vector<float>> weighted_mean_pooling(
+        const std::vector<std::vector<std::vector<float>>>& last_hidden_state,
+        const std::vector<std::vector<int32_t>>& attention_mask, const int max_length) {
+
+    if (last_hidden_state.empty() || attention_mask.empty() ||
+        last_hidden_state.size() != attention_mask.size()) {
+        throw std::invalid_argument("Invalid input sizes");
+    }
+
+    size_t batch_size = last_hidden_state.size();
+    size_t seq_length = last_hidden_state[0].size();
+    size_t hidden_dim = last_hidden_state[0][0].size();
+
+    std::vector<std::vector<float>> result(batch_size, std::vector<float>(hidden_dim, 0.0f));
+
+    for (size_t i = 0; i < batch_size; ++i) {
+        std::vector<float> sum(hidden_dim, 0.0f);
+        float mask_sum = 0.0f;
+
+        for (size_t j = 0; j < seq_length; ++j) {
+            auto mask_value = static_cast<float>(attention_mask[i][j]);
+            for (size_t k = 0; k < hidden_dim; ++k) {
+                sum[k] += last_hidden_state[i][j][k] * mask_value;
+            }
+            mask_sum += mask_value;
+        }
+
+        // Normalize, clamping the divisor to prevent division by zero
+        double norm_factor = std::max(mask_sum, 1e-9f);
+        for (size_t k = 0; k < hidden_dim; ++k) {
+            result[i][k] = sum[k] / norm_factor;
+        }
+    }
+
+    return result;
+}
+
 void tokenize(llama_embedder *embedder, const std::vector<std::string>& texts, std::vector<llama_tokenizer_data> &output,const bool add_special_tokens, const bool parse_special, const bool enable_padding) {
     if (!embedder) {
         throw std::runtime_error("Error: Null pointer passed to tokenize function");
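
For review context, here is a minimal standalone sketch of how the new `weighted_mean_pooling` helper could be exercised. It assumes the signature as reconstructed above (a `float` hidden-state tensor and an `int32_t` attention mask) and is meant to be compiled and linked together with the patched `src/embedder.cpp`; the toy batch, dimensions, and expected output are illustrative and not part of the patch.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Declaration of the helper added in this patch (defined in src/embedder.cpp).
std::vector<std::vector<float>> weighted_mean_pooling(
        const std::vector<std::vector<std::vector<float>>>& last_hidden_state,
        const std::vector<std::vector<int32_t>>& attention_mask, const int max_length);

int main() {
    // One sequence of length 3 with hidden size 2; the last position is padding.
    std::vector<std::vector<std::vector<float>>> hidden = {
        {{1.0f, 2.0f}, {3.0f, 4.0f}, {100.0f, 100.0f}}};
    std::vector<std::vector<int32_t>> mask = {{1, 1, 0}};

    // max_length mirrors the padded sequence length here; the body shown in the
    // diff does not read it.
    auto pooled = weighted_mean_pooling(hidden, mask, 3);

    // The padded position is masked out, so this prints "2 3"
    // ((1 + 3) / 2 and (2 + 4) / 2), not values skewed by the 100s.
    std::cout << pooled[0][0] << " " << pooled[0][1] << "\n";
    return 0;
}
```

Padding positions contribute zero to both the weighted sum and the mask total, and the 1e-9 lower bound on the divisor keeps an all-zero mask from causing a division by zero.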