From 8e4db9e43bc1402683f32b25d40ec0ce7f92ab0c Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Mon, 26 Feb 2024 13:15:38 +0100 Subject: [PATCH] refactor tests (#487) --- .../tests/test_mistral_chat_generator.py | 27 +++---------------- .../tests/test_mistral_document_embedder.py | 10 ++++--- .../tests/test_mistral_text_embedder.py | 14 +++++++--- 3 files changed, 22 insertions(+), 29 deletions(-) diff --git a/integrations/mistral/tests/test_mistral_chat_generator.py b/integrations/mistral/tests/test_mistral_chat_generator.py index d2a4129e2..181397c00 100644 --- a/integrations/mistral/tests/test_mistral_chat_generator.py +++ b/integrations/mistral/tests/test_mistral_chat_generator.py @@ -147,7 +147,8 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): with pytest.raises(ValueError, match="None of the .* environment variables are set"): MistralChatGenerator.from_dict(data) - def test_run(self, chat_messages): + def test_run(self, chat_messages, mock_chat_completion, monkeypatch): # noqa: ARG002 + monkeypatch.setenv("MISTRAL_API_KEY", "fake-api-key") component = MistralChatGenerator() response = component.run(chat_messages) @@ -158,7 +159,8 @@ def test_run(self, chat_messages): assert len(response["replies"]) == 1 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - def test_run_with_params(self, chat_messages, mock_chat_completion): + def test_run_with_params(self, chat_messages, mock_chat_completion, monkeypatch): + monkeypatch.setenv("MISTRAL_API_KEY", "fake-api-key") component = MistralChatGenerator(generation_kwargs={"max_tokens": 10, "temperature": 0.5}) response = component.run(chat_messages) @@ -174,27 +176,6 @@ def test_run_with_params(self, chat_messages, mock_chat_completion): assert len(response["replies"]) == 1 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - def test_run_with_params_streaming(self, chat_messages): - streaming_callback_called = False - - def streaming_callback(chunk: StreamingChunk) -> None: # noqa: ARG001 - nonlocal streaming_callback_called - streaming_callback_called = True - - component = MistralChatGenerator(streaming_callback=streaming_callback) - response = component.run(chat_messages) - - # check we called the streaming callback - assert streaming_callback_called - - # check that the component still returns the correct response - assert isinstance(response, dict) - assert "replies" in response - assert isinstance(response["replies"], list) - assert len(response["replies"]) == 1 - assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - assert "Paris" in response["replies"][0].content # see mock_chat_completion_chunk - def test_check_abnormal_completions(self, caplog): component = MistralChatGenerator(api_key=Secret.from_token("test-api-key")) messages = [ diff --git a/integrations/mistral/tests/test_mistral_document_embedder.py b/integrations/mistral/tests/test_mistral_document_embedder.py index 85d56bbaf..6e5c11759 100644 --- a/integrations/mistral/tests/test_mistral_document_embedder.py +++ b/integrations/mistral/tests/test_mistral_document_embedder.py @@ -12,7 +12,9 @@ class TestMistralDocumentEmbedder: - def test_init_default(self): + def test_init_default(self, monkeypatch): + monkeypatch.setenv("MISTRAL_API_KEY", "test-api-key") + embedder = MistralDocumentEmbedder() assert embedder.api_key == Secret.from_env_var(["MISTRAL_API_KEY"]) assert embedder.model == "mistral-embed" @@ -46,7 +48,9 @@ def test_init_with_parameters(self): assert embedder.meta_fields_to_embed == ["test_field"] assert embedder.embedding_separator == "-" - def test_to_dict(self): + def test_to_dict(self, monkeypatch): + monkeypatch.setenv("MISTRAL_API_KEY", "test-api-key") + embedder_component = MistralDocumentEmbedder() component_dict = embedder_component.to_dict() assert component_dict == { @@ -99,7 +103,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): @pytest.mark.skipif( not os.environ.get("MISTRAL_API_KEY", None), - reason="Export an env var called MISTRAL_API_KEY containing the Cohere API key to run this test.", + reason="Export an env var called MISTRAL_API_KEY containing the Mistral API key to run this test.", ) @pytest.mark.integration def test_run(self): diff --git a/integrations/mistral/tests/test_mistral_text_embedder.py b/integrations/mistral/tests/test_mistral_text_embedder.py index 82e9d23ee..af004b022 100644 --- a/integrations/mistral/tests/test_mistral_text_embedder.py +++ b/integrations/mistral/tests/test_mistral_text_embedder.py @@ -11,9 +11,12 @@ class TestMistralTextEmbedder: - def test_init_default(self): + def test_init_default(self, monkeypatch): + monkeypatch.setenv("MISTRAL_API_KEY", "test-api-key") + embedder = MistralTextEmbedder() assert embedder.api_key == Secret.from_env_var(["MISTRAL_API_KEY"]) + assert embedder.api_base_url == "https://api.mistral.ai/v1" assert embedder.model == "mistral-embed" assert embedder.prefix == "" assert embedder.suffix == "" @@ -26,11 +29,14 @@ def test_init_with_parameters(self): suffix="END", ) assert embedder.api_key == Secret.from_token("test-api-key") + assert embedder.api_base_url == "https://api.mistral.ai/v1" assert embedder.model == "mistral-embed-v2" assert embedder.prefix == "START" assert embedder.suffix == "END" - def test_to_dict(self): + def test_to_dict(self, monkeypatch): + monkeypatch.setenv("MISTRAL_API_KEY", "test-api-key") + embedder_component = MistralTextEmbedder() component_dict = embedder_component.to_dict() assert component_dict == { @@ -38,6 +44,7 @@ def test_to_dict(self): "init_parameters": { "api_key": {"env_vars": ["MISTRAL_API_KEY"], "strict": True, "type": "env_var"}, "model": "mistral-embed", + "api_base_url": "https://api.mistral.ai/v1", "dimensions": None, "organization": None, "prefix": "", @@ -60,6 +67,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): "init_parameters": { "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, "model": "mistral-embed-v2", + "api_base_url": "https://custom-api-base-url.com", "dimensions": None, "organization": None, "prefix": "START", @@ -69,7 +77,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): @pytest.mark.skipif( not os.environ.get("MISTRAL_API_KEY", None), - reason="Export an env var called MISTRAL_API_KEY containing the Cohere API key to run this test.", + reason="Export an env var called MISTRAL_API_KEY containing the Mistral API key to run this test.", ) @pytest.mark.integration def test_run(self):