diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
index 67e833f73..1f0430810 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
@@ -2,7 +2,7 @@
 import logging
 import os
 from abc import ABC, abstractmethod
-from typing import Any, Callable, ClassVar, Dict, List
+from typing import Any, Callable, ClassVar, Dict, List, Optional
 
 from botocore.eventstream import EventStream
 from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk
@@ -21,11 +21,12 @@ class BedrockModelChatAdapter(ABC):
     focusing on preparing the requests and extracting the responses from the Amazon Bedrock hosted chat LLMs.
     """
 
-    def __init__(self, generation_kwargs: Dict[str, Any]) -> None:
+    def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]) -> None:
         """
-        Initializes the chat adapter with the generation kwargs.
+        Initializes the chat adapter with the truncate parameter and generation kwargs.
         """
         self.generation_kwargs = generation_kwargs
+        self.truncate = truncate
 
     @abstractmethod
     def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]:
@@ -166,13 +167,14 @@ class AnthropicClaudeChatAdapter(BedrockModelChatAdapter):
         "system",
     ]
 
-    def __init__(self, generation_kwargs: Dict[str, Any]):
+    def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]):
         """
         Initializes the Anthropic Claude chat adapter.
 
+        :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit.
         :param generation_kwargs: The generation kwargs.
         """
-        super().__init__(generation_kwargs)
+        super().__init__(truncate, generation_kwargs)
 
         # We pop the model_max_length as it is not sent to the model
         # but used to truncate the prompt if needed
@@ -216,7 +218,7 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> Dict[str, Any]:
         Prepares the chat messages for the Anthropic Claude request.
 
         :param messages: The chat messages to prepare.
-        :returns: The prepared chat messages as a string.
+        :returns: The prepared chat messages as a dictionary.
         """
         body: Dict[str, Any] = {}
         system = messages[0].content if messages and messages[0].is_from(ChatRole.SYSTEM) else None
@@ -225,6 +227,11 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> Dict[str, Any]:
         ]
         if system:
             body["system"] = system
+        # Ensure token limit for each message in the body
+        if self.truncate:
+            for message in body["messages"]:
+                for content in message["content"]:
+                    content["text"] = self._ensure_token_limit(content["text"])
         return body
 
     def check_prompt(self, prompt: str) -> Dict[str, Any]:
@@ -316,13 +323,13 @@ class MistralChatAdapter(BedrockModelChatAdapter):
         "top_p",
     ]
 
-    def __init__(self, generation_kwargs: Dict[str, Any]):
+    def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]):
         """
         Initializes the Mistral chat adapter.
-
+        :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit.
         :param generation_kwargs: The generation kwargs.
""" - super().__init__(generation_kwargs) + super().__init__(truncate, generation_kwargs) # We pop the model_max_length as it is not sent to the model # but used to truncate the prompt if needed @@ -384,7 +391,9 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template( conversation=[self.to_openai_format(m) for m in messages], tokenize=False, chat_template=self.chat_template ) - return self._ensure_token_limit(prepared_prompt) + if self.truncate: + prepared_prompt = self._ensure_token_limit(prepared_prompt) + return prepared_prompt def to_openai_format(self, m: ChatMessage) -> Dict[str, Any]: """ @@ -470,12 +479,13 @@ class MetaLlama2ChatAdapter(BedrockModelChatAdapter): "{% endfor %}" ) - def __init__(self, generation_kwargs: Dict[str, Any]) -> None: + def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]) -> None: """ Initializes the Meta Llama 2 chat adapter. + :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit. :param generation_kwargs: The generation kwargs. """ - super().__init__(generation_kwargs) + super().__init__(truncate, generation_kwargs) # We pop the model_max_length as it is not sent to the model # but used to truncate the prompt if needed # Llama 2 has context window size of 4096 tokens @@ -519,7 +529,10 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template( conversation=messages, tokenize=False, chat_template=self.chat_template ) - return self._ensure_token_limit(prepared_prompt) + + if self.truncate: + prepared_prompt = self._ensure_token_limit(prepared_prompt) + return prepared_prompt def check_prompt(self, prompt: str) -> Dict[str, Any]: """ diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index f7bb0ba23..5fa9e0b8a 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -76,6 +76,7 @@ def __init__( generation_kwargs: Optional[Dict[str, Any]] = None, stop_words: Optional[List[str]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + truncate: Optional[bool] = True, ): """ Initializes the `AmazonBedrockChatGenerator` with the provided parameters. The parameters are passed to the @@ -108,6 +109,7 @@ def __init__( function that handles the streaming chunks. The callback function receives a [StreamingChunk](https://docs.haystack.deepset.ai/docs/data-classes#streamingchunk) object and switches the streaming mode on. + :param truncate: Whether to truncate the prompt messages or not. """ if not model: msg = "'model' cannot be None or empty string" @@ -118,13 +120,14 @@ def __init__( self.aws_session_token = aws_session_token self.aws_region_name = aws_region_name self.aws_profile_name = aws_profile_name + self.truncate = truncate # get the model adapter for the given model model_adapter_cls = self.get_model_adapter(model=model) if not model_adapter_cls: msg = f"AmazonBedrockGenerator doesn't support the model {model}." 
             raise AmazonBedrockConfigurationError(msg)
