From b8e45bcf57cf4b7c2cd3d35825bff306c29df0ca Mon Sep 17 00:00:00 2001 From: tstadel Date: Wed, 11 Dec 2024 11:57:05 +0100 Subject: [PATCH 1/5] feat: support model_arn in AmazonBedrockGenerator --- .../generators/amazon_bedrock/generator.py | 41 +++++++++++++++---- .../amazon_bedrock/tests/test_generator.py | 41 ++++++++++++++++++- 2 files changed, 73 insertions(+), 9 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 941fdbf71..9313b2e50 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -1,7 +1,7 @@ import json import logging import re -from typing import Any, Callable, ClassVar, Dict, List, Optional, Type +from typing import Any, Callable, ClassVar, Dict, List, Literal, Optional, Type, get_args from botocore.config import Config from botocore.exceptions import ClientError @@ -75,6 +75,26 @@ class AmazonBedrockGenerator: r"([a-z]{2}\.)?mistral.*": MistralAdapter, } + SUPPORTED_MODEL_FAMILIES: ClassVar[Dict[str, Type[BedrockModelAdapter]]] = { + "amazon.titan-text": AmazonTitanAdapter, + "ai21.j2": AI21LabsJurassic2Adapter, + "cohere.command": CohereCommandAdapter, + "cohere.command-r": CohereCommandRAdapter, + "anthropic.claude": AnthropicClaudeAdapter, + "meta.llama": MetaLlamaAdapter, + "mistral": MistralAdapter, + } + + MODEL_FAMILIES = Literal[ + "amazon.titan-text", + "ai21.j2", + "cohere.command", + "cohere.command-r", + "anthropic.claude", + "meta.llama", + "mistral", + ] + def __init__( self, model: str, @@ -89,6 +109,7 @@ def __init__( truncate: Optional[bool] = True, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, boto3_config: Optional[Dict[str, Any]] = None, + model_family: Optional[MODEL_FAMILIES] = None, **kwargs, ): """ @@ -105,6 +126,7 @@ def __init__( :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. :param boto3_config: The configuration for the boto3 client. + :param model_family: The model family to use. If not provided, the model adapter is selected based on the model name. :param kwargs: Additional keyword arguments to be passed to the model. These arguments are specific to the model. You can find them in the model's documentation. :raises ValueError: If the model name is empty or None. @@ -163,10 +185,7 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: max_length=self.max_length or 100, ) - model_adapter_cls = self.get_model_adapter(model=model) - if not model_adapter_cls: - msg = f"AmazonBedrockGenerator doesn't support the model {model}." - raise AmazonBedrockConfigurationError(msg) + model_adapter_cls = self.get_model_adapter(model=model, model_family=model_family) self.model_adapter = model_adapter_cls(model_kwargs=model_input_kwargs, max_length=self.max_length) def _ensure_token_limit(self, prompt: str) -> str: @@ -250,17 +269,25 @@ def run( return {"replies": replies} @classmethod - def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelAdapter]]: + def get_model_adapter(cls, model: str, model_family: Optional[str]) -> Type[BedrockModelAdapter]: """ Gets the model adapter for the given model. :param model: The model name. :returns: The model adapter class, or None if no adapter is found. """ + if model_family: + if model_family not in cls.SUPPORTED_MODEL_FAMILIES: + msg = f"Model family {model_family} is not supported. Must be one of {get_args(cls.MODEL_FAMILIES)}." + raise AmazonBedrockConfigurationError(msg) + return cls.SUPPORTED_MODEL_FAMILIES[model_family] + for pattern, adapter in cls.SUPPORTED_MODEL_PATTERNS.items(): if re.fullmatch(pattern, model): return adapter - return None + + msg = f"Could not auto-detect model family of {model}. `model_family` parameter must be one of {get_args(cls.MODEL_FAMILIES)}." + raise AmazonBedrockConfigurationError(msg) def to_dict(self) -> Dict[str, Any]: """ diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py index 54b185da5..ca28f290c 100644 --- a/integrations/amazon_bedrock/tests/test_generator.py +++ b/integrations/amazon_bedrock/tests/test_generator.py @@ -15,6 +15,7 @@ MetaLlamaAdapter, MistralAdapter, ) +from integrations.amazon_bedrock.src.haystack_integrations.common.amazon_bedrock.errors import AmazonBedrockConfigurationError @pytest.mark.parametrize( @@ -294,17 +295,53 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): ("eu.mistral.mixtral-8x7b-instruct-v0:1", MistralAdapter), # cross-region inference ("us.mistral.mistral-large-2402-v1:0", MistralAdapter), # cross-region inference ("mistral.mistral-medium-v8:0", MistralAdapter), # artificial - ("unknown_model", None), ], ) def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]]): """ Test that the correct model adapter is returned for a given model """ - model_adapter = AmazonBedrockGenerator.get_model_adapter(model=model) + model_adapter = AmazonBedrockGenerator.get_model_adapter(model=model, model_family=None) assert model_adapter == expected_model_adapter +@pytest.mark.parametrize( + "model_family, expected_model_adapter", + [ + ("anthropic.claude", AnthropicClaudeAdapter), + ("cohere.command", CohereCommandAdapter), + ("cohere.command-r", CohereCommandRAdapter), + ("ai21.j2", AI21LabsJurassic2Adapter), + ("amazon.titan-text", AmazonTitanAdapter), + ("meta.llama", MetaLlamaAdapter), + ("mistral", MistralAdapter), + ], +) +def test_get_model_adapter_with_model_family(model_family: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]]): + """ + Test that the correct model adapter is returned for a given model + """ + model_adapter = AmazonBedrockGenerator.get_model_adapter(model="arn:123435423", model_family=model_family) + assert model_adapter == expected_model_adapter + + +def test_get_model_adapter_with_invalid_model_family(): + """ + Test that the correct model adapter is returned for a given model + """ + with pytest.raises(AmazonBedrockConfigurationError): + AmazonBedrockGenerator.get_model_adapter(model="arn:123435423", model_family="invalid") + + +def test_get_model_adapter_auto_detect_family_fails(): + """ + Test that the correct model adapter is returned for a given model + """ + with pytest.raises(AmazonBedrockConfigurationError): + AmazonBedrockGenerator.get_model_adapter(model="arn:123435423", model_family=None) + + + class TestAnthropicClaudeAdapter: def test_default_init(self) -> None: adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=100) From 00234031be1774b00d861bc3bc87bcb83a5f810c Mon Sep 17 00:00:00 2001 From: tstadel Date: Wed, 11 Dec 2024 12:00:52 +0100 Subject: [PATCH 2/5] add test --- .../amazon_bedrock/tests/test_generator.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py index ca28f290c..bf66a2cf3 100644 --- a/integrations/amazon_bedrock/tests/test_generator.py +++ b/integrations/amazon_bedrock/tests/test_generator.py @@ -319,7 +319,7 @@ def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[Bed ) def test_get_model_adapter_with_model_family(model_family: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]]): """ - Test that the correct model adapter is returned for a given model + Test that the correct model adapter is returned for a given model model_family """ model_adapter = AmazonBedrockGenerator.get_model_adapter(model="arn:123435423", model_family=model_family) assert model_adapter == expected_model_adapter @@ -327,7 +327,7 @@ def test_get_model_adapter_with_model_family(model_family: str, expected_model_a def test_get_model_adapter_with_invalid_model_family(): """ - Test that the correct model adapter is returned for a given model + Test that an error is raised when an invalid model_family is provided """ with pytest.raises(AmazonBedrockConfigurationError): AmazonBedrockGenerator.get_model_adapter(model="arn:123435423", model_family="invalid") @@ -335,12 +335,21 @@ def test_get_model_adapter_with_invalid_model_family(): def test_get_model_adapter_auto_detect_family_fails(): """ - Test that the correct model adapter is returned for a given model + Test that an error is raised when auto-detection of model_family fails """ with pytest.raises(AmazonBedrockConfigurationError): AmazonBedrockGenerator.get_model_adapter(model="arn:123435423", model_family=None) +def test_get_model_adapter_model_family_over_auto_detection(): + """ + Test that the model_family is used over auto-detection + """ + model_adapter = AmazonBedrockGenerator.get_model_adapter( + model="cohere.command-text-v14", model_family="anthropic.claude" + ) + assert model_adapter == AnthropicClaudeAdapter + class TestAnthropicClaudeAdapter: def test_default_init(self) -> None: From 7b626ae924b52a6a5fc91c5cd20315c7a21b1b58 Mon Sep 17 00:00:00 2001 From: tstadel Date: Wed, 11 Dec 2024 12:12:10 +0100 Subject: [PATCH 3/5] fix tests --- .../generators/amazon_bedrock/generator.py | 12 +++++++++--- integrations/amazon_bedrock/tests/test_generator.py | 11 +++++++++-- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 9313b2e50..0418a3cfe 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -126,7 +126,8 @@ def __init__( :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. :param boto3_config: The configuration for the boto3 client. - :param model_family: The model family to use. If not provided, the model adapter is selected based on the model name. + :param model_family: The model family to use. If not provided, the model adapter is selected based on the model + name. :param kwargs: Additional keyword arguments to be passed to the model. These arguments are specific to the model. You can find them in the model's documentation. :raises ValueError: If the model name is empty or None. @@ -147,6 +148,7 @@ def __init__( self.streaming_callback = streaming_callback self.boto3_config = boto3_config self.kwargs = kwargs + self.model_family = model_family def resolve_secret(secret: Optional[Secret]) -> Optional[str]: return secret.resolve_value() if secret else None @@ -285,8 +287,11 @@ def get_model_adapter(cls, model: str, model_family: Optional[str]) -> Type[Bedr for pattern, adapter in cls.SUPPORTED_MODEL_PATTERNS.items(): if re.fullmatch(pattern, model): return adapter - - msg = f"Could not auto-detect model family of {model}. `model_family` parameter must be one of {get_args(cls.MODEL_FAMILIES)}." + + msg = ( + f"Could not auto-detect model family of {model}. " + f"`model_family` parameter must be one of {get_args(cls.MODEL_FAMILIES)}." + ) raise AmazonBedrockConfigurationError(msg) def to_dict(self) -> Dict[str, Any]: @@ -309,6 +314,7 @@ def to_dict(self) -> Dict[str, Any]: truncate=self.truncate, streaming_callback=callback_name, boto3_config=self.boto3_config, + model_family=self.model_family, **self.kwargs, ) diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py index bf66a2cf3..c545913e8 100644 --- a/integrations/amazon_bedrock/tests/test_generator.py +++ b/integrations/amazon_bedrock/tests/test_generator.py @@ -4,6 +4,9 @@ import pytest from haystack.dataclasses import StreamingChunk +from haystack_integrations.common.amazon_bedrock.errors import ( + AmazonBedrockConfigurationError, +) from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator from haystack_integrations.components.generators.amazon_bedrock.adapters import ( AI21LabsJurassic2Adapter, @@ -15,7 +18,6 @@ MetaLlamaAdapter, MistralAdapter, ) -from integrations.amazon_bedrock.src.haystack_integrations.common.amazon_bedrock.errors import AmazonBedrockConfigurationError @pytest.mark.parametrize( @@ -49,6 +51,7 @@ def test_to_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any]] "temperature": 10, "streaming_callback": None, "boto3_config": boto3_config, + "model_family": None, }, } @@ -80,6 +83,7 @@ def test_from_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any "model": "anthropic.claude-v2", "max_length": 99, "boto3_config": boto3_config, + "model_family": "anthropic.claude", }, } ) @@ -87,6 +91,7 @@ def test_from_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any assert generator.max_length == 99 assert generator.model == "anthropic.claude-v2" assert generator.boto3_config == boto3_config + assert generator.model_family == "anthropic.claude" def test_default_constructor(mock_boto3_session, set_env_variables): @@ -317,7 +322,9 @@ def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[Bed ("mistral", MistralAdapter), ], ) -def test_get_model_adapter_with_model_family(model_family: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]]): +def test_get_model_adapter_with_model_family( + model_family: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]] +): """ Test that the correct model adapter is returned for a given model model_family """ From 44e0c8209475a5be9d2a223e4f9ccefe4f7d1e69 Mon Sep 17 00:00:00 2001 From: tstadel Date: Wed, 11 Dec 2024 13:28:19 +0100 Subject: [PATCH 4/5] apply feedback --- .../components/generators/amazon_bedrock/generator.py | 7 ++++++- integrations/amazon_bedrock/tests/test_generator.py | 4 ++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 0418a3cfe..8d252db7b 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -271,12 +271,17 @@ def run( return {"replies": replies} @classmethod - def get_model_adapter(cls, model: str, model_family: Optional[str]) -> Type[BedrockModelAdapter]: + def get_model_adapter(cls, model: str, model_family: Optional[str] = None) -> Type[BedrockModelAdapter]: """ Gets the model adapter for the given model. + If `model_family` is provided, the adapter for the model family is returned. + If `model_family` is not provided, the adapter is auto-detected based on the model name. + :param model: The model name. + :param model_family: The model family. :returns: The model adapter class, or None if no adapter is found. + :raises AmazonBedrockConfigurationError: If the model family is not supported or the model cannot be auto-detected. """ if model_family: if model_family not in cls.SUPPORTED_MODEL_FAMILIES: diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py index c545913e8..3d2cbc01f 100644 --- a/integrations/amazon_bedrock/tests/test_generator.py +++ b/integrations/amazon_bedrock/tests/test_generator.py @@ -306,7 +306,7 @@ def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[Bed """ Test that the correct model adapter is returned for a given model """ - model_adapter = AmazonBedrockGenerator.get_model_adapter(model=model, model_family=None) + model_adapter = AmazonBedrockGenerator.get_model_adapter(model=model) assert model_adapter == expected_model_adapter @@ -345,7 +345,7 @@ def test_get_model_adapter_auto_detect_family_fails(): Test that an error is raised when auto-detection of model_family fails """ with pytest.raises(AmazonBedrockConfigurationError): - AmazonBedrockGenerator.get_model_adapter(model="arn:123435423", model_family=None) + AmazonBedrockGenerator.get_model_adapter(model="arn:123435423") def test_get_model_adapter_model_family_over_auto_detection(): From fddeb46a5ddb134ca677b930e92f39f6d67ee3a4 Mon Sep 17 00:00:00 2001 From: tstadel Date: Wed, 11 Dec 2024 13:29:13 +0100 Subject: [PATCH 5/5] fix lint --- .../components/generators/amazon_bedrock/generator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 8d252db7b..79dc07cdc 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -281,7 +281,8 @@ def get_model_adapter(cls, model: str, model_family: Optional[str] = None) -> Ty :param model: The model name. :param model_family: The model family. :returns: The model adapter class, or None if no adapter is found. - :raises AmazonBedrockConfigurationError: If the model family is not supported or the model cannot be auto-detected. + :raises AmazonBedrockConfigurationError: If the model family is not supported or the model cannot be + auto-detected. """ if model_family: if model_family not in cls.SUPPORTED_MODEL_FAMILIES: