Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support model_arn in AmazonBedrockGenerator #1244

Merged
merged 5 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging
import re
from typing import Any, Callable, ClassVar, Dict, List, Optional, Type
from typing import Any, Callable, ClassVar, Dict, List, Literal, Optional, Type, get_args

from botocore.config import Config
from botocore.exceptions import ClientError
Expand Down Expand Up @@ -75,6 +75,26 @@ class AmazonBedrockGenerator:
r"([a-z]{2}\.)?mistral.*": MistralAdapter,
}

# Maps an explicitly selected model family name to its adapter class.
# Used by get_model_adapter when the caller passes `model_family`, i.e. when
# auto-detection via the model-name regex table cannot work (tests exercise
# this with ARN-style model identifiers such as "arn:123435423").
SUPPORTED_MODEL_FAMILIES: ClassVar[Dict[str, Type[BedrockModelAdapter]]] = {
    "amazon.titan-text": AmazonTitanAdapter,
    "ai21.j2": AI21LabsJurassic2Adapter,
    "cohere.command": CohereCommandAdapter,
    "cohere.command-r": CohereCommandRAdapter,
    "anthropic.claude": AnthropicClaudeAdapter,
    "meta.llama": MetaLlamaAdapter,
    "mistral": MistralAdapter,
}

# Literal type spelling out the accepted `model_family` values; must stay in
# sync with SUPPORTED_MODEL_FAMILIES. `get_args(MODEL_FAMILIES)` is used to
# build the error messages listing the valid choices.
MODEL_FAMILIES = Literal[
    "amazon.titan-text",
    "ai21.j2",
    "cohere.command",
    "cohere.command-r",
    "anthropic.claude",
    "meta.llama",
    "mistral",
]

