diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
index dce52ac03..1b1287c59 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
@@ -156,32 +156,6 @@ class MetaLlama2ChatAdapter(BedrockModelChatAdapter):
     Model adapter for the Meta Llama 2 models.
     """
 
-    # Llama 2 chat template
-    chat_template = """
-    {% if messages[0]['role'] == 'system' %}
-        {% set loop_messages = messages[1:] %}
-        {% set system_message = messages[0]['content'] %}
-    {% else %}
-        {% set loop_messages = messages %}
-        {% set system_message = false %}
-    {% endif %}
-    {% for message in loop_messages %}
-        {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
-            {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
-        {% endif %}
-        {% if loop.index0 == 0 and system_message != false %}
-            {% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}
-        {% else %}
-            {% set content = message['content'] %}
-        {% endif %}
-        {% if message['role'] == 'user' %}
-            {{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}
-        {% elif message['role'] == 'assistant' %}
-            {{ ' ' + content.strip() + ' ' + eos_token }}
-        {% endif %}
-    {% endfor %}
-    """
-
     def __init__(self, generation_kwargs: Dict[str, Any]) -> None:
         super().__init__(generation_kwargs)
         # We pop the model_max_length as it is not sent to the model
@@ -190,7 +164,7 @@ def __init__(self, generation_kwargs: Dict[str, Any]) -> None:
         model_max_length = self.generation_kwargs.get("model_max_length", 4096)
         # Truncate prompt if prompt tokens > model_max_length-max_length
         self.prompt_handler = DefaultPromptHandler(
-            model="gpt2",  # use gpt2 tokenizer to estimate prompt length
+            model="meta-llama/Llama-2-7b-chat-hf",
             model_max_length=model_max_length,
             max_length=self.generation_kwargs.get("max_gen_len") or 512,
         )
@@ -208,7 +182,7 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[
 
     def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
         prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template(
-            conversation=messages, tokenize=False, chat_template=self.chat_template
+            conversation=messages, tokenize=False
         )
         return prepared_prompt
 
diff --git a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py
index 866a5a99d..622bae4ef 100644
--- a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py
+++ b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py
@@ -2,7 +2,7 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
-from haystack.components.generators.utils import default_streaming_callback
+from haystack.components.generators.utils import print_streaming_chunk
 from haystack.dataclasses import ChatMessage
 
 from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator
@@ -50,7 +50,7 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session):
         aws_profile_name="some_fake_profile",
         aws_region_name="fake_region",
         generation_kwargs={"temperature": 0.7},
-        streaming_callback=default_streaming_callback,
+        streaming_callback=print_streaming_chunk,
     )
     expected_dict = {
         "type": clazz,
@@ -58,7 +58,7 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session):
             "model": "anthropic.claude-v2",
             "generation_kwargs": {"temperature": 0.7},
             "stop_words": [],
-            "streaming_callback": default_streaming_callback,
+            "streaming_callback": print_streaming_chunk,
         },
     }
 
@@ -75,13 +75,13 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session):
             "init_parameters": {
                 "model": "anthropic.claude-v2",
                 "generation_kwargs": {"temperature": 0.7},
-                "streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
+                "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
             },
         }
     )
     assert generator.model == "anthropic.claude-v2"
     assert generator.model_adapter.generation_kwargs == {"temperature": 0.7}
-    assert generator.streaming_callback == default_streaming_callback
+    assert generator.streaming_callback == print_streaming_chunk
 
 
 def test_default_constructor(mock_auto_tokenizer, mock_boto3_session):
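
Note on the adapters.py change: the hand-maintained Jinja `chat_template` is dropped because the `meta-llama/Llama-2-7b-chat-hf` tokenizer ships its own Llama 2 chat template, which `apply_chat_template` falls back to when no `chat_template` argument is passed. Below is a minimal sketch of that fallback behavior, assuming `transformers` is installed and the Hugging Face account has been granted access to the gated meta-llama repository; the example messages are illustrative and not taken from the test suite.

```python
from transformers import AutoTokenizer

# Load the tokenizer the DefaultPromptHandler now uses; the repo is gated,
# so this requires an authenticated Hugging Face token with access granted.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]

# Without an explicit chat_template argument, apply_chat_template uses the
# template bundled with the tokenizer, producing the <s>[INST] ... [/INST] format.
prompt = tokenizer.apply_chat_template(conversation=messages, tokenize=False)
print(prompt)
```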