From f07d768e8b6630725b2fbf326a7526a62987a2f7 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 8 Mar 2024 17:26:57 +0100 Subject: [PATCH] Cosmetics --- .../amazon_bedrock/tests/test_chat_generator.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 8eda7da39..b40cf09f1 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -11,7 +11,8 @@ MetaLlama2ChatAdapter, ) -clazz = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" +KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" +MODELS_TO_TEST = ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "meta.llama2-13b-chat-v1"] def test_to_dict(mock_boto3_session): @@ -24,7 +25,7 @@ def test_to_dict(mock_boto3_session): streaming_callback=print_streaming_chunk, ) expected_dict = { - "type": clazz, + "type": KLASS, "init_parameters": { "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, @@ -47,7 +48,7 @@ def test_from_dict(mock_boto3_session): """ generator = AmazonBedrockChatGenerator.from_dict( { - "type": clazz, + "type": KLASS, "init_parameters": { "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, @@ -229,9 +230,7 @@ def test_get_responses(self) -> None: assert response_message == [ChatMessage.from_assistant(expected_response)] - @pytest.mark.parametrize( - "model_name", ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "meta.llama2-13b-chat-v1"] - ) + @pytest.mark.parametrize("model_name", MODELS_TO_TEST) @pytest.mark.integration def test_default_inference_params(self, model_name): messages = [ @@ -248,11 +247,11 @@ def test_default_inference_params(self, model_name): assert response["replies"][0].content assert ChatMessage.is_from(response["replies"][0], ChatRole.ASSISTANT) assert "paris" in response["replies"][0].content.lower() + + # validate meta assert len(response["replies"][0].meta) > 0 - @pytest.mark.parametrize( - "model_name", ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "meta.llama2-13b-chat-v1"] - ) + @pytest.mark.parametrize("model_name", MODELS_TO_TEST) @pytest.mark.integration def test_default_inference_with_streaming(self, model_name):