Skip to content

Commit

Permalink
Bedrock Text Embedder (#466)
Browse files Browse the repository at this point in the history
* wip

* Bedrock refactoring

* rm wip embedder

* bedrock - remove supports method

* rename commons to common

* fix pydoc config

* text embedder!

* more cleaning

* lint

* rename test module
  • Loading branch information
anakin87 authored Feb 22, 2024
1 parent e5011a7 commit 66fb26e
Show file tree
Hide file tree
Showing 3 changed files with 337 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from .text_embedder import AmazonBedrockTextEmbedder

__all__ = ["AmazonBedrockTextEmbedder"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import json
import logging
from typing import Any, Dict, List, Literal, Optional

from botocore.exceptions import ClientError
from haystack import component, default_from_dict, default_to_dict
from haystack.utils.auth import Secret, deserialize_secrets_inplace

from haystack_integrations.common.amazon_bedrock.errors import (
AmazonBedrockConfigurationError,
AmazonBedrockInferenceError,
)
from haystack_integrations.common.amazon_bedrock.utils import get_aws_session

logger = logging.getLogger(__name__)

SUPPORTED_EMBEDDING_MODELS = ["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"]


@component
class AmazonBedrockTextEmbedder:
"""
A component for embedding strings using Amazon Bedrock.
Usage example:
```python
import os
from haystack_integrations.components.embedders.amazon_bedrock import AmazonBedrockTextEmbedder
os.environ["AWS_ACCESS_KEY_ID"] = "..."
os.environ["AWS_SECRET_ACCESS_KEY_ID"] = "..."
os.environ["AWS_REGION_NAME"] = "..."
embedder = AmazonBedrockTextEmbedder(
model="cohere.embed-english-v3",
input_type="search_query",
)
print(text_embedder.run("I love Paris in the summer."))
# {'embedding': [0.002, 0.032, 0.504, ...]}
```
"""

def __init__(
self,
model: Literal["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"],
aws_access_key_id: Optional[Secret] = Secret.from_env_var("AWS_ACCESS_KEY_ID", strict=False), # noqa: B008
aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008
"AWS_SECRET_ACCESS_KEY", strict=False
),
aws_session_token: Optional[Secret] = Secret.from_env_var("AWS_SESSION_TOKEN", strict=False), # noqa: B008
aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008
aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False), # noqa: B008
**kwargs,
):
"""
Initializes the AmazonBedrockTextEmbedder with the provided parameters. The parameters are passed to the
Amazon Bedrock client.
Note that the AWS credentials are not required if the AWS environment is configured correctly. These are loaded
automatically from the environment or the AWS configuration file and do not need to be provided explicitly via
the constructor. If the AWS environment is not configured users need to provide the AWS credentials via the
constructor. Aside from model, three required parameters are `aws_access_key_id`, `aws_secret_access_key`,
and `aws_region_name`.
:param model: The embedding model to use. The model has to be specified in the format outlined in the Amazon
Bedrock [documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html).
:type model: Literal["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"]
:param aws_access_key_id: AWS access key ID.
:param aws_secret_access_key: AWS secret access key.
:param aws_session_token: AWS session token.
:param aws_region_name: AWS region name.
:param aws_profile_name: AWS profile name.
:param kwargs: Additional parameters to pass for model inference. For example, `input_type` and `truncate` for
Cohere models.
"""
if not model or model not in SUPPORTED_EMBEDDING_MODELS:
msg = "Please provide a valid model from the list of supported models: " + ", ".join(
SUPPORTED_EMBEDDING_MODELS
)
raise ValueError(msg)

def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
return secret.resolve_value() if secret else None

try:
session = get_aws_session(
aws_access_key_id=resolve_secret(aws_access_key_id),
aws_secret_access_key=resolve_secret(aws_secret_access_key),
aws_session_token=resolve_secret(aws_session_token),
aws_region_name=resolve_secret(aws_region_name),
aws_profile_name=resolve_secret(aws_profile_name),
)
self._client = session.client("bedrock-runtime")
except Exception as exception:
msg = (
"Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
"See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration"
)
raise AmazonBedrockConfigurationError(msg) from exception

self.model = model
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_session_token = aws_session_token
self.aws_region_name = aws_region_name
self.aws_profile_name = aws_profile_name
self.kwargs = kwargs

@component.output_types(embedding=List[float])
def run(self, text: str):
if not isinstance(text, str):
msg = (
"AmazonBedrockTextEmbedder expects a string as an input."
"In case you want to embed a list of Documents, please use the AmazonBedrockTextEmbedder."
)
raise TypeError(msg)

if "cohere" in self.model:
body = {
"texts": [text],
"input_type": self.kwargs.get("input_type", "search_query"), # mandatory parameter for Cohere models
}
if truncate := self.kwargs.get("truncate"):
body["truncate"] = truncate # optional parameter for Cohere models

elif "titan" in self.model:
body = {
"inputText": text,
}

try:
response = self._client.invoke_model(
body=json.dumps(body), modelId=self.model, accept="*/*", contentType="application/json"
)
except ClientError as exception:
msg = (
f"Could not connect to Amazon Bedrock model {self.model}. "
f"Make sure your AWS environment is configured correctly, "
f"the model is available in the configured AWS region, and you have access."
)
raise AmazonBedrockInferenceError(msg) from exception

response_body = json.loads(response.get("body").read())

if "cohere" in self.model:
embedding = response_body["embeddings"][0]
elif "titan" in self.model:
embedding = response_body["embedding"]

