From 728dc21432ef78b715765108300d99643fbc36a3 Mon Sep 17 00:00:00 2001 From: tstadel Date: Mon, 14 Oct 2024 18:42:44 +0200 Subject: [PATCH] fix: avoid bedrock read timeout --- .../components/generators/amazon_bedrock/generator.py | 10 +++++++++- integrations/amazon_bedrock/tests/test_generator.py | 1 + 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 1edde3526..08df615e2 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -4,6 +4,7 @@ from typing import Any, Callable, ClassVar, Dict, List, Optional, Type from botocore.exceptions import ClientError +from botocore.config import Config from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import StreamingChunk from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable @@ -87,6 +88,7 @@ def __init__( max_length: Optional[int] = 100, truncate: Optional[bool] = True, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + boto3_config: Optional[Dict[str, Any]] = None, **kwargs, ): """ @@ -102,6 +104,7 @@ def __init__( :param truncate: Whether to truncate the prompt or not. :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. + :param boto3_config: The configuration for the boto3 client. :param kwargs: Additional keyword arguments to be passed to the model. These arguments are specific to the model. You can find them in the model's documentation. :raises ValueError: If the model name is empty or None. @@ -120,6 +123,7 @@ def __init__( self.aws_region_name = aws_region_name self.aws_profile_name = aws_profile_name self.streaming_callback = streaming_callback + self.boto3_config = boto3_config self.kwargs = kwargs def resolve_secret(secret: Optional[Secret]) -> Optional[str]: @@ -133,7 +137,10 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: aws_region_name=resolve_secret(aws_region_name), aws_profile_name=resolve_secret(aws_profile_name), ) - self.client = session.client("bedrock-runtime") + config: Optional[Config] = None + if self.boto3_config: + config = Config(**self.boto3_config) + self.client = session.client("bedrock-runtime", config=config) except Exception as exception: msg = ( "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. " @@ -273,6 +280,7 @@ def to_dict(self) -> Dict[str, Any]: max_length=self.max_length, truncate=self.truncate, streaming_callback=callback_name, + boto3_config=self.boto3_config, **self.kwargs, ) diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py index 61ae9d6b4..a14e2e1c4 100644 --- a/integrations/amazon_bedrock/tests/test_generator.py +++ b/integrations/amazon_bedrock/tests/test_generator.py @@ -36,6 +36,7 @@ def test_to_dict(mock_boto3_session): "truncate": False, "temperature": 10, "streaming_callback": None, + "botot3_config": None, }, }