Skip to content

Commit

Permalink
Use the local_files_only option available as of fastembed==0.2.7. It … (
Browse files Browse the repository at this point in the history
#736)

* Use the local_files_only option available as of fastembed==0.2.7. It allows to not look for the models online, but only use the local, cached, files.
This way, we can download the model once then use this without internet access

* Fix lint issues

* add same param to doc embedder

---------

Co-authored-by: anakin87 <[email protected]>
  • Loading branch information
paulmartrencharpro and anakin87 authored May 15, 2024
1 parent 0e02fd6 commit 6f298ce
Show file tree
Hide file tree
Showing 10 changed files with 82 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,16 @@ def get_embedding_backend(
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
local_files_only: bool = False,
):
embedding_backend_id = f"{model_name}{cache_dir}{threads}"

if embedding_backend_id in _FastembedEmbeddingBackendFactory._instances:
return _FastembedEmbeddingBackendFactory._instances[embedding_backend_id]

embedding_backend = _FastembedEmbeddingBackend(model_name=model_name, cache_dir=cache_dir, threads=threads)
embedding_backend = _FastembedEmbeddingBackend(
model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only
)
_FastembedEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
return embedding_backend

Expand All @@ -40,8 +43,11 @@ def __init__(
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
local_files_only: bool = False,
):
self.model = TextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=threads)
self.model = TextEmbedding(
model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only
)

def embed(self, data: List[str], progress_bar=True, **kwargs) -> List[List[float]]:
# the embed method returns a Iterable[np.ndarray], so we convert it to a list of lists
Expand All @@ -66,14 +72,15 @@ def get_embedding_backend(
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
local_files_only: bool = False,
):
embedding_backend_id = f"{model_name}{cache_dir}{threads}"

if embedding_backend_id in _FastembedSparseEmbeddingBackendFactory._instances:
return _FastembedSparseEmbeddingBackendFactory._instances[embedding_backend_id]

embedding_backend = _FastembedSparseEmbeddingBackend(
model_name=model_name, cache_dir=cache_dir, threads=threads
model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only
)
_FastembedSparseEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
return embedding_backend
Expand All @@ -89,8 +96,11 @@ def __init__(
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
local_files_only: bool = False,
):
self.model = SparseTextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=threads)
self.model = SparseTextEmbedding(
model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only
)

