From d30f1250d5fa323e90e66829a10b96652b5847e5 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Thu, 7 Mar 2024 15:56:44 +0100 Subject: [PATCH] improvements to FastEmbed integration --- .../fastembed/embedding_backend/fastembed_backend.py | 11 +++++++++-- .../fastembed/fastembed_document_embedder.py | 12 +++++------- .../embedders/fastembed/fastembed_text_embedder.py | 7 +------ .../tests/test_fastembed_document_embedder.py | 2 +- .../fastembed/tests/test_fastembed_text_embedder.py | 10 ---------- 5 files changed, 16 insertions(+), 26 deletions(-) diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py index baf21c8a3..e44e50a61 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py @@ -1,5 +1,7 @@ from typing import ClassVar, Dict, List, Optional +from tqdm import tqdm + from fastembed import TextEmbedding @@ -39,7 +41,12 @@ def __init__( ): self.model = TextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=threads) - def embed(self, data: List[List[str]], **kwargs) -> List[List[float]]: + def embed(self, data: List[str], progress_bar=True, **kwargs) -> List[List[float]]: # the embed method returns a Iterable[np.ndarray], so we convert it to a list of lists - embeddings = [np_array.tolist() for np_array in self.model.embed(data, **kwargs)] + embeddings = [] + embeddings_iterable = self.model.embed(data, **kwargs) + for np_array in tqdm( + embeddings_iterable, disable=not progress_bar, desc="Calculating embeddings", total=len(data) + ): + embeddings.append(np_array.tolist()) return embeddings diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py index b5dd71231..ec0b918d9 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py @@ -131,11 +131,11 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: meta_values_to_embed = [ str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None ] - text_to_embed = [ - self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix, - ] + text_to_embed = ( + self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix + ) - texts_to_embed.append(text_to_embed[0]) + texts_to_embed.append(text_to_embed) return texts_to_embed @component.output_types(documents=List[Document]) @@ -157,13 +157,11 @@ def run(self, documents: List[Document]): msg = "The embedding model has not been loaded. Please call warm_up() before running." raise RuntimeError(msg) - # TODO: once non textual Documents are properly supported, we should also prepare them for embedding here - texts_to_embed = self._prepare_texts_to_embed(documents=documents) embeddings = self.embedding_backend.embed( texts_to_embed, batch_size=self.batch_size, - show_progress_bar=self.progress_bar, + progress_bar=self.progress_bar, parallel=self.parallel, ) diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py index 743884ec1..9bc4475a5 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py @@ -35,7 +35,6 @@ def __init__( threads: Optional[int] = None, prefix: str = "", suffix: str = "", - batch_size: int = 256, progress_bar: bool = True, parallel: Optional[int] = None, ): @@ -47,7 +46,6 @@ def __init__( Can be set using the `FASTEMBED_CACHE_PATH` env variable. Defaults to `fastembed_cache` in the system's temp directory. :param threads: The number of threads single onnxruntime session can use. Defaults to None. - :param batch_size: Number of strings to encode at once. :param prefix: A string to add to the beginning of each text. :param suffix: A string to add to the end of each text. :param progress_bar: If true, displays progress bar during embedding. @@ -62,7 +60,6 @@ def __init__( self.threads = threads self.prefix = prefix self.suffix = suffix - self.batch_size = batch_size self.progress_bar = progress_bar self.parallel = parallel @@ -80,7 +77,6 @@ def to_dict(self) -> Dict[str, Any]: threads=self.threads, prefix=self.prefix, suffix=self.suffix, - batch_size=self.batch_size, progress_bar=self.progress_bar, parallel=self.parallel, ) @@ -119,8 +115,7 @@ def run(self, text: str): embedding = list( self.embedding_backend.embed( text_to_embed, - batch_size=self.batch_size, - show_progress_bar=self.progress_bar, + progress_bar=self.progress_bar, parallel=self.parallel, )[0] ) diff --git a/integrations/fastembed/tests/test_fastembed_document_embedder.py b/integrations/fastembed/tests/test_fastembed_document_embedder.py index 797c295ba..75fdcc9c9 100644 --- a/integrations/fastembed/tests/test_fastembed_document_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_document_embedder.py @@ -261,7 +261,7 @@ def test_embed_metadata(self): "meta_value 4\ndocument-number 4", ], batch_size=256, - show_progress_bar=True, + progress_bar=True, parallel=None, ) diff --git a/integrations/fastembed/tests/test_fastembed_text_embedder.py b/integrations/fastembed/tests/test_fastembed_text_embedder.py index d5982c319..402980485 100644 --- a/integrations/fastembed/tests/test_fastembed_text_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_text_embedder.py @@ -19,7 +19,6 @@ def test_init_default(self): assert embedder.threads is None assert embedder.prefix == "" assert embedder.suffix == "" - assert embedder.batch_size == 256 assert embedder.progress_bar is True assert embedder.parallel is None @@ -33,7 +32,6 @@ def test_init_with_parameters(self): threads=2, prefix="prefix", suffix="suffix", - batch_size=64, progress_bar=False, parallel=1, ) @@ -42,7 +40,6 @@ def test_init_with_parameters(self): assert embedder.threads == 2 assert embedder.prefix == "prefix" assert embedder.suffix == "suffix" - assert embedder.batch_size == 64 assert embedder.progress_bar is False assert embedder.parallel == 1 @@ -60,7 +57,6 @@ def test_to_dict(self): "threads": None, "prefix": "", "suffix": "", - "batch_size": 256, "progress_bar": True, "parallel": None, }, @@ -76,7 +72,6 @@ def test_to_dict_with_custom_init_parameters(self): threads=2, prefix="prefix", suffix="suffix", - batch_size=64, progress_bar=False, parallel=1, ) @@ -89,7 +84,6 @@ def test_to_dict_with_custom_init_parameters(self): "threads": 2, "prefix": "prefix", "suffix": "suffix", - "batch_size": 64, "progress_bar": False, "parallel": 1, }, @@ -107,7 +101,6 @@ def test_from_dict(self): "threads": None, "prefix": "", "suffix": "", - "batch_size": 256, "progress_bar": True, "parallel": None, }, @@ -118,7 +111,6 @@ def test_from_dict(self): assert embedder.threads is None assert embedder.prefix == "" assert embedder.suffix == "" - assert embedder.batch_size == 256 assert embedder.progress_bar is True assert embedder.parallel is None @@ -134,7 +126,6 @@ def test_from_dict_with_custom_init_parameters(self): "threads": 2, "prefix": "prefix", "suffix": "suffix", - "batch_size": 64, "progress_bar": False, "parallel": 1, }, @@ -145,7 +136,6 @@ def test_from_dict_with_custom_init_parameters(self): assert embedder.threads == 2 assert embedder.prefix == "prefix" assert embedder.suffix == "suffix" - assert embedder.batch_size == 64 assert embedder.progress_bar is False assert embedder.parallel == 1