From b4b31f4b4f98b21a740ad70afcb83ca8c533e9d8 Mon Sep 17 00:00:00 2001
From: anakin87
Date: Thu, 2 Jan 2025 12:32:58 +0100
Subject: [PATCH] Adapt Mistral to OpenAI refactoring

---
 .../tests/test_mistral_chat_generator.py | 58 ++++++++++++-------
 1 file changed, 37 insertions(+), 21 deletions(-)

diff --git a/integrations/mistral/tests/test_mistral_chat_generator.py b/integrations/mistral/tests/test_mistral_chat_generator.py
index 6277b9c36..be3dce497 100644
--- a/integrations/mistral/tests/test_mistral_chat_generator.py
+++ b/integrations/mistral/tests/test_mistral_chat_generator.py
@@ -80,18 +80,24 @@ def test_to_dict_default(self, monkeypatch):
         monkeypatch.setenv("MISTRAL_API_KEY", "test-api-key")
         component = MistralChatGenerator()
         data = component.to_dict()
-        assert data == {
-            "type": "haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator",
-            "init_parameters": {
-                "api_key": {"env_vars": ["MISTRAL_API_KEY"], "strict": True, "type": "env_var"},
-                "model": "mistral-tiny",
-                "organization": None,
-                "streaming_callback": None,
-                "api_base_url": "https://api.mistral.ai/v1",
-                "generation_kwargs": {},
-            },
+
+        assert (
+            data["type"]
+            == "haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator"
+        )
+
+        expected_params = {
+            "api_key": {"env_vars": ["MISTRAL_API_KEY"], "strict": True, "type": "env_var"},
+            "model": "mistral-tiny",
+            "organization": None,
+            "streaming_callback": None,
+            "api_base_url": "https://api.mistral.ai/v1",
+            "generation_kwargs": {},
         }
+        for key, value in expected_params.items():
+            assert data["init_parameters"][key] == value
+
 
     def test_to_dict_with_parameters(self, monkeypatch):
         monkeypatch.setenv("ENV_VAR", "test-api-key")
         component = MistralChatGenerator(
@@ -102,18 +108,23 @@ def test_to_dict_with_parameters(self, monkeypatch):
             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
         )
         data = component.to_dict()
-        assert data == {
-            "type": "haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator",
-            "init_parameters": {
-                "api_key": {"env_vars": ["ENV_VAR"], "strict": True, "type": "env_var"},
-                "model": "mistral-small",
-                "api_base_url": "test-base-url",
-                "organization": None,
-                "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
-                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
-            },
+
+        assert (
+            data["type"]
+            == "haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator"
+        )
+
+        expected_params = {
+            "api_key": {"env_vars": ["ENV_VAR"], "strict": True, "type": "env_var"},
+            "model": "mistral-small",
+            "api_base_url": "test-base-url",
+            "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
+            "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
         }
+        for key, value in expected_params.items():
+            assert data["init_parameters"][key] == value
+
 
     def test_from_dict(self, monkeypatch):
         monkeypatch.setenv("MISTRAL_API_KEY", "fake-api-key")
         data = {
@@ -187,7 +198,12 @@ def test_check_abnormal_completions(self, caplog):
         ]
 
         for m in messages:
-            component._check_finish_reason(m)
+            try:
+                # Haystack >= 2.9.0
+                component._check_finish_reason(m.meta)
+            except AttributeError:
+                # Haystack < 2.9.0
+                component._check_finish_reason(m)
 
         # check truncation warning
         message_template = (