def embed(self, data: List[List[str]], progress_bar=True, **kwargs) -> List[SparseEmbedding]:
# The embed method returns a Iterable[SparseEmbedding], so we convert to Haystack SparseEmbedding type.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
batch_size: int = 256,
progress_bar: bool = True,
parallel: Optional[int] = None,
local_files_only: bool = False,
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
):
Expand All @@ -80,11 +81,12 @@ def __init__(
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
:param batch_size: Number of strings to encode at once.
:param progress_bar: If true, displays progress bar during embedding.
:param progress_bar: If `True`, displays progress bar during embedding.
:param parallel:
If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
If 0, use all available cores.
If None, don't use data-parallel processing, use default onnxruntime threading instead.
:param local_files_only: If `True`, only use the model files in the `cache_dir`.
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content.
:param embedding_separator: Separator used to concatenate the meta fields to the Document content.
"""
Expand All @@ -97,6 +99,7 @@ def __init__(
self.batch_size = batch_size
self.progress_bar = progress_bar
self.parallel = parallel
self.local_files_only = local_files_only
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator

Expand All @@ -116,6 +119,7 @@ def to_dict(self) -> Dict[str, Any]:
batch_size=self.batch_size,
progress_bar=self.progress_bar,
parallel=self.parallel,
local_files_only=self.local_files_only,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
)
Expand All @@ -126,7 +130,10 @@ def warm_up(self):
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(
model_name=self.model_name, cache_dir=self.cache_dir, threads=self.threads
model_name=self.model_name,
cache_dir=self.cache_dir,
threads=self.threads,
local_files_only=self.local_files_only,
)

def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(
batch_size: int = 32,
progress_bar: bool = True,
parallel: Optional[int] = None,
local_files_only: bool = False,
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
):
Expand All @@ -77,6 +78,7 @@ def __init__(
If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
If 0, use all available cores.
If None, don't use data-parallel processing, use default onnxruntime threading instead.
:param local_files_only: If `True`, only use the model files in the `cache_dir`.
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content.
:param embedding_separator: Separator used to concatenate the meta fields to the Document content.
"""
Expand All @@ -87,6 +89,7 @@ def __init__(
self.batch_size = batch_size
self.progress_bar = progress_bar
self.parallel = parallel
self.local_files_only = local_files_only
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator

Expand All @@ -104,6 +107,7 @@ def to_dict(self) -> Dict[str, Any]:
batch_size=self.batch_size,
progress_bar=self.progress_bar,
parallel=self.parallel,
local_files_only=self.local_files_only,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
)
Expand All @@ -114,7 +118,10 @@ def warm_up(self):
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _FastembedSparseEmbeddingBackendFactory.get_embedding_backend(
model_name=self.model_name, cache_dir=self.cache_dir, threads=self.threads
model_name=self.model_name,
cache_dir=self.cache_dir,
threads=self.threads,
local_files_only=self.local_files_only,
)

def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
threads: Optional[int] = None,
progress_bar: bool = True,
parallel: Optional[int] = None,
local_files_only: bool = False,
):
"""
Create a FastembedSparseTextEmbedder component.
Expand All @@ -43,18 +44,20 @@ def __init__(
Can be set using the `FASTEMBED_CACHE_PATH` env variable.
Defaults to `fastembed_cache` in the system's temp directory.
:param threads: The number of threads single onnxruntime session can use. Defaults to None.
:param progress_bar: If true, displays progress bar during embedding.
:param progress_bar: If `True`, displays progress bar during embedding.
:param parallel:
If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
If 0, use all available cores.
If None, don't use data-parallel processing, use default onnxruntime threading instead.
:param local_files_only: If `True`, only use the model files in the `cache_dir`.
"""

self.model_name = model
self.cache_dir = cache_dir
self.threads = threads
self.progress_bar = progress_bar
self.parallel = parallel
self.local_files_only = local_files_only

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -70,6 +73,7 @@ def to_dict(self) -> Dict[str, Any]:
threads=self.threads,
progress_bar=self.progress_bar,
parallel=self.parallel,
local_files_only=self.local_files_only,
)

def warm_up(self):
Expand All @@ -78,7 +82,10 @@ def warm_up(self):
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _FastembedSparseEmbeddingBackendFactory.get_embedding_backend(
model_name=self.model_name, cache_dir=self.cache_dir, threads=self.threads
model_name=self.model_name,
cache_dir=self.cache_dir,
threads=self.threads,
local_files_only=self.local_files_only,
)

@component.output_types(sparse_embedding=SparseEmbedding)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
suffix: str = "",
progress_bar: bool = True,
parallel: Optional[int] = None,
local_files_only: bool = False,
):
"""
Create a FastembedTextEmbedder component.
Expand All @@ -46,11 +47,12 @@ def __init__(
:param threads: The number of threads single onnxruntime session can use. Defaults to None.
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
:param progress_bar: If true, displays progress bar during embedding.
:param progress_bar: If `True`, displays progress bar during embedding.
:param parallel:
If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
If 0, use all available cores.
If None, don't use data-parallel processing, use default onnxruntime threading instead.
:param local_files_only: If `True`, only use the model files in the `cache_dir`.
"""

self.model_name = model
Expand All @@ -60,6 +62,7 @@ def __init__(
self.suffix = suffix
self.progress_bar = progress_bar
self.parallel = parallel
self.local_files_only = local_files_only

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -77,6 +80,7 @@ def to_dict(self) -> Dict[str, Any]:
suffix=self.suffix,
progress_bar=self.progress_bar,
parallel=self.parallel,
local_files_only=self.local_files_only,
)

