Revert back to llama2 tokenizer
vblagoje committed Feb 6, 2024
1 parent 3e842c5 commit af9bc60
Showing 2 changed files with 7 additions and 33 deletions.
@@ -156,32 +156,6 @@ class MetaLlama2ChatAdapter(BedrockModelChatAdapter):
    Model adapter for the Meta Llama 2 models.
    """

-    # Llama 2 chat template
-    chat_template = """
-    {% if messages[0]['role'] == 'system' %}
-    {% set loop_messages = messages[1:] %}
-    {% set system_message = messages[0]['content'] %}
-    {% else %}
-    {% set loop_messages = messages %}
-    {% set system_message = false %}
-    {% endif %}
-    {% for message in loop_messages %}
-    {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
-    {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
-    {% endif %}
-    {% if loop.index0 == 0 and system_message != false %}
-    {% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}
-    {% else %}
-    {% set content = message['content'] %}
-    {% endif %}
-    {% if message['role'] == 'user' %}
-    {{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}
-    {% elif message['role'] == 'assistant' %}
-    {{ ' ' + content.strip() + ' ' + eos_token }}
-    {% endif %}
-    {% endfor %}
-    """

    def __init__(self, generation_kwargs: Dict[str, Any]) -> None:
        super().__init__(generation_kwargs)
        # We pop the model_max_length as it is not sent to the model
@@ -190,7 +164,7 @@ def __init__(self, generation_kwargs: Dict[str, Any]) -> None:
        model_max_length = self.generation_kwargs.get("model_max_length", 4096)
        # Truncate prompt if prompt tokens > model_max_length-max_length
        self.prompt_handler = DefaultPromptHandler(
-            model="gpt2",  # use gpt2 tokenizer to estimate prompt length
+            model="meta-llama/Llama-2-7b-chat-hf",
            model_max_length=model_max_length,
            max_length=self.generation_kwargs.get("max_gen_len") or 512,
        )
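
Why the revert helps: DefaultPromptHandler truncates prompts whose token count exceeds model_max_length - max_length (with the defaults above, 4096 - 512 = 3584 tokens), and counting tokens with gpt2 only approximates what the Llama 2 tokenizer actually produces. A minimal sketch of the discrepancy, assuming transformers is installed and you have accepted the gated meta-llama license on the Hugging Face Hub:

from transformers import AutoTokenizer

# meta-llama/Llama-2-7b-chat-hf is a gated repository; this assumes you are
# authenticated (e.g. via `huggingface-cli login`) and have license access.
llama_tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
gpt2_tok = AutoTokenizer.from_pretrained("gpt2")

prompt = "[INST] Explain Amazon Bedrock in one sentence. [/INST]"

# The vocabularies differ, so the same text yields different token counts;
# estimating with gpt2 can over- or under-shoot the model's actual budget.
print("gpt2 tokens:  ", len(gpt2_tok.encode(prompt)))
print("llama2 tokens:", len(llama_tok.encode(prompt)))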
@@ -208,7 +182,7 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[

    def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
        prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template(
-            conversation=messages, tokenize=False, chat_template=self.chat_template
+            conversation=messages, tokenize=False
        )
        return prepared_prompt
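
With the chat_template= override gone, apply_chat_template falls back to the template bundled with the meta-llama/Llama-2-7b-chat-hf tokenizer, which encodes the same [INST] / <<SYS>> format as the Jinja block deleted above. A hedged sketch of that fallback (plain role/content dicts stand in for Haystack ChatMessage objects; assumes a transformers release with chat-template support, 4.34 or later):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "What is Amazon Bedrock?"},
]

# No chat_template= argument: the tokenizer's own Llama 2 template applies.
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
print(prompt)
# Expected shape:
# <s>[INST] <<SYS>>
# You are a concise assistant.
# <</SYS>>
#
# What is Amazon Bedrock? [/INST]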

integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py (10 changes: 5 additions & 5 deletions)
@@ -2,7 +2,7 @@
from unittest.mock import MagicMock, patch

import pytest
-from haystack.components.generators.utils import default_streaming_callback
+from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage

from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator
@@ -50,15 +50,15 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session):
        aws_profile_name="some_fake_profile",
        aws_region_name="fake_region",
        generation_kwargs={"temperature": 0.7},
-        streaming_callback=default_streaming_callback,
+        streaming_callback=print_streaming_chunk,
    )
    expected_dict = {
        "type": clazz,
        "init_parameters": {
            "model": "anthropic.claude-v2",
            "generation_kwargs": {"temperature": 0.7},
            "stop_words": [],
-            "streaming_callback": default_streaming_callback,
+            "streaming_callback": print_streaming_chunk,
        },
    }

@@ -75,13 +75,13 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session):
"init_parameters": {
"model": "anthropic.claude-v2",
"generation_kwargs": {"temperature": 0.7},
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
},
}
)
assert generator.model == "anthropic.claude-v2"
assert generator.model_adapter.generation_kwargs == {"temperature": 0.7}
assert generator.streaming_callback == default_streaming_callback
assert generator.streaming_callback == print_streaming_chunk


def test_default_constructor(mock_auto_tokenizer, mock_boto3_session):
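
The test updates track a rename in Haystack: default_streaming_callback in haystack.components.generators.utils is now print_streaming_chunk, with unchanged behavior apart from the import path. A minimal usage sketch, assuming AWS credentials with Bedrock model access are configured in the environment:

from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage

from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator

# Assumes valid AWS credentials and a region where anthropic.claude-v2 is enabled.
generator = AmazonBedrockChatGenerator(
    model="anthropic.claude-v2",
    streaming_callback=print_streaming_chunk,  # prints each chunk as it arrives
)
result = generator.run(messages=[ChatMessage.from_user("What is Amazon Bedrock?")])
print(result["replies"][0])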
