From 13279db6dd162880749e98eed4471291e144f97a Mon Sep 17 00:00:00 2001
From: Vladimir Blagojevic
Date: Tue, 6 Feb 2024 14:24:58 +0100
Subject: [PATCH] Use gpt2 with special_token_map, use llama2 chat template

---
 .../amazon_bedrock/chat/adapters.py        | 36 +++++++++++++++++--
 .../generators/amazon_bedrock/generator.py |  2 +-
 .../generators/amazon_bedrock/handlers.py  | 13 +++++--
 .../tests/test_amazon_chat_bedrock.py      |  3 +-
 4 files changed, 46 insertions(+), 8 deletions(-)

diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
index 1b1287c59..0c6335635 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
@@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List
 
 from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk
+from transformers import AutoTokenizer, PreTrainedTokenizer
 
 from haystack_integrations.components.generators.amazon_bedrock.handlers import DefaultPromptHandler
 
 
@@ -102,7 +103,7 @@ def __init__(self, generation_kwargs: Dict[str, Any]) -> None:
         # TODO use Anthropic tokenizer to get the precise prompt length
         # See https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#token-counting
         self.prompt_handler = DefaultPromptHandler(
-            model="gpt2",
+            tokenizer="gpt2",
             model_max_length=model_max_length,
             max_length=self.generation_kwargs.get("max_tokens_to_sample") or 512,
         )
@@ -156,6 +157,31 @@ class MetaLlama2ChatAdapter(BedrockModelChatAdapter):
     Model adapter for the Meta Llama 2 models.
     """
 
+    chat_template = (
+        "{% if messages[0]['role'] == 'system' %}"
+        "{% set loop_messages = messages[1:] %}"
+        "{% set system_message = messages[0]['content'] %}"
+        "{% else %}"
+        "{% set loop_messages = messages %}"
+        "{% set system_message = false %}"
+        "{% endif %}"
+        "{% for message in loop_messages %}"
+        "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
+        "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
+        "{% endif %}"
+        "{% if loop.index0 == 0 and system_message != false %}"
+        "{% set content = '<<SYS>>\n' + system_message + '\n<</SYS>>\n\n' + message['content'] %}"
+        "{% else %}"
+        "{% set content = message['content'] %}"
+        "{% endif %}"
+        "{% if message['role'] == 'user' %}"
+        "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
+        "{% elif message['role'] == 'assistant' %}"
+        "{{ ' ' + content.strip() + ' ' + eos_token }}"
+        "{% endif %}"
+        "{% endfor %}"
+    )
+
     def __init__(self, generation_kwargs: Dict[str, Any]) -> None:
         super().__init__(generation_kwargs)
         # We pop the model_max_length as it is not sent to the model
@@ -163,8 +189,12 @@ def __init__(self, generation_kwargs: Dict[str, Any]) -> None:
         # Llama 2 has context window size of 4096 tokens
         model_max_length = self.generation_kwargs.get("model_max_length", 4096)
         # Truncate prompt if prompt tokens > model_max_length-max_length
+        tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("gpt2")
+        tokenizer.bos_token = "<s>"
+        tokenizer.eos_token = "</s>"
+        tokenizer.unk_token = "<unk>"
         self.prompt_handler = DefaultPromptHandler(
-            model="meta-llama/Llama-2-7b-chat-hf",
+            tokenizer=tokenizer,
             model_max_length=model_max_length,
             max_length=self.generation_kwargs.get("max_gen_len") or 512,
         )
@@ -182,7 +212,7 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[
 
     def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
         prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template(
-            conversation=messages, tokenize=False
+            conversation=messages, tokenize=False, chat_template=self.chat_template
         )
         return prepared_prompt
 
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py
index 2d19159f9..48f22f59b 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py
@@ -104,7 +104,7 @@ def __init__(
         # It is hard to determine which tokenizer to use for the SageMaker model
         # so we use GPT2 tokenizer which will likely provide good token count approximation
         self.prompt_handler = DefaultPromptHandler(
-            model="gpt2", model_max_length=model_max_length, max_length=self.max_length or 100
+            tokenizer="gpt2", model_max_length=model_max_length, max_length=self.max_length or 100
         )
 
         model_adapter_cls = self.get_model_adapter(model=model)
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py
index 56dcb24d3..71450bec0 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Dict, Union
 
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
 
 
 class DefaultPromptHandler:
@@ -10,8 +10,15 @@ class DefaultPromptHandler:
     are within the model_max_length.
     """
 
-    def __init__(self, model: str, model_max_length: int, max_length: int = 100):
-        self.tokenizer = AutoTokenizer.from_pretrained(model)
+    def __init__(self, tokenizer: Union[str, PreTrainedTokenizerBase], model_max_length: int, max_length: int = 100):
+        if isinstance(tokenizer, str):
+            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+        elif isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
+            self.tokenizer = tokenizer
+        else:
+            msg = "tokenizer must be a string or a PreTrainedTokenizer instance"
+            raise ValueError(msg)
+
         self.tokenizer.model_max_length = model_max_length
         self.model_max_length = model_max_length
         self.max_length = max_length
diff --git a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py
index 622bae4ef..9592b5b39 100644
--- a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py
+++ b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py
@@ -241,7 +241,8 @@ def test_prepare_body_with_custom_inference_params(self) -> None:
 
         assert body == expected_body
 
-    def test_get_responses(self, mock_auto_tokenizer) -> None:
+    @pytest.mark.integration
+    def test_get_responses(self) -> None:
         adapter = MetaLlama2ChatAdapter(generation_kwargs={})
         response_body = {"generation": "This is a single response."}
         expected_response = "This is a single response."
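
Reviewer note (illustration only, not part of the patch): a minimal sketch of how the new class-level MetaLlama2ChatAdapter.chat_template renders a conversation once this patch is applied. The import path is inferred from the file layout above, plain role/content dicts stand in for Haystack ChatMessage objects, and the example messages are made up; the template is passed to apply_chat_template explicitly because the GPT-2 tokenizer ships without a chat template of its own.

    from transformers import AutoTokenizer

    from haystack_integrations.components.generators.amazon_bedrock.chat.adapters import MetaLlama2ChatAdapter

    # Same tokenizer setup as MetaLlama2ChatAdapter.__init__: GPT-2 vocabulary for token
    # counting, with the special tokens remapped to Llama 2's <s>, </s> and <unk>.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.bos_token = "<s>"
    tokenizer.eos_token = "</s>"
    tokenizer.unk_token = "<unk>"

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is Amazon Bedrock?"},
    ]

    # Mirrors prepare_chat_messages(): the Llama 2 template is supplied explicitly.
    prompt = tokenizer.apply_chat_template(
        conversation=messages,
        tokenize=False,
        chat_template=MetaLlama2ChatAdapter.chat_template,
    )
    # Expected rendering (single line):
    # <s>[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\n\nWhat is Amazon Bedrock? [/INST]
    print(prompt)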
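
A second sketch, under the same assumptions, of the two ways the reworked DefaultPromptHandler can now be constructed: from a tokenizer name, as the Anthropic adapter and the GPT-2 fallback in generator.py do, or from an already-configured tokenizer instance, as MetaLlama2ChatAdapter does; any other type raises a ValueError. The model_max_length and max_length values are example numbers only.

    from transformers import AutoTokenizer

    from haystack_integrations.components.generators.amazon_bedrock.handlers import DefaultPromptHandler

    # Name form: the handler loads the tokenizer itself.
    by_name = DefaultPromptHandler(tokenizer="gpt2", model_max_length=4096, max_length=512)

    # Instance form: pass a ready PreTrainedTokenizer / PreTrainedTokenizerFast object.
    by_instance = DefaultPromptHandler(
        tokenizer=AutoTokenizer.from_pretrained("gpt2"),
        model_max_length=4096,
        max_length=512,
    )

    # Either way, the handler truncates prompts so that prompt tokens plus max_length
    # stay within model_max_length.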