Commit: Cosmetics
vblagoje committed Mar 8, 2024
1 parent 16f08b6 commit f07d768
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions integrations/amazon_bedrock/tests/test_chat_generator.py
@@ -11,7 +11,8 @@
    MetaLlama2ChatAdapter,
)

-clazz = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator"
+KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator"
+MODELS_TO_TEST = ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "meta.llama2-13b-chat-v1"]


def test_to_dict(mock_boto3_session):
@@ -24,7 +25,7 @@ def test_to_dict(mock_boto3_session):
        streaming_callback=print_streaming_chunk,
    )
    expected_dict = {
-        "type": clazz,
+        "type": KLASS,
        "init_parameters": {
            "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
            "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
@@ -47,7 +48,7 @@ def test_from_dict(mock_boto3_session):
    """
    generator = AmazonBedrockChatGenerator.from_dict(
        {
-            "type": clazz,
+            "type": KLASS,
            "init_parameters": {
                "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
                "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
@@ -229,9 +230,7 @@ def test_get_responses(self) -> None:

        assert response_message == [ChatMessage.from_assistant(expected_response)]

-    @pytest.mark.parametrize(
-        "model_name", ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "meta.llama2-13b-chat-v1"]
-    )
+    @pytest.mark.parametrize("model_name", MODELS_TO_TEST)
    @pytest.mark.integration
    def test_default_inference_params(self, model_name):
        messages = [
@@ -248,11 +247,11 @@ def test_default_inference_params(self, model_name):
        assert response["replies"][0].content
        assert ChatMessage.is_from(response["replies"][0], ChatRole.ASSISTANT)
        assert "paris" in response["replies"][0].content.lower()

        # validate meta
        assert len(response["replies"][0].meta) > 0

-    @pytest.mark.parametrize(
-        "model_name", ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "meta.llama2-13b-chat-v1"]
-    )
+    @pytest.mark.parametrize("model_name", MODELS_TO_TEST)
    @pytest.mark.integration
    def test_default_inference_with_streaming(self, model_name):

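For context (not part of the commit): the change replaces the two inline model lists with the single MODELS_TO_TEST constant, and @pytest.mark.parametrize then generates one test case per model ID, so new Bedrock models only need to be added in one place. A minimal, runnable sketch of that pattern follows; the test function and its assertions are illustrative stand-ins of mine, not the repository's real integration tests, which call Bedrock and need AWS credentials.

import pytest

# Same model IDs as the MODELS_TO_TEST constant introduced in the commit.
MODELS_TO_TEST = [
    "anthropic.claude-3-sonnet-20240229-v1:0",
    "anthropic.claude-v2:1",
    "meta.llama2-13b-chat-v1",
]


@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_model_id_shape(model_name):
    # pytest runs this function once per entry in MODELS_TO_TEST and reports
    # each run as a separate test case; a trivial check stands in for the
    # real Bedrock calls.
    provider, _, model = model_name.partition(".")
    assert provider in {"anthropic", "meta"}
    assert model  # each Bedrock model ID has the form "<provider>.<model>"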