def warm_up(self):
Expand All @@ -85,7 +89,10 @@ def warm_up(self):
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(
model_name=self.model_name, cache_dir=self.cache_dir, threads=self.threads
model_name=self.model_name,
cache_dir=self.cache_dir,
threads=self.threads,
local_files_only=self.local_files_only,
)

@component.output_types(embedding=List[float])
Expand Down
4 changes: 3 additions & 1 deletion integrations/fastembed/tests/test_fastembed_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def test_model_initialization(mock_instructor):
_FastembedEmbeddingBackendFactory.get_embedding_backend(
model_name="BAAI/bge-small-en-v1.5",
)
mock_instructor.assert_called_once_with(model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None)
mock_instructor.assert_called_once_with(
model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None, local_files_only=False
)
# restore the factory state
_FastembedEmbeddingBackendFactory._instances = {}

Expand Down
12 changes: 11 additions & 1 deletion integrations/fastembed/tests/test_fastembed_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def test_init_default(self):
assert embedder.batch_size == 256
assert embedder.progress_bar is True
assert embedder.parallel is None
assert not embedder.local_files_only
assert embedder.meta_fields_to_embed == []
assert embedder.embedding_separator == "\n"

Expand All @@ -38,6 +39,7 @@ def test_init_with_parameters(self):
batch_size=64,
progress_bar=False,
parallel=1,
local_files_only=True,
meta_fields_to_embed=["test_field"],
embedding_separator=" | ",
)
Expand All @@ -49,6 +51,7 @@ def test_init_with_parameters(self):
assert embedder.batch_size == 64
assert embedder.progress_bar is False
assert embedder.parallel == 1
assert embedder.local_files_only
assert embedder.meta_fields_to_embed == ["test_field"]
assert embedder.embedding_separator == " | "

Expand All @@ -69,6 +72,7 @@ def test_to_dict(self):
"batch_size": 256,
"progress_bar": True,
"parallel": None,
"local_files_only": False,
"embedding_separator": "\n",
"meta_fields_to_embed": [],
},
Expand All @@ -87,6 +91,7 @@ def test_to_dict_with_custom_init_parameters(self):
batch_size=64,
progress_bar=False,
parallel=1,
local_files_only=True,
meta_fields_to_embed=["test_field"],
embedding_separator=" | ",
)
Expand All @@ -102,6 +107,7 @@ def test_to_dict_with_custom_init_parameters(self):
"batch_size": 64,
"progress_bar": False,
"parallel": 1,
"local_files_only": True,
"meta_fields_to_embed": ["test_field"],
"embedding_separator": " | ",
},
Expand All @@ -122,6 +128,7 @@ def test_from_dict(self):
"batch_size": 256,
"progress_bar": True,
"parallel": None,
"local_files_only": False,
"meta_fields_to_embed": [],
"embedding_separator": "\n",
},
Expand All @@ -135,6 +142,7 @@ def test_from_dict(self):
assert embedder.batch_size == 256
assert embedder.progress_bar is True
assert embedder.parallel is None
assert not embedder.local_files_only
assert embedder.meta_fields_to_embed == []
assert embedder.embedding_separator == "\n"

Expand All @@ -153,6 +161,7 @@ def test_from_dict_with_custom_init_parameters(self):
"batch_size": 64,
"progress_bar": False,
"parallel": 1,
"local_files_only": True,
"meta_fields_to_embed": ["test_field"],
"embedding_separator": " | ",
},
Expand All @@ -166,6 +175,7 @@ def test_from_dict_with_custom_init_parameters(self):
assert embedder.batch_size == 64
assert embedder.progress_bar is False
assert embedder.parallel == 1
assert embedder.local_files_only
assert embedder.meta_fields_to_embed == ["test_field"]
assert embedder.embedding_separator == " | "

Expand All @@ -180,7 +190,7 @@ def test_warmup(self, mocked_factory):
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None
model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None, local_files_only=False
)

@patch(
Expand Down
Loading

0 comments on commit 6f298ce

Please sign in to comment.