def __init__(
self,
model: str,
Expand All @@ -89,6 +109,7 @@ def __init__(
truncate: Optional[bool] = True,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
boto3_config: Optional[Dict[str, Any]] = None,
model_family: Optional[MODEL_FAMILIES] = None,
**kwargs,
):
"""
Expand All @@ -105,6 +126,8 @@ def __init__(
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument.
:param boto3_config: The configuration for the boto3 client.
:param model_family: The model family to use. If not provided, the model adapter is selected based on the model
name.
:param kwargs: Additional keyword arguments to be passed to the model.
These arguments are specific to the model. You can find them in the model's documentation.
:raises ValueError: If the model name is empty or None.
Expand All @@ -125,6 +148,7 @@ def __init__(
self.streaming_callback = streaming_callback
self.boto3_config = boto3_config
self.kwargs = kwargs
self.model_family = model_family

def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
return secret.resolve_value() if secret else None
Expand Down Expand Up @@ -163,10 +187,7 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
max_length=self.max_length or 100,
)

model_adapter_cls = self.get_model_adapter(model=model)
if not model_adapter_cls:
msg = f"AmazonBedrockGenerator doesn't support the model {model}."
raise AmazonBedrockConfigurationError(msg)
model_adapter_cls = self.get_model_adapter(model=model, model_family=model_family)
self.model_adapter = model_adapter_cls(model_kwargs=model_input_kwargs, max_length=self.max_length)

def _ensure_token_limit(self, prompt: str) -> str:
Expand Down Expand Up @@ -250,17 +271,34 @@ def run(
return {"replies": replies}

@classmethod
def get_model_adapter(cls, model: str, model_family: Optional[str] = None) -> Type[BedrockModelAdapter]:
    """
    Gets the model adapter for the given model.

    If `model_family` is provided, the adapter registered for that family is returned directly.
    Otherwise, the adapter is auto-detected by matching the model name against the known
    model-name patterns.

    :param model: The model name (or an identifier such as an ARN that the name patterns
        may not match, in which case `model_family` must be provided).
    :param model_family: The model family to use. Must be one of the keys of
        `SUPPORTED_MODEL_FAMILIES`. If not provided, the adapter is auto-detected from `model`.
    :returns: The model adapter class.
    :raises AmazonBedrockConfigurationError: If the given model family is not supported, or if
        no `model_family` was given and the family cannot be auto-detected from the model name.
    """
    # Explicit family selection takes precedence over name-based auto-detection.
    if model_family:
        if model_family not in cls.SUPPORTED_MODEL_FAMILIES:
            msg = f"Model family {model_family} is not supported. Must be one of {get_args(cls.MODEL_FAMILIES)}."
            raise AmazonBedrockConfigurationError(msg)
        return cls.SUPPORTED_MODEL_FAMILIES[model_family]

    # Fall back to auto-detection: first pattern that fully matches the model name wins.
    for pattern, adapter in cls.SUPPORTED_MODEL_PATTERNS.items():
        if re.fullmatch(pattern, model):
            return adapter

    msg = (
        f"Could not auto-detect model family of {model}. "
        f"`model_family` parameter must be one of {get_args(cls.MODEL_FAMILIES)}."
    )
    raise AmazonBedrockConfigurationError(msg)

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -282,6 +320,7 @@ def to_dict(self) -> Dict[str, Any]:
truncate=self.truncate,
streaming_callback=callback_name,
boto3_config=self.boto3_config,
model_family=self.model_family,
**self.kwargs,
)

Expand Down
55 changes: 54 additions & 1 deletion integrations/amazon_bedrock/tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import pytest
from haystack.dataclasses import StreamingChunk

from haystack_integrations.common.amazon_bedrock.errors import (
AmazonBedrockConfigurationError,
)
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator
from haystack_integrations.components.generators.amazon_bedrock.adapters import (
AI21LabsJurassic2Adapter,
Expand Down Expand Up @@ -48,6 +51,7 @@ def test_to_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any]]
"temperature": 10,
"streaming_callback": None,
"boto3_config": boto3_config,
"model_family": None,
},
}

Expand Down Expand Up @@ -79,13 +83,15 @@ def test_from_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any
"model": "anthropic.claude-v2",
"max_length": 99,
"boto3_config": boto3_config,
"model_family": "anthropic.claude",
},
}
)

assert generator.max_length == 99
assert generator.model == "anthropic.claude-v2"
assert generator.boto3_config == boto3_config
assert generator.model_family == "anthropic.claude"


def test_default_constructor(mock_boto3_session, set_env_variables):
Expand Down Expand Up @@ -294,7 +300,6 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session):
("eu.mistral.mixtral-8x7b-instruct-v0:1", MistralAdapter), # cross-region inference
("us.mistral.mistral-large-2402-v1:0", MistralAdapter), # cross-region inference
("mistral.mistral-medium-v8:0", MistralAdapter), # artificial
("unknown_model", None),
],
)
def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]]):
Expand All @@ -305,6 +310,54 @@ def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[Bed
assert model_adapter == expected_model_adapter


@pytest.mark.parametrize(
    "model_family, expected_model_adapter",
    [
        ("anthropic.claude", AnthropicClaudeAdapter),
        ("cohere.command", CohereCommandAdapter),
        ("cohere.command-r", CohereCommandRAdapter),
        ("ai21.j2", AI21LabsJurassic2Adapter),
        ("amazon.titan-text", AmazonTitanAdapter),
        ("meta.llama", MetaLlamaAdapter),
        ("mistral", MistralAdapter),
    ],
)
def test_get_model_adapter_with_model_family(
    model_family: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]]
):
    """
    An explicitly supplied `model_family` selects the matching adapter even when the
    model identifier (here an ARN) cannot be auto-detected.
    """
    adapter_cls = AmazonBedrockGenerator.get_model_adapter(model="arn:123435423", model_family=model_family)
    assert adapter_cls is expected_model_adapter


def test_get_model_adapter_with_invalid_model_family():
    """
    Passing a `model_family` outside the supported set must raise a configuration error.
    """
    unsupported_family = "invalid"
    with pytest.raises(AmazonBedrockConfigurationError):
        AmazonBedrockGenerator.get_model_adapter(model="arn:123435423", model_family=unsupported_family)


def test_get_model_adapter_auto_detect_family_fails():
    """
    When no `model_family` is given and the model name matches no known pattern,
    a configuration error must be raised.
    """
    undetectable_model = "arn:123435423"
    with pytest.raises(AmazonBedrockConfigurationError):
        AmazonBedrockGenerator.get_model_adapter(model=undetectable_model)


def test_get_model_adapter_model_family_over_auto_detection():
    """
    A caller-supplied `model_family` takes precedence over name-based auto-detection:
    a Cohere-named model still gets the Anthropic adapter when that family is requested.
    """
    adapter_cls = AmazonBedrockGenerator.get_model_adapter(
        model="cohere.command-text-v14", model_family="anthropic.claude"
    )
    assert adapter_cls is AnthropicClaudeAdapter


class TestAnthropicClaudeAdapter:
def test_default_init(self) -> None:
adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=100)
Expand Down
Loading