Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fastembed fix: added prefix and suffix #390

Merged
merged 61 commits into from
Feb 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
1e98d90
created project
nickprock Feb 8, 2024
ab89dc8
added parallel param
nickprock Feb 8, 2024
4dae714
updated test
nickprock Feb 8, 2024
1c6d4e8
version 0.0.1
nickprock Feb 8, 2024
551db0c
renamed folder
nickprock Feb 8, 2024
4ccac33
Merge branch 'main' into dev/fastembed
nickprock Feb 8, 2024
a0c7a6e
removed print
nickprock Feb 8, 2024
88a9282
updated readme
nickprock Feb 8, 2024
7153197
Merge branch 'dev/fastembed' of https://github.com/nickprock/haystack…
nickprock Feb 8, 2024
547728b
Merge branch 'main' into dev/fastembed
nickprock Feb 8, 2024
a73146f
added fastembed.yml
nickprock Feb 8, 2024
ead43ed
Merge branch 'dev/fastembed' of https://github.com/nickprock/haystack…
nickprock Feb 8, 2024
54565b7
fix typos
nickprock Feb 8, 2024
f68a310
python version to 3.9 for lint
nickprock Feb 8, 2024
cbba970
updated file
nickprock Feb 8, 2024
2a08cbe
force install black
nickprock Feb 8, 2024
2fccaf6
return to original file
nickprock Feb 8, 2024
afd4645
try to fix workflow
anakin87 Feb 8, 2024
2e01f32
retry
anakin87 Feb 8, 2024
d357330
add missing info to pyproject
anakin87 Feb 8, 2024
174c51b
add hatch-vcs to check version
anakin87 Feb 8, 2024
c9eadfa
Update pyproject.toml
anakin87 Feb 8, 2024
bba17ae
fixed typos
nickprock Feb 8, 2024
d91f449
removed python 3.9
nickprock Feb 8, 2024
eee09c3
Update fastembed.yml
nickprock Feb 9, 2024
2355aa3
Update fastembed_document_embedder.py
nickprock Feb 9, 2024
c31bb3b
Update fastembed_text_embedder.py
nickprock Feb 9, 2024
ba5cb28
ignore errors for bool arguments
anakin87 Feb 9, 2024
3e0c1fe
fix
anakin87 Feb 9, 2024
abe8a97
try moving noqa
anakin87 Feb 9, 2024
27a339c
move noqa
anakin87 Feb 9, 2024
2d5ad0a
formatted with black
nickprock Feb 9, 2024
79a5f9f
added numpy dependency
nickprock Feb 10, 2024
fb2bd05
removed numpy
nickprock Feb 10, 2024
d93092b
Merge branch 'dev/fastembed' of https://github.com/nickprock/haystack…
nickprock Feb 10, 2024
690659c
removed numpy
nickprock Feb 10, 2024
4bbb169
make mypy happy
anakin87 Feb 10, 2024
d9ae567
Update fastembed_backend.py
anakin87 Feb 10, 2024
87e3fe7
removed classvar
nickprock Feb 10, 2024
668572c
fix
nickprock Feb 10, 2024
206b842
Update pyproject.toml
anakin87 Feb 10, 2024
5c144d6
added import numpy lint
nickprock Feb 10, 2024
4c1f98b
Merge branch 'dev/fastembed' of https://github.com/nickprock/haystack…
nickprock Feb 10, 2024
070f04d
skip docs generation for the time being
anakin87 Feb 10, 2024
2842ff9
Update README.md
anakin87 Feb 10, 2024
a7bf308
added config.yml
nickprock Feb 10, 2024
8ca6daf
Merge branch 'dev/fastembed' of https://github.com/nickprock/haystack…
nickprock Feb 10, 2024
673e3e7
generate docs
nickprock Feb 10, 2024
e73d719
Update fastembed.yml
anakin87 Feb 10, 2024
9d79e1c
Update config.yml
anakin87 Feb 10, 2024
4d46161
rm unnecessary from_dict
anakin87 Feb 10, 2024
c5980d3
final touch
anakin87 Feb 10, 2024
8dd145e
updated labeler.yml
nickprock Feb 11, 2024
f0bb7d9
updated library readme
nickprock Feb 11, 2024
76e88c7
Merge branch 'main' into dev/fastembed
anakin87 Feb 11, 2024
0847181
fix typos
anakin87 Feb 11, 2024
37a6aaf
fix docstrings/README
anakin87 Feb 11, 2024
28f33c2
Merge branch 'dev/fastembed' into dev/fastembed_fix
nickprock Feb 11, 2024
2cef63f
added prefix and suffix
nickprock Feb 11, 2024
7d18df6
fixed typos
nickprock Feb 11, 2024
b705443
Merge branch 'main' into dev/fastembed_fix
anakin87 Feb 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class FastembedDocumentEmbedder:
def __init__(
self,
model: str = "BAAI/bge-small-en-v1.5",
prefix: str = "",
suffix: str = "",
batch_size: int = 256,
progress_bar: bool = True,
meta_fields_to_embed: Optional[List[str]] = None,
Expand All @@ -63,13 +65,17 @@ def __init__(

:param model: Local path or name of the model in Hugging Face's model hub,
such as ``'BAAI/bge-small-en-v1.5'``.
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
:param batch_size: Number of strings to encode at once.
:param progress_bar: If true, displays progress bar during embedding.
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content.
:param embedding_separator: Separator used to concatenate the meta fields to the Document content.
"""

self.model_name = model
self.prefix = prefix
self.suffix = suffix
self.batch_size = batch_size
self.progress_bar = progress_bar
self.meta_fields_to_embed = meta_fields_to_embed or []
Expand All @@ -82,6 +88,8 @@ def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
model=self.model_name,
prefix=self.prefix,
suffix=self.suffix,
batch_size=self.batch_size,
progress_bar=self.progress_bar,
meta_fields_to_embed=self.meta_fields_to_embed,
Expand All @@ -95,6 +103,19 @@ def warm_up(self):
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(model_name=self.model_name)

def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
    """
    Build the text that will be embedded for each Document.

    For every Document, the values of the meta fields named in ``self.meta_fields_to_embed``
    are converted to strings (keys that are missing or whose value is ``None`` are skipped),
    followed by the Document content (an empty string when the content is falsy). These pieces
    are joined with ``self.embedding_separator`` and the result is wrapped with ``self.prefix``
    and ``self.suffix``.

    :param documents: Documents whose meta fields and content are combined into texts.
    :return: One prepared string per Document, in the same order as ``documents``.
    """
    texts_to_embed = []
    for doc in documents:
        meta_values_to_embed = [
            str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
        ]
        # Meta values come first, then the content; the separator goes only between these pieces.
        joined_text = self.embedding_separator.join([*meta_values_to_embed, doc.content or ""])
        # Prefix and suffix wrap the whole joined text, not the individual pieces.
        texts_to_embed.append(self.prefix + joined_text + self.suffix)
    return texts_to_embed

@component.output_types(documents=List[Document])
def run(self, documents: List[Document]):
"""
Expand All @@ -113,16 +134,7 @@ def run(self, documents: List[Document]):

# TODO: once non textual Documents are properly supported, we should also prepare them for embedding here

texts_to_embed = []
for doc in documents:
meta_values_to_embed = [
str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
]
text_to_embed = [
self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]),
]

