formatted with black
nickprock committed Feb 9, 2024
1 parent 27a339c commit 2d5ad0a
Showing 6 changed files with 17 additions and 60 deletions.
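All six files change in the same way: statements that had been wrapped across several lines are collapsed back onto single lines. Judging by the longest reformatted lines, black appears to be running with a line length of 120 rather than its default 88; the configuration itself is not shown in this commit, but a command along the lines of black --line-length 120 integrations/fastembed would produce this style of change.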
@@ -22,9 +22,7 @@ def get_embedding_backend(
         embedding_backend = _FastembedEmbeddingBackend(
             model_name=model_name,
         )
-        _FastembedEmbeddingBackendFactory._instances[embedding_backend_id] = (
-            embedding_backend
-        )
+        _FastembedEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
         return embedding_backend
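For context, the single written-back line here is the cache write inside the backend factory. A simplified sketch of the pattern follows (the real module carries type hints and a wrapper around fastembed; deriving the cache key from the model name is an inference from the tests further down):

class _FastembedEmbeddingBackend:
    # Stand-in for the real wrapper around fastembed's TextEmbedding.
    def __init__(self, model_name):
        self.model_name = model_name


class _FastembedEmbeddingBackendFactory:
    # One backend instance is kept per model name and reused on later calls.
    _instances = {}

    @staticmethod
    def get_embedding_backend(model_name):
        embedding_backend_id = model_name  # assumed cache key
        if embedding_backend_id in _FastembedEmbeddingBackendFactory._instances:
            return _FastembedEmbeddingBackendFactory._instances[embedding_backend_id]
        embedding_backend = _FastembedEmbeddingBackend(model_name=model_name)
        _FastembedEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
        return embedding_backend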
@@ -104,23 +104,15 @@ def warm_up(self):
         Load the embedding backend.
         """
         if not hasattr(self, "embedding_backend"):
-            self.embedding_backend = (
-                _FastembedEmbeddingBackendFactory.get_embedding_backend(
-                    model_name=self.model_name
-                )
-            )
+            self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(model_name=self.model_name)

     @component.output_types(documents=List[Document])
     def run(self, documents: List[Document]):
         """
         Embed a list of Documents.
         The embedding of each Document is stored in the `embedding` field of the Document.
         """
-        if (
-            not isinstance(documents, list)
-            or documents
-            and not isinstance(documents[0], Document)
-        ):
+        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
             msg = (
                 "FastembedDocumentEmbedder expects a list of Documents as input. "
                 "In case you want to embed a list of strings, please use the FastembedTextEmbedder."
@@ -135,14 +127,10 @@ def run(self, documents: List[Document]):
         texts_to_embed = []
         for doc in documents:
             meta_values_to_embed = [
-                str(doc.meta[key])
-                for key in self.meta_fields_to_embed
-                if key in doc.meta and doc.meta[key] is not None
+                str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
             ]
             text_to_embed = [
-                self.embedding_separator.join(
-                    [*meta_values_to_embed, doc.content or ""]
-                ),
+                self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]),
             ]

             texts_to_embed.append(text_to_embed[0])
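As a usage sketch for the component these hunks touch (model name, meta field, and separator are illustrative; the package-level import is assumed from the module paths patched in the tests below):

from haystack import Document
from haystack_integrations.components.embedders.fastembed import FastembedDocumentEmbedder

embedder = FastembedDocumentEmbedder(
    model="BAAI/bge-small-en-v1.5",
    meta_fields_to_embed=["title"],  # illustrative meta field
    embedding_separator="\n",
)
embedder.warm_up()  # loads the backend through the factory shown above

doc = Document(content="Sample text about embeddings.", meta={"title": "Sample title"})
result = embedder.run(documents=[doc])

# Given the separator above, the text actually embedded is
# "Sample title\nSample text about embeddings.", and each returned
# Document carries its vector in the `embedding` field.
print(len(result["documents"][0].embedding))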
@@ -74,11 +74,7 @@ def warm_up(self):
         Load the embedding backend.
         """
         if not hasattr(self, "embedding_backend"):
-            self.embedding_backend = (
-                _FastembedEmbeddingBackendFactory.get_embedding_backend(
-                    model_name=self.model_name
-                )
-            )
+            self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(model_name=self.model_name)

     @component.output_types(embedding=List[float])
     def run(self, text: str):
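And the corresponding sketch for the text embedder (the output key comes from the output_types decorator above; the model name is illustrative):

from haystack_integrations.components.embedders.fastembed import FastembedTextEmbedder

text_embedder = FastembedTextEmbedder(model="BAAI/bge-small-en-v1.5")
text_embedder.warm_up()
embedding = text_embedder.run(text="Sample query to embed")["embedding"]  # List[float]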
integrations/fastembed/tests/test_fastembed_backend.py (6 additions, 18 deletions)
@@ -5,16 +5,10 @@
 )


-@patch(
-    "haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.TextEmbedding"
-)
+@patch("haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.TextEmbedding")
 def test_factory_behavior(mock_instructor):  # noqa: ARG001
-    embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(
-        model_name="BAAI/bge-small-en-v1.5"
-    )
-    same_embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(
-        "BAAI/bge-small-en-v1.5"
-    )
+    embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(model_name="BAAI/bge-small-en-v1.5")
+    same_embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend("BAAI/bge-small-en-v1.5")
     another_embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(
         model_name="BAAI/bge-base-en-v1.5"
     )
@@ -26,9 +20,7 @@ def test_factory_behavior(mock_instructor):  # noqa: ARG001
     _FastembedEmbeddingBackendFactory._instances = {}


-@patch(
-    "haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.TextEmbedding"
-)
+@patch("haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.TextEmbedding")
 def test_model_initialization(mock_instructor):
     _FastembedEmbeddingBackendFactory.get_embedding_backend(
         model_name="BAAI/bge-small-en-v1.5",
@@ -40,13 +32,9 @@ def test_model_initialization(mock_instructor):
     _FastembedEmbeddingBackendFactory._instances = {}


-@patch(
-    "haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.TextEmbedding"
-)
+@patch("haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.TextEmbedding")
 def test_embedding_function_with_kwargs(mock_instructor):  # noqa: ARG001
-    embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(
-        model_name="BAAI/bge-small-en-v1.5"
-    )
+    embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(model_name="BAAI/bge-small-en-v1.5")

     data = ["sentence1", "sentence2"]
     embedding_backend.embed(data=data)
integrations/fastembed/tests/test_fastembed_document_embedder.py (2 additions, 9 deletions)
@@ -162,9 +162,7 @@ def test_embed(self):
         """
         embedder = FastembedDocumentEmbedder(model="BAAI/bge-base-en-v1.5")
         embedder.embedding_backend = MagicMock()
-        embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand(  # noqa: ARG005
-            len(x), 16
-        ).tolist()
+        embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand(len(x), 16).tolist()  # noqa: ARG005

         documents = [Document(content=f"Sample-document text {i}") for i in range(5)]

@@ -210,12 +208,7 @@ def test_embed_metadata(self):
         )
         embedder.embedding_backend = MagicMock()

-        documents = [
-            Document(
-                content=f"document-number {i}", meta={"meta_field": f"meta_value {i}"}
-            )
-            for i in range(5)
-        ]
+        documents = [Document(content=f"document-number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)]

         embedder.run(documents=documents)
integrations/fastembed/tests/test_fastembed_text_embedder.py (3 additions, 9 deletions)
@@ -118,9 +118,7 @@ def test_warmup(self, mocked_factory):
         embedder = FastembedTextEmbedder(model="BAAI/bge-small-en-v1.5")
         mocked_factory.get_embedding_backend.assert_not_called()
         embedder.warm_up()
-        mocked_factory.get_embedding_backend.assert_called_once_with(
-            model_name="BAAI/bge-small-en-v1.5"
-        )
+        mocked_factory.get_embedding_backend.assert_called_once_with(model_name="BAAI/bge-small-en-v1.5")

     @patch(
         "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder._FastembedEmbeddingBackendFactory"
@@ -141,9 +139,7 @@ def test_embed(self):
         """
         embedder = FastembedTextEmbedder(model="BAAI/bge-base-en-v1.5")
         embedder.embedding_backend = MagicMock()
-        embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand(  # noqa: ARG005
-            len(x), 16
-        ).tolist()
+        embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand(len(x), 16).tolist()  # noqa: ARG005

         text = "Good text to embed"

@@ -162,9 +158,7 @@ def test_run_wrong_incorrect_format(self):

         list_integers_input = [1, 2, 3]

-        with pytest.raises(
-            TypeError, match="FastembedTextEmbedder expects a string as input"
-        ):
+        with pytest.raises(TypeError, match="FastembedTextEmbedder expects a string as input"):
             embedder.run(text=list_integers_input)

     @pytest.mark.integration