-        self.model_adapter = model_adapter_cls(generation_kwargs or {})
+        self.model_adapter = model_adapter_cls(self.truncate, generation_kwargs or {})
 
         # create the AWS session and client
         def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
@@ -243,6 +246,7 @@ def to_dict(self) -> Dict[str, Any]:
             stop_words=self.stop_words,
             generation_kwargs=self.model_adapter.generation_kwargs,
             streaming_callback=callback_name,
+            truncate=self.truncate,
         )
 
     @classmethod
diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py
index 79a04d52b..64e9ce2ef 100644
--- a/integrations/amazon_bedrock/tests/test_chat_generator.py
+++ b/integrations/amazon_bedrock/tests/test_chat_generator.py
@@ -1,7 +1,7 @@
 import logging
 import os
 from typing import Optional, Type
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 import pytest
 from haystack.components.generators.utils import print_streaming_chunk
@@ -45,6 +45,7 @@ def test_to_dict(mock_boto3_session):
             "generation_kwargs": {"temperature": 0.7},
             "stop_words": [],
             "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
+            "truncate": True,
         },
     }
 
@@ -67,6 +68,7 @@ def test_from_dict(mock_boto3_session):
                 "model": "anthropic.claude-v2",
                 "generation_kwargs": {"temperature": 0.7},
                 "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
+                "truncate": True,
             },
         }
     )
@@ -85,7 +87,7 @@ def test_default_constructor(mock_boto3_session, set_env_variables):
     )
 
     assert layer.model == "anthropic.claude-v2"
-
+    assert layer.truncate is True
     assert layer.model_adapter.prompt_handler is not None
     assert layer.model_adapter.prompt_handler.model_max_length == 100000
 
@@ -111,6 +113,15 @@ def test_constructor_with_generation_kwargs(mock_boto3_session):
     layer = AmazonBedrockChatGenerator(model="anthropic.claude-v2", generation_kwargs=generation_kwargs)
     assert "temperature" in layer.model_adapter.generation_kwargs
     assert layer.model_adapter.generation_kwargs["temperature"] == 0.7
+    assert layer.model_adapter.truncate is True
+
+
+def test_constructor_with_truncate(mock_boto3_session):
+    """
+    Test that truncate param is correctly set in the model constructor
+    """
+    layer = AmazonBedrockChatGenerator(model="anthropic.claude-v2", truncate=False)
+    assert layer.model_adapter.truncate is False
 
 
 def test_constructor_with_empty_model():
@@ -121,6 +132,123 @@ def test_constructor_with_empty_model():
         AmazonBedrockChatGenerator(model="")
 
 
