From 2d5ad0a8537b2b8ab0cfa3951f6b78b503ecefc9 Mon Sep 17 00:00:00 2001 From: Nico Date: Fri, 9 Feb 2024 19:33:13 +0100 Subject: [PATCH] formatted with black --- .../embedding_backend/fastembed_backend.py | 4 +--- .../fastembed/fastembed_document_embedder.py | 20 ++++------------ .../fastembed/fastembed_text_embedder.py | 6 +---- .../fastembed/tests/test_fastembed_backend.py | 24 +++++-------------- .../tests/test_fastembed_document_embedder.py | 11 ++------- .../tests/test_fastembed_text_embedder.py | 12 +++------- 6 files changed, 17 insertions(+), 60 deletions(-) diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py index bf7313103..2b6fc3f38 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py @@ -22,9 +22,7 @@ def get_embedding_backend( embedding_backend = _FastembedEmbeddingBackend( model_name=model_name, ) - _FastembedEmbeddingBackendFactory._instances[embedding_backend_id] = ( - embedding_backend - ) + _FastembedEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend return embedding_backend diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py index 0bf658b6f..63776f48c 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py @@ -104,11 +104,7 @@ def warm_up(self): Load the embedding backend. """ if not hasattr(self, "embedding_backend"): - self.embedding_backend = ( - _FastembedEmbeddingBackendFactory.get_embedding_backend( - model_name=self.model_name - ) - ) + self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(model_name=self.model_name) @component.output_types(documents=List[Document]) def run(self, documents: List[Document]): @@ -116,11 +112,7 @@ def run(self, documents: List[Document]): Embed a list of Documents. The embedding of each Document is stored in the `embedding` field of the Document. """ - if ( - not isinstance(documents, list) - or documents - and not isinstance(documents[0], Document) - ): + if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): msg = ( "FastembedDocumentEmbedder expects a list of Documents as input. " "In case you want to embed a list of strings, please use the FastembedTextEmbedder." @@ -135,14 +127,10 @@ def run(self, documents: List[Document]): texts_to_embed = [] for doc in documents: meta_values_to_embed = [ - str(doc.meta[key]) - for key in self.meta_fields_to_embed - if key in doc.meta and doc.meta[key] is not None + str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None ] text_to_embed = [ - self.embedding_separator.join( - [*meta_values_to_embed, doc.content or ""] - ), + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]), ] texts_to_embed.append(text_to_embed[0]) diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py index 6c2323df4..986ab1cd1 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py @@ -74,11 +74,7 @@ def warm_up(self): Load the embedding backend. """ if not hasattr(self, "embedding_backend"): - self.embedding_backend = ( - _FastembedEmbeddingBackendFactory.get_embedding_backend( - model_name=self.model_name - ) - ) + self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(model_name=self.model_name) @component.output_types(embedding=List[float]) def run(self, text: str): diff --git a/integrations/fastembed/tests/test_fastembed_backend.py b/integrations/fastembed/tests/test_fastembed_backend.py index 735301eb4..c564c72bf 100644 --- a/integrations/fastembed/tests/test_fastembed_backend.py +++ b/integrations/fastembed/tests/test_fastembed_backend.py @@ -5,16 +5,10 @@ ) -@patch( - "haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.TextEmbedding" -) +@patch("haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.TextEmbedding") def test_factory_behavior(mock_instructor): # noqa: ARG001 - embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend( - model_name="BAAI/bge-small-en-v1.5" - ) - same_embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend( - "BAAI/bge-small-en-v1.5" - ) + embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(model_name="BAAI/bge-small-en-v1.5") + same_embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend("BAAI/bge-small-en-v1.5") another_embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend( model_name="BAAI/bge-base-en-v1.5" ) @@ -26,9 +20,7 @@ def test_factory_behavior(mock_instructor): # noqa: ARG001 _FastembedEmbeddingBackendFactory._instances = {} -@patch( - "haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.TextEmbedding" -) +@patch("haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.TextEmbedding") def test_model_initialization(mock_instructor): _FastembedEmbeddingBackendFactory.get_embedding_backend( model_name="BAAI/bge-small-en-v1.5", @@ -40,13 +32,9 @@ def test_model_initialization(mock_instructor): _FastembedEmbeddingBackendFactory._instances = {} -@patch( - "haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.TextEmbedding" -) +@patch("haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.TextEmbedding") def test_embedding_function_with_kwargs(mock_instructor): # noqa: ARG001 - embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend( - model_name="BAAI/bge-small-en-v1.5" - ) + embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(model_name="BAAI/bge-small-en-v1.5") data = ["sentence1", "sentence2"] embedding_backend.embed(data=data) diff --git a/integrations/fastembed/tests/test_fastembed_document_embedder.py b/integrations/fastembed/tests/test_fastembed_document_embedder.py index 80c889b4f..a387db797 100644 --- a/integrations/fastembed/tests/test_fastembed_document_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_document_embedder.py @@ -162,9 +162,7 @@ def test_embed(self): """ embedder = FastembedDocumentEmbedder(model="BAAI/bge-base-en-v1.5") embedder.embedding_backend = MagicMock() - embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand( # noqa: ARG005 - len(x), 16 - ).tolist() + embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand(len(x), 16).tolist() # noqa: ARG005 documents = [Document(content=f"Sample-document text {i}") for i in range(5)] @@ -210,12 +208,7 @@ def test_embed_metadata(self): ) embedder.embedding_backend = MagicMock() - documents = [ - Document( - content=f"document-number {i}", meta={"meta_field": f"meta_value {i}"} - ) - for i in range(5) - ] + documents = [Document(content=f"document-number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)] embedder.run(documents=documents) diff --git a/integrations/fastembed/tests/test_fastembed_text_embedder.py b/integrations/fastembed/tests/test_fastembed_text_embedder.py index c30954f87..ee9bfd3da 100644 --- a/integrations/fastembed/tests/test_fastembed_text_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_text_embedder.py @@ -118,9 +118,7 @@ def test_warmup(self, mocked_factory): embedder = FastembedTextEmbedder(model="BAAI/bge-small-en-v1.5") mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() - mocked_factory.get_embedding_backend.assert_called_once_with( - model_name="BAAI/bge-small-en-v1.5" - ) + mocked_factory.get_embedding_backend.assert_called_once_with(model_name="BAAI/bge-small-en-v1.5") @patch( "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder._FastembedEmbeddingBackendFactory" @@ -141,9 +139,7 @@ def test_embed(self): """ embedder = FastembedTextEmbedder(model="BAAI/bge-base-en-v1.5") embedder.embedding_backend = MagicMock() - embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand( # noqa: ARG005 - len(x), 16 - ).tolist() + embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand(len(x), 16).tolist() # noqa: ARG005 text = "Good text to embed" @@ -162,9 +158,7 @@ def test_run_wrong_incorrect_format(self): list_integers_input = [1, 2, 3] - with pytest.raises( - TypeError, match="FastembedTextEmbedder expects a string as input" - ): + with pytest.raises(TypeError, match="FastembedTextEmbedder expects a string as input"): embedder.run(text=list_integers_input) @pytest.mark.integration