From fae1e366372ba89bd4b469604f1c24520c69d789 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 5 Jan 2024 10:59:20 +0100 Subject: [PATCH 1/2] Optimize API key reading (#162) * optimize api key reading * fmt and rm warning match --- .../cohere_haystack/chat/chat_generator.py | 11 ++++++---- .../embedders/document_embedder.py | 17 +++++++-------- .../embedders/text_embedder.py | 17 +++++++-------- .../cohere/src/cohere_haystack/generator.py | 8 +++---- .../tests/test_cohere_chat_generator.py | 4 ++-- .../src/jina_haystack/document_embedder.py | 21 ++++++++----------- .../jina/src/jina_haystack/text_embedder.py | 19 ++++++++--------- .../jina/tests/test_document_embedder.py | 2 +- integrations/jina/tests/test_text_embedder.py | 2 +- .../src/pinecone_haystack/document_store.py | 2 +- .../fileconverter.py | 19 +++++++++-------- 11 files changed, 60 insertions(+), 62 deletions(-) diff --git a/integrations/cohere/src/cohere_haystack/chat/chat_generator.py b/integrations/cohere/src/cohere_haystack/chat/chat_generator.py index f3178d567..be236f6ca 100644 --- a/integrations/cohere/src/cohere_haystack/chat/chat_generator.py +++ b/integrations/cohere/src/cohere_haystack/chat/chat_generator.py @@ -68,11 +68,14 @@ def __init__( """ cohere_import.check() + api_key = api_key or os.environ.get("COHERE_API_KEY") + # we check whether api_key is None or an empty string if not api_key: - api_key = os.environ.get("COHERE_API_KEY") - if not api_key: - error = "CohereChatGenerator needs an API key to run. Either provide it as init parameter or set the env var COHERE_API_KEY." # noqa: E501 - raise ValueError(error) + msg = ( + "CohereChatGenerator expects an API key. " + "Set the COHERE_API_KEY environment variable (recommended) or pass it explicitly." + ) + raise ValueError(msg) if not api_base_url: api_base_url = cohere.COHERE_API_URL diff --git a/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py b/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py index 6c87f3537..bc0b9381d 100644 --- a/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py +++ b/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py @@ -78,15 +78,14 @@ def __init__( :param embedding_separator: Separator used to concatenate the meta fields to the Document text. """ - if api_key is None: - try: - api_key = os.environ["COHERE_API_KEY"] - except KeyError as error_msg: - msg = ( - "CohereDocumentEmbedder expects an Cohere API key. Please provide one by setting the environment " - "variable COHERE_API_KEY (recommended) or by passing it explicitly." - ) - raise ValueError(msg) from error_msg + api_key = api_key or os.environ.get("COHERE_API_KEY") + # we check whether api_key is None or an empty string + if not api_key: + msg = ( + "CohereDocumentEmbedder expects an API key. " + "Set the COHERE_API_KEY environment variable (recommended) or pass it explicitly." + ) + raise ValueError(msg) self.api_key = api_key self.model_name = model_name diff --git a/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py b/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py index 25822223e..4ba8acd47 100644 --- a/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py +++ b/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py @@ -67,15 +67,14 @@ def __init__( :param timeout: Request timeout in seconds, defaults to `120`. """ - if api_key is None: - try: - api_key = os.environ["COHERE_API_KEY"] - except KeyError as error_msg: - msg = ( - "CohereTextEmbedder expects an Cohere API key. Please provide one by setting the environment " - "variable COHERE_API_KEY (recommended) or by passing it explicitly." - ) - raise ValueError(msg) from error_msg + api_key = api_key or os.environ.get("COHERE_API_KEY") + # we check whether api_key is None or an empty string + if not api_key: + msg = ( + "CohereTextEmbedder expects an API key. " + "Set the COHERE_API_KEY environment variable (recommended) or pass it explicitly." + ) + raise ValueError(msg) self.api_key = api_key self.model_name = model_name diff --git a/integrations/cohere/src/cohere_haystack/generator.py b/integrations/cohere/src/cohere_haystack/generator.py index 571464c0c..66c80afa4 100644 --- a/integrations/cohere/src/cohere_haystack/generator.py +++ b/integrations/cohere/src/cohere_haystack/generator.py @@ -73,12 +73,12 @@ def __init__( - 'logit_bias': Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens. The format is {token_id: bias} where bias is a float between -10 and 10. """ - if not api_key: - api_key = os.environ.get("COHERE_API_KEY") + api_key = api_key or os.environ.get("COHERE_API_KEY") + # we check whether api_key is None or an empty string if not api_key: msg = ( - "CohereGenerator needs an API key to run." - "Either provide it as init parameter or set the env var COHERE_API_KEY." + "CohereGenerator expects an API key. " + "Set the COHERE_API_KEY environment variable (recommended) or pass it explicitly." ) raise ValueError(msg) diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index 92954df8b..f9ac7b2c6 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -65,7 +65,7 @@ def test_init_default(self): @pytest.mark.unit def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("COHERE_API_KEY", raising=False) - with pytest.raises(ValueError, match=r"^CohereChatGenerator needs an API key to run. (.+)$"): + with pytest.raises(ValueError): CohereChatGenerator() @pytest.mark.unit @@ -167,7 +167,7 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, } - with pytest.raises(ValueError, match=r"^CohereChatGenerator needs an API key to run. (.+)$"): + with pytest.raises(ValueError): CohereChatGenerator.from_dict(data) @pytest.mark.unit diff --git a/integrations/jina/src/jina_haystack/document_embedder.py b/integrations/jina/src/jina_haystack/document_embedder.py index 03f64462f..9f51a9e26 100644 --- a/integrations/jina/src/jina_haystack/document_embedder.py +++ b/integrations/jina/src/jina_haystack/document_embedder.py @@ -57,22 +57,19 @@ def __init__( :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text. :param embedding_separator: Separator used to concatenate the meta fields to the Document text. """ - # if the user does not provide the API key, check if it is set in the module client - if api_key is None: - try: - api_key = os.environ["JINA_API_KEY"] - except KeyError as e: - msg = ( - "JinaDocumentEmbedder expects a Jina API key. " - "Set the JINA_API_KEY environment variable (recommended) or pass it explicitly." - ) - raise ValueError(msg) from e + + api_key = api_key or os.environ.get("JINA_API_KEY") + # we check whether api_key is None or an empty string + if not api_key: + msg = ( + "JinaDocumentEmbedder expects an API key. " + "Set the JINA_API_KEY environment variable (recommended) or pass it explicitly." + ) + raise ValueError(msg) self.model_name = model_name self.prefix = prefix self.suffix = suffix - self.prefix = prefix - self.suffix = suffix self.batch_size = batch_size self.progress_bar = progress_bar self.meta_fields_to_embed = meta_fields_to_embed or [] diff --git a/integrations/jina/src/jina_haystack/text_embedder.py b/integrations/jina/src/jina_haystack/text_embedder.py index 5b29bef6d..f717f4748 100644 --- a/integrations/jina/src/jina_haystack/text_embedder.py +++ b/integrations/jina/src/jina_haystack/text_embedder.py @@ -47,16 +47,15 @@ def __init__( :param prefix: A string to add to the beginning of each text. :param suffix: A string to add to the end of each text. """ - # if the user does not provide the API key, check if it is set in the module client - if api_key is None: - try: - api_key = os.environ["JINA_API_KEY"] - except KeyError as e: - msg = ( - "JinaTextEmbedder expects a Jina API key. " - "Set the JINA_API_KEY environment variable (recommended) or pass it explicitly." - ) - raise ValueError(msg) from e + + api_key = api_key or os.environ.get("JINA_API_KEY") + # we check whether api_key is None or an empty string + if not api_key: + msg = ( + "JinaTextEmbedder expects an API key. " + "Set the JINA_API_KEY environment variable (recommended) or pass it explicitly." + ) + raise ValueError(msg) self.model_name = model_name self.prefix = prefix diff --git a/integrations/jina/tests/test_document_embedder.py b/integrations/jina/tests/test_document_embedder.py index ac8bb6975..2ebc5d358 100644 --- a/integrations/jina/tests/test_document_embedder.py +++ b/integrations/jina/tests/test_document_embedder.py @@ -58,7 +58,7 @@ def test_init_with_parameters(self): def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("JINA_API_KEY", raising=False) - with pytest.raises(ValueError, match="JinaDocumentEmbedder expects a Jina API key"): + with pytest.raises(ValueError): JinaDocumentEmbedder() def test_to_dict(self): diff --git a/integrations/jina/tests/test_text_embedder.py b/integrations/jina/tests/test_text_embedder.py index c8a730c2f..7dfd64a05 100644 --- a/integrations/jina/tests/test_text_embedder.py +++ b/integrations/jina/tests/test_text_embedder.py @@ -32,7 +32,7 @@ def test_init_with_parameters(self): def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("JINA_API_KEY", raising=False) - with pytest.raises(ValueError, match="JinaTextEmbedder expects a Jina API key"): + with pytest.raises(ValueError): JinaTextEmbedder() def test_to_dict(self): diff --git a/integrations/pinecone/src/pinecone_haystack/document_store.py b/integrations/pinecone/src/pinecone_haystack/document_store.py index d6296e030..252437a9a 100644 --- a/integrations/pinecone/src/pinecone_haystack/document_store.py +++ b/integrations/pinecone/src/pinecone_haystack/document_store.py @@ -60,7 +60,7 @@ def __init__( api_key = api_key or os.environ.get("PINECONE_API_KEY") if not api_key: msg = ( - "PineconeDocumentStore expects a Pinecone API key. " + "PineconeDocumentStore expects an API key. " "Set the PINECONE_API_KEY environment variable (recommended) or pass it explicitly." ) raise ValueError(msg) diff --git a/integrations/unstructured/fileconverter/src/unstructured_fileconverter_haystack/fileconverter.py b/integrations/unstructured/fileconverter/src/unstructured_fileconverter_haystack/fileconverter.py index 5a565d00b..d94cb49c4 100644 --- a/integrations/unstructured/fileconverter/src/unstructured_fileconverter_haystack/fileconverter.py +++ b/integrations/unstructured/fileconverter/src/unstructured_fileconverter_haystack/fileconverter.py @@ -60,15 +60,16 @@ def __init__( self.progress_bar = progress_bar is_hosted_api = api_url == UNSTRUCTURED_HOSTED_API_URL - if api_key is None and is_hosted_api: - try: - api_key = os.environ["UNSTRUCTURED_API_KEY"] - except KeyError as e: - msg = ( - "To use the hosted version of Unstructured, you need to set the environment variable " - "UNSTRUCTURED_API_KEY (recommended) or explictly pass the parameter api_key." - ) - raise ValueError(msg) from e + + api_key = api_key or os.environ.get("UNSTRUCTURED_API_KEY") + # we check whether api_key is None or an empty string + if is_hosted_api and not api_key: + msg = ( + "To use the hosted version of Unstructured, you need to set the environment variable " + "UNSTRUCTURED_API_KEY (recommended) or explictly pass the parameter api_key." + ) + raise ValueError(msg) + self.api_key = api_key def to_dict(self) -> Dict[str, Any]: From 3e314ea0af74fe72ab97b93609a7080411afbffc Mon Sep 17 00:00:00 2001 From: sahusiddharth <112792547+sahusiddharth@users.noreply.github.com> Date: Fri, 5 Jan 2024 18:54:19 +0530 Subject: [PATCH 2/2] renamed QdrntRetriever to QdrntEmbeddingRetriever (#174) --- .../qdrant/src/qdrant_haystack/retriever.py | 6 +++--- integrations/qdrant/tests/test_retriever.py | 14 +++++++------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/integrations/qdrant/src/qdrant_haystack/retriever.py b/integrations/qdrant/src/qdrant_haystack/retriever.py index 054ba96ac..bf378688c 100644 --- a/integrations/qdrant/src/qdrant_haystack/retriever.py +++ b/integrations/qdrant/src/qdrant_haystack/retriever.py @@ -6,7 +6,7 @@ @component -class QdrantRetriever: +class QdrantEmbeddingRetriever: """ A component for retrieving documents from an QdrantDocumentStore. """ @@ -20,7 +20,7 @@ def __init__( return_embedding: bool = False, # noqa: FBT001, FBT002 ): """ - Create a QdrantRetriever component. + Create a QdrantEmbeddingRetriever component. :param document_store: An instance of QdrantDocumentStore. :param filters: A dictionary with filters to narrow down the search space. Default is None. @@ -59,7 +59,7 @@ def to_dict(self) -> Dict[str, Any]: return d @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "QdrantRetriever": + def from_dict(cls, data: Dict[str, Any]) -> "QdrantEmbeddingRetriever": """ Deserialize this component from a dictionary. """ diff --git a/integrations/qdrant/tests/test_retriever.py b/integrations/qdrant/tests/test_retriever.py index 22eabfaad..ed220c5bc 100644 --- a/integrations/qdrant/tests/test_retriever.py +++ b/integrations/qdrant/tests/test_retriever.py @@ -7,13 +7,13 @@ ) from qdrant_haystack import QdrantDocumentStore -from qdrant_haystack.retriever import QdrantRetriever +from qdrant_haystack.retriever import QdrantEmbeddingRetriever class TestQdrantRetriever(FilterableDocsFixtureMixin): def test_init_default(self): document_store = QdrantDocumentStore(location=":memory:", index="test") - retriever = QdrantRetriever(document_store=document_store) + retriever = QdrantEmbeddingRetriever(document_store=document_store) assert retriever._document_store == document_store assert retriever._filters is None assert retriever._top_k == 10 @@ -21,10 +21,10 @@ def test_init_default(self): def test_to_dict(self): document_store = QdrantDocumentStore(location=":memory:", index="test") - retriever = QdrantRetriever(document_store=document_store) + retriever = QdrantEmbeddingRetriever(document_store=document_store) res = retriever.to_dict() assert res == { - "type": "qdrant_haystack.retriever.QdrantRetriever", + "type": "qdrant_haystack.retriever.QdrantEmbeddingRetriever", "init_parameters": { "document_store": { "type": "qdrant_haystack.document_store.QdrantDocumentStore", @@ -74,7 +74,7 @@ def test_to_dict(self): def test_from_dict(self): data = { - "type": "qdrant_haystack.retriever.QdrantRetriever", + "type": "qdrant_haystack.retriever.QdrantEmbeddingRetriever", "init_parameters": { "document_store": { "init_parameters": {"location": ":memory:", "index": "test"}, @@ -86,7 +86,7 @@ def test_from_dict(self): "return_embedding": True, }, } - retriever = QdrantRetriever.from_dict(data) + retriever = QdrantEmbeddingRetriever.from_dict(data) assert isinstance(retriever._document_store, QdrantDocumentStore) assert retriever._document_store.index == "test" assert retriever._filters is None @@ -99,7 +99,7 @@ def test_run(self, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) - retriever = QdrantRetriever(document_store=document_store) + retriever = QdrantEmbeddingRetriever(document_store=document_store) results: List[Document] = retriever.run(query_embedding=_random_embeddings(768))