Skip to content

Commit

Permalink
Use gpt2 with special_token_map, use llama2 chat template
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Feb 7, 2024
1 parent baac9fe commit 13279db
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Callable, Dict, List

from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk
from transformers import AutoTokenizer, PreTrainedTokenizer

from haystack_integrations.components.generators.amazon_bedrock.handlers import DefaultPromptHandler

Expand Down Expand Up @@ -102,7 +103,7 @@ def __init__(self, generation_kwargs: Dict[str, Any]) -> None:
# TODO use Anthropic tokenizer to get the precise prompt length
# See https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#token-counting
self.prompt_handler = DefaultPromptHandler(
model="gpt2",
tokenizer="gpt2",
model_max_length=model_max_length,
max_length=self.generation_kwargs.get("max_tokens_to_sample") or 512,
)
Expand Down Expand Up @@ -156,15 +157,44 @@ class MetaLlama2ChatAdapter(BedrockModelChatAdapter):
Model adapter for the Meta Llama 2 models.
"""

chat_template = (
"{% if messages[0]['role'] == 'system' %}"
"{% set loop_messages = messages[1:] %}"
"{% set system_message = messages[0]['content'] %}"
"{% else %}"
"{% set loop_messages = messages %}"
"{% set system_message = false %}"
"{% endif %}"
"{% for message in loop_messages %}"
"{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
"{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
"{% endif %}"
"{% if loop.index0 == 0 and system_message != false %}"
"{% set content = '<<SYS>>\n' + system_message + '\n<</SYS>>\n\n' + message['content'] %}"
"{% else %}"
"{% set content = message['content'] %}"
"{% endif %}"
"{% if message['role'] == 'user' %}"
"{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
"{% elif message['role'] == 'assistant' %}"
"{{ ' ' + content.strip() + ' ' + eos_token }}"
"{% endif %}"
"{% endfor %}"
)

def __init__(self, generation_kwargs: Dict[str, Any]) -> None:
super().__init__(generation_kwargs)
# We pop the model_max_length as it is not sent to the model
# but used to truncate the prompt if needed
# Llama 2 has context window size of 4096 tokens
model_max_length = self.generation_kwargs.get("model_max_length", 4096)
# Truncate prompt if prompt tokens > model_max_length-max_length
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.bos_token = "<s>"
tokenizer.eos_token = "</s>"
tokenizer.unk_token = "<unk>"
self.prompt_handler = DefaultPromptHandler(
model="meta-llama/Llama-2-7b-chat-hf",
tokenizer=tokenizer,
model_max_length=model_max_length,
max_length=self.generation_kwargs.get("max_gen_len") or 512,
)
Expand All @@ -182,7 +212,7 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[

def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template(
conversation=messages, tokenize=False
conversation=messages, tokenize=False, chat_template=self.chat_template
)
return prepared_prompt

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __init__(
# It is hard to determine which tokenizer to use for the SageMaker model
# so we use GPT2 tokenizer which will likely provide good token count approximation
self.prompt_handler = DefaultPromptHandler(
model="gpt2", model_max_length=model_max_length, max_length=self.max_length or 100
tokenizer="gpt2", model_max_length=model_max_length, max_length=self.max_length or 100
)

model_adapter_cls = self.get_model_adapter(model=model)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import Dict, Union

from transformers import AutoTokenizer
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast


class DefaultPromptHandler:
Expand All @@ -10,8 +10,15 @@ class DefaultPromptHandler:
are within the model_max_length.
"""

def __init__(self, model: str, model_max_length: int, max_length: int = 100):
self.tokenizer = AutoTokenizer.from_pretrained(model)
def __init__(self, tokenizer: Union[str, PreTrainedTokenizerBase], model_max_length: int, max_length: int = 100):
if isinstance(tokenizer, str):
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
elif isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
self.tokenizer = tokenizer
else:
msg = "model must be a string or a PreTrainedTokenizer instance"
raise ValueError(msg)

self.tokenizer.model_max_length = model_max_length
self.model_max_length = model_max_length
self.max_length = max_length
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,8 @@ def test_prepare_body_with_custom_inference_params(self) -> None:

assert body == expected_body

def test_get_responses(self, mock_auto_tokenizer) -> None:
@pytest.mark.integration
def test_get_responses(self) -> None:
adapter = MetaLlama2ChatAdapter(generation_kwargs={})
response_body = {"generation": "This is a single response."}
expected_response = "This is a single response."
Expand Down

0 comments on commit 13279db

Please sign in to comment.