Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ollama document embedder #400

Merged
merged 9 commits into from
Feb 15, 2024
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .document_embedder import OllamaDocumentEmbedder
from .text_embedder import OllamaTextEmbedder

__all__ = ["OllamaTextEmbedder"]
__all__ = ["OllamaTextEmbedder", "OllamaDocumentEmbedder"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from typing import Any, Dict, List, Optional

import requests
from haystack import Document, component
from tqdm import tqdm


@component
class OllamaDocumentEmbedder:
def __init__(
self,
model: str = "orca-mini",
url: str = "http://localhost:11434/api/embeddings",
generation_kwargs: Optional[Dict[str, Any]] = None,
timeout: int = 120,
prefix: str = "",
suffix: str = "",
progress_bar: bool = True,
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
):
"""
:param model: The name of the model to use. The model should be available in the running Ollama instance.
Default is "orca-mini".
:param url: The URL of the chat endpoint of a running Ollama instance.
Default is "http://localhost:11434/api/embeddings".
:param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature,
top_p, and others. See the available arguments in
[Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
:param timeout: The number of seconds before throwing a timeout error from the Ollama API.
Default is 120 seconds.
"""
self.timeout = timeout
self.generation_kwargs = generation_kwargs or {}
self.url = url
self.model = model
self.batch_size = 1 # API only supports a single call at the moment
self.progress_bar = progress_bar
self.meta_fields_to_embed = meta_fields_to_embed
self.embedding_separator = embedding_separator
self.suffix = suffix
self.prefix = prefix

def _create_json_payload(self, text: str, generation_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
"""
Returns A dictionary of JSON arguments for a POST request to an Ollama service
:param text: Text that is to be converted to an embedding
:param generation_kwargs:
:return: A dictionary of arguments for a POST request to an Ollama service
"""
return {"model": self.model, "prompt": text, "options": {**self.generation_kwargs, **(generation_kwargs or {})}}

def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
"""
Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
"""
texts_to_embed = []
for doc in documents:
if self.meta_fields_to_embed is not None:
meta_values_to_embed = [
str(doc.meta[key])
for key in self.meta_fields_to_embed
if key in doc.meta and doc.meta[key] is not None
]
else:
meta_values_to_embed = []

text_to_embed = (
self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix
).replace("\n", " ")

texts_to_embed.append(text_to_embed)
return texts_to_embed

def _embed_batch(
self, texts_to_embed: List[str], batch_size: int, generation_kwargs: Optional[Dict[str, Any]] = None
):
"""
Ollama Embedding only allows single uploads, not batching. Currently the batch size is set to 1.
If this changes in the future, line 86 (the first line within the for loop), can contain:
batch = texts_to_embed[i + i + batch_size]
"""

all_embeddings = []
meta: Dict[str, Any] = {"model": ""}

for i in tqdm(
range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
):
batch = texts_to_embed[i] # Single batch only
payload = self._create_json_payload(batch, generation_kwargs)
response = requests.post(url=self.url, json=payload, timeout=self.timeout)
response.raise_for_status()
result = response.json()
all_embeddings.append(result["embedding"])

meta["model"] = self.model

return all_embeddings, meta

@component.output_types(documents=List[Document], meta=Dict[str, Any])
def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, Any]] = None):
"""
Run an Ollama Model on a provided documents.
:param documents: Documents to be converted to an embedding.
:param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature,
top_p, etc. See the
[Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
:return: Documents with embedding information attached and metadata in a dictionary
"""
if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
msg = (
"OllamaDocumentEmbedder expects a list of Documents as input."
"In case you want to embed a list of strings, please use the OllamaTextEmbedder."
)
raise TypeError(msg)

texts_to_embed = self._prepare_texts_to_embed(documents=documents)
embeddings, meta = self._embed_batch(
texts_to_embed=texts_to_embed, batch_size=self.batch_size, generation_kwargs=generation_kwargs
)

for doc, emb in zip(documents, embeddings):
doc.embedding = emb

return {"documents": documents, "meta": meta}
51 changes: 51 additions & 0 deletions integrations/ollama/tests/test_document_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import pytest
from haystack import Document
from haystack_integrations.components.embedders.ollama import OllamaDocumentEmbedder
from requests import HTTPError


class TestOllamaDocumentEmbedder:
def test_init_defaults(self):
embedder = OllamaDocumentEmbedder()

assert embedder.timeout == 120
assert embedder.generation_kwargs == {}
assert embedder.url == "http://localhost:11434/api/embeddings"
assert embedder.model == "orca-mini"

def test_init(self):
embedder = OllamaDocumentEmbedder(
model="orca-mini",
url="http://my-custom-endpoint:11434/api/embeddings",
generation_kwargs={"temperature": 0.5},
timeout=3000,
)

assert embedder.timeout == 3000
assert embedder.generation_kwargs == {"temperature": 0.5}
assert embedder.url == "http://my-custom-endpoint:11434/api/embeddings"
assert embedder.model == "orca-mini"

@pytest.mark.integration
def test_model_not_found(self):
embedder = OllamaDocumentEmbedder(model="cheese")

with pytest.raises(HTTPError):
embedder.run([Document("hello")])

@pytest.mark.integration
def import_text_in_embedder(self):
embedder = OllamaDocumentEmbedder(model="orca-mini")

with pytest.raises(TypeError):
embedder.run("This is a text string. This should not work.")

@pytest.mark.integration
def test_run(self):
embedder = OllamaDocumentEmbedder(model="orca-mini")
list_of_docs = [Document(content="This is a document containing some text.")]
reply = embedder.run(list_of_docs)

assert isinstance(reply, dict)
assert all(isinstance(element, float) for element in reply["documents"][0].embedding)
assert reply["meta"]["model"] == "orca-mini"