diff --git a/integrations/amazon_bedrock/pydoc/config.yml b/integrations/amazon_bedrock/pydoc/config.yml index 277ba85be..2506c3369 100644 --- a/integrations/amazon_bedrock/pydoc/config.yml +++ b/integrations/amazon_bedrock/pydoc/config.yml @@ -4,7 +4,7 @@ loaders: modules: [ "haystack_integrations.components.generators.amazon_bedrock.generator", "haystack_integrations.components.generators.amazon_bedrock.adapters", - "haystack_integrations.components.generators.amazon_bedrock.errors", + "haystack_integrations.common.amazon_bedrock.errors", "haystack_integrations.components.generators.amazon_bedrock.handlers", ] ignore_when_discovered: ["__init__"] diff --git a/integrations/amazon_bedrock/src/haystack_integrations/common/amazon_bedrock/__init__.py b/integrations/amazon_bedrock/src/haystack_integrations/common/amazon_bedrock/__init__.py new file mode 100644 index 000000000..e873bc332 --- /dev/null +++ b/integrations/amazon_bedrock/src/haystack_integrations/common/amazon_bedrock/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/errors.py b/integrations/amazon_bedrock/src/haystack_integrations/common/amazon_bedrock/errors.py similarity index 100% rename from integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/errors.py rename to integrations/amazon_bedrock/src/haystack_integrations/common/amazon_bedrock/errors.py diff --git a/integrations/amazon_bedrock/src/haystack_integrations/common/amazon_bedrock/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/common/amazon_bedrock/utils.py new file mode 100644 index 000000000..e1683e3b3 --- /dev/null +++ b/integrations/amazon_bedrock/src/haystack_integrations/common/amazon_bedrock/utils.py @@ -0,0 +1,60 @@ +from typing import Optional + +import boto3 +from botocore.exceptions import BotoCoreError + 
# Names of the keyword arguments that carry AWS connection settings.
AWS_CONFIGURATION_KEYS = [
    "aws_access_key_id",
    "aws_secret_access_key",
    "aws_session_token",
    "aws_region_name",
    "aws_profile_name",
]


def get_aws_session(
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    aws_session_token: Optional[str] = None,
    aws_region_name: Optional[str] = None,
    aws_profile_name: Optional[str] = None,
    **kwargs,
):
    """
    Creates an AWS Session with the given parameters.

    :param aws_access_key_id: AWS access key ID.
    :param aws_secret_access_key: AWS secret access key.
    :param aws_session_token: AWS session token.
    :param aws_region_name: AWS region name.
    :param aws_profile_name: AWS profile name.
    :param kwargs: Accepted so callers can forward their full keyword set; not used when building the session.
    :raises AWSConfigurationError: If boto3 fails to create the session (any `BotoCoreError`).
    :return: The created AWS session.
    """
    try:
        return boto3.Session(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            region_name=aws_region_name,
            profile_name=aws_profile_name,
        )
    except BotoCoreError as e:
        # The AWS settings arrive as named parameters, so they can never show up in `kwargs`;
        # inspect the parameters themselves. Only the *names* of the supplied settings are
        # reported, so that secrets never end up in exception messages or logs.
        supplied = {
            "aws_access_key_id": aws_access_key_id,
            "aws_secret_access_key": aws_secret_access_key,
            "aws_session_token": aws_session_token,
            "aws_region_name": aws_region_name,
            "aws_profile_name": aws_profile_name,
        }
        provided_aws_config = [key for key in AWS_CONFIGURATION_KEYS if supplied[key] is not None]
        msg = f"Failed to initialize the session with provided AWS credentials {provided_aws_config}"
        raise AWSConfigurationError(msg) from e


def aws_configured(**kwargs) -> bool:
    """
    Checks whether AWS configuration is provided.

    Only the presence of a key is checked, not its value: a key explicitly set to `None` still counts.

    :param kwargs: The kwargs passed down to the generator.
    :return: True if at least one of `AWS_CONFIGURATION_KEYS` is present in `kwargs`, False otherwise.
    """
    return any(key in kwargs for key in AWS_CONFIGURATION_KEYS)
SUPPORTED_EMBEDDING_MODELS = ["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"]


@component
class AmazonBedrockTextEmbedder:
    """
    A component for embedding strings using Amazon Bedrock.

    Usage example:
    ```python
    import os
    from haystack_integrations.components.embedders.amazon_bedrock import AmazonBedrockTextEmbedder

    os.environ["AWS_ACCESS_KEY_ID"] = "..."
    os.environ["AWS_SECRET_ACCESS_KEY"] = "..."
    os.environ["AWS_DEFAULT_REGION"] = "..."

    embedder = AmazonBedrockTextEmbedder(
        model="cohere.embed-english-v3",
        input_type="search_query",
    )

    print(embedder.run("I love Paris in the summer."))

    # {'embedding': [0.002, 0.032, 0.504, ...]}
    ```
    """

    def __init__(
        self,
        model: Literal["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"],
        aws_access_key_id: Optional[Secret] = Secret.from_env_var("AWS_ACCESS_KEY_ID", strict=False),  # noqa: B008
        aws_secret_access_key: Optional[Secret] = Secret.from_env_var(  # noqa: B008
            "AWS_SECRET_ACCESS_KEY", strict=False
        ),
        aws_session_token: Optional[Secret] = Secret.from_env_var("AWS_SESSION_TOKEN", strict=False),  # noqa: B008
        aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False),  # noqa: B008
        aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False),  # noqa: B008
        **kwargs,
    ):
        """
        Initializes the AmazonBedrockTextEmbedder with the provided parameters. The parameters are passed to the
        Amazon Bedrock client.

        Note that the AWS credentials are not required if the AWS environment is configured correctly. These are loaded
        automatically from the environment or the AWS configuration file and do not need to be provided explicitly via
        the constructor. If the AWS environment is not configured, users need to provide the AWS credentials via the
        constructor: besides `model`, the required parameters are then `aws_access_key_id`,
        `aws_secret_access_key`, and `aws_region_name`.

        :param model: The embedding model to use. The model has to be specified in the format outlined in the Amazon
            Bedrock [documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html).
            Must be one of `SUPPORTED_EMBEDDING_MODELS`.
        :param aws_access_key_id: AWS access key ID.
        :param aws_secret_access_key: AWS secret access key.
        :param aws_session_token: AWS session token.
        :param aws_region_name: AWS region name.
        :param aws_profile_name: AWS profile name.
        :param kwargs: Additional parameters to pass for model inference. For example, `input_type` and `truncate` for
            Cohere models.
        :raises ValueError: If `model` is empty or not in `SUPPORTED_EMBEDDING_MODELS`.
        :raises AmazonBedrockConfigurationError: If the AWS session or the Bedrock client cannot be created.
        """
        if not model or model not in SUPPORTED_EMBEDDING_MODELS:
            msg = "Please provide a valid model from the list of supported models: " + ", ".join(
                SUPPORTED_EMBEDDING_MODELS
            )
            raise ValueError(msg)

        def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
            return secret.resolve_value() if secret else None

        try:
            session = get_aws_session(
                aws_access_key_id=resolve_secret(aws_access_key_id),
                aws_secret_access_key=resolve_secret(aws_secret_access_key),
                aws_session_token=resolve_secret(aws_session_token),
                aws_region_name=resolve_secret(aws_region_name),
                aws_profile_name=resolve_secret(aws_profile_name),
            )
            self._client = session.client("bedrock-runtime")
        except Exception as exception:
            # Any failure while building the session/client is reported as a configuration problem;
            # the original exception stays attached as the cause.
            msg = (
                "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
                "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration"
            )
            raise AmazonBedrockConfigurationError(msg) from exception

        self.model = model
        # The Secret objects (not their resolved values) are kept so that `to_dict` can serialize them.
        self.aws_access_key_id = aws_access_key_id
        self.aws_secret_access_key = aws_secret_access_key
        self.aws_session_token = aws_session_token
        self.aws_region_name = aws_region_name
        self.aws_profile_name = aws_profile_name
        self.kwargs = kwargs

    @component.output_types(embedding=List[float])
    def run(self, text: str):
        """
        Embeds a single string with the configured Amazon Bedrock model.

        :param text: The text to embed.
        :raises TypeError: If `text` is not a string.
        :raises ValueError: If the configured model is neither a Cohere nor a Titan model.
        :raises AmazonBedrockInferenceError: If the model invocation fails with a `ClientError`.
        :return: A dictionary with the key `embedding` holding the embedding of `text`.
        """
        if not isinstance(text, str):
            msg = (
                "AmazonBedrockTextEmbedder expects a string as an input. "
                "In case you want to embed a list of Documents, please use a Document embedder instead."
            )
            raise TypeError(msg)

        is_cohere = "cohere" in self.model
        if is_cohere:
            body = {
                "texts": [text],
                "input_type": self.kwargs.get("input_type", "search_query"),  # mandatory parameter for Cohere models
            }
            if truncate := self.kwargs.get("truncate"):
                body["truncate"] = truncate  # optional parameter for Cohere models
        elif "titan" in self.model:
            body = {
                "inputText": text,
            }
        else:
            # Not reachable for models accepted by __init__; guards against `self.model` being
            # reassigned to something this method does not know how to call.
            msg = f"Unsupported embedding model: {self.model}"
            raise ValueError(msg)

        try:
            response = self._client.invoke_model(
                body=json.dumps(body), modelId=self.model, accept="*/*", contentType="application/json"
            )
        except ClientError as exception:
            msg = (
                f"Could not connect to Amazon Bedrock model {self.model}. "
                f"Make sure your AWS environment is configured correctly, "
                f"the model is available in the configured AWS region, and you have access."
            )
            raise AmazonBedrockInferenceError(msg) from exception

        response_body = json.loads(response.get("body").read())

        # Cohere answers with one embedding per input text; Titan answers with a single embedding.
        embedding = response_body["embeddings"][0] if is_cohere else response_body["embedding"]

        return {"embedding": embedding}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        Secrets are stored in their serialized form; the extra inference `kwargs` are flattened into
        the init parameters.

        :return: The serialized component as a dictionary.
        """
        return default_to_dict(
            self,
            aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None,
            aws_secret_access_key=self.aws_secret_access_key.to_dict() if self.aws_secret_access_key else None,
            aws_session_token=self.aws_session_token.to_dict() if self.aws_session_token else None,
            aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None,
            aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
            model=self.model,
            **self.kwargs,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockTextEmbedder":
        """
        Deserialize this component from a dictionary.

        Note that the secrets inside `data["init_parameters"]` are deserialized in place.

        :param data: The dictionary representation of this component.
        :return: The deserialized component instance.
        """
        deserialize_secrets_inplace(
            data["init_parameters"],
            ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"],
        )
        return default_from_dict(cls, data)
haystack.utils.auth import Secret, deserialize_secrets_inplace -from haystack_integrations.components.generators.amazon_bedrock.errors import ( +from haystack_integrations.common.amazon_bedrock.errors import ( AmazonBedrockConfigurationError, AmazonBedrockInferenceError, - AWSConfigurationError, ) +from haystack_integrations.common.amazon_bedrock.utils import get_aws_session from .adapters import AnthropicClaudeChatAdapter, BedrockModelChatAdapter, MetaLlama2ChatAdapter logger = logging.getLogger(__name__) -AWS_CONFIGURATION_KEYS = [ - "aws_access_key_id", - "aws_secret_access_key", - "aws_session_token", - "aws_region_name", - "aws_profile_name", -] - @component class AmazonBedrockChatGenerator: @@ -123,7 +114,7 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: return secret.resolve_value() if secret else None try: - session = self.get_aws_session( + session = get_aws_session( aws_access_key_id=resolve_secret(aws_access_key_id), aws_secret_access_key=resolve_secret(aws_secret_access_key), aws_session_token=resolve_secret(aws_session_token), @@ -186,53 +177,6 @@ def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelChatAdapter] return adapter return None - @classmethod - def aws_configured(cls, **kwargs) -> bool: - """ - Checks whether AWS configuration is provided. - :param kwargs: The kwargs passed down to the generator. - :return: True if AWS configuration is provided, False otherwise. - """ - aws_config_provided = any(key in kwargs for key in AWS_CONFIGURATION_KEYS) - return aws_config_provided - - @classmethod - def get_aws_session( - cls, - aws_access_key_id: Optional[str] = None, - aws_secret_access_key: Optional[str] = None, - aws_session_token: Optional[str] = None, - aws_region_name: Optional[str] = None, - aws_profile_name: Optional[str] = None, - **kwargs, - ): - """ - Creates an AWS Session with the given parameters. - Checks if the provided AWS credentials are valid and can be used to connect to AWS. 
- - :param aws_access_key_id: AWS access key ID. - :param aws_secret_access_key: AWS secret access key. - :param aws_session_token: AWS session token. - :param aws_region_name: AWS region name. - :param aws_profile_name: AWS profile name. - :param kwargs: The kwargs passed down to the service client. Supported kwargs depend on the model chosen. - See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html. - :raises AWSConfigurationError: If the provided AWS credentials are invalid. - :return: The created AWS session. - """ - try: - return boto3.Session( - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - aws_session_token=aws_session_token, - region_name=aws_region_name, - profile_name=aws_profile_name, - ) - except BotoCoreError as e: - provided_aws_config = {k: v for k, v in kwargs.items() if k in AWS_CONFIGURATION_KEYS} - msg = f"Failed to initialize the session with provided AWS credentials {provided_aws_config}" - raise AWSConfigurationError(msg) from e - def to_dict(self) -> Dict[str, Any]: """ Serialize this component to a dictionary. 
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index e1820497d..9c3e157cb 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -3,11 +3,16 @@ import re from typing import Any, ClassVar, Dict, List, Optional, Type, Union -import boto3 -from botocore.exceptions import BotoCoreError, ClientError +from botocore.exceptions import ClientError from haystack import component, default_from_dict, default_to_dict from haystack.utils.auth import Secret, deserialize_secrets_inplace +from haystack_integrations.common.amazon_bedrock.errors import ( + AmazonBedrockConfigurationError, + AmazonBedrockInferenceError, +) +from haystack_integrations.common.amazon_bedrock.utils import get_aws_session + from .adapters import ( AI21LabsJurassic2Adapter, AmazonTitanAdapter, @@ -16,11 +21,6 @@ CohereCommandAdapter, MetaLlama2ChatAdapter, ) -from .errors import ( - AmazonBedrockConfigurationError, - AmazonBedrockInferenceError, - AWSConfigurationError, -) from .handlers import ( DefaultPromptHandler, DefaultTokenStreamingHandler, @@ -29,14 +29,6 @@ logger = logging.getLogger(__name__) -AWS_CONFIGURATION_KEYS = [ - "aws_access_key_id", - "aws_secret_access_key", - "aws_session_token", - "aws_region_name", - "aws_profile_name", -] - @component class AmazonBedrockGenerator: @@ -98,7 +90,7 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: return secret.resolve_value() if secret else None try: - session = self.get_aws_session( + session = get_aws_session( aws_access_key_id=resolve_secret(aws_access_key_id), aws_secret_access_key=resolve_secret(aws_secret_access_key), aws_session_token=resolve_secret(aws_session_token), @@ -212,53 
+204,6 @@ def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelAdapter]]: return adapter return None - @classmethod - def aws_configured(cls, **kwargs) -> bool: - """ - Checks whether AWS configuration is provided. - :param kwargs: The kwargs passed down to the generator. - :return: True if AWS configuration is provided, False otherwise. - """ - aws_config_provided = any(key in kwargs for key in AWS_CONFIGURATION_KEYS) - return aws_config_provided - - @classmethod - def get_aws_session( - cls, - aws_access_key_id: Optional[str] = None, - aws_secret_access_key: Optional[str] = None, - aws_session_token: Optional[str] = None, - aws_region_name: Optional[str] = None, - aws_profile_name: Optional[str] = None, - **kwargs, - ): - """ - Creates an AWS Session with the given parameters. - Checks if the provided AWS credentials are valid and can be used to connect to AWS. - - :param aws_access_key_id: AWS access key ID. - :param aws_secret_access_key: AWS secret access key. - :param aws_session_token: AWS session token. - :param aws_region_name: AWS region name. - :param aws_profile_name: AWS profile name. - :param kwargs: The kwargs passed down to the service client. Supported kwargs depend on the model chosen. - See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html. - :raises AWSConfigurationError: If the provided AWS credentials are invalid. - :return: The created AWS session. 
- """ - try: - return boto3.Session( - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - aws_session_token=aws_session_token, - region_name=aws_region_name, - profile_name=aws_profile_name, - ) - except BotoCoreError as e: - provided_aws_config = {k: v for k, v in kwargs.items() if k in AWS_CONFIGURATION_KEYS} - msg = f"Failed to initialize the session with provided AWS credentials {provided_aws_config}" - raise AWSConfigurationError(msg) from e - def to_dict(self) -> Dict[str, Any]: """ Serialize this component to a dictionary. diff --git a/integrations/amazon_bedrock/tests/conftest.py b/integrations/amazon_bedrock/tests/conftest.py index 4c8ce688c..9406559bf 100644 --- a/integrations/amazon_bedrock/tests/conftest.py +++ b/integrations/amazon_bedrock/tests/conftest.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest @@ -12,14 +12,6 @@ def set_env_variables(monkeypatch): monkeypatch.setenv("AWS_PROFILE", "some_fake_profile") -@pytest.fixture -def mock_auto_tokenizer(): - with patch("transformers.AutoTokenizer.from_pretrained", autospec=True) as mock_from_pretrained: - mock_tokenizer = MagicMock() - mock_from_pretrained.return_value = mock_tokenizer - yield mock_tokenizer - - # create a fixture with mocked boto3 client and session @pytest.fixture def mock_boto3_session(): diff --git a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py b/integrations/amazon_bedrock/tests/test_chat_generator.py similarity index 86% rename from integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py rename to integrations/amazon_bedrock/tests/test_chat_generator.py index 86a610811..0f7bced89 100644 --- a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -1,5 +1,4 @@ from typing import Optional, Type -from unittest.mock import MagicMock, patch import pytest from haystack.components.generators.utils 
import print_streaming_chunk @@ -15,30 +14,7 @@ clazz = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" -@pytest.fixture -def mock_auto_tokenizer(): - with patch("transformers.AutoTokenizer.from_pretrained", autospec=True) as mock_from_pretrained: - mock_tokenizer = MagicMock() - mock_from_pretrained.return_value = mock_tokenizer - yield mock_tokenizer - - -# create a fixture with mocked boto3 client and session -@pytest.fixture -def mock_boto3_session(): - with patch("boto3.Session") as mock_client: - yield mock_client - - -@pytest.fixture -def mock_prompt_handler(): - with patch( - "haystack_integrations.components.generators.amazon_bedrock.handlers.DefaultPromptHandler" - ) as mock_prompt_handler: - yield mock_prompt_handler - - -def test_to_dict(mock_auto_tokenizer, mock_boto3_session, set_env_variables): +def test_to_dict(mock_boto3_session): """ Test that the to_dict method returns the correct dictionary without aws credentials """ @@ -65,7 +41,7 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session, set_env_variables): assert generator.to_dict() == expected_dict -def test_from_dict(mock_auto_tokenizer, mock_boto3_session): +def test_from_dict(mock_boto3_session): """ Test that the from_dict method returns the correct object """ @@ -89,7 +65,7 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session): assert generator.streaming_callback == print_streaming_chunk -def test_default_constructor(mock_auto_tokenizer, mock_boto3_session, set_env_variables): +def test_default_constructor(mock_boto3_session, set_env_variables): """ Test that the default constructor sets the correct values """ @@ -116,7 +92,7 @@ def test_default_constructor(mock_auto_tokenizer, mock_boto3_session, set_env_va ) -def test_constructor_with_generation_kwargs(mock_auto_tokenizer, mock_boto3_session): +def test_constructor_with_generation_kwargs(mock_boto3_session): """ Test that model_kwargs are correctly set in the 
constructor """ @@ -135,8 +111,7 @@ def test_constructor_with_empty_model(): AmazonBedrockChatGenerator(model="") -@pytest.mark.unit -def test_invoke_with_no_kwargs(mock_auto_tokenizer, mock_boto3_session): +def test_invoke_with_no_kwargs(mock_boto3_session): """ Test invoke raises an error if no messages are provided """ @@ -145,7 +120,6 @@ def test_invoke_with_no_kwargs(mock_auto_tokenizer, mock_boto3_session): layer.invoke() -@pytest.mark.unit @pytest.mark.parametrize( "model, expected_model_adapter", [ @@ -168,7 +142,7 @@ def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[Bed class TestAnthropicClaudeAdapter: - def test_prepare_body_with_default_params(self, mock_auto_tokenizer) -> None: + def test_prepare_body_with_default_params(self) -> None: layer = AnthropicClaudeChatAdapter(generation_kwargs={}) prompt = "Hello, how are you?" expected_body = { @@ -181,7 +155,7 @@ def test_prepare_body_with_default_params(self, mock_auto_tokenizer) -> None: assert body == expected_body - def test_prepare_body_with_custom_inference_params(self, mock_auto_tokenizer) -> None: + def test_prepare_body_with_custom_inference_params(self) -> None: layer = AnthropicClaudeChatAdapter(generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4}) prompt = "Hello, how are you?" 
expected_body = { diff --git a/integrations/amazon_bedrock/tests/test_amazon_bedrock.py b/integrations/amazon_bedrock/tests/test_generator.py similarity index 98% rename from integrations/amazon_bedrock/tests/test_amazon_bedrock.py rename to integrations/amazon_bedrock/tests/test_generator.py index e43cc94cf..1ab9ba3c8 100644 --- a/integrations/amazon_bedrock/tests/test_amazon_bedrock.py +++ b/integrations/amazon_bedrock/tests/test_generator.py @@ -14,8 +14,7 @@ ) -@pytest.mark.unit -def test_to_dict(mock_auto_tokenizer, mock_boto3_session, set_env_variables): +def test_to_dict(mock_boto3_session): """ Test that the to_dict method returns the correct dictionary without aws credentials """ @@ -40,8 +39,7 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session, set_env_variables): assert generator.to_dict() == expected_dict -@pytest.mark.unit -def test_from_dict(mock_auto_tokenizer, mock_boto3_session, set_env_variables): +def test_from_dict(mock_boto3_session): """ Test that the from_dict method returns the correct object """ @@ -64,8 +62,7 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session, set_env_variables): assert generator.model == "anthropic.claude-v2" -@pytest.mark.unit -def test_default_constructor(mock_auto_tokenizer, mock_boto3_session, set_env_variables): +def test_default_constructor(mock_boto3_session, set_env_variables): """ Test that the default constructor sets the correct values """ @@ -94,8 +91,7 @@ def test_default_constructor(mock_auto_tokenizer, mock_boto3_session, set_env_va ) -@pytest.mark.unit -def test_constructor_prompt_handler_initialized(mock_auto_tokenizer, mock_boto3_session, mock_prompt_handler): +def test_constructor_prompt_handler_initialized(mock_boto3_session, mock_prompt_handler): """ Test that the constructor sets the prompt_handler correctly, with the correct model_max_length for llama-2 """ @@ -104,8 +100,7 @@ def test_constructor_prompt_handler_initialized(mock_auto_tokenizer, mock_boto3_ assert 
layer.prompt_handler.model_max_length == 4096 -@pytest.mark.unit -def test_constructor_with_model_kwargs(mock_auto_tokenizer, mock_boto3_session): +def test_constructor_with_model_kwargs(mock_boto3_session): """ Test that model_kwargs are correctly set in the constructor """ @@ -116,7 +111,6 @@ def test_constructor_with_model_kwargs(mock_auto_tokenizer, mock_boto3_session): assert layer.model_adapter.model_kwargs["temperature"] == 0.7 -@pytest.mark.unit def test_constructor_with_empty_model(): """ Test that the constructor raises an error when the model is empty @@ -125,8 +119,7 @@ def test_constructor_with_empty_model(): AmazonBedrockGenerator(model="") -@pytest.mark.unit -def test_invoke_with_no_kwargs(mock_auto_tokenizer, mock_boto3_session): +def test_invoke_with_no_kwargs(mock_boto3_session): """ Test invoke raises an error if no prompt is provided """ @@ -135,7 +128,6 @@ def test_invoke_with_no_kwargs(mock_auto_tokenizer, mock_boto3_session): layer.invoke() -@pytest.mark.unit def test_short_prompt_is_not_truncated(mock_boto3_session): """ Test that a short prompt is not truncated @@ -166,7 +158,6 @@ def test_short_prompt_is_not_truncated(mock_boto3_session): assert prompt_after_resize == mock_prompt_text -@pytest.mark.unit def test_long_prompt_is_truncated(mock_boto3_session): """ Test that a long prompt is truncated @@ -201,7 +192,6 @@ def test_long_prompt_is_truncated(mock_boto3_session): assert prompt_after_resize == truncated_prompt_text -@pytest.mark.unit @pytest.mark.parametrize( "model, expected_model_adapter", [ diff --git a/integrations/amazon_bedrock/tests/test_text_embedder.py b/integrations/amazon_bedrock/tests/test_text_embedder.py new file mode 100644 index 000000000..022c0c0a5 --- /dev/null +++ b/integrations/amazon_bedrock/tests/test_text_embedder.py @@ -0,0 +1,150 @@ +import io +from unittest.mock import patch + +import pytest +from botocore.exceptions import ClientError + +from haystack_integrations.common.amazon_bedrock.errors import ( + 
EMBEDDER_TYPE = "haystack_integrations.components.embedders.amazon_bedrock.text_embedder.AmazonBedrockTextEmbedder"


def _env_secret(env_var):
    """Serialized form of a non-strict, env-var backed secret."""
    return {"type": "env_var", "env_vars": [env_var], "strict": False}


def _serialized_embedder():
    """
    Build the serialized Cohere embedder used by the (de)serialization tests.

    A fresh dictionary is returned on every call because `from_dict` deserializes secrets in place.
    """
    return {
        "type": EMBEDDER_TYPE,
        "init_parameters": {
            "aws_access_key_id": _env_secret("AWS_ACCESS_KEY_ID"),
            "aws_secret_access_key": _env_secret("AWS_SECRET_ACCESS_KEY"),
            "aws_session_token": _env_secret("AWS_SESSION_TOKEN"),
            "aws_region_name": _env_secret("AWS_DEFAULT_REGION"),
            "aws_profile_name": _env_secret("AWS_PROFILE"),
            "model": "cohere.embed-english-v3",
            "input_type": "search_query",
        },
    }


class TestAmazonBedrockTextEmbedder:
    def test_init(self, mock_boto3_session, set_env_variables):
        """The constructor stores model/kwargs and opens exactly one boto3 session from the env values."""
        embedder = AmazonBedrockTextEmbedder(model="cohere.embed-english-v3", input_type="fake_input_type")

        assert embedder.model == "cohere.embed-english-v3"
        assert embedder.kwargs == {"input_type": "fake_input_type"}

        # a single session, created with the values resolved from the environment
        mock_boto3_session.assert_called_once_with(
            aws_access_key_id="some_fake_id",
            aws_secret_access_key="some_fake_key",
            aws_session_token="some_fake_token",
            profile_name="some_fake_profile",
            region_name="fake_region",
        )

    def test_connection_error(self, mock_boto3_session):
        """Any failure while creating the session surfaces as a configuration error."""
        mock_boto3_session.side_effect = Exception("some connection error")

        with pytest.raises(AmazonBedrockConfigurationError):
            AmazonBedrockTextEmbedder(model="cohere.embed-english-v3", input_type="fake_input_type")

    def test_to_dict(self, mock_boto3_session):
        """Serialization keeps the secrets as env-var references and flattens the kwargs."""
        embedder = AmazonBedrockTextEmbedder(model="cohere.embed-english-v3", input_type="search_query")

        assert embedder.to_dict() == _serialized_embedder()

    def test_from_dict(self, mock_boto3_session):
        """Deserialization restores the model and the extra kwargs."""
        embedder = AmazonBedrockTextEmbedder.from_dict(_serialized_embedder())

        assert embedder.model == "cohere.embed-english-v3"
        assert embedder.kwargs == {"input_type": "search_query"}


def test_init_invalid_model():
    """Empty and unsupported model names are both rejected."""
    for bad_model in ("", "my-unsupported-model"):
        with pytest.raises(ValueError):
            AmazonBedrockTextEmbedder(model=bad_model)


def test_run_wrong_type(mock_boto3_session):
    """Non-string input is rejected."""
    embedder = AmazonBedrockTextEmbedder(model="cohere.embed-english-v3")

    with pytest.raises(TypeError):
        embedder.run(text=123)


def test_cohere_invocation(mock_boto3_session):
    """Cohere models get a `texts` payload and the first returned embedding is used."""
    embedder = AmazonBedrockTextEmbedder(model="cohere.embed-english-v3")
    fake_response = {"body": io.StringIO('{"embeddings": [[0.1, 0.2, 0.3]]}')}

    with patch.object(embedder._client, "invoke_model") as mock_invoke_model:
        mock_invoke_model.return_value = fake_response
        result = embedder.run(text="some text")

        mock_invoke_model.assert_called_once_with(
            body='{"texts": ["some text"], "input_type": "search_query"}',
            modelId="cohere.embed-english-v3",
            accept="*/*",
            contentType="application/json",
        )

    assert result == {"embedding": [0.1, 0.2, 0.3]}


def test_titan_invocation(mock_boto3_session):
    """Titan models get an `inputText` payload and return a single embedding."""
    embedder = AmazonBedrockTextEmbedder(model="amazon.titan-embed-text-v1")
    fake_response = {"body": io.StringIO('{"embedding": [0.1, 0.2, 0.3]}')}

    with patch.object(embedder._client, "invoke_model") as mock_invoke_model:
        mock_invoke_model.return_value = fake_response
        result = embedder.run(text="some text")

        mock_invoke_model.assert_called_once_with(
            body='{"inputText": "some text"}',
            modelId="amazon.titan-embed-text-v1",
            accept="*/*",
            contentType="application/json",
        )

    assert result == {"embedding": [0.1, 0.2, 0.3]}


def test_run_invocation_error(mock_boto3_session):
    """A botocore ClientError during invocation is wrapped in an inference error."""
    embedder = AmazonBedrockTextEmbedder(model="cohere.embed-english-v3")
    client_error = ClientError(
        error_response={"Error": {"Code": "some_code", "Message": "some_message"}},
        operation_name="some_operation",
    )

    with patch.object(embedder._client, "invoke_model") as mock_invoke_model:
        mock_invoke_model.side_effect = client_error

        with pytest.raises(AmazonBedrockInferenceError):
            embedder.run(text="some text")