diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 941fdbf71..9313b2e50 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -1,7 +1,7 @@ import json import logging import re -from typing import Any, Callable, ClassVar, Dict, List, Optional, Type +from typing import Any, Callable, ClassVar, Dict, List, Literal, Optional, Type, get_args from botocore.config import Config from botocore.exceptions import ClientError @@ -75,6 +75,26 @@ class AmazonBedrockGenerator: r"([a-z]{2}\.)?mistral.*": MistralAdapter, } + SUPPORTED_MODEL_FAMILIES: ClassVar[Dict[str, Type[BedrockModelAdapter]]] = { + "amazon.titan-text": AmazonTitanAdapter, + "ai21.j2": AI21LabsJurassic2Adapter, + "cohere.command": CohereCommandAdapter, + "cohere.command-r": CohereCommandRAdapter, + "anthropic.claude": AnthropicClaudeAdapter, + "meta.llama": MetaLlamaAdapter, + "mistral": MistralAdapter, + } + + MODEL_FAMILIES = Literal[ + "amazon.titan-text", + "ai21.j2", + "cohere.command", + "cohere.command-r", + "anthropic.claude", + "meta.llama", + "mistral", + ] + def __init__( self, model: str, @@ -89,6 +109,7 @@ def __init__( truncate: Optional[bool] = True, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, boto3_config: Optional[Dict[str, Any]] = None, + model_family: Optional[MODEL_FAMILIES] = None, **kwargs, ): """ @@ -105,6 +126,7 @@ def __init__( :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. :param boto3_config: The configuration for the boto3 client. + :param model_family: The model family to use. If not provided, the model adapter is selected based on the model name. :param kwargs: Additional keyword arguments to be passed to the model. These arguments are specific to the model. You can find them in the model's documentation. :raises ValueError: If the model name is empty or None. @@ -163,10 +185,7 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: max_length=self.max_length or 100, ) - model_adapter_cls = self.get_model_adapter(model=model) - if not model_adapter_cls: - msg = f"AmazonBedrockGenerator doesn't support the model {model}." - raise AmazonBedrockConfigurationError(msg) + model_adapter_cls = self.get_model_adapter(model=model, model_family=model_family) self.model_adapter = model_adapter_cls(model_kwargs=model_input_kwargs, max_length=self.max_length) def _ensure_token_limit(self, prompt: str) -> str: @@ -250,17 +269,25 @@ def run( return {"replies": replies} @classmethod - def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelAdapter]]: + def get_model_adapter(cls, model: str, model_family: Optional[str]) -> Type[BedrockModelAdapter]: """ Gets the model adapter for the given model. :param model: The model name. :returns: The model adapter class, or None if no adapter is found. """ + if model_family: + if model_family not in cls.SUPPORTED_MODEL_FAMILIES: + msg = f"Model family {model_family} is not supported. Must be one of {get_args(cls.MODEL_FAMILIES)}." + raise AmazonBedrockConfigurationError(msg) + return cls.SUPPORTED_MODEL_FAMILIES[model_family] + for pattern, adapter in cls.SUPPORTED_MODEL_PATTERNS.items(): if re.fullmatch(pattern, model): return adapter - return None + + msg = f"Could not auto-detect model family of {model}. `model_family` parameter must be one of {get_args(cls.MODEL_FAMILIES)}." + raise AmazonBedrockConfigurationError(msg) def to_dict(self) -> Dict[str, Any]: """ diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py index 54b185da5..ca28f290c 100644 --- a/integrations/amazon_bedrock/tests/test_generator.py +++ b/integrations/amazon_bedrock/tests/test_generator.py @@ -15,6 +15,7 @@ MetaLlamaAdapter, MistralAdapter, ) +from integrations.amazon_bedrock.src.haystack_integrations.common.amazon_bedrock.errors import AmazonBedrockConfigurationError @pytest.mark.parametrize( @@ -294,17 +295,53 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): ("eu.mistral.mixtral-8x7b-instruct-v0:1", MistralAdapter), # cross-region inference ("us.mistral.mistral-large-2402-v1:0", MistralAdapter), # cross-region inference ("mistral.mistral-medium-v8:0", MistralAdapter), # artificial - ("unknown_model", None), ], ) def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]]): """ Test that the correct model adapter is returned for a given model """ - model_adapter = AmazonBedrockGenerator.get_model_adapter(model=model) + model_adapter = AmazonBedrockGenerator.get_model_adapter(model=model, model_family=None) assert model_adapter == expected_model_adapter +@pytest.mark.parametrize( + "model_family, expected_model_adapter", + [ + ("anthropic.claude", AnthropicClaudeAdapter), + ("cohere.command", CohereCommandAdapter), + ("cohere.command-r", CohereCommandRAdapter), + ("ai21.j2", AI21LabsJurassic2Adapter), + ("amazon.titan-text", AmazonTitanAdapter), + ("meta.llama", MetaLlamaAdapter), + ("mistral", MistralAdapter), + ], +) +def test_get_model_adapter_with_model_family(model_family: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]]): + """ + Test that the correct model adapter is returned for a given model + """ + model_adapter = AmazonBedrockGenerator.get_model_adapter(model="arn:123435423", model_family=model_family) + assert model_adapter == expected_model_adapter + + +def test_get_model_adapter_with_invalid_model_family(): + """ + Test that the correct model adapter is returned for a given model + """ + with pytest.raises(AmazonBedrockConfigurationError): + AmazonBedrockGenerator.get_model_adapter(model="arn:123435423", model_family="invalid") + + +def test_get_model_adapter_auto_detect_family_fails(): + """ + Test that the correct model adapter is returned for a given model + """ + with pytest.raises(AmazonBedrockConfigurationError): + AmazonBedrockGenerator.get_model_adapter(model="arn:123435423", model_family=None) + + + class TestAnthropicClaudeAdapter: def test_default_init(self) -> None: adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=100)