feat: make truncation optional for bedrock chat generator (#967)
* Added truncate param to chat generator and adapters
* Added tests to check truncation
* Added docstring
* Fixed linting
Amnah199 authored Aug 13, 2024
1 parent 0451e6f commit 4c8c881
Showing 3 changed files with 175 additions and 27 deletions.
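The change surfaces a new `truncate` flag on `AmazonBedrockChatGenerator` (threaded through each model adapter); it defaults to `True`, so existing pipelines keep the old always-truncate behavior. A minimal usage sketch, assuming the integration's usual import path and an example model id:

from haystack.dataclasses import ChatMessage

# Import path assumed from the amazon-bedrock Haystack integration.
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator

# truncate defaults to True (the previous, always-truncate behavior).
# Pass False to send prompts as-is and let Bedrock enforce its own limits.
generator = AmazonBedrockChatGenerator(
    model="anthropic.claude-3-sonnet-20240229-v1:0",  # example model id
    truncate=False,
)

result = generator.run(messages=[ChatMessage.from_user("Summarize RAG in one line.")])
print(result["replies"][0].content)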
@@ -2,7 +2,7 @@
import logging
import os
from abc import ABC, abstractmethod
-from typing import Any, Callable, ClassVar, Dict, List
+from typing import Any, Callable, ClassVar, Dict, List, Optional

from botocore.eventstream import EventStream
from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk
@@ -21,11 +21,12 @@ class BedrockModelChatAdapter(ABC):
    focusing on preparing the requests and extracting the responses from the Amazon Bedrock hosted chat LLMs.
    """

-    def __init__(self, generation_kwargs: Dict[str, Any]) -> None:
+    def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]) -> None:
        """
-        Initializes the chat adapter with the generation kwargs.
+        Initializes the chat adapter with the truncate parameter and generation kwargs.
        """
        self.generation_kwargs = generation_kwargs
+        self.truncate = truncate

    @abstractmethod
    def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]:
@@ -166,13 +167,14 @@ class AnthropicClaudeChatAdapter(BedrockModelChatAdapter):
        "system",
    ]

-    def __init__(self, generation_kwargs: Dict[str, Any]):
+    def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]):
        """
        Initializes the Anthropic Claude chat adapter.
+        :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit.
        :param generation_kwargs: The generation kwargs.
        """
-        super().__init__(generation_kwargs)
+        super().__init__(truncate, generation_kwargs)

        # We pop the model_max_length as it is not sent to the model
        # but used to truncate the prompt if needed
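As the comment in this hunk notes, `model_max_length` is popped from `generation_kwargs` because it is a client-side setting: it sizes the token budget the prompt handler truncates against rather than being sent to Bedrock. A hedged sketch of how a caller might combine it with the new flag (the budget value here is an assumption):

# model_max_length never reaches Bedrock; it only sets the truncation budget
# that _ensure_token_limit enforces when truncate=True.
adapter = AnthropicClaudeChatAdapter(
    truncate=True,
    generation_kwargs={"model_max_length": 100_000, "max_tokens": 512},
)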
@@ -216,7 +218,7 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> Dict[str, Any]:
        Prepares the chat messages for the Anthropic Claude request.
        :param messages: The chat messages to prepare.
-        :returns: The prepared chat messages as a string.
+        :returns: The prepared chat messages as a dictionary.
        """
        body: Dict[str, Any] = {}
        system = messages[0].content if messages and messages[0].is_from(ChatRole.SYSTEM) else None
@@ -225,6 +227,11 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> Dict[str, Any]:
        ]
        if system:
            body["system"] = system
+        # Ensure token limit for each message in the body
+        if self.truncate:
+            for message in body["messages"]:
+                for content in message["content"]:
+                    content["text"] = self._ensure_token_limit(content["text"])
        return body

    def check_prompt(self, prompt: str) -> Dict[str, Any]:
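The per-block loop above exists because the body follows the Anthropic Messages API shape: each message carries a list of content blocks, so truncation is applied to every text block individually. Reusing the adapter from the earlier sketch, an illustrative view of the structure this method returns (values invented for the example):

messages = [
    ChatMessage.from_system("You are terse."),
    ChatMessage.from_user("What is Amazon Bedrock?"),
]
body = adapter.prepare_chat_messages(messages)
# Expected shape (illustrative):
# {
#     "system": "You are terse.",
#     "messages": [
#         {"role": "user", "content": [{"type": "text", "text": "What is Amazon Bedrock?"}]},
#     ],
# }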
@@ -316,13 +323,13 @@ class MistralChatAdapter(BedrockModelChatAdapter):
        "top_p",
    ]

-    def __init__(self, generation_kwargs: Dict[str, Any]):
+    def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]):
        """
        Initializes the Mistral chat adapter.
+        :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit.
        :param generation_kwargs: The generation kwargs.
        """
-        super().__init__(generation_kwargs)
+        super().__init__(truncate, generation_kwargs)

        # We pop the model_max_length as it is not sent to the model
        # but used to truncate the prompt if needed
@@ -384,7 +391,9 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
        prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template(
            conversation=[self.to_openai_format(m) for m in messages], tokenize=False, chat_template=self.chat_template
        )
-        return self._ensure_token_limit(prepared_prompt)
+        if self.truncate:
+            prepared_prompt = self._ensure_token_limit(prepared_prompt)
+        return prepared_prompt

    def to_openai_format(self, m: ChatMessage) -> Dict[str, Any]:
        """
@@ -470,12 +479,13 @@ class MetaLlama2ChatAdapter(BedrockModelChatAdapter):
        "{% endfor %}"
    )

-    def __init__(self, generation_kwargs: Dict[str, Any]) -> None:
+    def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]) -> None:
        """
        Initializes the Meta Llama 2 chat adapter.
+        :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit.
        :param generation_kwargs: The generation kwargs.
        """
-        super().__init__(generation_kwargs)
+        super().__init__(truncate, generation_kwargs)
        # We pop the model_max_length as it is not sent to the model
        # but used to truncate the prompt if needed
        # Llama 2 has context window size of 4096 tokens
@@ -519,7 +529,10 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
        prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template(
            conversation=messages, tokenize=False, chat_template=self.chat_template
        )
-        return self._ensure_token_limit(prepared_prompt)
+
+        if self.truncate:
+            prepared_prompt = self._ensure_token_limit(prepared_prompt)
+        return prepared_prompt

    def check_prompt(self, prompt: str) -> Dict[str, Any]:
        """
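`_ensure_token_limit` itself is outside this diff. Judging from how it is called, its contract is to measure the prompt in tokens and shorten it to fit the `model_max_length` budget, warning when it does. A rough sketch of that contract, assuming the adapter's prompt handler and a module-level logger; this is not the repository's actual implementation:

def _ensure_token_limit(self, prompt: str) -> str:
    # Assumed behavior: the prompt handler tokenizes the prompt and, if it is
    # over budget, returns a resized copy plus before/after token counts.
    resize_info = self.prompt_handler(prompt)
    if resize_info["prompt_length"] != resize_info["new_prompt_length"]:
        logger.warning(
            "The prompt was truncated from %s tokens to %s tokens so the answer fits the context window.",
            resize_info["prompt_length"],
            resize_info["new_prompt_length"],
        )
    return resize_info["resized_prompt"]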
@@ -76,6 +76,7 @@ def __init__(
        generation_kwargs: Optional[Dict[str, Any]] = None,
        stop_words: Optional[List[str]] = None,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        truncate: Optional[bool] = True,
    ):
        """
        Initializes the `AmazonBedrockChatGenerator` with the provided parameters. The parameters are passed to the
@@ -108,6 +109,7 @@ def __init__(
            function that handles the streaming chunks. The callback function receives a
            [StreamingChunk](https://docs.haystack.deepset.ai/docs/data-classes#streamingchunk) object and
            switches the streaming mode on.
+        :param truncate: Whether to truncate the prompt messages or not.
        """
        if not model:
            msg = "'model' cannot be None or empty string"
@@ -118,13 +120,14 @@ def __init__(
        self.aws_session_token = aws_session_token
        self.aws_region_name = aws_region_name
        self.aws_profile_name = aws_profile_name
+        self.truncate = truncate

        # get the model adapter for the given model
        model_adapter_cls = self.get_model_adapter(model=model)
        if not model_adapter_cls:
            msg = f"AmazonBedrockGenerator doesn't support the model {model}."
            raise AmazonBedrockConfigurationError(msg)
-        self.model_adapter = model_adapter_cls(generation_kwargs or {})
+        self.model_adapter = model_adapter_cls(self.truncate, generation_kwargs or {})

        # create the AWS session and client
        def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
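`get_model_adapter` is not shown in this hunk; conceptually it maps the Bedrock model id to one of the adapter classes above. The dispatch below is an assumed illustration of that idea, not the repository's actual lookup:

# Assumed sketch: match the model id prefix to an adapter class.
SUPPORTED_MODEL_FAMILIES = {
    "anthropic.claude": AnthropicClaudeChatAdapter,
    "mistral.": MistralChatAdapter,
    "meta.llama2": MetaLlama2ChatAdapter,
}

def get_model_adapter(model: str):
    for prefix, adapter_cls in SUPPORTED_MODEL_FAMILIES.items():
        if model.startswith(prefix):
            return adapter_cls
    return None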
@@ -243,6 +246,7 @@ def to_dict(self) -> Dict[str, Any]:
            stop_words=self.stop_words,
            generation_kwargs=self.model_adapter.generation_kwargs,
            streaming_callback=callback_name,
+            truncate=self.truncate,
        )

    @classmethod
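Because `to_dict` now records `truncate`, the setting survives pipeline serialization and is restored by `from_dict` (defined just below this hunk). A hedged round-trip sketch, assuming Haystack's usual `init_parameters` layout:

generator = AmazonBedrockChatGenerator(
    model="anthropic.claude-3-sonnet-20240229-v1:0",  # example model id
    truncate=False,
)
data = generator.to_dict()
assert data["init_parameters"]["truncate"] is False  # key layout assumed

restored = AmazonBedrockChatGenerator.from_dict(data)
assert restored.truncate is False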
