feat: support model_arn in AmazonBedrockGenerator
tstadel committed Dec 11, 2024
1 parent adba166 commit b8e45bc
Showing 2 changed files with 73 additions and 9 deletions.
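In short, this commit lets you pass a Bedrock model ARN (for example an inference profile or a custom/provisioned model) as model. Because an ARN does not carry a recognizable model name, the new model_family parameter tells the generator which adapter to use. A minimal usage sketch, assuming the package's usual import path; the ARN and prompt are placeholders:

from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator

# Placeholder ARN: the model family cannot be auto-detected from an ARN,
# so it is passed explicitly via the new model_family parameter.
generator = AmazonBedrockGenerator(
    model="arn:aws:bedrock:us-east-1:123456789012:inference-profile/my-claude-profile",
    model_family="anthropic.claude",
)

result = generator.run("Briefly explain what Amazon Bedrock is.")
print(result["replies"][0])

AWS credentials are resolved through the usual boto3 mechanisms, as before.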
@@ -1,7 +1,7 @@
import json
import logging
import re
from typing import Any, Callable, ClassVar, Dict, List, Optional, Type
from typing import Any, Callable, ClassVar, Dict, List, Literal, Optional, Type, get_args

from botocore.config import Config
from botocore.exceptions import ClientError
@@ -75,6 +75,26 @@ class AmazonBedrockGenerator:
r"([a-z]{2}\.)?mistral.*": MistralAdapter,
}

SUPPORTED_MODEL_FAMILIES: ClassVar[Dict[str, Type[BedrockModelAdapter]]] = {
"amazon.titan-text": AmazonTitanAdapter,
"ai21.j2": AI21LabsJurassic2Adapter,
"cohere.command": CohereCommandAdapter,
"cohere.command-r": CohereCommandRAdapter,
"anthropic.claude": AnthropicClaudeAdapter,
"meta.llama": MetaLlamaAdapter,
"mistral": MistralAdapter,
}

MODEL_FAMILIES = Literal[
"amazon.titan-text",
"ai21.j2",
"cohere.command",
"cohere.command-r",
"anthropic.claude",
"meta.llama",
"mistral",
]

def __init__(
self,
model: str,
@@ -89,6 +109,7 @@ def __init__(
truncate: Optional[bool] = True,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
boto3_config: Optional[Dict[str, Any]] = None,
model_family: Optional[MODEL_FAMILIES] = None,
**kwargs,
):
"""
@@ -105,6 +126,7 @@ def __init__(
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument.
:param boto3_config: The configuration for the boto3 client.
:param model_family: The model family to use. If not provided, the model adapter is selected based on the model name.
:param kwargs: Additional keyword arguments to be passed to the model.
These arguments are specific to the model. You can find them in the model's documentation.
:raises ValueError: If the model name is empty or None.
@@ -163,10 +185,7 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
max_length=self.max_length or 100,
)

model_adapter_cls = self.get_model_adapter(model=model)
if not model_adapter_cls:
msg = f"AmazonBedrockGenerator doesn't support the model {model}."
raise AmazonBedrockConfigurationError(msg)
model_adapter_cls = self.get_model_adapter(model=model, model_family=model_family)
self.model_adapter = model_adapter_cls(model_kwargs=model_input_kwargs, max_length=self.max_length)

def _ensure_token_limit(self, prompt: str) -> str:
@@ -250,17 +269,25 @@ def run(
return {"replies": replies}

@classmethod
def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelAdapter]]:
def get_model_adapter(cls, model: str, model_family: Optional[str]) -> Type[BedrockModelAdapter]:
"""
Gets the model adapter for the given model.
:param model: The model name.
:param model_family: The model family to use. If not provided, the adapter is auto-detected from the model name.
:returns: The model adapter class.
:raises AmazonBedrockConfigurationError: If the model family is not supported or the adapter cannot be auto-detected.
"""
if model_family:
if model_family not in cls.SUPPORTED_MODEL_FAMILIES:
msg = f"Model family {model_family} is not supported. Must be one of {get_args(cls.MODEL_FAMILIES)}."
raise AmazonBedrockConfigurationError(msg)
return cls.SUPPORTED_MODEL_FAMILIES[model_family]

for pattern, adapter in cls.SUPPORTED_MODEL_PATTERNS.items():
if re.fullmatch(pattern, model):
return adapter
return None

msg = f"Could not auto-detect model family of {model}. `model_family` parameter must be one of {get_args(cls.MODEL_FAMILIES)}."
raise AmazonBedrockConfigurationError(msg)

def to_dict(self) -> Dict[str, Any]:
"""
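Taken together, get_model_adapter now resolves the adapter in this order: an explicit model_family wins (and must be a key of SUPPORTED_MODEL_FAMILIES), otherwise the model name is matched against SUPPORTED_MODEL_PATTERNS, and if neither applies an AmazonBedrockConfigurationError is raised instead of returning None. A small sketch of the expected behavior; the ARNs are placeholders and the first call assumes the existing Claude pattern matches plain model IDs as before:

from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator

# Plain model ID: auto-detected via SUPPORTED_MODEL_PATTERNS (existing behavior).
AmazonBedrockGenerator.get_model_adapter(
    model="anthropic.claude-3-5-sonnet-20240620-v1:0", model_family=None
)  # -> AnthropicClaudeAdapter

# ARN: no recognizable model name, so the family is given explicitly (new behavior).
AmazonBedrockGenerator.get_model_adapter(
    model="arn:aws:bedrock:us-east-1:123456789012:custom-model/my-model",
    model_family="anthropic.claude",
)  # -> AnthropicClaudeAdapter via SUPPORTED_MODEL_FAMILIES

# ARN without a family: auto-detection fails and a configuration error is raised.
AmazonBedrockGenerator.get_model_adapter(
    model="arn:aws:bedrock:us-east-1:123456789012:custom-model/my-model",
    model_family=None,
)  # raises AmazonBedrockConfigurationError

Note that SUPPORTED_MODEL_FAMILIES (the lookup table) and the MODEL_FAMILIES Literal (used for typing and for the error messages via get_args) list the same keys and have to stay in sync.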
41 changes: 39 additions & 2 deletions integrations/amazon_bedrock/tests/test_generator.py
@@ -15,6 +15,7 @@
MetaLlamaAdapter,
MistralAdapter,
)
from integrations.amazon_bedrock.src.haystack_integrations.common.amazon_bedrock.errors import AmazonBedrockConfigurationError


@pytest.mark.parametrize(
@@ -294,17 +295,53 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session):
("eu.mistral.mixtral-8x7b-instruct-v0:1", MistralAdapter), # cross-region inference
("us.mistral.mistral-large-2402-v1:0", MistralAdapter), # cross-region inference
("mistral.mistral-medium-v8:0", MistralAdapter), # artificial
("unknown_model", None),
],
)
def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]]):
"""
Test that the correct model adapter is returned for a given model
"""
model_adapter = AmazonBedrockGenerator.get_model_adapter(model=model)
model_adapter = AmazonBedrockGenerator.get_model_adapter(model=model, model_family=None)
assert model_adapter == expected_model_adapter


@pytest.mark.parametrize(
"model_family, expected_model_adapter",
[
("anthropic.claude", AnthropicClaudeAdapter),
("cohere.command", CohereCommandAdapter),
("cohere.command-r", CohereCommandRAdapter),
("ai21.j2", AI21LabsJurassic2Adapter),
("amazon.titan-text", AmazonTitanAdapter),
("meta.llama", MetaLlamaAdapter),
("mistral", MistralAdapter),
],
)
def test_get_model_adapter_with_model_family(model_family: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]]):
"""
Test that the correct model adapter is returned for a given model family
"""
model_adapter = AmazonBedrockGenerator.get_model_adapter(model="arn:123435423", model_family=model_family)
assert model_adapter == expected_model_adapter


def test_get_model_adapter_with_invalid_model_family():
"""
Test that an error is raised when an unsupported model family is provided
"""
with pytest.raises(AmazonBedrockConfigurationError):
AmazonBedrockGenerator.get_model_adapter(model="arn:123435423", model_family="invalid")


def test_get_model_adapter_auto_detect_family_fails():
"""
Test that an error is raised when the model family cannot be auto-detected from the model name
"""
with pytest.raises(AmazonBedrockConfigurationError):
AmazonBedrockGenerator.get_model_adapter(model="arn:123435423", model_family=None)


class TestAnthropicClaudeAdapter:
def test_default_init(self) -> None:
adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=100)
