Skip to content

Commit

Permalink
fix: Allow passing boto3 config to all AWS Bedrock classes (#1166)
Browse files: browse the repository at this point in the history.
* Allow passing boto3 config to AmazonBedrockChatGenerator

* Allow passing boto3 config to AmazonBedrockDocumentEmbedder

* Allow passing boto3 config to AmazonBedrockTextEmbedder

* Remove whitespace from blank line

* Reorder setting attributes for readability

* Remove blank line

* fix: adapt our implementation to breaking changes in Chroma 0.5.17 (#1165)

* fix chroma breaking changes

* improve warning

* better warning

* Update the changelog

* Parametrize to_dict and from_dict tests with boto3_config

---------

Co-authored-by: Stefano Fiorucci <[email protected]>
Co-authored-by: HaystackBot <[email protected]>
Co-authored-by: David S. Batista <[email protected]>
  • Loading branch information
4 people authored Dec 3, 2024
1 parent 319b64b commit 798fb98
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
from typing import Any, Dict, List, Literal, Optional

from botocore.config import Config
from botocore.exceptions import ClientError
from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import Document
Expand Down Expand Up @@ -73,6 +74,7 @@ def __init__(
progress_bar: bool = True,
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
boto3_config: Optional[Dict[str, Any]] = None,
**kwargs,
):
"""
Expand All @@ -98,6 +100,7 @@ def __init__(
to keep the logs clean.
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text.
:param embedding_separator: Separator used to concatenate the meta fields to the Document text.
:param boto3_config: The configuration for the boto3 client.
:param kwargs: Additional parameters to pass for model inference. For example, `input_type` and `truncate` for
Cohere models.
:raises ValueError: If the model is not supported.
Expand All @@ -110,6 +113,19 @@ def __init__(
)
raise ValueError(msg)

self.model = model
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_session_token = aws_session_token
self.aws_region_name = aws_region_name
self.aws_profile_name = aws_profile_name
self.batch_size = batch_size
self.progress_bar = progress_bar
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator
self.boto3_config = boto3_config
self.kwargs = kwargs

def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
return secret.resolve_value() if secret else None

Expand All @@ -121,26 +137,17 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
aws_region_name=resolve_secret(aws_region_name),
aws_profile_name=resolve_secret(aws_profile_name),
)
self._client = session.client("bedrock-runtime")
config: Optional[Config] = None
if self.boto3_config:
config = Config(**self.boto3_config)
self._client = session.client("bedrock-runtime", config=config)
except Exception as exception:
msg = (
"Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
"See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration"
)
raise AmazonBedrockConfigurationError(msg) from exception

self.model = model
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_session_token = aws_session_token
self.aws_region_name = aws_region_name
self.aws_profile_name = aws_profile_name
self.batch_size = batch_size
self.progress_bar = progress_bar
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator
self.kwargs = kwargs

def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
"""
Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
Expand Down Expand Up @@ -269,6 +276,7 @@ def to_dict(self) -> Dict[str, Any]:
progress_bar=self.progress_bar,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
boto3_config=self.boto3_config,
**self.kwargs,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
from typing import Any, Dict, List, Literal, Optional

from botocore.config import Config
from botocore.exceptions import ClientError
from haystack import component, default_from_dict, default_to_dict
from haystack.utils.auth import Secret, deserialize_secrets_inplace
Expand Down Expand Up @@ -62,6 +63,7 @@ def __init__(
aws_session_token: Optional[Secret] = Secret.from_env_var("AWS_SESSION_TOKEN", strict=False), # noqa: B008
aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008
aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False), # noqa: B008
boto3_config: Optional[Dict[str, Any]] = None,
**kwargs,
):
"""
Expand All @@ -81,6 +83,7 @@ def __init__(
:param aws_session_token: AWS session token.
:param aws_region_name: AWS region name.
:param aws_profile_name: AWS profile name.
:param boto3_config: The configuration for the boto3 client.
:param kwargs: Additional parameters to pass for model inference. For example, `input_type` and `truncate` for
Cohere models.
:raises ValueError: If the model is not supported.
Expand All @@ -92,6 +95,15 @@ def __init__(
)
raise ValueError(msg)

self.model = model
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_session_token = aws_session_token
self.aws_region_name = aws_region_name
self.aws_profile_name = aws_profile_name
self.boto3_config = boto3_config
self.kwargs = kwargs

def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
return secret.resolve_value() if secret else None

Expand All @@ -103,22 +115,17 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
aws_region_name=resolve_secret(aws_region_name),
aws_profile_name=resolve_secret(aws_profile_name),
)
self._client = session.client("bedrock-runtime")
config: Optional[Config] = None
if self.boto3_config:
config = Config(**self.boto3_config)
self._client = session.client("bedrock-runtime", config=config)
except Exception as exception:
msg = (
"Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
"See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration"
)
raise AmazonBedrockConfigurationError(msg) from exception

self.model = model
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_session_token = aws_session_token
self.aws_region_name = aws_region_name
self.aws_profile_name = aws_profile_name
self.kwargs = kwargs

@component.output_types(embedding=List[float])
def run(self, text: str):
"""Embeds the input text using the Amazon Bedrock model.
Expand Down Expand Up @@ -185,6 +192,7 @@ def to_dict(self) -> Dict[str, Any]:
aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None,
aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
model=self.model,
boto3_config=self.boto3_config,
**self.kwargs,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re
from typing import Any, Callable, ClassVar, Dict, List, Optional, Type

from botocore.config import Config
from botocore.exceptions import ClientError
from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import ChatMessage, StreamingChunk
Expand Down Expand Up @@ -77,6 +78,7 @@ def __init__(
stop_words: Optional[List[str]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
truncate: Optional[bool] = True,
boto3_config: Optional[Dict[str, Any]] = None,
):
"""
Initializes the `AmazonBedrockChatGenerator` with the provided parameters. The parameters are passed to the
Expand Down Expand Up @@ -110,6 +112,11 @@ def __init__(
[StreamingChunk](https://docs.haystack.deepset.ai/docs/data-classes#streamingchunk) object and
switches the streaming mode on.
:param truncate: Whether to truncate the prompt messages or not.
:param boto3_config: The configuration for the boto3 client.
:raises ValueError: If the model name is empty or None.
:raises AmazonBedrockConfigurationError: If the AWS environment is not configured correctly or the model is
not supported.
"""
if not model:
msg = "'model' cannot be None or empty string"
Expand All @@ -120,7 +127,10 @@ def __init__(
self.aws_session_token = aws_session_token
self.aws_region_name = aws_region_name
self.aws_profile_name = aws_profile_name
self.stop_words = stop_words or []
self.streaming_callback = streaming_callback
self.truncate = truncate
self.boto3_config = boto3_config

# get the model adapter for the given model
model_adapter_cls = self.get_model_adapter(model=model)
Expand All @@ -141,17 +151,17 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
aws_region_name=resolve_secret(aws_region_name),
aws_profile_name=resolve_secret(aws_profile_name),
)
self.client = session.client("bedrock-runtime")
config: Optional[Config] = None
if self.boto3_config:
config = Config(**self.boto3_config)
self.client = session.client("bedrock-runtime", config=config)
except Exception as exception:
msg = (
"Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
"See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration"
)
raise AmazonBedrockConfigurationError(msg) from exception

self.stop_words = stop_words or []
self.streaming_callback = streaming_callback

@component.output_types(replies=List[ChatMessage])
def run(
self,
Expand Down Expand Up @@ -256,6 +266,7 @@ def to_dict(self) -> Dict[str, Any]:
generation_kwargs=self.model_adapter.generation_kwargs,
streaming_callback=callback_name,
truncate=self.truncate,
boto3_config=self.boto3_config,
)

@classmethod
Expand Down
28 changes: 25 additions & 3 deletions integrations/amazon_bedrock/tests/test_chat_generator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging
import os
from typing import Optional, Type
from typing import Any, Dict, Optional, Type
from unittest.mock import MagicMock, patch

import pytest
Expand All @@ -26,14 +26,24 @@
]


def test_to_dict(mock_boto3_session):
@pytest.mark.parametrize(
"boto3_config",
[
None,
{
"read_timeout": 1000,
},
],
)
def test_to_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any]]):
"""
Test that the to_dict method returns the correct dictionary without aws credentials
"""
generator = AmazonBedrockChatGenerator(
model="anthropic.claude-v2",
generation_kwargs={"temperature": 0.7},
streaming_callback=print_streaming_chunk,
boto3_config=boto3_config,
)
expected_dict = {
"type": KLASS,
Expand All @@ -48,13 +58,23 @@ def test_to_dict(mock_boto3_session):
"stop_words": [],
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"truncate": True,
"boto3_config": boto3_config,
},
}

