Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fastembed integration new parameters #446

Merged
merged 7 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import ClassVar, Dict, List
from typing import ClassVar, Dict, List, Optional

from fastembed import TextEmbedding

Expand All @@ -13,15 +13,15 @@ class _FastembedEmbeddingBackendFactory:
@staticmethod
def get_embedding_backend(
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
):
embedding_backend_id = f"{model_name}"
embedding_backend_id = f"{model_name}{cache_dir}{threads}"

if embedding_backend_id in _FastembedEmbeddingBackendFactory._instances:
return _FastembedEmbeddingBackendFactory._instances[embedding_backend_id]

embedding_backend = _FastembedEmbeddingBackend(
model_name=model_name,
)
embedding_backend = _FastembedEmbeddingBackend(model_name=model_name, cache_dir=cache_dir, threads=threads)
_FastembedEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
return embedding_backend

Expand All @@ -34,8 +34,10 @@ class _FastembedEmbeddingBackend:
def __init__(
self,
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
):
self.model = TextEmbedding(model_name=model_name)
self.model = TextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=threads)

def embed(self, data: List[List[str]], **kwargs) -> List[List[float]]:
# the embed method returns a Iterable[np.ndarray], so we convert it to a list of lists
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class FastembedDocumentEmbedder:
def __init__(
self,
model: str = "BAAI/bge-small-en-v1.5",
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
prefix: str = "",
suffix: str = "",
batch_size: int = 256,
Expand All @@ -66,6 +68,10 @@ def __init__(

:param model: Local path or name of the model in Hugging Face's model hub,
such as ``'BAAI/bge-small-en-v1.5'``.
:param cache_dir (str, optional): The path to the cache directory.
Can be set using the `FASTEMBED_CACHE_PATH` env variable.
Defaults to `fastembed_cache` in the system's temp directory.
:param threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
:param batch_size: Number of strings to encode at once.
Expand All @@ -79,6 +85,8 @@ def __init__(
"""

self.model_name = model
self.cache_dir = cache_dir
self.threads = threads
self.prefix = prefix
self.suffix = suffix
self.batch_size = batch_size
Expand All @@ -94,6 +102,8 @@ def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
model=self.model_name,
cache_dir=self.cache_dir,
threads=self.threads,
prefix=self.prefix,
suffix=self.suffix,
batch_size=self.batch_size,
Expand All @@ -108,7 +118,9 @@ def warm_up(self):
Load the embedding backend.
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(model_name=self.model_name)
self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(
model_name=self.model_name, cache_dir=self.cache_dir, threads=self.threads
)

def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
texts_to_embed = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class FastembedTextEmbedder:
def __init__(
self,
model: str = "BAAI/bge-small-en-v1.5",
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
prefix: str = "",
suffix: str = "",
batch_size: int = 256,
Expand All @@ -42,6 +44,10 @@ def __init__(

:param model: Local path or name of the model in Fastembed's model hub,
such as ``'BAAI/bge-small-en-v1.5'``.
:param cache_dir (str, optional): The path to the cache directory.
Can be set using the `FASTEMBED_CACHE_PATH` env variable.
Defaults to `fastembed_cache` in the system's temp directory.
:param threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
:param batch_size: Number of strings to encode at once.
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
Expand All @@ -52,9 +58,9 @@ def __init__(
If None, don't use data-parallel processing, use default onnxruntime threading instead.
"""

# TODO add parallel

self.model_name = model
self.cache_dir = cache_dir
self.threads = threads
self.prefix = prefix
self.suffix = suffix
self.batch_size = batch_size
Expand All @@ -68,6 +74,8 @@ def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
model=self.model_name,
cache_dir=self.cache_dir,
threads=self.threads,
prefix=self.prefix,
suffix=self.suffix,
batch_size=self.batch_size,
Expand All @@ -80,7 +88,9 @@ def warm_up(self):
Load the embedding backend.
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(model_name=self.model_name)
self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(
model_name=self.model_name, cache_dir=self.cache_dir, threads=self.threads
)

@component.output_types(embedding=List[float])
def run(self, text: str):
Expand Down
8 changes: 4 additions & 4 deletions integrations/fastembed/tests/test_fastembed_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
@patch("haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.TextEmbedding")
def test_factory_behavior(mock_instructor): # noqa: ARG001
embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(model_name="BAAI/bge-small-en-v1.5")
same_embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend("BAAI/bge-small-en-v1.5")
same_embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(
model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None
)
another_embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(
model_name="BAAI/bge-base-en-v1.5"
)
Expand All @@ -25,9 +27,7 @@ def test_model_initialization(mock_instructor):
_FastembedEmbeddingBackendFactory.get_embedding_backend(
model_name="BAAI/bge-small-en-v1.5",
)
mock_instructor.assert_called_once_with(
model_name="BAAI/bge-small-en-v1.5",
)
mock_instructor.assert_called_once_with(model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None)
# restore the factory state
_FastembedEmbeddingBackendFactory._instances = {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def test_init_default(self):
"""
embedder = FastembedDocumentEmbedder(model="BAAI/bge-small-en-v1.5")
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.cache_dir is None
assert embedder.threads is None
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 256
Expand All @@ -29,6 +31,8 @@ def test_init_with_parameters(self):
"""
embedder = FastembedDocumentEmbedder(
model="BAAI/bge-small-en-v1.5",
cache_dir="fake_dir",
threads=2,
prefix="prefix",
suffix="suffix",
batch_size=64,
Expand All @@ -38,6 +42,8 @@ def test_init_with_parameters(self):
embedding_separator=" | ",
)
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.cache_dir == "fake_dir"
assert embedder.threads == 2
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
Expand All @@ -56,6 +62,8 @@ def test_to_dict(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder.FastembedDocumentEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"cache_dir": None,
"threads": None,
"prefix": "",
"suffix": "",
"batch_size": 256,
Expand All @@ -72,6 +80,8 @@ def test_to_dict_with_custom_init_parameters(self):
"""
embedder = FastembedDocumentEmbedder(
model="BAAI/bge-small-en-v1.5",
cache_dir="fake_dir",
threads=2,
prefix="prefix",
suffix="suffix",
batch_size=64,
Expand All @@ -85,6 +95,8 @@ def test_to_dict_with_custom_init_parameters(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder.FastembedDocumentEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"cache_dir": "fake_dir",
"threads": 2,
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
Expand All @@ -103,6 +115,8 @@ def test_from_dict(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder.FastembedDocumentEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"cache_dir": None,
"threads": None,
"prefix": "",
"suffix": "",
"batch_size": 256,
Expand All @@ -114,6 +128,8 @@ def test_from_dict(self):
}
embedder = default_from_dict(FastembedDocumentEmbedder, embedder_dict)
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.cache_dir is None
assert embedder.threads is None
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 256
Expand All @@ -130,6 +146,8 @@ def test_from_dict_with_custom_init_parameters(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder.FastembedDocumentEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"cache_dir": "fake_dir",
"threads": 2,
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
Expand All @@ -141,6 +159,8 @@ def test_from_dict_with_custom_init_parameters(self):
}
embedder = default_from_dict(FastembedDocumentEmbedder, embedder_dict)
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.cache_dir == "fake_dir"
assert embedder.threads == 2
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
Expand All @@ -160,7 +180,7 @@ def test_warmup(self, mocked_factory):
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model_name="BAAI/bge-small-en-v1.5",
model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None
)

@patch(
Expand Down
24 changes: 23 additions & 1 deletion integrations/fastembed/tests/test_fastembed_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def test_init_default(self):
"""
embedder = FastembedTextEmbedder(model="BAAI/bge-small-en-v1.5")
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.cache_dir is None
assert embedder.threads is None
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 256
Expand All @@ -27,13 +29,17 @@ def test_init_with_parameters(self):
"""
embedder = FastembedTextEmbedder(
model="BAAI/bge-small-en-v1.5",
cache_dir="fake_dir",
threads=2,
prefix="prefix",
suffix="suffix",
batch_size=64,
progress_bar=False,
parallel=1,
)
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.cache_dir == "fake_dir"
assert embedder.threads == 2
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
Expand All @@ -50,6 +56,8 @@ def test_to_dict(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder.FastembedTextEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"cache_dir": None,
"threads": None,
"prefix": "",
"suffix": "",
"batch_size": 256,
Expand All @@ -64,6 +72,8 @@ def test_to_dict_with_custom_init_parameters(self):
"""
embedder = FastembedTextEmbedder(
model="BAAI/bge-small-en-v1.5",
cache_dir="fake_dir",
threads=2,
prefix="prefix",
suffix="suffix",
batch_size=64,
Expand All @@ -75,6 +85,8 @@ def test_to_dict_with_custom_init_parameters(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder.FastembedTextEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"cache_dir": "fake_dir",
"threads": 2,
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
Expand All @@ -91,6 +103,8 @@ def test_from_dict(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder.FastembedTextEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"cache_dir": None,
"threads": None,
"prefix": "",
"suffix": "",
"batch_size": 256,
Expand All @@ -100,6 +114,8 @@ def test_from_dict(self):
}
embedder = default_from_dict(FastembedTextEmbedder, embedder_dict)
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.cache_dir is None
assert embedder.threads is None
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 256
Expand All @@ -114,6 +130,8 @@ def test_from_dict_with_custom_init_parameters(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder.FastembedTextEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"cache_dir": "fake_dir",
"threads": 2,
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
Expand All @@ -123,6 +141,8 @@ def test_from_dict_with_custom_init_parameters(self):
}
embedder = default_from_dict(FastembedTextEmbedder, embedder_dict)
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.cache_dir == "fake_dir"
assert embedder.threads == 2
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
Expand All @@ -139,7 +159,9 @@ def test_warmup(self, mocked_factory):
embedder = FastembedTextEmbedder(model="BAAI/bge-small-en-v1.5")
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(model_name="BAAI/bge-small-en-v1.5")
mocked_factory.get_embedding_backend.assert_called_once_with(
model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None
)

@patch(
"haystack_integrations.components.embedders.fastembed.fastembed_text_embedder._FastembedEmbeddingBackendFactory"
Expand Down