From 62483b78232a75e729a5325cf774673ca0906c1e Mon Sep 17 00:00:00 2001 From: 1greentangerine <158560711+1greentangerine@users.noreply.github.com> Date: Fri, 4 Oct 2024 12:19:06 +0200 Subject: [PATCH] modify regex to allow cross-region inference in bedrock (#1120) * modify regex to allow cross-region inference in bedrock (only possible for claude models) * add tests for multi-region inference with claude models --- .../components/generators/amazon_bedrock/chat/chat_generator.py | 2 +- .../components/generators/amazon_bedrock/generator.py | 2 +- integrations/amazon_bedrock/tests/test_chat_generator.py | 2 ++ integrations/amazon_bedrock/tests/test_generator.py | 2 ++ 4 files changed, 6 insertions(+), 2 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 988452a97..e1732646a 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -58,7 +58,7 @@ class AmazonBedrockChatGenerator: """ SUPPORTED_MODEL_PATTERNS: ClassVar[Dict[str, Type[BedrockModelChatAdapter]]] = { - r"anthropic.claude.*": AnthropicClaudeChatAdapter, + r"(.+\.)?anthropic.claude.*": AnthropicClaudeChatAdapter, r"meta.llama2.*": MetaLlama2ChatAdapter, r"mistral.*": MistralChatAdapter, } diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 6ef0a4765..1edde3526 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -69,7 +69,7 @@ class AmazonBedrockGenerator: r"ai21.j2.*": AI21LabsJurassic2Adapter, r"cohere.command-[^r].*": CohereCommandAdapter, r"cohere.command-r.*": CohereCommandRAdapter, - r"anthropic.claude.*": AnthropicClaudeAdapter, + r"(.+\.)?anthropic.claude.*": AnthropicClaudeAdapter, r"meta.llama.*": MetaLlamaAdapter, r"mistral.*": MistralAdapter, } diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index a455d2c93..49abc0979 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -254,6 +254,8 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): [ ("anthropic.claude-v1", AnthropicClaudeChatAdapter), ("anthropic.claude-v2", AnthropicClaudeChatAdapter), + ("eu.anthropic.claude-v1", AnthropicClaudeChatAdapter), # cross-region inference + ("us.anthropic.claude-v2", AnthropicClaudeChatAdapter), # cross-region inference ("anthropic.claude-instant-v1", AnthropicClaudeChatAdapter), ("anthropic.claude-super-v5", AnthropicClaudeChatAdapter), # artificial ("meta.llama2-13b-chat-v1", MetaLlama2ChatAdapter), diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py index f0233888c..61ae9d6b4 100644 --- a/integrations/amazon_bedrock/tests/test_generator.py +++ b/integrations/amazon_bedrock/tests/test_generator.py @@ -231,6 +231,8 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): [ ("anthropic.claude-v1", AnthropicClaudeAdapter), ("anthropic.claude-v2", AnthropicClaudeAdapter), + ("eu.anthropic.claude-v1", AnthropicClaudeAdapter), # cross-region inference + ("us.anthropic.claude-v2", AnthropicClaudeAdapter), # cross-region inference ("anthropic.claude-instant-v1", AnthropicClaudeAdapter), ("anthropic.claude-super-v5", AnthropicClaudeAdapter), # artificial ("cohere.command-text-v14", CohereCommandAdapter),