From 66fb26e3ef4539036af7449ee214887087cefd13 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 22 Feb 2024 11:32:47 +0100 Subject: [PATCH] Bedrock Text Embedder (#466) * wip * Bedrock refactoring * rm wip embedder * bedrock - remove supports method * rename commons to common * fix pydoc config * text embedder! * more cleaning * lint * rename test module --- .../embedders/amazon_bedrock/__init__.py | 6 + .../embedders/amazon_bedrock/text_embedder.py | 181 ++++++++++++++++++ .../tests/test_text_embedder.py | 150 +++++++++++++++ 3 files changed, 337 insertions(+) create mode 100644 integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/__init__.py create mode 100644 integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py create mode 100644 integrations/amazon_bedrock/tests/test_text_embedder.py diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/__init__.py b/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/__init__.py new file mode 100644 index 000000000..5f1ff6caa --- /dev/null +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/__init__.py @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from .text_embedder import AmazonBedrockTextEmbedder + +__all__ = ["AmazonBedrockTextEmbedder"] diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py b/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py new file mode 100644 index 000000000..8804702a0 --- /dev/null +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py @@ -0,0 +1,181 @@ +import json +import logging +from typing import Any, Dict, List, Literal, Optional + +from botocore.exceptions import ClientError +from haystack import component, default_from_dict, default_to_dict +from haystack.utils.auth import Secret, deserialize_secrets_inplace + +from haystack_integrations.common.amazon_bedrock.errors import ( + AmazonBedrockConfigurationError, + AmazonBedrockInferenceError, +) +from haystack_integrations.common.amazon_bedrock.utils import get_aws_session + +logger = logging.getLogger(__name__) + +SUPPORTED_EMBEDDING_MODELS = ["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"] + + +@component +class AmazonBedrockTextEmbedder: + """ + A component for embedding strings using Amazon Bedrock. + + Usage example: + ```python + import os + from haystack_integrations.components.embedders.amazon_bedrock import AmazonBedrockTextEmbedder + + os.environ["AWS_ACCESS_KEY_ID"] = "..." + os.environ["AWS_SECRET_ACCESS_KEY_ID"] = "..." + os.environ["AWS_REGION_NAME"] = "..." + + embedder = AmazonBedrockTextEmbedder( + model="cohere.embed-english-v3", + input_type="search_query", + ) + + print(text_embedder.run("I love Paris in the summer.")) + + # {'embedding': [0.002, 0.032, 0.504, ...]} + ``` + """ + + def __init__( + self, + model: Literal["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"], + aws_access_key_id: Optional[Secret] = Secret.from_env_var("AWS_ACCESS_KEY_ID", strict=False), # noqa: B008 + aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008 + "AWS_SECRET_ACCESS_KEY", strict=False + ), + aws_session_token: Optional[Secret] = Secret.from_env_var("AWS_SESSION_TOKEN", strict=False), # noqa: B008 + aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008 + aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False), # noqa: B008 + **kwargs, + ): + """ + Initializes the AmazonBedrockTextEmbedder with the provided parameters. The parameters are passed to the + Amazon Bedrock client. + + Note that the AWS credentials are not required if the AWS environment is configured correctly. These are loaded + automatically from the environment or the AWS configuration file and do not need to be provided explicitly via + the constructor. If the AWS environment is not configured users need to provide the AWS credentials via the + constructor. Aside from model, three required parameters are `aws_access_key_id`, `aws_secret_access_key`, + and `aws_region_name`. + + :param model: The embedding model to use. The model has to be specified in the format outlined in the Amazon + Bedrock [documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html). + :type model: Literal["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"] + :param aws_access_key_id: AWS access key ID. + :param aws_secret_access_key: AWS secret access key. + :param aws_session_token: AWS session token. + :param aws_region_name: AWS region name. + :param aws_profile_name: AWS profile name. + :param kwargs: Additional parameters to pass for model inference. For example, `input_type` and `truncate` for + Cohere models. + """ + if not model or model not in SUPPORTED_EMBEDDING_MODELS: + msg = "Please provide a valid model from the list of supported models: " + ", ".join( + SUPPORTED_EMBEDDING_MODELS + ) + raise ValueError(msg) + + def resolve_secret(secret: Optional[Secret]) -> Optional[str]: + return secret.resolve_value() if secret else None + + try: + session = get_aws_session( + aws_access_key_id=resolve_secret(aws_access_key_id), + aws_secret_access_key=resolve_secret(aws_secret_access_key), + aws_session_token=resolve_secret(aws_session_token), + aws_region_name=resolve_secret(aws_region_name), + aws_profile_name=resolve_secret(aws_profile_name), + ) + self._client = session.client("bedrock-runtime") + except Exception as exception: + msg = ( + "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. " + "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration" + ) + raise AmazonBedrockConfigurationError(msg) from exception + + self.model = model + self.aws_access_key_id = aws_access_key_id + self.aws_secret_access_key = aws_secret_access_key + self.aws_session_token = aws_session_token + self.aws_region_name = aws_region_name + self.aws_profile_name = aws_profile_name + self.kwargs = kwargs + + @component.output_types(embedding=List[float]) + def run(self, text: str): + if not isinstance(text, str): + msg = ( + "AmazonBedrockTextEmbedder expects a string as an input." + "In case you want to embed a list of Documents, please use the AmazonBedrockTextEmbedder." + ) + raise TypeError(msg) + + if "cohere" in self.model: + body = { + "texts": [text], + "input_type": self.kwargs.get("input_type", "search_query"), # mandatory parameter for Cohere models + } + if truncate := self.kwargs.get("truncate"): + body["truncate"] = truncate # optional parameter for Cohere models + + elif "titan" in self.model: + body = { + "inputText": text, + } + + try: + response = self._client.invoke_model( + body=json.dumps(body), modelId=self.model, accept="*/*", contentType="application/json" + ) + except ClientError as exception: + msg = ( + f"Could not connect to Amazon Bedrock model {self.model}. " + f"Make sure your AWS environment is configured correctly, " + f"the model is available in the configured AWS region, and you have access." + ) + raise AmazonBedrockInferenceError(msg) from exception + + response_body = json.loads(response.get("body").read()) + + if "cohere" in self.model: + embedding = response_body["embeddings"][0] + elif "titan" in self.model: + embedding = response_body["embedding"] + + return {"embedding": embedding} + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + :return: The serialized component as a dictionary. + """ + return default_to_dict( + self, + aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None, + aws_secret_access_key=self.aws_secret_access_key.to_dict() if self.aws_secret_access_key else None, + aws_session_token=self.aws_session_token.to_dict() if self.aws_session_token else None, + aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None, + aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None, + model=self.model, + **self.kwargs, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockTextEmbedder": + """ + Deserialize this component from a dictionary. + :param data: The dictionary representation of this component. + :return: The deserialized component instance. + """ + deserialize_secrets_inplace( + data["init_parameters"], + ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"], + ) + return default_from_dict(cls, data) diff --git a/integrations/amazon_bedrock/tests/test_text_embedder.py b/integrations/amazon_bedrock/tests/test_text_embedder.py new file mode 100644 index 000000000..022c0c0a5 --- /dev/null +++ b/integrations/amazon_bedrock/tests/test_text_embedder.py @@ -0,0 +1,150 @@ +import io +from unittest.mock import patch + +import pytest +from botocore.exceptions import ClientError + +from haystack_integrations.common.amazon_bedrock.errors import ( + AmazonBedrockConfigurationError, + AmazonBedrockInferenceError, +) +from haystack_integrations.components.embedders.amazon_bedrock import AmazonBedrockTextEmbedder + + +class TestAmazonBedrockTextEmbedder: + def test_init(self, mock_boto3_session, set_env_variables): + embedder = AmazonBedrockTextEmbedder( + model="cohere.embed-english-v3", + input_type="fake_input_type", + ) + + assert embedder.model == "cohere.embed-english-v3" + assert embedder.kwargs == {"input_type": "fake_input_type"} + + # assert mocked boto3 client called exactly once + mock_boto3_session.assert_called_once() + + # assert mocked boto3 client was called with the correct parameters + mock_boto3_session.assert_called_with( + aws_access_key_id="some_fake_id", + aws_secret_access_key="some_fake_key", + aws_session_token="some_fake_token", + profile_name="some_fake_profile", + region_name="fake_region", + ) + + def test_connection_error(self, mock_boto3_session): + mock_boto3_session.side_effect = Exception("some connection error") + + with pytest.raises(AmazonBedrockConfigurationError): + AmazonBedrockTextEmbedder( + model="cohere.embed-english-v3", + input_type="fake_input_type", + ) + + def test_to_dict(self, mock_boto3_session): + + embedder = AmazonBedrockTextEmbedder( + model="cohere.embed-english-v3", + input_type="search_query", + ) + + expected_dict = { + "type": "haystack_integrations.components.embedders.amazon_bedrock.text_embedder.AmazonBedrockTextEmbedder", + "init_parameters": { + "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, + "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, + "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, + "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "model": "cohere.embed-english-v3", + "input_type": "search_query", + }, + } + + assert embedder.to_dict() == expected_dict + + def test_from_dict(self, mock_boto3_session): + + data = { + "type": "haystack_integrations.components.embedders.amazon_bedrock.text_embedder.AmazonBedrockTextEmbedder", + "init_parameters": { + "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, + "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, + "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, + "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "model": "cohere.embed-english-v3", + "input_type": "search_query", + }, + } + + embedder = AmazonBedrockTextEmbedder.from_dict(data) + + assert embedder.model == "cohere.embed-english-v3" + assert embedder.kwargs == {"input_type": "search_query"} + + +def test_init_invalid_model(): + with pytest.raises(ValueError): + AmazonBedrockTextEmbedder(model="") + + with pytest.raises(ValueError): + AmazonBedrockTextEmbedder(model="my-unsupported-model") + + +def test_run_wrong_type(mock_boto3_session): + embedder = AmazonBedrockTextEmbedder(model="cohere.embed-english-v3") + with pytest.raises(TypeError): + embedder.run(text=123) + + +def test_cohere_invocation(mock_boto3_session): + embedder = AmazonBedrockTextEmbedder(model="cohere.embed-english-v3") + + with patch.object(embedder._client, "invoke_model") as mock_invoke_model: + mock_invoke_model.return_value = { + "body": io.StringIO('{"embeddings": [[0.1, 0.2, 0.3]]}'), + } + result = embedder.run(text="some text") + + mock_invoke_model.assert_called_once_with( + body='{"texts": ["some text"], "input_type": "search_query"}', + modelId="cohere.embed-english-v3", + accept="*/*", + contentType="application/json", + ) + + assert result == {"embedding": [0.1, 0.2, 0.3]} + + +def test_titan_invocation(mock_boto3_session): + embedder = AmazonBedrockTextEmbedder(model="amazon.titan-embed-text-v1") + + with patch.object(embedder._client, "invoke_model") as mock_invoke_model: + mock_invoke_model.return_value = { + "body": io.StringIO('{"embedding": [0.1, 0.2, 0.3]}'), + } + result = embedder.run(text="some text") + + mock_invoke_model.assert_called_once_with( + body='{"inputText": "some text"}', + modelId="amazon.titan-embed-text-v1", + accept="*/*", + contentType="application/json", + ) + + assert result == {"embedding": [0.1, 0.2, 0.3]} + + +def test_run_invocation_error(mock_boto3_session): + embedder = AmazonBedrockTextEmbedder(model="cohere.embed-english-v3") + + with patch.object(embedder._client, "invoke_model") as mock_invoke_model: + mock_invoke_model.side_effect = ClientError( + error_response={"Error": {"Code": "some_code", "Message": "some_message"}}, + operation_name="some_operation", + ) + + with pytest.raises(AmazonBedrockInferenceError): + embedder.run(text="some text")