Skip to content

Commit

Permalink
feat: OllamaDocumentEmbedder - allow batching embeddings (#1224)
Browse files Browse the repository at this point in the history
* use batch embeddings

* extend embeddings with batch result

* add batch_size parameter to OllamaDocumentEmbedder

* use correct embed parameter

* add unit test for ollama batch embed

* refinements

---------

Co-authored-by: David S. Batista <[email protected]>
Co-authored-by: anakin87 <[email protected]>
  • Loading branch information
3 people authored Nov 28, 2024
1 parent 5751605 commit eb2cfb1
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
progress_bar: bool = True,
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
batch_size: int = 32,
):
"""
:param model:
Expand All @@ -48,12 +49,24 @@ def __init__(
[Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
:param timeout:
The number of seconds before throwing a timeout error from the Ollama API.
:param prefix:
A string to add at the beginning of each text.
:param suffix:
A string to add at the end of each text.
:param progress_bar:
If `True`, shows a progress bar when running.
:param meta_fields_to_embed:
List of metadata fields to embed along with the document text.
:param embedding_separator:
Separator used to concatenate the metadata fields to the document text.
:param batch_size:
Number of documents to process at once.
"""
self.timeout = timeout
self.generation_kwargs = generation_kwargs or {}
self.url = url
self.model = model
self.batch_size = 1 # API only supports a single call at the moment
self.batch_size = batch_size
self.progress_bar = progress_bar
self.meta_fields_to_embed = meta_fields_to_embed
self.embedding_separator = embedding_separator
Expand Down Expand Up @@ -88,24 +101,19 @@ def _embed_batch(
self, texts_to_embed: List[str], batch_size: int, generation_kwargs: Optional[Dict[str, Any]] = None
):
"""
Ollama Embedding only allows single uploads, not batching. Currently the batch size is set to 1.
If this changes in the future, line 86 (the first line within the for loop), can contain:
batch = texts_to_embed[i + i + batch_size]
Internal method to embed a batch of texts.
"""

all_embeddings = []
meta: Dict[str, Any] = {"model": ""}

for i in tqdm(
range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
):
batch = texts_to_embed[i] # Single batch only
result = self._client.embeddings(model=self.model, prompt=batch, options=generation_kwargs).model_dump()
all_embeddings.append(result["embedding"])
batch = texts_to_embed[i : i + batch_size]
result = self._client.embed(model=self.model, input=batch, options=generation_kwargs)
all_embeddings.extend(result["embeddings"])

meta["model"] = self.model

return all_embeddings, meta
return all_embeddings

@component.output_types(documents=List[Document], meta=Dict[str, Any])
def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, Any]] = None):
Expand All @@ -129,12 +137,14 @@ def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, A
)
raise TypeError(msg)

generation_kwargs = generation_kwargs or self.generation_kwargs

texts_to_embed = self._prepare_texts_to_embed(documents=documents)
embeddings, meta = self._embed_batch(
embeddings = self._embed_batch(
texts_to_embed=texts_to_embed, batch_size=self.batch_size, generation_kwargs=generation_kwargs
)

for doc, emb in zip(documents, embeddings):
doc.embedding = emb

return {"documents": documents, "meta": meta}
return {"documents": documents, "meta": {"model": self.model}}
18 changes: 11 additions & 7 deletions integrations/ollama/tests/test_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,14 @@ def import_text_in_embedder(self):

@pytest.mark.integration
def test_run(self):
embedder = OllamaDocumentEmbedder(model="nomic-embed-text")
list_of_docs = [Document(content="This is a document containing some text.")]
reply = embedder.run(list_of_docs)

assert isinstance(reply, dict)
assert all(isinstance(element, float) for element in reply["documents"][0].embedding)
assert reply["meta"]["model"] == "nomic-embed-text"
embedder = OllamaDocumentEmbedder(model="nomic-embed-text", batch_size=2)
list_of_docs = [
Document(content="Llamas are amazing animals known for their soft wool and gentle demeanor."),
Document(content="The Andes mountains are the natural habitat of many llamas."),
Document(content="Llamas have been used as pack animals for centuries, especially in South America."),
]
result = embedder.run(list_of_docs)
assert result["meta"]["model"] == "nomic-embed-text"
documents = result["documents"]
assert len(documents) == 3
assert all(isinstance(element, float) for document in documents for element in document.embedding)

0 comments on commit eb2cfb1

Please sign in to comment.