Skip to content

Commit

Permalink
fastembed integration new parameters (#446)
Browse files Browse the repository at this point in the history
* added threads and cache_dir in backend

* added threads and cache_dir to text embedding

* added threads and cache_dir in documents embedder

* fix test

* formatted with black

* fixed test
  • Loading branch information
nickprock authored Feb 20, 2024
1 parent f762d76 commit f3ea6be
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 16 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import ClassVar, Dict, List
from typing import ClassVar, Dict, List, Optional

from fastembed import TextEmbedding

Expand All @@ -13,15 +13,15 @@ class _FastembedEmbeddingBackendFactory:
@staticmethod
def get_embedding_backend(
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
):
embedding_backend_id = f"{model_name}"
embedding_backend_id = f"{model_name}{cache_dir}{threads}"

if embedding_backend_id in _FastembedEmbeddingBackendFactory._instances:
return _FastembedEmbeddingBackendFactory._instances[embedding_backend_id]

embedding_backend = _FastembedEmbeddingBackend(
model_name=model_name,
)
embedding_backend = _FastembedEmbeddingBackend(model_name=model_name, cache_dir=cache_dir, threads=threads)
_FastembedEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
return embedding_backend

Expand All @@ -34,8 +34,10 @@ class _FastembedEmbeddingBackend:
def __init__(
self,
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
):
self.model = TextEmbedding(model_name=model_name)
self.model = TextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=threads)

