Skip to content

Commit

Permalink
fastembed fix: add parallel (#403)
Browse files: browse the repository at this point in the history
* created project

* added parallel param

* updated test

* version 0.0.1

* renamed folder

* removed print

* updated readme

* added fastembed.yml

* fix typos

* python version to 3.9 for lint

* updated file

* force install black

* return to original file

* try to fix workflow

* retry

* add missing info to pyproject

* add hatch-vcs to check version

* Update pyproject.toml

* fixed typos

* removed python 3.9

* Update fastembed.yml

* Update fastembed_document_embedder.py

* Update fastembed_text_embedder.py

* ignore errors for bool arguments

* fix

* try moving noqa

* move noqa

* formatted with black

* added numpy dependency

* removed numpy

* removed numpy

* make mypy happy

* Update fastembed_backend.py

* removed classvar

* fix

* Update pyproject.toml

* added import numpy lint

* skip docs generation for the time being

* Update README.md

* added config.yml

* generate docs

* Update fastembed.yml

* Update config.yml

* rm unnecessary from_dict

* final touch

* updated labeler.yml

* updated library readme

* fix typos

* fix docstrings/README

* added prefix and suffix

* fixed typos

* Update fastembed_text_embedder.py

from numpy float to float

* Update fastembed_document_embedder.py

from numpy float to float

* Update test_fastembed_text_embedder.py

from numpy float to float

* Update test_fastembed_document_embedder.py

from numpy float to float

* Update fastembed_document_embedder.py

fix typos

* Update fastembed_text_embedder.py

added if in run

* Update fastembed_document_embedder.py

added if into run

* modify backend

* added parallel to text_embedder

* fix: added Optional

* fix

* added comment

* added parameter parallel in document_embedder

---------

Co-authored-by: Stefano Fiorucci <[email protected]>
  • Loading branch information
nickprock and anakin87 authored Feb 16, 2024
1 parent babef6f commit ec58d6f
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
suffix: str = "",
batch_size: int = 256,
progress_bar: bool = True,
parallel: Optional[int] = None,
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
):
Expand All @@ -69,6 +70,10 @@ def __init__(
:param suffix: A string to add to the end of each text.
:param batch_size: Number of strings to encode at once.
:param progress_bar: If true, displays progress bar during embedding.
:param parallel:
If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
If 0, use all available cores.
If None, don't use data-parallel processing, use default onnxruntime threading instead.
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content.
:param embedding_separator: Separator used to concatenate the meta fields to the Document content.
"""
Expand All @@ -78,6 +83,7 @@ def __init__(
self.suffix = suffix
self.batch_size = batch_size
self.progress_bar = progress_bar
self.parallel = parallel
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator

Expand All @@ -92,6 +98,7 @@ def to_dict(self) -> Dict[str, Any]:
suffix=self.suffix,
batch_size=self.batch_size,
progress_bar=self.progress_bar,
parallel=self.parallel,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
)
Expand Down Expand Up @@ -139,6 +146,7 @@ def run(self, documents: List[Document]):
texts_to_embed,
batch_size=self.batch_size,
show_progress_bar=self.progress_bar,
parallel=self.parallel,
)

for doc, emb in zip(documents, embeddings):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional

from haystack import component, default_to_dict

Expand Down Expand Up @@ -35,6 +35,7 @@ def __init__(
suffix: str = "",
batch_size: int = 256,
progress_bar: bool = True,
parallel: Optional[int] = None,
):
"""
Create a FastembedTextEmbedder component.
Expand All @@ -44,6 +45,11 @@ def __init__(
:param batch_size: Number of strings to encode at once.
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
:param progress_bar: If true, displays progress bar during embedding.
:param parallel:
If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
If 0, use all available cores.
If None, don't use data-parallel processing, use default onnxruntime threading instead.
"""

# `parallel` is stored on the instance below, serialised by to_dict(), and passed through to the embed call in run()
Expand All @@ -53,6 +59,7 @@ def __init__(
self.suffix = suffix
self.batch_size = batch_size
self.progress_bar = progress_bar
self.parallel = parallel

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -65,6 +72,7 @@ def to_dict(self) -> Dict[str, Any]:
suffix=self.suffix,
batch_size=self.batch_size,
progress_bar=self.progress_bar,
parallel=self.parallel,
)

def warm_up(self):
Expand Down Expand Up @@ -93,6 +101,7 @@ def run(self, text: str):
text_to_embed,
batch_size=self.batch_size,
show_progress_bar=self.progress_bar,
parallel=self.parallel,
)[0]
)
return {"embedding": embedding}
11 changes: 11 additions & 0 deletions integrations/fastembed/tests/test_fastembed_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_init_default(self):
assert embedder.suffix == ""
assert embedder.batch_size == 256
assert embedder.progress_bar is True
assert embedder.parallel is None
assert embedder.meta_fields_to_embed == []
assert embedder.embedding_separator == "\n"

Expand All @@ -32,6 +33,7 @@ def test_init_with_parameters(self):
suffix="suffix",
batch_size=64,
progress_bar=False,
parallel=1,
meta_fields_to_embed=["test_field"],
embedding_separator=" | ",
)
Expand All @@ -40,6 +42,7 @@ def test_init_with_parameters(self):
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
assert embedder.progress_bar is False
assert embedder.parallel == 1
assert embedder.meta_fields_to_embed == ["test_field"]
assert embedder.embedding_separator == " | "

Expand All @@ -57,6 +60,7 @@ def test_to_dict(self):
"suffix": "",
"batch_size": 256,
"progress_bar": True,
"parallel": None,
"embedding_separator": "\n",
"meta_fields_to_embed": [],
},
Expand All @@ -72,6 +76,7 @@ def test_to_dict_with_custom_init_parameters(self):
suffix="suffix",
batch_size=64,
progress_bar=False,
parallel=1,
meta_fields_to_embed=["test_field"],
embedding_separator=" | ",
)
Expand All @@ -84,6 +89,7 @@ def test_to_dict_with_custom_init_parameters(self):
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
"parallel": 1,
"meta_fields_to_embed": ["test_field"],
"embedding_separator": " | ",
},
Expand All @@ -101,6 +107,7 @@ def test_from_dict(self):
"suffix": "",
"batch_size": 256,
"progress_bar": True,
"parallel": None,
"meta_fields_to_embed": [],
"embedding_separator": "\n",
},
Expand All @@ -111,6 +118,7 @@ def test_from_dict(self):
assert embedder.suffix == ""
assert embedder.batch_size == 256
assert embedder.progress_bar is True
assert embedder.parallel is None
assert embedder.meta_fields_to_embed == []
assert embedder.embedding_separator == "\n"

Expand All @@ -126,6 +134,7 @@ def test_from_dict_with_custom_init_parameters(self):
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
"parallel": 1,
"meta_fields_to_embed": ["test_field"],
"embedding_separator": " | ",
},
Expand All @@ -136,6 +145,7 @@ def test_from_dict_with_custom_init_parameters(self):
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
assert embedder.progress_bar is False
assert embedder.parallel == 1
assert embedder.meta_fields_to_embed == ["test_field"]
assert embedder.embedding_separator == " | "

Expand Down Expand Up @@ -232,6 +242,7 @@ def test_embed_metadata(self):
],
batch_size=256,
show_progress_bar=True,
parallel=None,
)

@pytest.mark.integration
Expand Down
10 changes: 10 additions & 0 deletions integrations/fastembed/tests/test_fastembed_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_init_default(self):
assert embedder.suffix == ""
assert embedder.batch_size == 256
assert embedder.progress_bar is True
assert embedder.parallel is None

def test_init_with_parameters(self):
"""
Expand All @@ -30,12 +31,14 @@ def test_init_with_parameters(self):
suffix="suffix",
batch_size=64,
progress_bar=False,
parallel=1,
)
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
assert embedder.progress_bar is False
assert embedder.parallel == 1

def test_to_dict(self):
"""
Expand All @@ -51,6 +54,7 @@ def test_to_dict(self):
"suffix": "",
"batch_size": 256,
"progress_bar": True,
"parallel": None,
},
}

Expand All @@ -64,6 +68,7 @@ def test_to_dict_with_custom_init_parameters(self):
suffix="suffix",
batch_size=64,
progress_bar=False,
parallel=1,
)
embedder_dict = embedder.to_dict()
assert embedder_dict == {
Expand All @@ -74,6 +79,7 @@ def test_to_dict_with_custom_init_parameters(self):
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
"parallel": 1,
},
}

Expand All @@ -89,6 +95,7 @@ def test_from_dict(self):
"suffix": "",
"batch_size": 256,
"progress_bar": True,
"parallel": None,
},
}
embedder = default_from_dict(FastembedTextEmbedder, embedder_dict)
Expand All @@ -97,6 +104,7 @@ def test_from_dict(self):
assert embedder.suffix == ""
assert embedder.batch_size == 256
assert embedder.progress_bar is True
assert embedder.parallel is None

def test_from_dict_with_custom_init_parameters(self):
"""
Expand All @@ -110,6 +118,7 @@ def test_from_dict_with_custom_init_parameters(self):
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
"parallel": 1,
},
}
embedder = default_from_dict(FastembedTextEmbedder, embedder_dict)
Expand All @@ -118,6 +127,7 @@ def test_from_dict_with_custom_init_parameters(self):
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
assert embedder.progress_bar is False
assert embedder.parallel == 1

@patch(
"haystack_integrations.components.embedders.fastembed.fastembed_text_embedder._FastembedEmbeddingBackendFactory"
Expand Down

0 comments on commit ec58d6f

Please sign in to comment.