texts_to_embed.append(text_to_embed[0])
texts_to_embed = self._prepare_texts_to_embed(documents=documents)
embeddings = self.embedding_backend.embed(
texts_to_embed,
batch_size=self.batch_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class FastembedTextEmbedder:
def __init__(
self,
model: str = "BAAI/bge-small-en-v1.5",
prefix: str = "",
suffix: str = "",
batch_size: int = 256,
progress_bar: bool = True,
):
Expand All @@ -40,11 +42,15 @@ def __init__(
:param model: Local path or name of the model in Fastembed's model hub,
such as ``'BAAI/bge-small-en-v1.5'``.
:param batch_size: Number of strings to encode at once.
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
"""

# TODO add parallel

self.model_name = model
self.prefix = prefix
self.suffix = suffix
self.batch_size = batch_size
self.progress_bar = progress_bar

Expand All @@ -55,6 +61,8 @@ def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
model=self.model_name,
prefix=self.prefix,
suffix=self.suffix,
batch_size=self.batch_size,
progress_bar=self.progress_bar,
)
Expand All @@ -79,7 +87,7 @@ def run(self, text: str):
msg = "The embedding model has not been loaded. Please call warm_up() before running."
raise RuntimeError(msg)

text_to_embed = [text]
text_to_embed = [self.prefix + text + self.suffix]
embedding = list(
self.embedding_backend.embed(
text_to_embed,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def test_init_default(self):
"""
embedder = FastembedDocumentEmbedder(model="BAAI/bge-small-en-v1.5")
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 256
assert embedder.progress_bar is True
assert embedder.meta_fields_to_embed == []
Expand All @@ -26,12 +28,16 @@ def test_init_with_parameters(self):
"""
embedder = FastembedDocumentEmbedder(
model="BAAI/bge-small-en-v1.5",
prefix="prefix",
suffix="suffix",
batch_size=64,
progress_bar=False,
meta_fields_to_embed=["test_field"],
embedding_separator=" | ",
)
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
assert embedder.progress_bar is False
assert embedder.meta_fields_to_embed == ["test_field"]
Expand All @@ -47,6 +53,8 @@ def test_to_dict(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder.FastembedDocumentEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"prefix": "",
"suffix": "",
"batch_size": 256,
"progress_bar": True,
"embedding_separator": "\n",
Expand All @@ -60,6 +68,8 @@ def test_to_dict_with_custom_init_parameters(self):
"""
embedder = FastembedDocumentEmbedder(
model="BAAI/bge-small-en-v1.5",
prefix="prefix",
suffix="suffix",
batch_size=64,
progress_bar=False,
meta_fields_to_embed=["test_field"],
Expand All @@ -70,6 +80,8 @@ def test_to_dict_with_custom_init_parameters(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder.FastembedDocumentEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
"meta_fields_to_embed": ["test_field"],
Expand All @@ -85,6 +97,8 @@ def test_from_dict(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder.FastembedDocumentEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"prefix": "",
"suffix": "",
"batch_size": 256,
"progress_bar": True,
"meta_fields_to_embed": [],
Expand All @@ -93,6 +107,8 @@ def test_from_dict(self):
}
embedder = default_from_dict(FastembedDocumentEmbedder, embedder_dict)
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 256
assert embedder.progress_bar is True
assert embedder.meta_fields_to_embed == []
Expand All @@ -106,6 +122,8 @@ def test_from_dict_with_custom_init_parameters(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder.FastembedDocumentEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
"meta_fields_to_embed": ["test_field"],
Expand All @@ -114,6 +132,8 @@ def test_from_dict_with_custom_init_parameters(self):
}
embedder = default_from_dict(FastembedDocumentEmbedder, embedder_dict)
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
assert embedder.progress_bar is False
assert embedder.meta_fields_to_embed == ["test_field"]
Expand Down
20 changes: 20 additions & 0 deletions integrations/fastembed/tests/test_fastembed_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def test_init_default(self):
"""
embedder = FastembedTextEmbedder(model="BAAI/bge-small-en-v1.5")
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 256
assert embedder.progress_bar is True

Expand All @@ -24,10 +26,14 @@ def test_init_with_parameters(self):
"""
embedder = FastembedTextEmbedder(
model="BAAI/bge-small-en-v1.5",
prefix="prefix",
suffix="suffix",
batch_size=64,
progress_bar=False,
)
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
assert embedder.progress_bar is False

Expand All @@ -41,6 +47,8 @@ def test_to_dict(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder.FastembedTextEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"prefix": "",
"suffix": "",
"batch_size": 256,
"progress_bar": True,
},
Expand All @@ -52,6 +60,8 @@ def test_to_dict_with_custom_init_parameters(self):
"""
embedder = FastembedTextEmbedder(
model="BAAI/bge-small-en-v1.5",
prefix="prefix",
suffix="suffix",
batch_size=64,
progress_bar=False,
)
Expand All @@ -60,6 +70,8 @@ def test_to_dict_with_custom_init_parameters(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder.FastembedTextEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
},
Expand All @@ -73,12 +85,16 @@ def test_from_dict(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder.FastembedTextEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"prefix": "",
"suffix": "",
"batch_size": 256,
"progress_bar": True,
},
}
embedder = default_from_dict(FastembedTextEmbedder, embedder_dict)
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 256
assert embedder.progress_bar is True

Expand All @@ -90,12 +106,16 @@ def test_from_dict_with_custom_init_parameters(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder.FastembedTextEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
},
}
embedder = default_from_dict(FastembedTextEmbedder, embedder_dict)
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
assert embedder.progress_bar is False

Expand Down