-
Notifications
You must be signed in to change notification settings - Fork 127
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into cohere_callable_serialization
- Loading branch information
Showing
15 changed files
with
464 additions
and
190 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
3 changes: 3 additions & 0 deletions
3
integrations/amazon_bedrock/src/haystack_integrations/common/amazon_bedrock/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 |
File renamed without changes.
60 changes: 60 additions & 0 deletions
60
integrations/amazon_bedrock/src/haystack_integrations/common/amazon_bedrock/utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from typing import Optional | ||
|
||
import boto3 | ||
from botocore.exceptions import BotoCoreError | ||
|
||
from haystack_integrations.common.amazon_bedrock.errors import AWSConfigurationError | ||
|
||
AWS_CONFIGURATION_KEYS = [ | ||
"aws_access_key_id", | ||
"aws_secret_access_key", | ||
"aws_session_token", | ||
"aws_region_name", | ||
"aws_profile_name", | ||
] | ||
|
||
|
||
def get_aws_session(
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    aws_session_token: Optional[str] = None,
    aws_region_name: Optional[str] = None,
    aws_profile_name: Optional[str] = None,
    **kwargs,
):
    """
    Creates an AWS Session with the given parameters.

    Any `BotoCoreError` raised while constructing the session is re-raised as `AWSConfigurationError`.

    :param aws_access_key_id: AWS access key ID.
    :param aws_secret_access_key: AWS secret access key.
    :param aws_session_token: AWS session token.
    :param aws_region_name: AWS region name.
    :param aws_profile_name: AWS profile name.
    :param kwargs: Extra keyword arguments. They are accepted so callers can pass their full configuration
        through, but they are not forwarded to `boto3.Session` by this function.
    :raises AWSConfigurationError: If the session cannot be initialized with the provided AWS configuration.
    :return: The created AWS session.
    """
    session_config = {
        "aws_access_key_id": aws_access_key_id,
        "aws_secret_access_key": aws_secret_access_key,
        "aws_session_token": aws_session_token,
        "aws_region_name": aws_region_name,
        "aws_profile_name": aws_profile_name,
    }
    try:
        return boto3.Session(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            region_name=aws_region_name,
            profile_name=aws_profile_name,
        )
    except BotoCoreError as e:
        # The AWS settings are bound to the named parameters, so they never show up in **kwargs.
        # Report the names of the settings that were supplied, never their values, so that
        # credentials cannot leak into logs or tracebacks.
        provided_aws_config = [name for name, value in session_config.items() if value is not None]
        msg = f"Failed to initialize the session with provided AWS credentials {provided_aws_config}"
        raise AWSConfigurationError(msg) from e
|
||
|
||
def aws_configured(**kwargs) -> bool:
    """
    Tells whether at least one AWS configuration key is present among the keyword arguments.

    Only the presence of a key is checked; its value is not inspected.

    :param kwargs: The kwargs passed down to the generator.
    :return: True if any key listed in `AWS_CONFIGURATION_KEYS` is present, False otherwise.
    """
    for config_key in AWS_CONFIGURATION_KEYS:
        if config_key in kwargs:
            return True
    return False
