-
Notifications
You must be signed in to change notification settings - Fork 126
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* wip * Bedrock refactoring * rm wip embedder * bedrock - remove supports method * rename commons to common * fix pydoc config * text embedder! * more cleaning * lint * rename test module
- Loading branch information
Showing
3 changed files
with
337 additions
and
0 deletions.
There are no files selected for viewing
6 changes: 6 additions & 0 deletions
6
.../amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from .text_embedder import AmazonBedrockTextEmbedder | ||
|
||
__all__ = ["AmazonBedrockTextEmbedder"] |
181 changes: 181 additions & 0 deletions
181
...on_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
import json | ||
import logging | ||
from typing import Any, Dict, List, Literal, Optional | ||
|
||
from botocore.exceptions import ClientError | ||
from haystack import component, default_from_dict, default_to_dict | ||
from haystack.utils.auth import Secret, deserialize_secrets_inplace | ||
|
||
from haystack_integrations.common.amazon_bedrock.errors import ( | ||
AmazonBedrockConfigurationError, | ||
AmazonBedrockInferenceError, | ||
) | ||
from haystack_integrations.common.amazon_bedrock.utils import get_aws_session | ||
|
||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Bedrock model ids accepted by the embedder in this module; the constructor
# rejects any model id that is not listed here.
SUPPORTED_EMBEDDING_MODELS = [
    "amazon.titan-embed-text-v1",
    "cohere.embed-english-v3",
    "cohere.embed-multilingual-v3",
]
|
||
|
||
@component
class AmazonBedrockTextEmbedder:
    """
    A component for embedding strings using Amazon Bedrock.

    Usage example:
    ```python
    import os
    from haystack_integrations.components.embedders.amazon_bedrock import AmazonBedrockTextEmbedder

    os.environ["AWS_ACCESS_KEY_ID"] = "..."
    os.environ["AWS_SECRET_ACCESS_KEY"] = "..."
    os.environ["AWS_DEFAULT_REGION"] = "..."

    embedder = AmazonBedrockTextEmbedder(
        model="cohere.embed-english-v3",
        input_type="search_query",
    )

    print(embedder.run("I love Paris in the summer."))
    # {'embedding': [0.002, 0.032, 0.504, ...]}
    ```
    """

    def __init__(
        self,
        model: Literal["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"],
        aws_access_key_id: Optional[Secret] = Secret.from_env_var("AWS_ACCESS_KEY_ID", strict=False),  # noqa: B008
        aws_secret_access_key: Optional[Secret] = Secret.from_env_var(  # noqa: B008
            "AWS_SECRET_ACCESS_KEY", strict=False
        ),
        aws_session_token: Optional[Secret] = Secret.from_env_var("AWS_SESSION_TOKEN", strict=False),  # noqa: B008
        aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False),  # noqa: B008
        aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False),  # noqa: B008
        **kwargs,
    ):
        """
        Initializes the AmazonBedrockTextEmbedder with the provided parameters. The parameters are passed to the
        Amazon Bedrock client.

        Note that the AWS credentials are not required if the AWS environment is configured correctly. These are loaded
        automatically from the environment or the AWS configuration file and do not need to be provided explicitly via
        the constructor. If the AWS environment is not configured, users need to provide the AWS credentials via the
        constructor. Aside from `model`, the three required parameters are then `aws_access_key_id`,
        `aws_secret_access_key`, and `aws_region_name`.

        :param model: The embedding model to use. The model has to be specified in the format outlined in the Amazon
            Bedrock [documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html).
        :param aws_access_key_id: AWS access key ID.
        :param aws_secret_access_key: AWS secret access key.
        :param aws_session_token: AWS session token.
        :param aws_region_name: AWS region name.
        :param aws_profile_name: AWS profile name.
        :param kwargs: Additional parameters to pass for model inference. For example, `input_type` and `truncate` for
            Cohere models.
        :raises ValueError: If `model` is empty or not one of the supported embedding models.
        :raises AmazonBedrockConfigurationError: If the AWS session or the Bedrock runtime client cannot be created.
        """
        # An empty string or None is simply "not in the list", so a single membership test covers both cases.
        if model not in SUPPORTED_EMBEDDING_MODELS:
            msg = "Please provide a valid model from the list of supported models: " + ", ".join(
                SUPPORTED_EMBEDDING_MODELS
            )
            raise ValueError(msg)

        def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
            # Secrets are optional: a missing one is forwarded to the session factory as None.
            return secret.resolve_value() if secret else None

        try:
            session = get_aws_session(
                aws_access_key_id=resolve_secret(aws_access_key_id),
                aws_secret_access_key=resolve_secret(aws_secret_access_key),
                aws_session_token=resolve_secret(aws_session_token),
                aws_region_name=resolve_secret(aws_region_name),
                aws_profile_name=resolve_secret(aws_profile_name),
            )
            self._client = session.client("bedrock-runtime")
        except Exception as exception:
            msg = (
                "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
                "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration"
            )
            raise AmazonBedrockConfigurationError(msg) from exception

        # Keep the unresolved Secret objects (not their values) so that `to_dict` can serialize them.
        self.model = model
        self.aws_access_key_id = aws_access_key_id
        self.aws_secret_access_key = aws_secret_access_key
        self.aws_session_token = aws_session_token
        self.aws_region_name = aws_region_name
        self.aws_profile_name = aws_profile_name
        self.kwargs = kwargs

    @component.output_types(embedding=List[float])
    def run(self, text: str):
        """
        Embed a single string with the configured Amazon Bedrock model.

        :param text: The text to embed.
        :return: A dictionary with the key `embedding` holding the embedding returned by the model.
        :raises TypeError: If `text` is not a string.
        :raises ValueError: If `self.model` belongs to neither the Cohere nor the Titan model family.
        :raises AmazonBedrockInferenceError: If the call to the Bedrock model fails with a `ClientError`.
        """
        if not isinstance(text, str):
            msg = (
                "AmazonBedrockTextEmbedder expects a string as an input. "
                "In case you want to embed a list of Documents, please use the AmazonBedrockDocumentEmbedder."
            )
            raise TypeError(msg)

        is_cohere = "cohere" in self.model
        if is_cohere:
            body = {
                "texts": [text],
                "input_type": self.kwargs.get("input_type", "search_query"),  # mandatory parameter for Cohere models
            }
            if truncate := self.kwargs.get("truncate"):
                body["truncate"] = truncate  # optional parameter for Cohere models
        elif "titan" in self.model:
            body = {
                "inputText": text,
            }
        else:
            # `__init__` only accepts supported models, but `model` is a public attribute that can be reassigned;
            # fail with a clear message instead of an UnboundLocalError further down.
            msg = f"Unsupported model {self.model}. Supported models are: " + ", ".join(SUPPORTED_EMBEDDING_MODELS)
            raise ValueError(msg)

        try:
            response = self._client.invoke_model(
                body=json.dumps(body), modelId=self.model, accept="*/*", contentType="application/json"
            )
        except ClientError as exception:
            msg = (
                f"Could not connect to Amazon Bedrock model {self.model}. "
                f"Make sure your AWS environment is configured correctly, "
                f"the model is available in the configured AWS region, and you have access."
            )
            raise AmazonBedrockInferenceError(msg) from exception

        response_body = json.loads(response.get("body").read())

        # Cohere answers with one embedding per input text (we sent exactly one); Titan returns a single embedding.
        embedding = response_body["embeddings"][0] if is_cohere else response_body["embedding"]

        return {"embedding": embedding}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        The AWS secrets are serialized through `Secret.to_dict()`; the extra inference keyword arguments are stored
        as top-level init parameters.

        :return: The serialized component as a dictionary.
        """
        return default_to_dict(
            self,
            aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None,
            aws_secret_access_key=self.aws_secret_access_key.to_dict() if self.aws_secret_access_key else None,
            aws_session_token=self.aws_session_token.to_dict() if self.aws_session_token else None,
            aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None,
            aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
            model=self.model,
            **self.kwargs,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockTextEmbedder":
        """
        Deserialize this component from a dictionary.

        :param data: The dictionary representation of this component. Its `init_parameters` entry is modified in
            place while the serialized secrets are deserialized.
        :return: The deserialized component instance.
        """
        deserialize_secrets_inplace(
            data["init_parameters"],
            ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"],
        )
        return default_from_dict(cls, data)
150 changes: 150 additions & 0 deletions
150
integrations/amazon_bedrock/tests/test_text_embedder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
import io | ||
from unittest.mock import patch | ||
|
||
import pytest | ||
from botocore.exceptions import ClientError | ||
|
||
from haystack_integrations.common.amazon_bedrock.errors import ( | ||
AmazonBedrockConfigurationError, | ||
AmazonBedrockInferenceError, | ||
) | ||
from haystack_integrations.components.embedders.amazon_bedrock import AmazonBedrockTextEmbedder | ||
|
||
|
||
class TestAmazonBedrockTextEmbedder:
    """
    Tests for construction and (de)serialization of AmazonBedrockTextEmbedder.

    The `mock_boto3_session` and `set_env_variables` fixtures are not defined in this module
    (presumably they come from the test suite's conftest).
    """

    def test_init(self, mock_boto3_session, set_env_variables):
        """The constructor stores model and kwargs and opens exactly one boto3 session."""
        text_embedder = AmazonBedrockTextEmbedder(model="cohere.embed-english-v3", input_type="fake_input_type")

        assert text_embedder.model == "cohere.embed-english-v3"
        assert text_embedder.kwargs == {"input_type": "fake_input_type"}

        # The mocked boto3 session must have been created exactly once, with these parameters
        # (the values are presumably the ones provided by the `set_env_variables` fixture).
        mock_boto3_session.assert_called_once_with(
            aws_access_key_id="some_fake_id",
            aws_secret_access_key="some_fake_key",
            aws_session_token="some_fake_token",
            profile_name="some_fake_profile",
            region_name="fake_region",
        )

    def test_connection_error(self, mock_boto3_session):
        """A failure while creating the session surfaces as AmazonBedrockConfigurationError."""
        mock_boto3_session.side_effect = Exception("some connection error")

        with pytest.raises(AmazonBedrockConfigurationError):
            AmazonBedrockTextEmbedder(model="cohere.embed-english-v3", input_type="fake_input_type")

    def test_to_dict(self, mock_boto3_session):
        """Serialization emits the env-var secrets, the model and the extra kwargs."""
        text_embedder = AmazonBedrockTextEmbedder(model="cohere.embed-english-v3", input_type="search_query")

        init_parameters = {
            "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
            "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
            "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
            "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
            "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
            "model": "cohere.embed-english-v3",
            "input_type": "search_query",
        }
        expected = {
            "type": "haystack_integrations.components.embedders.amazon_bedrock.text_embedder.AmazonBedrockTextEmbedder",
            "init_parameters": init_parameters,
        }

        assert text_embedder.to_dict() == expected

    def test_from_dict(self, mock_boto3_session):
        """Deserialization restores the model and routes unknown init parameters into kwargs."""
        init_parameters = {
            "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
            "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
            "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
            "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
            "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
            "model": "cohere.embed-english-v3",
            "input_type": "search_query",
        }
        serialized = {
            "type": "haystack_integrations.components.embedders.amazon_bedrock.text_embedder.AmazonBedrockTextEmbedder",
            "init_parameters": init_parameters,
        }

        text_embedder = AmazonBedrockTextEmbedder.from_dict(serialized)

        assert text_embedder.model == "cohere.embed-english-v3"
        assert text_embedder.kwargs == {"input_type": "search_query"}
|
||
|
||
def test_init_invalid_model():
    """An empty or unknown model id is rejected with a ValueError."""
    for bad_model in ("", "my-unsupported-model"):
        with pytest.raises(ValueError):
            AmazonBedrockTextEmbedder(model=bad_model)
|
||
|
||
def test_run_wrong_type(mock_boto3_session):
    """Passing a non-string to `run` raises a TypeError."""
    text_embedder = AmazonBedrockTextEmbedder(model="cohere.embed-english-v3")

    not_a_string = 123
    with pytest.raises(TypeError):
        text_embedder.run(text=not_a_string)
|
||
|
||
def test_cohere_invocation(mock_boto3_session):
    """A Cohere model is invoked with a `texts`/`input_type` body and the first embedding is returned."""
    text_embedder = AmazonBedrockTextEmbedder(model="cohere.embed-english-v3")

    # Stand-in for the Bedrock response: `run` only calls `.read()` on the body.
    fake_response = {"body": io.StringIO('{"embeddings": [[0.1, 0.2, 0.3]]}')}

    with patch.object(text_embedder._client, "invoke_model", return_value=fake_response) as mock_invoke_model:
        output = text_embedder.run(text="some text")

    mock_invoke_model.assert_called_once_with(
        body='{"texts": ["some text"], "input_type": "search_query"}',
        modelId="cohere.embed-english-v3",
        accept="*/*",
        contentType="application/json",
    )
    assert output == {"embedding": [0.1, 0.2, 0.3]}
|
||
|
||
def test_titan_invocation(mock_boto3_session):
    """A Titan model is invoked with an `inputText` body and its `embedding` field is returned."""
    text_embedder = AmazonBedrockTextEmbedder(model="amazon.titan-embed-text-v1")

    # Stand-in for the Bedrock response: `run` only calls `.read()` on the body.
    fake_response = {"body": io.StringIO('{"embedding": [0.1, 0.2, 0.3]}')}

    with patch.object(text_embedder._client, "invoke_model", return_value=fake_response) as mock_invoke_model:
        output = text_embedder.run(text="some text")

    mock_invoke_model.assert_called_once_with(
        body='{"inputText": "some text"}',
        modelId="amazon.titan-embed-text-v1",
        accept="*/*",
        contentType="application/json",
    )
    assert output == {"embedding": [0.1, 0.2, 0.3]}
|
||
|
||
def test_run_invocation_error(mock_boto3_session):
    """A botocore ClientError raised by `invoke_model` is re-raised as AmazonBedrockInferenceError."""
    text_embedder = AmazonBedrockTextEmbedder(model="cohere.embed-english-v3")

    client_error = ClientError(
        error_response={"Error": {"Code": "some_code", "Message": "some_message"}},
        operation_name="some_operation",
    )

    with patch.object(text_embedder._client, "invoke_model", side_effect=client_error):
        with pytest.raises(AmazonBedrockInferenceError):
            text_embedder.run(text="some text")