Skip to content

Commit

Permalink
improvements to FastEmbed integration
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Mar 7, 2024
1 parent 810ad84 commit d30f125
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 26 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import ClassVar, Dict, List, Optional

from tqdm import tqdm

from fastembed import TextEmbedding


Expand Down Expand Up @@ -39,7 +41,12 @@ def __init__(
):
self.model = TextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=threads)

def embed(self, data: List[List[str]], **kwargs) -> List[List[float]]:
def embed(self, data: List[str], progress_bar=True, **kwargs) -> List[List[float]]:
# the embed method returns a Iterable[np.ndarray], so we convert it to a list of lists
embeddings = [np_array.tolist() for np_array in self.model.embed(data, **kwargs)]
embeddings = []
embeddings_iterable = self.model.embed(data, **kwargs)
for np_array in tqdm(
embeddings_iterable, disable=not progress_bar, desc="Calculating embeddings", total=len(data)
):
embeddings.append(np_array.tolist())
return embeddings
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,11 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
meta_values_to_embed = [
str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
]
text_to_embed = [
self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix,
]
text_to_embed = (
self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix
)

texts_to_embed.append(text_to_embed[0])
texts_to_embed.append(text_to_embed)
return texts_to_embed

@component.output_types(documents=List[Document])
Expand All @@ -157,13 +157,11 @@ def run(self, documents: List[Document]):
msg = "The embedding model has not been loaded. Please call warm_up() before running."
raise RuntimeError(msg)

# TODO: once non textual Documents are properly supported, we should also prepare them for embedding here

texts_to_embed = self._prepare_texts_to_embed(documents=documents)
embeddings = self.embedding_backend.embed(
texts_to_embed,
batch_size=self.batch_size,
show_progress_bar=self.progress_bar,
progress_bar=self.progress_bar,
parallel=self.parallel,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def __init__(
threads: Optional[int] = None,
prefix: str = "",
suffix: str = "",
batch_size: int = 256,
progress_bar: bool = True,
parallel: Optional[int] = None,
):
Expand All @@ -47,7 +46,6 @@ def __init__(
Can be set using the `FASTEMBED_CACHE_PATH` env variable.
Defaults to `fastembed_cache` in the system's temp directory.
:param threads: The number of threads single onnxruntime session can use. Defaults to None.
:param batch_size: Number of strings to encode at once.
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
:param progress_bar: If true, displays progress bar during embedding.
Expand All @@ -62,7 +60,6 @@ def __init__(
self.threads = threads
self.prefix = prefix
self.suffix = suffix
self.batch_size = batch_size
self.progress_bar = progress_bar
self.parallel = parallel

Expand All @@ -80,7 +77,6 @@ def to_dict(self) -> Dict[str, Any]:
threads=self.threads,
prefix=self.prefix,
suffix=self.suffix,
batch_size=self.batch_size,
progress_bar=self.progress_bar,
parallel=self.parallel,
)
Expand Down Expand Up @@ -119,8 +115,7 @@ def run(self, text: str):
embedding = list(
self.embedding_backend.embed(
text_to_embed,
batch_size=self.batch_size,
show_progress_bar=self.progress_bar,
progress_bar=self.progress_bar,
parallel=self.parallel,
)[0]
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def test_embed_metadata(self):
"meta_value 4\ndocument-number 4",
],
batch_size=256,
show_progress_bar=True,
progress_bar=True,
parallel=None,
)

Expand Down
10 changes: 0 additions & 10 deletions integrations/fastembed/tests/test_fastembed_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def test_init_default(self):
assert embedder.threads is None
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 256
assert embedder.progress_bar is True
assert embedder.parallel is None

Expand All @@ -33,7 +32,6 @@ def test_init_with_parameters(self):
threads=2,
prefix="prefix",
suffix="suffix",
batch_size=64,
progress_bar=False,
parallel=1,
)
Expand All @@ -42,7 +40,6 @@ def test_init_with_parameters(self):
assert embedder.threads == 2
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
assert embedder.progress_bar is False
assert embedder.parallel == 1

Expand All @@ -60,7 +57,6 @@ def test_to_dict(self):
"threads": None,
"prefix": "",
"suffix": "",
"batch_size": 256,
"progress_bar": True,
"parallel": None,
},
Expand All @@ -76,7 +72,6 @@ def test_to_dict_with_custom_init_parameters(self):
threads=2,
prefix="prefix",
suffix="suffix",
batch_size=64,
progress_bar=False,
parallel=1,
)
Expand All @@ -89,7 +84,6 @@ def test_to_dict_with_custom_init_parameters(self):
"threads": 2,
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
"parallel": 1,
},
Expand All @@ -107,7 +101,6 @@ def test_from_dict(self):
"threads": None,
"prefix": "",
"suffix": "",
"batch_size": 256,
"progress_bar": True,
"parallel": None,
},
Expand All @@ -118,7 +111,6 @@ def test_from_dict(self):
assert embedder.threads is None
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 256
assert embedder.progress_bar is True
assert embedder.parallel is None

Expand All @@ -134,7 +126,6 @@ def test_from_dict_with_custom_init_parameters(self):
"threads": 2,
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
"parallel": 1,
},
Expand All @@ -145,7 +136,6 @@ def test_from_dict_with_custom_init_parameters(self):
assert embedder.threads == 2
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
assert embedder.progress_bar is False
assert embedder.parallel == 1

Expand Down

0 comments on commit d30f125

Please sign in to comment.