From 00234031be1774b00d861bc3bc87bcb83a5f810c Mon Sep 17 00:00:00 2001 From: tstadel Date: Wed, 11 Dec 2024 12:00:52 +0100 Subject: [PATCH] add test --- .../amazon_bedrock/tests/test_generator.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py index ca28f290c..bf66a2cf3 100644 --- a/integrations/amazon_bedrock/tests/test_generator.py +++ b/integrations/amazon_bedrock/tests/test_generator.py @@ -319,7 +319,7 @@ def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[Bed ) def test_get_model_adapter_with_model_family(model_family: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]]): """ - Test that the correct model adapter is returned for a given model + Test that the correct model adapter is returned for a given model model_family """ model_adapter = AmazonBedrockGenerator.get_model_adapter(model="arn:123435423", model_family=model_family) assert model_adapter == expected_model_adapter @@ -327,7 +327,7 @@ def test_get_model_adapter_with_model_family(model_family: str, expected_model_a def test_get_model_adapter_with_invalid_model_family(): """ - Test that the correct model adapter is returned for a given model + Test that an error is raised when an invalid model_family is provided """ with pytest.raises(AmazonBedrockConfigurationError): AmazonBedrockGenerator.get_model_adapter(model="arn:123435423", model_family="invalid") @@ -335,12 +335,21 @@ def test_get_model_adapter_with_invalid_model_family(): def test_get_model_adapter_auto_detect_family_fails(): """ - Test that the correct model adapter is returned for a given model + Test that an error is raised when auto-detection of model_family fails """ with pytest.raises(AmazonBedrockConfigurationError): AmazonBedrockGenerator.get_model_adapter(model="arn:123435423", model_family=None) +def test_get_model_adapter_model_family_over_auto_detection(): + """ + Test that the model_family is used over auto-detection + """ + model_adapter = AmazonBedrockGenerator.get_model_adapter( + model="cohere.command-text-v14", model_family="anthropic.claude" + ) + assert model_adapter == AnthropicClaudeAdapter + class TestAnthropicClaudeAdapter: def test_default_init(self) -> None: