Skip to content

Commit

Permalink
fix: avoid bedrock read timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
tstadel committed Oct 14, 2024
1 parent 518cf27 commit 728dc21
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Callable, ClassVar, Dict, List, Optional, Type

from botocore.exceptions import ClientError
from botocore.config import Config
from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import StreamingChunk
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
Expand Down Expand Up @@ -87,6 +88,7 @@ def __init__(
max_length: Optional[int] = 100,
truncate: Optional[bool] = True,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
boto3_config: Optional[Dict[str, Any]] = None,
**kwargs,
):
"""
Expand All @@ -102,6 +104,7 @@ def __init__(
:param truncate: Whether to truncate the prompt or not.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument.
:param boto3_config: The configuration for the boto3 client.
:param kwargs: Additional keyword arguments to be passed to the model.
These arguments are specific to the model. You can find them in the model's documentation.
:raises ValueError: If the model name is empty or None.
Expand All @@ -120,6 +123,7 @@ def __init__(
self.aws_region_name = aws_region_name
self.aws_profile_name = aws_profile_name
self.streaming_callback = streaming_callback
self.boto3_config = boto3_config
self.kwargs = kwargs

def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
Expand All @@ -133,7 +137,10 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
aws_region_name=resolve_secret(aws_region_name),
aws_profile_name=resolve_secret(aws_profile_name),
)
self.client = session.client("bedrock-runtime")
config: Optional[Config] = None
if self.boto3_config:
config = Config(**self.boto3_config)
self.client = session.client("bedrock-runtime", config=config)
except Exception as exception:
msg = (
"Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
Expand Down Expand Up @@ -273,6 +280,7 @@ def to_dict(self) -> Dict[str, Any]:
max_length=self.max_length,
truncate=self.truncate,
streaming_callback=callback_name,
boto3_config=self.boto3_config,
**self.kwargs,
)

Expand Down
1 change: 1 addition & 0 deletions integrations/amazon_bedrock/tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def test_to_dict(mock_boto3_session):
"truncate": False,
"temperature": 10,
"streaming_callback": None,
"botot3_config": None,
},
}

Expand Down

0 comments on commit 728dc21

Please sign in to comment.