From f3ea6be51d8263327badb68c10a0acbfc471f90a Mon Sep 17 00:00:00 2001 From: Nicola Procopio Date: Tue, 20 Feb 2024 15:32:09 +0100 Subject: [PATCH] fastembed integration new parameters (#446) * added threads and cache_dir in backend * added threads and cache_dir to text embedding * added threads and cache_dir in documents embedder * fix test * formatted with black * fixed test --- .../embedding_backend/fastembed_backend.py | 14 ++++++----- .../fastembed/fastembed_document_embedder.py | 14 ++++++++++- .../fastembed/fastembed_text_embedder.py | 16 ++++++++++--- .../fastembed/tests/test_fastembed_backend.py | 8 +++---- .../tests/test_fastembed_document_embedder.py | 22 ++++++++++++++++- .../tests/test_fastembed_text_embedder.py | 24 ++++++++++++++++++- 6 files changed, 82 insertions(+), 16 deletions(-) diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py index ee51283e6..baf21c8a3 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py @@ -1,4 +1,4 @@ -from typing import ClassVar, Dict, List +from typing import ClassVar, Dict, List, Optional from fastembed import TextEmbedding @@ -13,15 +13,15 @@ class _FastembedEmbeddingBackendFactory: @staticmethod def get_embedding_backend( model_name: str, + cache_dir: Optional[str] = None, + threads: Optional[int] = None, ): - embedding_backend_id = f"{model_name}" + embedding_backend_id = f"{model_name}{cache_dir}{threads}" if embedding_backend_id in _FastembedEmbeddingBackendFactory._instances: return _FastembedEmbeddingBackendFactory._instances[embedding_backend_id] - embedding_backend = _FastembedEmbeddingBackend( - 
model_name=model_name, - ) + embedding_backend = _FastembedEmbeddingBackend(model_name=model_name, cache_dir=cache_dir, threads=threads) _FastembedEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend return embedding_backend @@ -34,8 +34,10 @@ class _FastembedEmbeddingBackend: def __init__( self, model_name: str, + cache_dir: Optional[str] = None, + threads: Optional[int] = None, ): - self.model = TextEmbedding(model_name=model_name) + self.model = TextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=threads) def embed(self, data: List[List[str]], **kwargs) -> List[List[float]]: # the embed method returns a Iterable[np.ndarray], so we convert it to a list of lists diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py index 8f7dd8cd0..1aa2b9539 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py @@ -53,6 +53,8 @@ class FastembedDocumentEmbedder: def __init__( self, model: str = "BAAI/bge-small-en-v1.5", + cache_dir: Optional[str] = None, + threads: Optional[int] = None, prefix: str = "", suffix: str = "", batch_size: int = 256, @@ -66,6 +68,10 @@ def __init__( :param model: Local path or name of the model in Hugging Face's model hub, such as ``'BAAI/bge-small-en-v1.5'``. + :param cache_dir (str, optional): The path to the cache directory. + Can be set using the `FASTEMBED_CACHE_PATH` env variable. + Defaults to `fastembed_cache` in the system's temp directory. + :param threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None. :param prefix: A string to add to the beginning of each text. 
:param suffix: A string to add to the end of each text. :param batch_size: Number of strings to encode at once. @@ -79,6 +85,8 @@ def __init__( """ self.model_name = model + self.cache_dir = cache_dir + self.threads = threads self.prefix = prefix self.suffix = suffix self.batch_size = batch_size @@ -94,6 +102,8 @@ def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, model=self.model_name, + cache_dir=self.cache_dir, + threads=self.threads, prefix=self.prefix, suffix=self.suffix, batch_size=self.batch_size, @@ -108,7 +118,9 @@ def warm_up(self): Load the embedding backend. """ if not hasattr(self, "embedding_backend"): - self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(model_name=self.model_name) + self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend( + model_name=self.model_name, cache_dir=self.cache_dir, threads=self.threads + ) def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: texts_to_embed = [] diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py index 2f0b3ae62..c075875b7 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py @@ -31,6 +31,8 @@ class FastembedTextEmbedder: def __init__( self, model: str = "BAAI/bge-small-en-v1.5", + cache_dir: Optional[str] = None, + threads: Optional[int] = None, prefix: str = "", suffix: str = "", batch_size: int = 256, @@ -42,6 +44,10 @@ def __init__( :param model: Local path or name of the model in Fastembed's model hub, such as ``'BAAI/bge-small-en-v1.5'``. + :param cache_dir (str, optional): The path to the cache directory. 
+ Can be set using the `FASTEMBED_CACHE_PATH` env variable. + Defaults to `fastembed_cache` in the system's temp directory. + :param threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None. :param batch_size: Number of strings to encode at once. :param prefix: A string to add to the beginning of each text. :param suffix: A string to add to the end of each text. @@ -52,9 +58,9 @@ def __init__( If None, don't use data-parallel processing, use default onnxruntime threading instead. """ - # TODO add parallel - self.model_name = model + self.cache_dir = cache_dir + self.threads = threads self.prefix = prefix self.suffix = suffix self.batch_size = batch_size @@ -68,6 +74,8 @@ def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, model=self.model_name, + cache_dir=self.cache_dir, + threads=self.threads, prefix=self.prefix, suffix=self.suffix, batch_size=self.batch_size, @@ -80,7 +88,9 @@ def warm_up(self): Load the embedding backend. """ if not hasattr(self, "embedding_backend"): - self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(model_name=self.model_name) + self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend( + model_name=self.model_name, cache_dir=self.cache_dir, threads=self.threads + ) @component.output_types(embedding=List[float]) def run(self, text: str): diff --git a/integrations/fastembed/tests/test_fastembed_backend.py b/integrations/fastembed/tests/test_fastembed_backend.py index c564c72bf..4dad9525d 100644 --- a/integrations/fastembed/tests/test_fastembed_backend.py +++ b/integrations/fastembed/tests/test_fastembed_backend.py @@ -8,7 +8,9 @@ @patch("haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.TextEmbedding") def test_factory_behavior(mock_instructor): # noqa: ARG001 embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(model_name="BAAI/bge-small-en-v1.5") - 
same_embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend("BAAI/bge-small-en-v1.5") + same_embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend( + model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None + ) another_embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend( model_name="BAAI/bge-base-en-v1.5" ) @@ -25,9 +27,7 @@ def test_model_initialization(mock_instructor): _FastembedEmbeddingBackendFactory.get_embedding_backend( model_name="BAAI/bge-small-en-v1.5", ) - mock_instructor.assert_called_once_with( - model_name="BAAI/bge-small-en-v1.5", - ) + mock_instructor.assert_called_once_with(model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None) # restore the factory state _FastembedEmbeddingBackendFactory._instances = {} diff --git a/integrations/fastembed/tests/test_fastembed_document_embedder.py b/integrations/fastembed/tests/test_fastembed_document_embedder.py index baaf250f2..797c295ba 100644 --- a/integrations/fastembed/tests/test_fastembed_document_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_document_embedder.py @@ -15,6 +15,8 @@ def test_init_default(self): """ embedder = FastembedDocumentEmbedder(model="BAAI/bge-small-en-v1.5") assert embedder.model_name == "BAAI/bge-small-en-v1.5" + assert embedder.cache_dir is None + assert embedder.threads is None assert embedder.prefix == "" assert embedder.suffix == "" assert embedder.batch_size == 256 @@ -29,6 +31,8 @@ def test_init_with_parameters(self): """ embedder = FastembedDocumentEmbedder( model="BAAI/bge-small-en-v1.5", + cache_dir="fake_dir", + threads=2, prefix="prefix", suffix="suffix", batch_size=64, @@ -38,6 +42,8 @@ def test_init_with_parameters(self): embedding_separator=" | ", ) assert embedder.model_name == "BAAI/bge-small-en-v1.5" + assert embedder.cache_dir == "fake_dir" + assert embedder.threads == 2 assert embedder.prefix == "prefix" assert embedder.suffix == "suffix" assert 
embedder.batch_size == 64 @@ -56,6 +62,8 @@ def test_to_dict(self): "type": "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder.FastembedDocumentEmbedder", # noqa "init_parameters": { "model": "BAAI/bge-small-en-v1.5", + "cache_dir": None, + "threads": None, "prefix": "", "suffix": "", "batch_size": 256, @@ -72,6 +80,8 @@ def test_to_dict_with_custom_init_parameters(self): """ embedder = FastembedDocumentEmbedder( model="BAAI/bge-small-en-v1.5", + cache_dir="fake_dir", + threads=2, prefix="prefix", suffix="suffix", batch_size=64, @@ -85,6 +95,8 @@ def test_to_dict_with_custom_init_parameters(self): "type": "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder.FastembedDocumentEmbedder", # noqa "init_parameters": { "model": "BAAI/bge-small-en-v1.5", + "cache_dir": "fake_dir", + "threads": 2, "prefix": "prefix", "suffix": "suffix", "batch_size": 64, @@ -103,6 +115,8 @@ def test_from_dict(self): "type": "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder.FastembedDocumentEmbedder", # noqa "init_parameters": { "model": "BAAI/bge-small-en-v1.5", + "cache_dir": None, + "threads": None, "prefix": "", "suffix": "", "batch_size": 256, @@ -114,6 +128,8 @@ def test_from_dict(self): } embedder = default_from_dict(FastembedDocumentEmbedder, embedder_dict) assert embedder.model_name == "BAAI/bge-small-en-v1.5" + assert embedder.cache_dir is None + assert embedder.threads is None assert embedder.prefix == "" assert embedder.suffix == "" assert embedder.batch_size == 256 @@ -130,6 +146,8 @@ def test_from_dict_with_custom_init_parameters(self): "type": "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder.FastembedDocumentEmbedder", # noqa "init_parameters": { "model": "BAAI/bge-small-en-v1.5", + "cache_dir": "fake_dir", + "threads": 2, "prefix": "prefix", "suffix": "suffix", "batch_size": 64, @@ -141,6 +159,8 @@ def test_from_dict_with_custom_init_parameters(self): } 
embedder = default_from_dict(FastembedDocumentEmbedder, embedder_dict) assert embedder.model_name == "BAAI/bge-small-en-v1.5" + assert embedder.cache_dir == "fake_dir" + assert embedder.threads == 2 assert embedder.prefix == "prefix" assert embedder.suffix == "suffix" assert embedder.batch_size == 64 @@ -160,7 +180,7 @@ def test_warmup(self, mocked_factory): mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() mocked_factory.get_embedding_backend.assert_called_once_with( - model_name="BAAI/bge-small-en-v1.5", + model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None ) @patch( diff --git a/integrations/fastembed/tests/test_fastembed_text_embedder.py b/integrations/fastembed/tests/test_fastembed_text_embedder.py index 42134a60e..d5982c319 100644 --- a/integrations/fastembed/tests/test_fastembed_text_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_text_embedder.py @@ -15,6 +15,8 @@ def test_init_default(self): """ embedder = FastembedTextEmbedder(model="BAAI/bge-small-en-v1.5") assert embedder.model_name == "BAAI/bge-small-en-v1.5" + assert embedder.cache_dir is None + assert embedder.threads is None assert embedder.prefix == "" assert embedder.suffix == "" assert embedder.batch_size == 256 @@ -27,6 +29,8 @@ def test_init_with_parameters(self): """ embedder = FastembedTextEmbedder( model="BAAI/bge-small-en-v1.5", + cache_dir="fake_dir", + threads=2, prefix="prefix", suffix="suffix", batch_size=64, @@ -34,6 +38,8 @@ def test_init_with_parameters(self): parallel=1, ) assert embedder.model_name == "BAAI/bge-small-en-v1.5" + assert embedder.cache_dir == "fake_dir" + assert embedder.threads == 2 assert embedder.prefix == "prefix" assert embedder.suffix == "suffix" assert embedder.batch_size == 64 @@ -50,6 +56,8 @@ def test_to_dict(self): "type": "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder.FastembedTextEmbedder", # noqa "init_parameters": { "model": "BAAI/bge-small-en-v1.5", + "cache_dir": 
None, + "threads": None, "prefix": "", "suffix": "", "batch_size": 256, @@ -64,6 +72,8 @@ def test_to_dict_with_custom_init_parameters(self): """ embedder = FastembedTextEmbedder( model="BAAI/bge-small-en-v1.5", + cache_dir="fake_dir", + threads=2, prefix="prefix", suffix="suffix", batch_size=64, @@ -75,6 +85,8 @@ def test_to_dict_with_custom_init_parameters(self): "type": "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder.FastembedTextEmbedder", # noqa "init_parameters": { "model": "BAAI/bge-small-en-v1.5", + "cache_dir": "fake_dir", + "threads": 2, "prefix": "prefix", "suffix": "suffix", "batch_size": 64, @@ -91,6 +103,8 @@ def test_from_dict(self): "type": "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder.FastembedTextEmbedder", # noqa "init_parameters": { "model": "BAAI/bge-small-en-v1.5", + "cache_dir": None, + "threads": None, "prefix": "", "suffix": "", "batch_size": 256, @@ -100,6 +114,8 @@ def test_from_dict(self): } embedder = default_from_dict(FastembedTextEmbedder, embedder_dict) assert embedder.model_name == "BAAI/bge-small-en-v1.5" + assert embedder.cache_dir is None + assert embedder.threads is None assert embedder.prefix == "" assert embedder.suffix == "" assert embedder.batch_size == 256 @@ -114,6 +130,8 @@ def test_from_dict_with_custom_init_parameters(self): "type": "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder.FastembedTextEmbedder", # noqa "init_parameters": { "model": "BAAI/bge-small-en-v1.5", + "cache_dir": "fake_dir", + "threads": 2, "prefix": "prefix", "suffix": "suffix", "batch_size": 64, @@ -123,6 +141,8 @@ def test_from_dict_with_custom_init_parameters(self): } embedder = default_from_dict(FastembedTextEmbedder, embedder_dict) assert embedder.model_name == "BAAI/bge-small-en-v1.5" + assert embedder.cache_dir == "fake_dir" + assert embedder.threads == 2 assert embedder.prefix == "prefix" assert embedder.suffix == "suffix" assert embedder.batch_size == 
64 @@ -139,7 +159,9 @@ def test_warmup(self, mocked_factory): embedder = FastembedTextEmbedder(model="BAAI/bge-small-en-v1.5") mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() - mocked_factory.get_embedding_backend.assert_called_once_with(model_name="BAAI/bge-small-en-v1.5") + mocked_factory.get_embedding_backend.assert_called_once_with( + model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None + ) @patch( "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder._FastembedEmbeddingBackendFactory"