diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py
index c6c814de4..941fdbf71 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py
@@ -152,15 +152,16 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
 
         # We pop the model_max_length as it is not sent to the model but used to truncate the prompt if needed
         model_max_length = kwargs.get("model_max_length", 4096)
-        # Truncate prompt if prompt tokens > model_max_length-max_length
-        # (max_length is the length of the generated text)
-        # we use GPT2 tokenizer which will likely provide good token count approximation
-
-        self.prompt_handler = DefaultPromptHandler(
-            tokenizer="gpt2",
-            model_max_length=model_max_length,
-            max_length=self.max_length or 100,
-        )
+        # we initialize the prompt handler only if truncate is True: we avoid unnecessarily downloading the tokenizer
+        if self.truncate:
+            # Truncate prompt if prompt tokens > model_max_length-max_length
+            # (max_length is the length of the generated text)
+            # we use GPT2 tokenizer which will likely provide good token count approximation
+            self.prompt_handler = DefaultPromptHandler(
+                tokenizer="gpt2",
+                model_max_length=model_max_length,
+                max_length=self.max_length or 100,
+            )
 
         model_adapter_cls = self.get_model_adapter(model=model)
         if not model_adapter_cls:
diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py
index 79246b4aa..be645218e 100644
--- a/integrations/amazon_bedrock/tests/test_generator.py
+++ b/integrations/amazon_bedrock/tests/test_generator.py
@@ -108,6 +108,14 @@ def test_constructor_prompt_handler_initialized(mock_boto3_session, mock_prompt_
     assert layer.prompt_handler.model_max_length == 4096
 
 
+def test_prompt_handler_absent_when_truncate_false(mock_boto3_session):
+    """
+    Test that the prompt_handler is not initialized when truncate is set to False.
+    """
+    generator = AmazonBedrockGenerator(model="anthropic.claude-v2", truncate=False)
+    assert not hasattr(generator, "prompt_handler")
+
+
 def test_constructor_with_model_kwargs(mock_boto3_session):
     """
     Test that model_kwargs are correctly set in the constructor