+def test_short_prompt_is_not_truncated(mock_boto3_session):
+    """
+    Test that a short prompt is not truncated
+    """
+    # Define a short mock prompt and its tokenized version
+    mock_prompt_text = "I am a tokenized prompt"
+    mock_prompt_tokens = mock_prompt_text.split()
+
+    # Mock the tokenizer so it returns our predefined tokens
+    mock_tokenizer = MagicMock()
+    mock_tokenizer.tokenize.return_value = mock_prompt_tokens
+
+    # We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens
+    # Since our mock prompt is 5 tokens long, it doesn't exceed the
+    # total limit (5 prompt tokens + 3 generated tokens < 10 tokens)
+    max_length_generated_text = 3
+    total_model_max_length = 10
+
+    with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer):
+        layer = AmazonBedrockChatGenerator(
+            "anthropic.claude-v2",
+            generation_kwargs={"model_max_length": total_model_max_length, "max_tokens": max_length_generated_text},
+        )
+        prompt_after_resize = layer.model_adapter._ensure_token_limit(mock_prompt_text)
+
+    # The prompt doesn't exceed the limit, _ensure_token_limit doesn't truncate it
+    assert prompt_after_resize == mock_prompt_text
+
+
+def test_long_prompt_is_truncated(mock_boto3_session):
+    """
+    Test that a long prompt is truncated
+    """
+    # Define a long mock prompt and its tokenized version
+    long_prompt_text = "I am a tokenized prompt of length eight"
+    long_prompt_tokens = long_prompt_text.split()
+
+    # _ensure_token_limit will truncate the prompt to make it fit into the model's max token limit
+    truncated_prompt_text = "I am a tokenized prompt of length"
+
+    # Mock the tokenizer to return our predefined tokens
+    # convert tokens to our predefined truncated text
+    mock_tokenizer = MagicMock()
+    mock_tokenizer.tokenize.return_value = long_prompt_tokens
+    mock_tokenizer.convert_tokens_to_string.return_value = truncated_prompt_text
+
+    # We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens
+    # Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens)
+    max_length_generated_text = 3
+    total_model_max_length = 10
+
+    with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer):
+        layer = AmazonBedrockChatGenerator(
+            "anthropic.claude-v2",
+            generation_kwargs={"model_max_length": total_model_max_length, "max_tokens": max_length_generated_text},
+        )
+        prompt_after_resize = layer.model_adapter._ensure_token_limit(long_prompt_text)
+
+    # The prompt exceeds the limit, _ensure_token_limit truncates it
+    assert prompt_after_resize == truncated_prompt_text
+
+
+def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session):
+    """
+    Test that a long prompt is not truncated and _ensure_token_limit is not called when truncate is set to False
+    """
+    messages = [ChatMessage.from_system("What is the biggest city in United States?")]
+
+    # Use a small model_max_length and max_tokens so that truncation would normally be triggered
+    max_length_generated_text = 3
+    total_model_max_length = 10
+
+    with patch("transformers.AutoTokenizer.from_pretrained", return_value=MagicMock()):
+        generator = AmazonBedrockChatGenerator(
+            model="anthropic.claude-v2",
+            truncate=False,
+            generation_kwargs={"model_max_length": total_model_max_length, "max_tokens": max_length_generated_text},
+        )
+
+    # Mock the _ensure_token_limit method to track if it is called
+    with patch.object(
+        generator.model_adapter, "_ensure_token_limit", wraps=generator.model_adapter._ensure_token_limit
+    ) as mock_ensure_token_limit:
+        # Mock the model adapter to avoid actual invocation
+        generator.model_adapter.prepare_body = MagicMock(return_value={})
+        generator.client = MagicMock()
+        generator.client.invoke_model = MagicMock(
+            return_value={"body": MagicMock(read=MagicMock(return_value=b'{"generated_text": "response"}'))}
+        )
+
+        generator.model_adapter.get_responses = MagicMock(
+            return_value=[
+                ChatMessage(
+                    content="Some text",
+                    role=ChatRole.ASSISTANT,
+                    name=None,
+                    meta=[
+                        {
+                            "model": "claude-3-sonnet-20240229",
+                            "index": 0,
+                            "finish_reason": "end_turn",
+                            "usage": {"prompt_tokens": 16, "completion_tokens": 55},
+                        }
+                    ],
+                )
+            ]
+        )
+        # Invoke the generator
+        generator.run(messages=messages)
+
+        # Ensure _ensure_token_limit was not called
+        mock_ensure_token_limit.assert_not_called()
+
+        # Check the prompt passed to prepare_body
+        generator.model_adapter.prepare_body.assert_called_with(messages=messages, stop_words=[], stream=False)
+
+
 @pytest.mark.parametrize(
     "model, expected_model_adapter",
     [
@@ -144,7 +272,7 @@ def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[Bed
 class TestAnthropicClaudeAdapter:
     def test_prepare_body_with_default_params(self) -> None:
-        layer = AnthropicClaudeChatAdapter(generation_kwargs={})
+        layer = AnthropicClaudeChatAdapter(truncate=True, generation_kwargs={})
         prompt = "Hello, how are you?"
         expected_body = {
             "anthropic_version": "bedrock-2023-05-31",
@@ -157,7 +285,9 @@ def test_prepare_body_with_default_params(self) -> None:
         assert body == expected_body
 
     def test_prepare_body_with_custom_inference_params(self) -> None:
-        layer = AnthropicClaudeChatAdapter(generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4})
+        layer = AnthropicClaudeChatAdapter(
+            truncate=True, generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4}
+        )
         prompt = "Hello, how are you?"
         expected_body = {
             "anthropic_version": "bedrock-2023-05-31",
@@ -178,7 +308,7 @@ def test_prepare_body_with_custom_inference_params(self) -> None:
 
 class TestMistralAdapter:
     def test_prepare_body_with_default_params(self) -> None:
-        layer = MistralChatAdapter(generation_kwargs={})
+        layer = MistralChatAdapter(truncate=True, generation_kwargs={})
         prompt = "Hello, how are you?"
         expected_body = {
             "max_tokens": 512,
             "prompt": "[INST] Hello, how are you? [/INST]",
         }
         body = layer.prepare_body([ChatMessage.from_user(prompt)])
         assert body == expected_body
@@ -190,7 +320,7 @@ def test_prepare_body_with_default_params(self) -> None:
 
     def test_prepare_body_with_custom_inference_params(self) -> None:
-        layer = MistralChatAdapter(generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4})
+        layer = MistralChatAdapter(truncate=True, generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4})
         prompt = "Hello, how are you?"
         expected_body = {
             "prompt": "[INST] Hello, how are you? [/INST]",
[/INST]", @@ -204,12 +334,12 @@ def test_prepare_body_with_custom_inference_params(self) -> None: assert body == expected_body def test_mistral_chat_template_correct_order(self): - layer = MistralChatAdapter(generation_kwargs={}) + layer = MistralChatAdapter(truncate=True, generation_kwargs={}) layer.prepare_body([ChatMessage.from_user("A"), ChatMessage.from_assistant("B"), ChatMessage.from_user("C")]) layer.prepare_body([ChatMessage.from_system("A"), ChatMessage.from_user("B"), ChatMessage.from_assistant("C")]) def test_mistral_chat_template_incorrect_order(self): - layer = MistralChatAdapter(generation_kwargs={}) + layer = MistralChatAdapter(truncate=True, generation_kwargs={}) try: layer.prepare_body([ChatMessage.from_assistant("B"), ChatMessage.from_assistant("C")]) msg = "Expected TemplateError" @@ -238,7 +368,7 @@ def test_use_mistral_adapter_without_hf_token(self, monkeypatch, caplog) -> None patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"), caplog.at_level(logging.WARNING), ): - MistralChatAdapter(generation_kwargs={}) + MistralChatAdapter(truncate=True, generation_kwargs={}) mock_pretrained.assert_called_with("NousResearch/Llama-2-7b-chat-hf") assert "no HF_TOKEN was found" in caplog.text @@ -248,7 +378,7 @@ def test_use_mistral_adapter_with_hf_token(self, monkeypatch) -> None: patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained, patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"), ): - MistralChatAdapter(generation_kwargs={}) + MistralChatAdapter(truncate=True, generation_kwargs={}) mock_pretrained.assert_called_with("mistralai/Mistral-7B-Instruct-v0.1") @pytest.mark.skipif( @@ -291,7 +421,7 @@ class TestMetaLlama2ChatAdapter: def test_prepare_body_with_default_params(self) -> None: # leave this test as integration because we really need only tokenizer from HF # that way we can ensure prompt chat message formatting - layer = MetaLlama2ChatAdapter(generation_kwargs={}) + layer = MetaLlama2ChatAdapter(truncate=True, generation_kwargs={}) prompt = "Hello, how are you?" expected_body = {"prompt": "[INST] Hello, how are you? [/INST]", "max_gen_len": 512} @@ -304,7 +434,8 @@ def test_prepare_body_with_custom_inference_params(self) -> None: # leave this test as integration because we really need only tokenizer from HF # that way we can ensure prompt chat message formatting layer = MetaLlama2ChatAdapter( - generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 5, "stop_sequences": ["CUSTOM_STOP"]} + truncate=True, + generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 5, "stop_sequences": ["CUSTOM_STOP"]}, ) prompt = "Hello, how are you?" @@ -329,7 +460,7 @@ def test_prepare_body_with_custom_inference_params(self) -> None: @pytest.mark.integration def test_get_responses(self) -> None: - adapter = MetaLlama2ChatAdapter(generation_kwargs={}) + adapter = MetaLlama2ChatAdapter(truncate=True, generation_kwargs={}) response_body = {"generation": "This is a single response."} expected_response = "This is a single response." response_message = adapter.get_responses(response_body)