From eb2cfb1a6a3f272023be776be9eb0cad83ba2420 Mon Sep 17 00:00:00 2001 From: latifboubyan Date: Thu, 28 Nov 2024 19:37:48 +0300 Subject: [PATCH] feat: `OllamaDocumentEmbedder` - allow batching embeddings (#1224) * use batch embeddings * extend embeddings with batch result * add batch_size parameter to OllamaDocumentEmbedder * use correct embed parameter * add unit test for ollama batch embed * refinements --------- Co-authored-by: David S. Batista Co-authored-by: anakin87 --- .../embedders/ollama/document_embedder.py | 36 ++++++++++++------- .../ollama/tests/test_document_embedder.py | 18 ++++++---- 2 files changed, 34 insertions(+), 20 deletions(-) diff --git a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py index 2fab6c72f..8d2f5f505 100644 --- a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py +++ b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py @@ -36,6 +36,7 @@ def __init__( progress_bar: bool = True, meta_fields_to_embed: Optional[List[str]] = None, embedding_separator: str = "\n", + batch_size: int = 32, ): """ :param model: @@ -48,12 +49,24 @@ def __init__( [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values). :param timeout: The number of seconds before throwing a timeout error from the Ollama API. + :param prefix: + A string to add at the beginning of each text. + :param suffix: + A string to add at the end of each text. + :param progress_bar: + If `True`, shows a progress bar when running. + :param meta_fields_to_embed: + List of metadata fields to embed along with the document text. + :param embedding_separator: + Separator used to concatenate the metadata fields to the document text. + :param batch_size: + Number of documents to process at once. """ self.timeout = timeout self.generation_kwargs = generation_kwargs or {} self.url = url self.model = model - self.batch_size = 1 # API only supports a single call at the moment + self.batch_size = batch_size self.progress_bar = progress_bar self.meta_fields_to_embed = meta_fields_to_embed self.embedding_separator = embedding_separator @@ -88,24 +101,19 @@ def _embed_batch( self, texts_to_embed: List[str], batch_size: int, generation_kwargs: Optional[Dict[str, Any]] = None ): """ - Ollama Embedding only allows single uploads, not batching. Currently the batch size is set to 1. - If this changes in the future, line 86 (the first line within the for loop), can contain: - batch = texts_to_embed[i + i + batch_size] + Internal method to embed a batch of texts. """ all_embeddings = [] - meta: Dict[str, Any] = {"model": ""} for i in tqdm( range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings" ): - batch = texts_to_embed[i] # Single batch only - result = self._client.embeddings(model=self.model, prompt=batch, options=generation_kwargs).model_dump() - all_embeddings.append(result["embedding"]) + batch = texts_to_embed[i : i + batch_size] + result = self._client.embed(model=self.model, input=batch, options=generation_kwargs) + all_embeddings.extend(result["embeddings"]) - meta["model"] = self.model - - return all_embeddings, meta + return all_embeddings @component.output_types(documents=List[Document], meta=Dict[str, Any]) def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, Any]] = None): @@ -129,12 +137,14 @@ def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, A ) raise TypeError(msg) + generation_kwargs = generation_kwargs or self.generation_kwargs + texts_to_embed = self._prepare_texts_to_embed(documents=documents) - embeddings, meta = self._embed_batch( + embeddings = self._embed_batch( texts_to_embed=texts_to_embed, batch_size=self.batch_size, generation_kwargs=generation_kwargs ) for doc, emb in zip(documents, embeddings): doc.embedding = emb - return {"documents": documents, "meta": meta} + return {"documents": documents, "meta": {"model": self.model}} diff --git a/integrations/ollama/tests/test_document_embedder.py b/integrations/ollama/tests/test_document_embedder.py index 4fe3cfbb3..7d972e898 100644 --- a/integrations/ollama/tests/test_document_embedder.py +++ b/integrations/ollama/tests/test_document_embedder.py @@ -43,10 +43,14 @@ def import_text_in_embedder(self): @pytest.mark.integration def test_run(self): - embedder = OllamaDocumentEmbedder(model="nomic-embed-text") - list_of_docs = [Document(content="This is a document containing some text.")] - reply = embedder.run(list_of_docs) - - assert isinstance(reply, dict) - assert all(isinstance(element, float) for element in reply["documents"][0].embedding) - assert reply["meta"]["model"] == "nomic-embed-text" + embedder = OllamaDocumentEmbedder(model="nomic-embed-text", batch_size=2) + list_of_docs = [ + Document(content="Llamas are amazing animals known for their soft wool and gentle demeanor."), + Document(content="The Andes mountains are the natural habitat of many llamas."), + Document(content="Llamas have been used as pack animals for centuries, especially in South America."), + ] + result = embedder.run(list_of_docs) + assert result["meta"]["model"] == "nomic-embed-text" + documents = result["documents"] + assert len(documents) == 3 + assert all(isinstance(element, float) for document in documents for element in document.embedding)