feat: add prefixes to supported model patterns to allow cross region model ids (#1127)

* feat: add prefixes to supported model patterns to allow cross region model ids
abrahamy authored Oct 17, 2024
1 parent f95dd06 commit ac0e4c2
Showing 4 changed files with 20 additions and 13 deletions.
@@ -58,9 +58,9 @@ class AmazonBedrockChatGenerator:
"""

SUPPORTED_MODEL_PATTERNS: ClassVar[Dict[str, Type[BedrockModelChatAdapter]]] = {
- r"(.+\.)?anthropic.claude.*": AnthropicClaudeChatAdapter,
- r"meta.llama2.*": MetaLlama2ChatAdapter,
- r"mistral.*": MistralChatAdapter,
+ r"([a-z]{2}\.)?anthropic.claude.*": AnthropicClaudeChatAdapter,
+ r"([a-z]{2}\.)?meta.llama2.*": MetaLlama2ChatAdapter,
+ r"([a-z]{2}\.)?mistral.*": MistralChatAdapter,
}

def __init__(
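For context, Bedrock cross-region inference profile ids prepend a two-letter region code (for example "us." or "eu.") to the base model id, which is why each pattern gains the optional ([a-z]{2}\.)? group. Below is a minimal sketch of how the lookup behaves with the new patterns, assuming the generator picks its adapter via anchored regex matching (re.fullmatch) over SUPPORTED_MODEL_PATTERNS; the adapter classes and the helper here are simplified stand-ins, not the integration's actual code.

import re
from typing import Optional, Type

# Simplified stand-ins for the real chat adapter classes in the integration.
class AnthropicClaudeChatAdapter: ...
class MetaLlama2ChatAdapter: ...
class MistralChatAdapter: ...

SUPPORTED_MODEL_PATTERNS = {
    r"([a-z]{2}\.)?anthropic.claude.*": AnthropicClaudeChatAdapter,
    r"([a-z]{2}\.)?meta.llama2.*": MetaLlama2ChatAdapter,
    r"([a-z]{2}\.)?mistral.*": MistralChatAdapter,
}

def get_model_adapter(model: str) -> Optional[Type]:
    # The first pattern that matches the whole model id wins; None means unsupported.
    for pattern, adapter in SUPPORTED_MODEL_PATTERNS.items():
        if re.fullmatch(pattern, model):
            return adapter
    return None

# A bare model id and its cross-region counterpart resolve to the same adapter.
assert get_model_adapter("anthropic.claude-v2") is AnthropicClaudeChatAdapter
assert get_model_adapter("us.anthropic.claude-v2") is AnthropicClaudeChatAdapter
assert get_model_adapter("unknown_model") is None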
@@ -66,13 +66,13 @@ class AmazonBedrockGenerator:
"""

SUPPORTED_MODEL_PATTERNS: ClassVar[Dict[str, Type[BedrockModelAdapter]]] = {
- r"amazon.titan-text.*": AmazonTitanAdapter,
- r"ai21.j2.*": AI21LabsJurassic2Adapter,
- r"cohere.command-[^r].*": CohereCommandAdapter,
- r"cohere.command-r.*": CohereCommandRAdapter,
- r"(.+\.)?anthropic.claude.*": AnthropicClaudeAdapter,
- r"meta.llama.*": MetaLlamaAdapter,
- r"mistral.*": MistralAdapter,
+ r"([a-z]{2}\.)?amazon.titan-text.*": AmazonTitanAdapter,
+ r"([a-z]{2}\.)?ai21.j2.*": AI21LabsJurassic2Adapter,
+ r"([a-z]{2}\.)?cohere.command-[^r].*": CohereCommandAdapter,
+ r"([a-z]{2}\.)?cohere.command-r.*": CohereCommandRAdapter,
+ r"([a-z]{2}\.)?anthropic.claude.*": AnthropicClaudeAdapter,
+ r"([a-z]{2}\.)?meta.llama.*": MetaLlamaAdapter,
+ r"([a-z]{2}\.)?mistral.*": MistralAdapter,
}

def __init__(
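Note that the Claude entry also narrows from (.+\.)?, which tolerated any dotted prefix, to ([a-z]{2}\.)?, so every adapter now accepts exactly the same optional two-letter region prefix and nothing broader. A quick illustration of why the prefix group is needed at all, again assuming anchored matching with re.fullmatch:

import re

old_pattern = r"amazon.titan-text.*"               # pre-change: no prefix allowed
new_pattern = r"([a-z]{2}\.)?amazon.titan-text.*"  # post-change: optional two-letter prefix
model_id = "us.amazon.titan-text-express-v1"       # cross-region inference profile id

print(bool(re.fullmatch(old_pattern, model_id)))                        # False
print(bool(re.fullmatch(new_pattern, model_id)))                        # True
print(bool(re.fullmatch(new_pattern, "amazon.titan-text-express-v1")))  # True, bare id still works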
6 changes: 4 additions & 2 deletions integrations/amazon_bedrock/tests/test_chat_generator.py
@@ -243,7 +243,7 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session):
generator.run(messages=messages)

# Ensure _ensure_token_limit was not called
- mock_ensure_token_limit.assert_not_called(),
+ mock_ensure_token_limit.assert_not_called()

# Check the prompt passed to prepare_body
generator.model_adapter.prepare_body.assert_called_with(messages=messages, stop_words=[], stream=False)
@@ -261,6 +261,9 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session):
("meta.llama2-13b-chat-v1", MetaLlama2ChatAdapter),
("meta.llama2-70b-chat-v1", MetaLlama2ChatAdapter),
("meta.llama2-130b-v5", MetaLlama2ChatAdapter),  # artificial
+ ("us.meta.llama2-13b-chat-v1", MetaLlama2ChatAdapter),  # cross-region inference
+ ("eu.meta.llama2-70b-chat-v1", MetaLlama2ChatAdapter),  # cross-region inference
+ ("de.meta.llama2-130b-v5", MetaLlama2ChatAdapter),  # cross-region inference
("unknown_model", None),
],
)
@@ -517,7 +520,6 @@ def test_get_responses(self) -> None:
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
@pytest.mark.integration
def test_default_inference_params(self, model_name, chat_messages):
-
client = AmazonBedrockChatGenerator(model=model_name)
response = client.run(chat_messages)
7 changes: 6 additions & 1 deletion integrations/amazon_bedrock/tests/test_generator.py
@@ -225,7 +225,7 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session):
generator.run(prompt=long_prompt_text)

# Ensure _ensure_token_limit was not called
- mock_ensure_token_limit.assert_not_called(),
+ mock_ensure_token_limit.assert_not_called()

# Check the prompt passed to prepare_body
generator.model_adapter.prepare_body.assert_called_with(prompt=long_prompt_text, stream=False)
@@ -251,17 +251,22 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session):
("ai21.j2-mega-v5", AI21LabsJurassic2Adapter),  # artificial
("amazon.titan-text-lite-v1", AmazonTitanAdapter),
("amazon.titan-text-express-v1", AmazonTitanAdapter),
+ ("us.amazon.titan-text-express-v1", AmazonTitanAdapter),  # cross-region inference
("amazon.titan-text-agile-v1", AmazonTitanAdapter),
("amazon.titan-text-lightning-v8", AmazonTitanAdapter),  # artificial
("meta.llama2-13b-chat-v1", MetaLlamaAdapter),
("meta.llama2-70b-chat-v1", MetaLlamaAdapter),
+ ("eu.meta.llama2-13b-chat-v1", MetaLlamaAdapter),  # cross-region inference
+ ("us.meta.llama2-70b-chat-v1", MetaLlamaAdapter),  # cross-region inference
("meta.llama2-130b-v5", MetaLlamaAdapter),  # artificial
("meta.llama3-8b-instruct-v1:0", MetaLlamaAdapter),
("meta.llama3-70b-instruct-v1:0", MetaLlamaAdapter),
("meta.llama3-130b-instruct-v5:9", MetaLlamaAdapter),  # artificial
("mistral.mistral-7b-instruct-v0:2", MistralAdapter),
("mistral.mixtral-8x7b-instruct-v0:1", MistralAdapter),
("mistral.mistral-large-2402-v1:0", MistralAdapter),
+ ("eu.mistral.mixtral-8x7b-instruct-v0:1", MistralAdapter),  # cross-region inference
+ ("us.mistral.mistral-large-2402-v1:0", MistralAdapter),  # cross-region inference
("mistral.mistral-medium-v8:0", MistralAdapter),  # artificial
("unknown_model", None),
],
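The practical effect is that a cross-region inference profile id can be passed straight to the generator. A hedged usage sketch, assuming AWS credentials are configured in the environment and that the import path and constructor below match the released integration:

from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator

# "eu.mistral.mistral-large-2402-v1:0" is a cross-region inference profile id;
# before this change it would not have matched any supported model pattern.
generator = AmazonBedrockGenerator(model="eu.mistral.mistral-large-2402-v1:0")
result = generator.run(prompt="Summarize cross-region inference in one sentence.")
print(result["replies"][0])

The same id form works for AmazonBedrockChatGenerator, whose pattern table changed in the same way.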
