From 2a62b8c48d30dcdbec2549cb342ef942ca13ff15 Mon Sep 17 00:00:00 2001
From: Vladimir Blagojevic
Date: Wed, 7 Feb 2024 11:35:09 +0100
Subject: [PATCH] Hook prompt length check

---
 .../amazon_bedrock/chat/adapters.py           | 48 ++++++++++++++++---
 .../amazon_bedrock/chat/chat_generator.py     |  6 +--
 .../generators/amazon_bedrock/handlers.py     |  5 ++
 3 files changed, 47 insertions(+), 12 deletions(-)

diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
index 0c6335635..37c490f6b 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
@@ -72,6 +72,29 @@ def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str

         return kwargs

+    def _ensure_token_limit(self, prompt: str) -> str:
+        resize_info = self.check_prompt(prompt)
+        if resize_info["prompt_length"] != resize_info["new_prompt_length"]:
+            logger.warning(
+                "The prompt was truncated from %s tokens to %s tokens so that the prompt length and "
+                "the answer length (%s tokens) fit within the model's max token limit (%s tokens). "
+                "Shorten the prompt or it will be cut off.",
+                resize_info["prompt_length"],
+                max(0, resize_info["model_max_length"] - resize_info["max_length"]),  # type: ignore
+                resize_info["max_length"],
+                resize_info["model_max_length"],
+            )
+        return str(resize_info["resized_prompt"])
+
+    @abstractmethod
+    def check_prompt(self, prompt: str) -> Dict[str, Any]:
+        """
+        Checks the prompt length and resizes it if necessary.
+
+        :param prompt: The prompt to check.
+        :return: A dictionary containing the resized prompt and additional information.
+        """
+
     @abstractmethod
     def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
         """Extracts the responses from the Amazon Bedrock response."""
@@ -89,14 +112,14 @@ class AnthropicClaudeChatAdapter(BedrockModelChatAdapter):
     ANTHROPIC_USER_TOKEN = "\n\nHuman:"
     ANTHROPIC_ASSISTANT_TOKEN = "\n\nAssistant:"

-    def __init__(self, generation_kwargs: Dict[str, Any]) -> None:
+    def __init__(self, generation_kwargs: Dict[str, Any]):
         super().__init__(generation_kwargs)

         # We pop the model_max_length as it is not sent to the model
         # but used to truncate the prompt if needed
         # Anthropic Claude has a limit of at least 100000 tokens
         # https://docs.anthropic.com/claude/reference/input-and-output-sizes
-        model_max_length = self.generation_kwargs.get("model_max_length", 100000)
+        model_max_length = self.generation_kwargs.pop("model_max_length", 100000)

         # Truncate prompt if prompt tokens > model_max_length-max_length
         # (max_length is the length of the generated text)
@@ -142,7 +165,11 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
                 invalid_role = f"Invalid role {message.role} for message {message.content}"
                 raise ValueError(invalid_role)

-        return "".join(conversation) + AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN + " "
+        prepared_prompt = "".join(conversation) + AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN + " "
+        return self._ensure_token_limit(prepared_prompt)
+
+    def check_prompt(self, prompt: str) -> Dict[str, Any]:
+        return self.prompt_handler(prompt)

     def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
         metadata = {k: v for (k, v) in response_body.items() if k != "completion"}
@@ -187,9 +214,13 @@ def __init__(self, generation_kwargs: Dict[str, Any]) -> None:
         # We pop the model_max_length as it is not sent to the model
         # but used to truncate the prompt if needed
         # Llama 2 has context window size of 4096 tokens
-        model_max_length = self.generation_kwargs.get("model_max_length", 4096)
-        # Truncate prompt if prompt tokens > model_max_length-max_length
-        tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("gpt2")
+        # with some exceptions when the context window has been extended
+        model_max_length = self.generation_kwargs.pop("model_max_length", 4096)
+
+        # Use `google/flan-t5-base` as it's also a BPE sentencepiece tokenizer, just like llama 2
+        # a) we should get good estimates for the prompt length (empirically close to llama 2)
+        # b) we can use apply_chat_template with the template above to delineate ChatMessages
+        tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
         tokenizer.bos_token = "<s>"
         tokenizer.eos_token = "</s>"
         tokenizer.unk_token = "<unk>"
@@ -214,7 +245,10 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
         prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template(
             conversation=messages, tokenize=False, chat_template=self.chat_template
         )
-        return prepared_prompt
+        return self._ensure_token_limit(prepared_prompt)
+
+    def check_prompt(self, prompt: str) -> Dict[str, Any]:
+        return self.prompt_handler(prompt)

     def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
         metadata = {k: v for (k, v) in response_body.items() if k != "generation"}

diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
index 94bec3a72..fda9d4fff 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
@@ -153,11 +153,7 @@ def invoke(self, *args, **kwargs):
             response_body = json.loads(response.get("body").read().decode("utf-8"))
             responses = self.model_adapter.get_responses(response_body=response_body)
         except ClientError as exception:
-            msg = (
-                f"Could not connect to Amazon Bedrock model {self.model}. "
-                f"Make sure your AWS environment is configured correctly, "
-                f"the model is available in the configured AWS region, and you have access."
-            )
+            msg = f"Could not run inference on Amazon Bedrock model {self.model} due to: {exception}"
             raise AmazonBedrockInferenceError(msg) from exception

         return responses

diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py
index 71450bec0..b7b555ec0 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py
@@ -11,6 +11,11 @@ class DefaultPromptHandler:
     """

     def __init__(self, tokenizer: Union[str, PreTrainedTokenizerBase], model_max_length: int, max_length: int = 100):
+        """
+        :param tokenizer: The tokenizer to be used to tokenize the prompt.
+        :param model_max_length: The maximum length of the prompt and answer tokens combined.
+        :param max_length: The maximum length of the answer tokens.
+        """
         if isinstance(tokenizer, str):
             self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
         elif isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
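
Note (not part of the patch): a minimal usage sketch of the truncation path this diff wires up. It assumes the DefaultPromptHandler call contract implied by _ensure_token_limit above, i.e. the handler instance is callable and returns the "resized_prompt", "prompt_length", "new_prompt_length", "model_max_length", and "max_length" fields that the adapters read. The "gpt2" tokenizer name and the length values are illustrative only, chosen to make truncation easy to trigger.

    from haystack_integrations.components.generators.amazon_bedrock.handlers import DefaultPromptHandler

    # Reserve 20 answer tokens inside a 50-token budget (hypothetical values, only to force truncation).
    handler = DefaultPromptHandler(tokenizer="gpt2", model_max_length=50, max_length=20)

    resize_info = handler("A long prompt. " * 100)

    # These are the fields _ensure_token_limit() compares before logging the truncation warning.
    print(resize_info["prompt_length"])      # tokens in the original prompt
    print(resize_info["new_prompt_length"])  # tokens kept, at most model_max_length - max_length
    print(resize_info["resized_prompt"])     # the truncated prompt the adapters actually send

Both adapters reach this code the same way: prepare_chat_messages() now ends in _ensure_token_limit(), which delegates to the adapter's check_prompt() and hence to the prompt handler shown here.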