diff --git a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/__init__.py b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/__init__.py index 6b257a856..46042a1c9 100644 --- a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/__init__.py +++ b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/__init__.py @@ -1,3 +1,4 @@ +from .document_embedder import OllamaDocumentEmbedder from .text_embedder import OllamaTextEmbedder -__all__ = ["OllamaTextEmbedder"] +__all__ = ["OllamaTextEmbedder", "OllamaDocumentEmbedder"] diff --git a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py new file mode 100644 index 000000000..17b32f065 --- /dev/null +++ b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py @@ -0,0 +1,126 @@ +from typing import Any, Dict, List, Optional + +import requests +from haystack import Document, component +from tqdm import tqdm + + +@component +class OllamaDocumentEmbedder: + def __init__( + self, + model: str = "orca-mini", + url: str = "http://localhost:11434/api/embeddings", + generation_kwargs: Optional[Dict[str, Any]] = None, + timeout: int = 120, + prefix: str = "", + suffix: str = "", + progress_bar: bool = True, + meta_fields_to_embed: Optional[List[str]] = None, + embedding_separator: str = "\n", + ): + """ + :param model: The name of the model to use. The model should be available in the running Ollama instance. + Default is "orca-mini". + :param url: The URL of the chat endpoint of a running Ollama instance. + Default is "http://localhost:11434/api/embeddings". + :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature, + top_p, and others. See the available arguments in + [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values). + :param timeout: The number of seconds before throwing a timeout error from the Ollama API. + Default is 120 seconds. + """ + self.timeout = timeout + self.generation_kwargs = generation_kwargs or {} + self.url = url + self.model = model + self.batch_size = 1 # API only supports a single call at the moment + self.progress_bar = progress_bar + self.meta_fields_to_embed = meta_fields_to_embed + self.embedding_separator = embedding_separator + self.suffix = suffix + self.prefix = prefix + + def _create_json_payload(self, text: str, generation_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]: + """ + Returns A dictionary of JSON arguments for a POST request to an Ollama service + :param text: Text that is to be converted to an embedding + :param generation_kwargs: + :return: A dictionary of arguments for a POST request to an Ollama service + """ + return {"model": self.model, "prompt": text, "options": {**self.generation_kwargs, **(generation_kwargs or {})}} + + def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: + """ + Prepare the texts to embed by concatenating the Document text with the metadata fields to embed. + """ + texts_to_embed = [] + for doc in documents: + if self.meta_fields_to_embed is not None: + meta_values_to_embed = [ + str(doc.meta[key]) + for key in self.meta_fields_to_embed + if key in doc.meta and doc.meta[key] is not None + ] + else: + meta_values_to_embed = [] + + text_to_embed = ( + self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix + ).replace("\n", " ") + + texts_to_embed.append(text_to_embed) + return texts_to_embed + + def _embed_batch( + self, texts_to_embed: List[str], batch_size: int, generation_kwargs: Optional[Dict[str, Any]] = None + ): + """ + Ollama Embedding only allows single uploads, not batching. Currently the batch size is set to 1. + If this changes in the future, line 86 (the first line within the for loop), can contain: + batch = texts_to_embed[i + i + batch_size] + """ + + all_embeddings = [] + meta: Dict[str, Any] = {"model": ""} + + for i in tqdm( + range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings" + ): + batch = texts_to_embed[i] # Single batch only + payload = self._create_json_payload(batch, generation_kwargs) + response = requests.post(url=self.url, json=payload, timeout=self.timeout) + response.raise_for_status() + result = response.json() + all_embeddings.append(result["embedding"]) + + meta["model"] = self.model + + return all_embeddings, meta + + @component.output_types(documents=List[Document], meta=Dict[str, Any]) + def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, Any]] = None): + """ + Run an Ollama Model on a provided documents. + :param documents: Documents to be converted to an embedding. + :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature, + top_p, etc. See the + [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values). + :return: Documents with embedding information attached and metadata in a dictionary + """ + if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + msg = ( + "OllamaDocumentEmbedder expects a list of Documents as input." + "In case you want to embed a list of strings, please use the OllamaTextEmbedder." + ) + raise TypeError(msg) + + texts_to_embed = self._prepare_texts_to_embed(documents=documents) + embeddings, meta = self._embed_batch( + texts_to_embed=texts_to_embed, batch_size=self.batch_size, generation_kwargs=generation_kwargs + ) + + for doc, emb in zip(documents, embeddings): + doc.embedding = emb + + return {"documents": documents, "meta": meta} diff --git a/integrations/ollama/tests/test_document_embedder.py b/integrations/ollama/tests/test_document_embedder.py new file mode 100644 index 000000000..a5694db33 --- /dev/null +++ b/integrations/ollama/tests/test_document_embedder.py @@ -0,0 +1,51 @@ +import pytest +from haystack import Document +from haystack_integrations.components.embedders.ollama import OllamaDocumentEmbedder +from requests import HTTPError + + +class TestOllamaDocumentEmbedder: + def test_init_defaults(self): + embedder = OllamaDocumentEmbedder() + + assert embedder.timeout == 120 + assert embedder.generation_kwargs == {} + assert embedder.url == "http://localhost:11434/api/embeddings" + assert embedder.model == "orca-mini" + + def test_init(self): + embedder = OllamaDocumentEmbedder( + model="orca-mini", + url="http://my-custom-endpoint:11434/api/embeddings", + generation_kwargs={"temperature": 0.5}, + timeout=3000, + ) + + assert embedder.timeout == 3000 + assert embedder.generation_kwargs == {"temperature": 0.5} + assert embedder.url == "http://my-custom-endpoint:11434/api/embeddings" + assert embedder.model == "orca-mini" + + @pytest.mark.integration + def test_model_not_found(self): + embedder = OllamaDocumentEmbedder(model="cheese") + + with pytest.raises(HTTPError): + embedder.run([Document("hello")]) + + @pytest.mark.integration + def import_text_in_embedder(self): + embedder = OllamaDocumentEmbedder(model="orca-mini") + + with pytest.raises(TypeError): + embedder.run("This is a text string. This should not work.") + + @pytest.mark.integration + def test_run(self): + embedder = OllamaDocumentEmbedder(model="orca-mini") + list_of_docs = [Document(content="This is a document containing some text.")] + reply = embedder.run(list_of_docs) + + assert isinstance(reply, dict) + assert all(isinstance(element, float) for element in reply["documents"][0].embedding) + assert reply["meta"]["model"] == "orca-mini"