def embed(self, data: List[List[str]], **kwargs) -> List[List[float]]:
# the embed method returns an Iterable[np.ndarray], so we convert it to a list of lists
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class FastembedDocumentEmbedder:
def __init__(
self,
model: str = "BAAI/bge-small-en-v1.5",
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
prefix: str = "",
suffix: str = "",
batch_size: int = 256,
Expand All @@ -66,6 +68,10 @@ def __init__(
:param model: Local path or name of the model in Hugging Face's model hub,
such as ``'BAAI/bge-small-en-v1.5'``.
:param cache_dir (str, optional): The path to the cache directory.
Can be set using the `FASTEMBED_CACHE_PATH` env variable.
Defaults to `fastembed_cache` in the system's temp directory.
:param threads (int, optional): The number of threads a single onnxruntime session can use. Defaults to None.
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
:param batch_size: Number of strings to encode at once.
Expand All @@ -79,6 +85,8 @@ def __init__(
"""

self.model_name = model
self.cache_dir = cache_dir
self.threads = threads
self.prefix = prefix
self.suffix = suffix
self.batch_size = batch_size
Expand All @@ -94,6 +102,8 @@ def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
model=self.model_name,
cache_dir=self.cache_dir,
threads=self.threads,
prefix=self.prefix,
suffix=self.suffix,
batch_size=self.batch_size,
Expand All @@ -108,7 +118,9 @@ def warm_up(self):
Load the embedding backend.
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(model_name=self.model_name)
self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(
model_name=self.model_name, cache_dir=self.cache_dir, threads=self.threads
)

def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
texts_to_embed = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class FastembedTextEmbedder:
def __init__(
self,
model: str = "BAAI/bge-small-en-v1.5",
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
prefix: str = "",
suffix: str = "",
batch_size: int = 256,
Expand All @@ -42,6 +44,10 @@ def __init__(
:param model: Local path or name of the model in Fastembed's model hub,
such as ``'BAAI/bge-small-en-v1.5'``.
:param cache_dir (str, optional): The path to the cache directory.
Can be set using the `FASTEMBED_CACHE_PATH` env variable.
Defaults to `fastembed_cache` in the system's temp directory.
:param threads (int, optional): The number of threads a single onnxruntime session can use. Defaults to None.
:param batch_size: Number of strings to encode at once.
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
Expand All @@ -52,9 +58,9 @@ def __init__(
If None, don't use data-parallel processing, use default onnxruntime threading instead.
"""

# TODO add parallel

self.model_name = model
self.cache_dir = cache_dir
self.threads = threads
self.prefix = prefix
self.suffix = suffix
self.batch_size = batch_size
Expand All @@ -68,6 +74,8 @@ def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
model=self.model_name,
cache_dir=self.cache_dir,
threads=self.threads,
prefix=self.prefix,
suffix=self.suffix,
batch_size=self.batch_size,
Expand All @@ -80,7 +88,9 @@ def warm_up(self):
Load the embedding backend.
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(model_name=self.model_name)
self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(
model_name=self.model_name, cache_dir=self.cache_dir, threads=self.threads
)

@component.output_types(embedding=List[float])
def run(self, text: str):
Expand Down
8 changes: 4 additions & 4 deletions integrations/fastembed/tests/test_fastembed_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
@patch("haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.TextEmbedding")
def test_factory_behavior(mock_instructor): # noqa: ARG001
embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(model_name="BAAI/bge-small-en-v1.5")
same_embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend("BAAI/bge-small-en-v1.5")
same_embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(
model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None
)
another_embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(
model_name="BAAI/bge-base-en-v1.5"
)
Expand All @@ -25,9 +27,7 @@ def test_model_initialization(mock_instructor):
_FastembedEmbeddingBackendFactory.get_embedding_backend(
model_name="BAAI/bge-small-en-v1.5",
)
mock_instructor.assert_called_once_with(
model_name="BAAI/bge-small-en-v1.5",
)
mock_instructor.assert_called_once_with(model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None)
# restore the factory state
_FastembedEmbeddingBackendFactory._instances = {}

Expand Down
22 changes: 21 additions & 1 deletion integrations/fastembed/tests/test_fastembed_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def test_init_default(self):
"""
embedder = FastembedDocumentEmbedder(model="BAAI/bge-small-en-v1.5")
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.cache_dir is None
assert embedder.threads is None
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 256
Expand All @@ -29,6 +31,8 @@ def test_init_with_parameters(self):
"""
embedder = FastembedDocumentEmbedder(
model="BAAI/bge-small-en-v1.5",
cache_dir="fake_dir",
threads=2,
prefix="prefix",
suffix="suffix",
batch_size=64,
Expand All @@ -38,6 +42,8 @@ def test_init_with_parameters(self):
embedding_separator=" | ",
)
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.cache_dir == "fake_dir"
assert embedder.threads == 2
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
Expand All @@ -56,6 +62,8 @@ def test_to_dict(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder.FastembedDocumentEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"cache_dir": None,
"threads": None,
"prefix": "",
"suffix": "",
"batch_size": 256,
Expand All @@ -72,6 +80,8 @@ def test_to_dict_with_custom_init_parameters(self):
"""
embedder = FastembedDocumentEmbedder(
model="BAAI/bge-small-en-v1.5",
cache_dir="fake_dir",
threads=2,
prefix="prefix",
suffix="suffix",
batch_size=64,
Expand All @@ -85,6 +95,8 @@ def test_to_dict_with_custom_init_parameters(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder.FastembedDocumentEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"cache_dir": "fake_dir",
"threads": 2,
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
Expand All @@ -103,6 +115,8 @@ def test_from_dict(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder.FastembedDocumentEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"cache_dir": None,
"threads": None,
"prefix": "",
"suffix": "",
"batch_size": 256,
Expand All @@ -114,6 +128,8 @@ def test_from_dict(self):
}
embedder = default_from_dict(FastembedDocumentEmbedder, embedder_dict)
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.cache_dir is None
assert embedder.threads is None
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 256
Expand All @@ -130,6 +146,8 @@ def test_from_dict_with_custom_init_parameters(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder.FastembedDocumentEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"cache_dir": "fake_dir",
"threads": 2,
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
Expand All @@ -141,6 +159,8 @@ def test_from_dict_with_custom_init_parameters(self):
}
embedder = default_from_dict(FastembedDocumentEmbedder, embedder_dict)
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.cache_dir == "fake_dir"
assert embedder.threads == 2
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
Expand All @@ -160,7 +180,7 @@ def test_warmup(self, mocked_factory):
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model_name="BAAI/bge-small-en-v1.5",
model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None
)

@patch(
Expand Down
24 changes: 23 additions & 1 deletion integrations/fastembed/tests/test_fastembed_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def test_init_default(self):
"""
embedder = FastembedTextEmbedder(model="BAAI/bge-small-en-v1.5")
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.cache_dir is None
assert embedder.threads is None
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 256
Expand All @@ -27,13 +29,17 @@ def test_init_with_parameters(self):
"""
embedder = FastembedTextEmbedder(
model="BAAI/bge-small-en-v1.5",
cache_dir="fake_dir",
threads=2,
prefix="prefix",
suffix="suffix",
batch_size=64,
progress_bar=False,
parallel=1,
)
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.cache_dir == "fake_dir"
assert embedder.threads == 2
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
Expand All @@ -50,6 +56,8 @@ def test_to_dict(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder.FastembedTextEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"cache_dir": None,
"threads": None,
"prefix": "",
"suffix": "",
"batch_size": 256,
Expand All @@ -64,6 +72,8 @@ def test_to_dict_with_custom_init_parameters(self):
"""
embedder = FastembedTextEmbedder(
model="BAAI/bge-small-en-v1.5",
cache_dir="fake_dir",
threads=2,
prefix="prefix",
suffix="suffix",
batch_size=64,
Expand All @@ -75,6 +85,8 @@ def test_to_dict_with_custom_init_parameters(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder.FastembedTextEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"cache_dir": "fake_dir",
"threads": 2,
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
Expand All @@ -91,6 +103,8 @@ def test_from_dict(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder.FastembedTextEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"cache_dir": None,
"threads": None,
"prefix": "",
"suffix": "",
"batch_size": 256,
Expand All @@ -100,6 +114,8 @@ def test_from_dict(self):
}
embedder = default_from_dict(FastembedTextEmbedder, embedder_dict)
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.cache_dir is None
assert embedder.threads is None
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 256
Expand All @@ -114,6 +130,8 @@ def test_from_dict_with_custom_init_parameters(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder.FastembedTextEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"cache_dir": "fake_dir",
"threads": 2,
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
Expand All @@ -123,6 +141,8 @@ def test_from_dict_with_custom_init_parameters(self):
}
embedder = default_from_dict(FastembedTextEmbedder, embedder_dict)
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.cache_dir == "fake_dir"
assert embedder.threads == 2
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
Expand All @@ -139,7 +159,9 @@ def test_warmup(self, mocked_factory):
embedder = FastembedTextEmbedder(model="BAAI/bge-small-en-v1.5")
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(model_name="BAAI/bge-small-en-v1.5")
mocked_factory.get_embedding_backend.assert_called_once_with(
model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None
)

@patch(
"haystack_integrations.components.embedders.fastembed.fastembed_text_embedder._FastembedEmbeddingBackendFactory"
Expand Down

0 comments on commit f3ea6be

Please sign in to comment.