Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bedrock Text Embedder #466

Merged
merged 15 commits into from
Feb 22, 2024
2 changes: 1 addition & 1 deletion integrations/amazon_bedrock/pydoc/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ loaders:
modules: [
"haystack_integrations.components.generators.amazon_bedrock.generator",
"haystack_integrations.components.generators.amazon_bedrock.adapters",
"haystack_integrations.components.generators.amazon_bedrock.errors",
"haystack_integrations.common.amazon_bedrock.errors",
"haystack_integrations.components.generators.amazon_bedrock.handlers",
]
ignore_when_discovered: ["__init__"]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import Optional

import boto3
from botocore.exceptions import BotoCoreError

from haystack_integrations.common.amazon_bedrock.errors import AWSConfigurationError

AWS_CONFIGURATION_KEYS = [
"aws_access_key_id",
"aws_secret_access_key",
"aws_session_token",
"aws_region_name",
"aws_profile_name",
]


def get_aws_session(
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
**kwargs,
):
"""
Creates an AWS Session with the given parameters.
Checks if the provided AWS credentials are valid and can be used to connect to AWS.

:param aws_access_key_id: AWS access key ID.
:param aws_secret_access_key: AWS secret access key.
:param aws_session_token: AWS session token.
:param aws_region_name: AWS region name.
:param aws_profile_name: AWS profile name.
:param kwargs: The kwargs passed down to the service client. Supported kwargs depend on the model chosen.
See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html.
:raises AWSConfigurationError: If the provided AWS credentials are invalid.
:return: The created AWS session.
"""
try:
return boto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
region_name=aws_region_name,
profile_name=aws_profile_name,
)
except BotoCoreError as e:
provided_aws_config = {k: v for k, v in kwargs.items() if k in AWS_CONFIGURATION_KEYS}
msg = f"Failed to initialize the session with provided AWS credentials {provided_aws_config}"
raise AWSConfigurationError(msg) from e


def aws_configured(**kwargs) -> bool:
"""
Checks whether AWS configuration is provided.
:param kwargs: The kwargs passed down to the generator.
:return: True if AWS configuration is provided, False otherwise.
"""
aws_config_provided = any(key in kwargs for key in AWS_CONFIGURATION_KEYS)
return aws_config_provided
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from .text_embedder import AmazonBedrockTextEmbedder

__all__ = ["AmazonBedrockTextEmbedder"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import json
import logging
from typing import Any, Dict, List, Literal, Optional

from botocore.exceptions import ClientError
from haystack import component, default_from_dict, default_to_dict
from haystack.utils.auth import Secret, deserialize_secrets_inplace

from haystack_integrations.common.amazon_bedrock.errors import (
AmazonBedrockConfigurationError,
AmazonBedrockInferenceError,
)
from haystack_integrations.common.amazon_bedrock.utils import get_aws_session

logger = logging.getLogger(__name__)

SUPPORTED_EMBEDDING_MODELS = ["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"]


@component
class AmazonBedrockTextEmbedder:
"""
A component for embedding strings using Amazon Bedrock.

Usage example:
```python
import os
from haystack_integrations.components.embedders.amazon_bedrock import AmazonBedrockTextEmbedder

os.environ["AWS_ACCESS_KEY_ID"] = "..."
os.environ["AWS_SECRET_ACCESS_KEY_ID"] = "..."
os.environ["AWS_REGION_NAME"] = "..."

embedder = AmazonBedrockTextEmbedder(
model="cohere.embed-english-v3",
input_type="search_query",
)

print(text_embedder.run("I love Paris in the summer."))

# {'embedding': [0.002, 0.032, 0.504, ...]}
```
"""

def __init__(
self,
model: Literal["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"],
aws_access_key_id: Optional[Secret] = Secret.from_env_var("AWS_ACCESS_KEY_ID", strict=False), # noqa: B008
aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008
"AWS_SECRET_ACCESS_KEY", strict=False
),
aws_session_token: Optional[Secret] = Secret.from_env_var("AWS_SESSION_TOKEN", strict=False), # noqa: B008
aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008
aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False), # noqa: B008
**kwargs,
):
"""
Initializes the AmazonBedrockTextEmbedder with the provided parameters. The parameters are passed to the
Amazon Bedrock client.

Note that the AWS credentials are not required if the AWS environment is configured correctly. These are loaded
automatically from the environment or the AWS configuration file and do not need to be provided explicitly via
the constructor. If the AWS environment is not configured users need to provide the AWS credentials via the
constructor. Aside from model, three required parameters are `aws_access_key_id`, `aws_secret_access_key`,
and `aws_region_name`.

:param model: The embedding model to use. The model has to be specified in the format outlined in the Amazon
Bedrock [documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html).
:type model: Literal["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"]
:param aws_access_key_id: AWS access key ID.
:param aws_secret_access_key: AWS secret access key.
:param aws_session_token: AWS session token.
:param aws_region_name: AWS region name.
:param aws_profile_name: AWS profile name.
:param kwargs: Additional parameters to pass for model inference. For example, `input_type` and `truncate` for
Cohere models.
"""
if not model or model not in SUPPORTED_EMBEDDING_MODELS:
msg = "Please provide a valid model from the list of supported models: " + ", ".join(
SUPPORTED_EMBEDDING_MODELS
)
raise ValueError(msg)

def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
return secret.resolve_value() if secret else None

try:
session = get_aws_session(
aws_access_key_id=resolve_secret(aws_access_key_id),
aws_secret_access_key=resolve_secret(aws_secret_access_key),
aws_session_token=resolve_secret(aws_session_token),
aws_region_name=resolve_secret(aws_region_name),
aws_profile_name=resolve_secret(aws_profile_name),
)
self._client = session.client("bedrock-runtime")
except Exception as exception:
msg = (
"Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
"See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration"
)
raise AmazonBedrockConfigurationError(msg) from exception

self.model = model
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_session_token = aws_session_token
self.aws_region_name = aws_region_name
self.aws_profile_name = aws_profile_name
self.kwargs = kwargs

@component.output_types(embedding=List[float])
def run(self, text: str):
if not isinstance(text, str):
msg = (
"AmazonBedrockTextEmbedder expects a string as an input."
"In case you want to embed a list of Documents, please use the AmazonBedrockTextEmbedder."
)
raise TypeError(msg)

if "cohere" in self.model:
anakin87 marked this conversation as resolved.
Show resolved Hide resolved
body = {
"texts": [text],
"input_type": self.kwargs.get("input_type", "search_query"), # mandatory parameter for Cohere models
}
if truncate := self.kwargs.get("truncate"):
body["truncate"] = truncate # optional parameter for Cohere models

elif "titan" in self.model:
body = {
"inputText": text,
}

try:
response = self._client.invoke_model(
body=json.dumps(body), modelId=self.model, accept="*/*", contentType="application/json"
)
except ClientError as exception:
msg = (
f"Could not connect to Amazon Bedrock model {self.model}. "
f"Make sure your AWS environment is configured correctly, "
f"the model is available in the configured AWS region, and you have access."
)
raise AmazonBedrockInferenceError(msg) from exception

response_body = json.loads(response.get("body").read())

if "cohere" in self.model:
embedding = response_body["embeddings"][0]
elif "titan" in self.model:
embedding = response_body["embedding"]

return {"embedding": embedding}

def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
:return: The serialized component as a dictionary.
"""
return default_to_dict(
self,
aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None,
aws_secret_access_key=self.aws_secret_access_key.to_dict() if self.aws_secret_access_key else None,
aws_session_token=self.aws_session_token.to_dict() if self.aws_session_token else None,
aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None,
aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
model=self.model,
**self.kwargs,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockTextEmbedder":
"""
Deserialize this component from a dictionary.
:param data: The dictionary representation of this component.
:return: The deserialized component instance.
"""
deserialize_secrets_inplace(
data["init_parameters"],
["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"],
)
return default_from_dict(cls, data)
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,22 @@
import re
from typing import Any, Callable, ClassVar, Dict, List, Optional, Type

import boto3
from botocore.exceptions import BotoCoreError, ClientError
from botocore.exceptions import ClientError
from haystack import component, default_from_dict, default_to_dict
from haystack.components.generators.utils import deserialize_callback_handler, serialize_callback_handler
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.utils.auth import Secret, deserialize_secrets_inplace

from haystack_integrations.components.generators.amazon_bedrock.errors import (
from haystack_integrations.common.amazon_bedrock.errors import (
AmazonBedrockConfigurationError,
AmazonBedrockInferenceError,
AWSConfigurationError,
)
from haystack_integrations.common.amazon_bedrock.utils import get_aws_session

from .adapters import AnthropicClaudeChatAdapter, BedrockModelChatAdapter, MetaLlama2ChatAdapter

logger = logging.getLogger(__name__)

AWS_CONFIGURATION_KEYS = [
"aws_access_key_id",
"aws_secret_access_key",
"aws_session_token",
"aws_region_name",
"aws_profile_name",
]


@component
class AmazonBedrockChatGenerator:
Expand Down Expand Up @@ -123,7 +114,7 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
return secret.resolve_value() if secret else None

try:
session = self.get_aws_session(
session = get_aws_session(
aws_access_key_id=resolve_secret(aws_access_key_id),
aws_secret_access_key=resolve_secret(aws_secret_access_key),
aws_session_token=resolve_secret(aws_session_token),
Expand Down Expand Up @@ -186,53 +177,6 @@ def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelChatAdapter]
return adapter
return None

@classmethod
def aws_configured(cls, **kwargs) -> bool:
"""
Checks whether AWS configuration is provided.
:param kwargs: The kwargs passed down to the generator.
:return: True if AWS configuration is provided, False otherwise.
"""
aws_config_provided = any(key in kwargs for key in AWS_CONFIGURATION_KEYS)
return aws_config_provided

@classmethod
def get_aws_session(
cls,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
**kwargs,
):
"""
Creates an AWS Session with the given parameters.
Checks if the provided AWS credentials are valid and can be used to connect to AWS.

:param aws_access_key_id: AWS access key ID.
:param aws_secret_access_key: AWS secret access key.
:param aws_session_token: AWS session token.
:param aws_region_name: AWS region name.
:param aws_profile_name: AWS profile name.
:param kwargs: The kwargs passed down to the service client. Supported kwargs depend on the model chosen.
See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html.
:raises AWSConfigurationError: If the provided AWS credentials are invalid.
:return: The created AWS session.
"""
try:
return boto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
region_name=aws_region_name,
profile_name=aws_profile_name,
)
except BotoCoreError as e:
provided_aws_config = {k: v for k, v in kwargs.items() if k in AWS_CONFIGURATION_KEYS}
msg = f"Failed to initialize the session with provided AWS credentials {provided_aws_config}"
raise AWSConfigurationError(msg) from e

def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
Expand Down
Loading