improvements to FastEmbed integration #558

Merged · 2 commits · Mar 7, 2024

Changes from all commits
11 changes: 4 additions & 7 deletions integrations/fastembed/pyproject.toml
@@ -158,16 +158,13 @@ ban-relative-imports = "parents"
 "examples/**/*" = ["T201"]
 
 [tool.coverage.run]
-source_pkgs = ["src", "tests"]
+source = ["haystack_integrations"]
 branch = true
-parallel = true
-
-
-[tool.coverage.paths]
-fastembed_haystack = ["src/haystack_integrations", "*/fastembed-haystack/src"]
-tests = ["tests", "*/fastembed-haystack/tests"]
+parallel = false
 
 [tool.coverage.report]
+omit = ["*/tests/*", "*/__init__.py"]
+show_missing=true
Comment on lines -161 to +167 (Member Author): show coverage

 exclude_lines = [
   "no cov",
   "if __name__ == .__main__.:",
@@ -1,5 +1,7 @@
 from typing import ClassVar, Dict, List, Optional
 
+from tqdm import tqdm
+
 from fastembed import TextEmbedding
 

@@ -39,7 +41,12 @@ def __init__(
     ):
         self.model = TextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=threads)
 
-    def embed(self, data: List[List[str]], **kwargs) -> List[List[float]]:
+    def embed(self, data: List[str], progress_bar=True, **kwargs) -> List[List[float]]:
         # the embed method returns an Iterable[np.ndarray], so we convert it to a list of lists
-        embeddings = [np_array.tolist() for np_array in self.model.embed(data, **kwargs)]
+        embeddings = []
+        embeddings_iterable = self.model.embed(data, **kwargs)
+        for np_array in tqdm(
+            embeddings_iterable, disable=not progress_bar, desc="Calculating embeddings", total=len(data)
+        ):
+            embeddings.append(np_array.tolist())
Comment on lines +44 to +51 (Member Author): Since the original library does not provide this feature, we create the progress bar using tqdm in the embedding backend.

         return embeddings
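For readers outside the PR, here is a minimal standalone sketch of the pattern above, assuming only tqdm: fastembed's embed() yields results lazily, so the wrapper passes an explicit total because the iterable has no length of its own. fake_embed is a made-up stand-in for self.model.embed, not a real fastembed API.

from typing import Iterable, List

from tqdm import tqdm


def fake_embed(data: List[str]) -> Iterable[List[float]]:
    # stand-in for fastembed's lazy, generator-based embed(); not a real API
    for text in data:
        yield [float(len(text))]  # dummy one-dimensional "embedding"


def embed_with_progress(data: List[str], progress_bar: bool = True) -> List[List[float]]:
    # total=len(data) gives tqdm a denominator, since a generator exposes no len()
    iterable = tqdm(fake_embed(data), disable=not progress_bar, desc="Calculating embeddings", total=len(data))
    return [vector for vector in iterable]


print(embed_with_progress(["hello", "world"], progress_bar=False))  # [[5.0], [5.0]]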
@@ -131,11 +131,11 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
             meta_values_to_embed = [
                 str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
             ]
-            text_to_embed = [
-                self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix,
-            ]
+            text_to_embed = (
+                self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix
+            )
 
-            texts_to_embed.append(text_to_embed[0])
+            texts_to_embed.append(text_to_embed)
         return texts_to_embed

     @component.output_types(documents=List[Document])
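A worked example of the string the rewritten _prepare_texts_to_embed builds per Document; the prefix, separator, and meta value below are invented for illustration, and the result is identical to what the old list-then-[0] version produced:

# assumed settings, not taken from this diff:
prefix, suffix = "passage: ", ""
embedding_separator = "\n"
meta_values_to_embed = ["FastEmbed news"]  # e.g. str(doc.meta["title"])
doc_content = "FastEmbed got faster."

text_to_embed = prefix + embedding_separator.join([*meta_values_to_embed, doc_content]) + suffix
assert text_to_embed == "passage: FastEmbed news\nFastEmbed got faster."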
@@ -157,13 +157,11 @@ def run(self, documents: List[Document]):
             msg = "The embedding model has not been loaded. Please call warm_up() before running."
             raise RuntimeError(msg)
 
-        # TODO: once non textual Documents are properly supported, we should also prepare them for embedding here
-
         texts_to_embed = self._prepare_texts_to_embed(documents=documents)
         embeddings = self.embedding_backend.embed(
             texts_to_embed,
             batch_size=self.batch_size,
-            show_progress_bar=self.progress_bar,
+            progress_bar=self.progress_bar,
             parallel=self.parallel,
         )
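At the call-site level, a hedged usage sketch of the document embedder with the renamed flag; the class name FastembedDocumentEmbedder and its default model are assumed from the integration's public API rather than shown in this diff:

from haystack import Document
from haystack_integrations.components.embedders.fastembed import FastembedDocumentEmbedder  # assumed export

embedder = FastembedDocumentEmbedder(progress_bar=True)  # progress_bar now flows through to the backend
embedder.warm_up()  # loads the model; run() raises RuntimeError otherwise
result = embedder.run(documents=[Document(content="FastEmbed is supported in Haystack.")])
print(len(result["documents"][0].embedding))  # embedding dimensionality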

@@ -35,7 +35,6 @@ def __init__(
         threads: Optional[int] = None,
         prefix: str = "",
         suffix: str = "",
-        batch_size: int = 256,
Comment (Member Author): The text embedder only accepts a single string, so batch_size was misleading.

         progress_bar: bool = True,
         parallel: Optional[int] = None,
     ):
@@ -47,7 +46,6 @@
             Can be set using the `FASTEMBED_CACHE_PATH` env variable.
             Defaults to `fastembed_cache` in the system's temp directory.
         :param threads: The number of threads single onnxruntime session can use. Defaults to None.
-        :param batch_size: Number of strings to encode at once.
         :param prefix: A string to add to the beginning of each text.
         :param suffix: A string to add to the end of each text.
         :param progress_bar: If true, displays progress bar during embedding.
@@ -62,7 +60,6 @@
         self.threads = threads
         self.prefix = prefix
         self.suffix = suffix
-        self.batch_size = batch_size
         self.progress_bar = progress_bar
         self.parallel = parallel

@@ -80,7 +77,6 @@ def to_dict(self) -> Dict[str, Any]:
             threads=self.threads,
             prefix=self.prefix,
             suffix=self.suffix,
-            batch_size=self.batch_size,
             progress_bar=self.progress_bar,
             parallel=self.parallel,
         )
@@ -119,8 +115,7 @@ def run(self, text: str):
         embedding = list(
             self.embedding_backend.embed(
                 text_to_embed,
-                batch_size=self.batch_size,
-                show_progress_bar=self.progress_bar,
+                progress_bar=self.progress_bar,
                 parallel=self.parallel,
             )[0]
         )
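And a matching sketch for the text embedder after the removal: it embeds exactly one string per run() call, so batch_size had nothing to batch and only progress_bar and parallel remain. FastembedTextEmbedder is assumed to be the exported class name:

from haystack_integrations.components.embedders.fastembed import FastembedTextEmbedder  # assumed export

embedder = FastembedTextEmbedder(prefix="query: ", progress_bar=False)
embedder.warm_up()
result = embedder.run(text="Who supports FastEmbed?")  # a single string, hence no batch_size
print(len(result["embedding"]))  # length of the embedding vector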
@@ -261,7 +261,7 @@ def test_embed_metadata(self):
"meta_value 4\ndocument-number 4",
],
batch_size=256,
show_progress_bar=True,
progress_bar=True,
parallel=None,
)

10 changes: 0 additions & 10 deletions integrations/fastembed/tests/test_fastembed_text_embedder.py
@@ -19,7 +19,6 @@ def test_init_default(self):
         assert embedder.threads is None
         assert embedder.prefix == ""
         assert embedder.suffix == ""
-        assert embedder.batch_size == 256
         assert embedder.progress_bar is True
         assert embedder.parallel is None

@@ -33,7 +32,6 @@ def test_init_with_parameters(self):
             threads=2,
             prefix="prefix",
             suffix="suffix",
-            batch_size=64,
             progress_bar=False,
             parallel=1,
         )
@@ -42,7 +40,6 @@
         assert embedder.threads == 2
         assert embedder.prefix == "prefix"
         assert embedder.suffix == "suffix"
-        assert embedder.batch_size == 64
         assert embedder.progress_bar is False
         assert embedder.parallel == 1

@@ -60,7 +57,6 @@ def test_to_dict(self):
"threads": None,
"prefix": "",
"suffix": "",
"batch_size": 256,
"progress_bar": True,
"parallel": None,
},
@@ -76,7 +72,6 @@ def test_to_dict_with_custom_init_parameters(self):
             threads=2,
             prefix="prefix",
             suffix="suffix",
-            batch_size=64,
             progress_bar=False,
             parallel=1,
         )
@@ -89,7 +84,6 @@
"threads": 2,
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
"parallel": 1,
},
@@ -107,7 +101,6 @@ def test_from_dict(self):
"threads": None,
"prefix": "",
"suffix": "",
"batch_size": 256,
"progress_bar": True,
"parallel": None,
},
@@ -118,7 +111,6 @@
         assert embedder.threads is None
         assert embedder.prefix == ""
         assert embedder.suffix == ""
-        assert embedder.batch_size == 256
         assert embedder.progress_bar is True
         assert embedder.parallel is None

@@ -134,7 +126,6 @@ def test_from_dict_with_custom_init_parameters(self):
"threads": 2,
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
"parallel": 1,
},
@@ -145,7 +136,6 @@
         assert embedder.threads == 2
         assert embedder.prefix == "prefix"
         assert embedder.suffix == "suffix"
-        assert embedder.batch_size == 64
         assert embedder.progress_bar is False
         assert embedder.parallel == 1