assert generator.to_dict() == expected_dict


def test_from_dict(mock_boto3_session):
@pytest.mark.parametrize(
"boto3_config",
[
None,
{
"read_timeout": 1000,
},
],
)
def test_from_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any]]):
"""
Test that the from_dict method returns the correct object
"""
Expand All @@ -71,12 +91,14 @@ def test_from_dict(mock_boto3_session):
"generation_kwargs": {"temperature": 0.7},
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"truncate": True,
"boto3_config": boto3_config,
},
}
)
assert generator.model == "anthropic.claude-v2"
assert generator.model_adapter.generation_kwargs == {"temperature": 0.7}
assert generator.streaming_callback == print_streaming_chunk
assert generator.boto3_config == boto3_config


def test_default_constructor(mock_boto3_session, set_env_variables):
Expand Down
27 changes: 25 additions & 2 deletions integrations/amazon_bedrock/tests/test_document_embedder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import io
from typing import Any, Dict, Optional
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -66,10 +67,20 @@ def test_connection_error(self, mock_boto3_session):
input_type="fake_input_type",
)

def test_to_dict(self, mock_boto3_session):
@pytest.mark.parametrize(
"boto3_config",
[
None,
{
"read_timeout": 1000,
},
],
)
def test_to_dict(self, mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any]]):
embedder = AmazonBedrockDocumentEmbedder(
model="cohere.embed-english-v3",
input_type="search_document",
boto3_config=boto3_config,
)

expected_dict = {
Expand All @@ -86,12 +97,22 @@ def test_to_dict(self, mock_boto3_session):
"progress_bar": True,
"meta_fields_to_embed": [],
"embedding_separator": "\n",
"boto3_config": boto3_config,
},
}

assert embedder.to_dict() == expected_dict

def test_from_dict(self, mock_boto3_session):
@pytest.mark.parametrize(
"boto3_config",
[
None,
{
"read_timeout": 1000,
},
],
)
def test_from_dict(self, mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any]]):
data = {
"type": TYPE,
"init_parameters": {
Expand All @@ -106,6 +127,7 @@ def test_from_dict(self, mock_boto3_session):
"progress_bar": True,
"meta_fields_to_embed": [],
"embedding_separator": "\n",
"boto3_config": boto3_config,
},
}

Expand All @@ -117,6 +139,7 @@ def test_from_dict(self, mock_boto3_session):
assert embedder.progress_bar
assert embedder.meta_fields_to_embed == []
assert embedder.embedding_separator == "\n"
assert embedder.boto3_config == boto3_config

def test_init_invalid_model(self):
with pytest.raises(ValueError):
Expand Down
Loading

0 comments on commit 798fb98

Please sign in to comment.