Commit
* Added Ollama document embedder and tests
* Cleaned up unused variables and batch restrictions
* Fixed the test_document_embedder.py import_text_in_embedder test; the test was incorrect
* Fixed lint issues and tests
* chore: Exclude evaluator private classes in API docs (#392)
* rename astraretriever (#399)
* rename retriever (#407)
* test patch: documents embedding wasn't working as expected

---------

Co-authored-by: Madeesh Kannan <[email protected]>
Co-authored-by: Stefano Fiorucci <[email protected]>
1 parent: d0909b7. Commit: e486ff0. Showing 3 changed files with 179 additions and 1 deletion.
integrations/ollama/src/haystack_integrations/components/embedders/ollama/__init__.py (3 changes: 2 additions & 1 deletion)
@@ -1,3 +1,4 @@
+from .document_embedder import OllamaDocumentEmbedder
 from .text_embedder import OllamaTextEmbedder

-__all__ = ["OllamaTextEmbedder"]
+__all__ = ["OllamaTextEmbedder", "OllamaDocumentEmbedder"]
...rations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py (126 changes: 126 additions & 0 deletions)
@@ -0,0 +1,126 @@
from typing import Any, Dict, List, Optional

import requests
from haystack import Document, component
from tqdm import tqdm


@component
class OllamaDocumentEmbedder:
    def __init__(
        self,
        model: str = "orca-mini",
        url: str = "http://localhost:11434/api/embeddings",
        generation_kwargs: Optional[Dict[str, Any]] = None,
        timeout: int = 120,
        prefix: str = "",
        suffix: str = "",
        progress_bar: bool = True,
        meta_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
    ):
        """
        :param model: The name of the model to use. The model should be available in the running Ollama instance.
            Default is "orca-mini".
        :param url: The URL of the embeddings endpoint of a running Ollama instance.
            Default is "http://localhost:11434/api/embeddings".
        :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature,
            top_p, and others. See the available arguments in
            [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
        :param timeout: The number of seconds before throwing a timeout error from the Ollama API.
            Default is 120 seconds.
        """
        self.timeout = timeout
        self.generation_kwargs = generation_kwargs or {}
        self.url = url
        self.model = model
        self.batch_size = 1  # API only supports a single call at the moment
        self.progress_bar = progress_bar
        self.meta_fields_to_embed = meta_fields_to_embed
        self.embedding_separator = embedding_separator
        self.suffix = suffix
        self.prefix = prefix

    def _create_json_payload(self, text: str, generation_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Returns a dictionary of JSON arguments for a POST request to an Ollama service.
        :param text: Text that is to be converted to an embedding
        :param generation_kwargs: Additional model options, merged on top of the instance-level generation_kwargs
        :return: A dictionary of arguments for a POST request to an Ollama service
        """
        return {"model": self.model, "prompt": text, "options": {**self.generation_kwargs, **(generation_kwargs or {})}}

    def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
        """
        Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
        """
        texts_to_embed = []
        for doc in documents:
            if self.meta_fields_to_embed is not None:
                meta_values_to_embed = [
                    str(doc.meta[key])
                    for key in self.meta_fields_to_embed
                    if key in doc.meta and doc.meta[key] is not None
                ]
            else:
                meta_values_to_embed = []

            text_to_embed = (
                self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix
            ).replace("\n", " ")

            texts_to_embed.append(text_to_embed)
        return texts_to_embed

    def _embed_batch(
        self, texts_to_embed: List[str], batch_size: int, generation_kwargs: Optional[Dict[str, Any]] = None
    ):
        """
        Ollama Embedding only allows single uploads, not batching. Currently the batch size is set to 1.
        If this changes in the future, the first line inside the for loop can become:
        batch = texts_to_embed[i : i + batch_size]
        """

        all_embeddings = []
        meta: Dict[str, Any] = {"model": ""}

        for i in tqdm(
            range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
        ):
            batch = texts_to_embed[i]  # Single batch only
            payload = self._create_json_payload(batch, generation_kwargs)
            response = requests.post(url=self.url, json=payload, timeout=self.timeout)
            response.raise_for_status()
            result = response.json()
            all_embeddings.append(result["embedding"])

        meta["model"] = self.model

        return all_embeddings, meta

    @component.output_types(documents=List[Document], meta=Dict[str, Any])
    def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, Any]] = None):
        """
        Run an Ollama model on the provided documents.
        :param documents: Documents to be converted to an embedding.
        :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature,
            top_p, etc. See the
            [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
        :return: Documents with embedding information attached and metadata in a dictionary
        """
        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
            msg = (
                "OllamaDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a list of strings, please use the OllamaTextEmbedder."
            )
            raise TypeError(msg)

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)
        embeddings, meta = self._embed_batch(
            texts_to_embed=texts_to_embed, batch_size=self.batch_size, generation_kwargs=generation_kwargs
        )

        for doc, emb in zip(documents, embeddings):
            doc.embedding = emb

        return {"documents": documents, "meta": meta}
test_document_embedder.py (51 changes: 51 additions & 0 deletions)
@@ -0,0 +1,51 @@
import pytest
from haystack import Document
from haystack_integrations.components.embedders.ollama import OllamaDocumentEmbedder
from requests import HTTPError


class TestOllamaDocumentEmbedder:
    def test_init_defaults(self):
        embedder = OllamaDocumentEmbedder()

        assert embedder.timeout == 120
        assert embedder.generation_kwargs == {}
        assert embedder.url == "http://localhost:11434/api/embeddings"
        assert embedder.model == "orca-mini"

    def test_init(self):
        embedder = OllamaDocumentEmbedder(
            model="orca-mini",
            url="http://my-custom-endpoint:11434/api/embeddings",
            generation_kwargs={"temperature": 0.5},
            timeout=3000,
        )

        assert embedder.timeout == 3000
        assert embedder.generation_kwargs == {"temperature": 0.5}
        assert embedder.url == "http://my-custom-endpoint:11434/api/embeddings"
        assert embedder.model == "orca-mini"

    @pytest.mark.integration
    def test_model_not_found(self):
        embedder = OllamaDocumentEmbedder(model="cheese")

        with pytest.raises(HTTPError):
            embedder.run([Document("hello")])

    @pytest.mark.integration
    def test_import_text_in_embedder(self):
        embedder = OllamaDocumentEmbedder(model="orca-mini")

        with pytest.raises(TypeError):
            embedder.run("This is a text string. This should not work.")

    @pytest.mark.integration
    def test_run(self):
        embedder = OllamaDocumentEmbedder(model="orca-mini")
        list_of_docs = [Document(content="This is a document containing some text.")]
        reply = embedder.run(list_of_docs)

        assert isinstance(reply, dict)
        assert all(isinstance(element, float) for element in reply["documents"][0].embedding)
        assert reply["meta"]["model"] == "orca-mini"