return {"embedding": embedding}

def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
:return: The serialized component as a dictionary.
"""
return default_to_dict(
self,
aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None,
aws_secret_access_key=self.aws_secret_access_key.to_dict() if self.aws_secret_access_key else None,
aws_session_token=self.aws_session_token.to_dict() if self.aws_session_token else None,
aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None,
aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
model=self.model,
**self.kwargs,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockTextEmbedder":
"""
Deserialize this component from a dictionary.
:param data: The dictionary representation of this component.
:return: The deserialized component instance.
"""
deserialize_secrets_inplace(
data["init_parameters"],
["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"],
)
return default_from_dict(cls, data)
150 changes: 150 additions & 0 deletions integrations/amazon_bedrock/tests/test_text_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import io
from unittest.mock import patch

import pytest
from botocore.exceptions import ClientError

from haystack_integrations.common.amazon_bedrock.errors import (
AmazonBedrockConfigurationError,
AmazonBedrockInferenceError,
)
from haystack_integrations.components.embedders.amazon_bedrock import AmazonBedrockTextEmbedder


class TestAmazonBedrockTextEmbedder:
def test_init(self, mock_boto3_session, set_env_variables):
embedder = AmazonBedrockTextEmbedder(
model="cohere.embed-english-v3",
input_type="fake_input_type",
)

assert embedder.model == "cohere.embed-english-v3"
assert embedder.kwargs == {"input_type": "fake_input_type"}

# assert mocked boto3 client called exactly once
mock_boto3_session.assert_called_once()

# assert mocked boto3 client was called with the correct parameters
mock_boto3_session.assert_called_with(
aws_access_key_id="some_fake_id",
aws_secret_access_key="some_fake_key",
aws_session_token="some_fake_token",
profile_name="some_fake_profile",
region_name="fake_region",
)

def test_connection_error(self, mock_boto3_session):
mock_boto3_session.side_effect = Exception("some connection error")

with pytest.raises(AmazonBedrockConfigurationError):
AmazonBedrockTextEmbedder(
model="cohere.embed-english-v3",
input_type="fake_input_type",
)

def test_to_dict(self, mock_boto3_session):

embedder = AmazonBedrockTextEmbedder(
model="cohere.embed-english-v3",
input_type="search_query",
)

expected_dict = {
"type": "haystack_integrations.components.embedders.amazon_bedrock.text_embedder.AmazonBedrockTextEmbedder",
"init_parameters": {
"aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
"aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
"aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
"aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
"aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
"model": "cohere.embed-english-v3",
"input_type": "search_query",
},
}

assert embedder.to_dict() == expected_dict

def test_from_dict(self, mock_boto3_session):

data = {
"type": "haystack_integrations.components.embedders.amazon_bedrock.text_embedder.AmazonBedrockTextEmbedder",
"init_parameters": {
"aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
"aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
"aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
"aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
"aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
"model": "cohere.embed-english-v3",
"input_type": "search_query",
},
}

embedder = AmazonBedrockTextEmbedder.from_dict(data)

assert embedder.model == "cohere.embed-english-v3"
assert embedder.kwargs == {"input_type": "search_query"}


def test_init_invalid_model():
with pytest.raises(ValueError):
AmazonBedrockTextEmbedder(model="")

with pytest.raises(ValueError):
AmazonBedrockTextEmbedder(model="my-unsupported-model")


def test_run_wrong_type(mock_boto3_session):
embedder = AmazonBedrockTextEmbedder(model="cohere.embed-english-v3")
with pytest.raises(TypeError):
embedder.run(text=123)


def test_cohere_invocation(mock_boto3_session):
embedder = AmazonBedrockTextEmbedder(model="cohere.embed-english-v3")

with patch.object(embedder._client, "invoke_model") as mock_invoke_model:
mock_invoke_model.return_value = {
"body": io.StringIO('{"embeddings": [[0.1, 0.2, 0.3]]}'),
}
result = embedder.run(text="some text")

mock_invoke_model.assert_called_once_with(
body='{"texts": ["some text"], "input_type": "search_query"}',
modelId="cohere.embed-english-v3",
accept="*/*",
contentType="application/json",
)

assert result == {"embedding": [0.1, 0.2, 0.3]}


def test_titan_invocation(mock_boto3_session):
embedder = AmazonBedrockTextEmbedder(model="amazon.titan-embed-text-v1")

with patch.object(embedder._client, "invoke_model") as mock_invoke_model:
mock_invoke_model.return_value = {
"body": io.StringIO('{"embedding": [0.1, 0.2, 0.3]}'),
}
result = embedder.run(text="some text")

mock_invoke_model.assert_called_once_with(
body='{"inputText": "some text"}',
modelId="amazon.titan-embed-text-v1",
accept="*/*",
contentType="application/json",
)

assert result == {"embedding": [0.1, 0.2, 0.3]}


def test_run_invocation_error(mock_boto3_session):
embedder = AmazonBedrockTextEmbedder(model="cohere.embed-english-v3")

with patch.object(embedder._client, "invoke_model") as mock_invoke_model:
mock_invoke_model.side_effect = ClientError(
error_response={"Error": {"Code": "some_code", "Message": "some_message"}},
operation_name="some_operation",
)

with pytest.raises(AmazonBedrockInferenceError):
embedder.run(text="some text")

0 comments on commit 66fb26e

Please sign in to comment.