6 changes: 6 additions & 0 deletions
6
.../amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from .text_embedder import AmazonBedrockTextEmbedder | ||
|
||
__all__ = ["AmazonBedrockTextEmbedder"] |
181 changes: 181 additions & 0 deletions
181
...on_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
import json | ||
import logging | ||
from typing import Any, Dict, List, Literal, Optional | ||
|
||
from botocore.exceptions import ClientError | ||
from haystack import component, default_from_dict, default_to_dict | ||
from haystack.utils.auth import Secret, deserialize_secrets_inplace | ||
|
||
from haystack_integrations.common.amazon_bedrock.errors import ( | ||
AmazonBedrockConfigurationError, | ||
AmazonBedrockInferenceError, | ||
) | ||
from haystack_integrations.common.amazon_bedrock.utils import get_aws_session | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
SUPPORTED_EMBEDDING_MODELS = ["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"] | ||
|
||
|
||
@component
class AmazonBedrockTextEmbedder:
    """
    A component for embedding strings using Amazon Bedrock.

    Usage example:
    ```python
    import os
    from haystack_integrations.components.embedders.amazon_bedrock import AmazonBedrockTextEmbedder

    os.environ["AWS_ACCESS_KEY_ID"] = "..."
    os.environ["AWS_SECRET_ACCESS_KEY"] = "..."
    os.environ["AWS_DEFAULT_REGION"] = "..."

    embedder = AmazonBedrockTextEmbedder(
        model="cohere.embed-english-v3",
        input_type="search_query",
    )
    print(embedder.run("I love Paris in the summer."))
    # {'embedding': [0.002, 0.032, 0.504, ...]}
    ```
    """

    def __init__(
        self,
        model: Literal["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"],
        aws_access_key_id: Optional[Secret] = Secret.from_env_var("AWS_ACCESS_KEY_ID", strict=False),  # noqa: B008
        aws_secret_access_key: Optional[Secret] = Secret.from_env_var(  # noqa: B008
            "AWS_SECRET_ACCESS_KEY", strict=False
        ),
        aws_session_token: Optional[Secret] = Secret.from_env_var("AWS_SESSION_TOKEN", strict=False),  # noqa: B008
        aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False),  # noqa: B008
        aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False),  # noqa: B008
        **kwargs,
    ):
        """
        Initializes the AmazonBedrockTextEmbedder with the provided parameters. The parameters are passed to the
        Amazon Bedrock client.

        Note that the AWS credentials are not required if the AWS environment is configured correctly. These are loaded
        automatically from the environment or the AWS configuration file and do not need to be provided explicitly via
        the constructor. If the AWS environment is not configured users need to provide the AWS credentials via the
        constructor. Aside from model, three required parameters are `aws_access_key_id`, `aws_secret_access_key`,
        and `aws_region_name`.

        :param model: The embedding model to use. The model has to be specified in the format outlined in the Amazon
            Bedrock [documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html).
        :type model: Literal["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"]
        :param aws_access_key_id: AWS access key ID.
        :param aws_secret_access_key: AWS secret access key.
        :param aws_session_token: AWS session token.
        :param aws_region_name: AWS region name.
        :param aws_profile_name: AWS profile name.
        :param kwargs: Additional parameters to pass for model inference. For example, `input_type` and `truncate` for
            Cohere models.
        :raises ValueError: If `model` is empty or not one of the supported embedding models.
        :raises AmazonBedrockConfigurationError: If the AWS session or the Bedrock runtime client cannot be created.
        """
        if not model or model not in SUPPORTED_EMBEDDING_MODELS:
            msg = "Please provide a valid model from the list of supported models: " + ", ".join(
                SUPPORTED_EMBEDDING_MODELS
            )
            raise ValueError(msg)

        def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
            # A missing Secret means "let boto3 fall back to its own configuration lookup".
            return secret.resolve_value() if secret else None

        try:
            session = get_aws_session(
                aws_access_key_id=resolve_secret(aws_access_key_id),
                aws_secret_access_key=resolve_secret(aws_secret_access_key),
                aws_session_token=resolve_secret(aws_session_token),
                aws_region_name=resolve_secret(aws_region_name),
                aws_profile_name=resolve_secret(aws_profile_name),
            )
            self._client = session.client("bedrock-runtime")
        except Exception as exception:
            msg = (
                "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
                "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration"
            )
            raise AmazonBedrockConfigurationError(msg) from exception

        # Keep the unresolved Secret objects (not their values) so that to_dict() can serialize them.
        self.model = model
        self.aws_access_key_id = aws_access_key_id
        self.aws_secret_access_key = aws_secret_access_key
        self.aws_session_token = aws_session_token
        self.aws_region_name = aws_region_name
        self.aws_profile_name = aws_profile_name
        self.kwargs = kwargs

    @component.output_types(embedding=List[float])
    def run(self, text: str):
        """
        Embeds the given text with the configured Amazon Bedrock model.

        :param text: The string to embed.
        :raises TypeError: If `text` is not a string.
        :raises ValueError: If the configured model belongs to neither the Cohere nor the Titan family.
        :raises AmazonBedrockInferenceError: If the model invocation fails with a `ClientError`.
        :return: A dictionary with the key `embedding` holding the embedding returned by the model.
        """
        if not isinstance(text, str):
            msg = (
                "AmazonBedrockTextEmbedder expects a string as an input. "
                "In case you want to embed a list of Documents, please use a document embedder instead."
            )
            raise TypeError(msg)

        is_cohere = "cohere" in self.model
        is_titan = "titan" in self.model

        body: Dict[str, Any]
        if is_cohere:
            body = {
                "texts": [text],
                "input_type": self.kwargs.get("input_type", "search_query"),  # mandatory parameter for Cohere models
            }
            if truncate := self.kwargs.get("truncate"):
                body["truncate"] = truncate  # optional parameter for Cohere models
        elif is_titan:
            body = {
                "inputText": text,
            }
        else:
            # __init__ validates the model, so this only triggers if `self.model` was changed afterwards.
            # Fail with a clear message instead of an unbound-variable error further down.
            msg = f"Unsupported embedding model: {self.model}"
            raise ValueError(msg)

        try:
            response = self._client.invoke_model(
                body=json.dumps(body), modelId=self.model, accept="*/*", contentType="application/json"
            )
        except ClientError as exception:
            msg = (
                f"Could not connect to Amazon Bedrock model {self.model}. "
                f"Make sure your AWS environment is configured correctly, "
                f"the model is available in the configured AWS region, and you have access."
            )
            raise AmazonBedrockInferenceError(msg) from exception

        response_body = json.loads(response.get("body").read())

        # The texts were sent as a one-element batch to Cohere models, so take the first entry of
        # "embeddings"; Titan responses are read from the single "embedding" field.
        if is_cohere:
            embedding = response_body["embeddings"][0]
        else:
            embedding = response_body["embedding"]

        return {"embedding": embedding}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :return: The serialized component as a dictionary.
        """
        return default_to_dict(
            self,
            aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None,
            aws_secret_access_key=self.aws_secret_access_key.to_dict() if self.aws_secret_access_key else None,
            aws_session_token=self.aws_session_token.to_dict() if self.aws_session_token else None,
            aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None,
            aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
            model=self.model,
            **self.kwargs,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockTextEmbedder":
        """
        Deserialize this component from a dictionary.

        :param data: The dictionary representation of this component.
        :return: The deserialized component instance.
        """
        deserialize_secrets_inplace(
            data["init_parameters"],
            ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"],
        )
        return default_from_dict(cls, data)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.