From b577150f1c79a85fcf6565736f75fd1d7c638549 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Thu, 15 Feb 2024 10:40:39 +0100 Subject: [PATCH 01/14] initial import --- .../amazon_bedrock/chat/chat_generator.py | 29 ++++++++++++------- integrations/amazon_sagemaker/pyproject.toml | 8 ++--- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 804d44413..e99e6f847 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -61,11 +61,13 @@ class AmazonBedrockChatGenerator: def __init__( self, model: str, - aws_access_key_id: Optional[str] = None, - aws_secret_access_key: Optional[str] = None, - aws_session_token: Optional[str] = None, - aws_region_name: Optional[str] = None, - aws_profile_name: Optional[str] = None, + aws_access_key_id: Optional[Secret] = Secret.from_env_var(["AWS_ACCESS_KEY_ID"], strict=False),# noqa: B008 + aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008 + ["AWS_SECRET_ACCESS_KEY"], strict=False + ), + aws_session_token: Optional[Secret] = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False), # noqa: B008 + aws_region_name: Optional[Secret] = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008 + aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008 generation_kwargs: Optional[Dict[str, Any]] = None, stop_words: Optional[List[str]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, @@ -82,12 +84,14 @@ def __init__( :param model: The model to use for generation. The model must be available in Amazon Bedrock. The model has to be specified in the format outlined in the Amazon Bedrock [documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html). + :param aws_access_key_id: AWS access key ID. :param aws_secret_access_key: AWS secret access key. :param aws_session_token: AWS session token. :param aws_region_name: AWS region name. :param aws_profile_name: AWS profile name. :param generation_kwargs: Additional generation keyword arguments passed to the model. The defined keyword + parameters are specific to a specific model and can be found in the model's documentation. For example, the Anthropic Claude generation parameters can be found [here](https://docs.anthropic.com/claude/reference/complete_post). :param stop_words: A list of stop words that stop model generation when encountered. They can be provided via @@ -111,14 +115,17 @@ def __init__( self.model_adapter = model_adapter_cls(generation_kwargs or {}) # create the AWS session and client + def resolve_secret(secret: Optional[Secret]) -> Optional[str]: + return secret.resolve_value() if secret else None + try: session = self.get_aws_session( - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - aws_session_token=aws_session_token, - aws_region_name=aws_region_name, - aws_profile_name=aws_profile_name, - ) + aws_access_key_id=resolve_secret(aws_access_key_id), + aws_secret_access_key=resolve_secret(aws_secret_access_key), + aws_session_token=resolve_secret(aws_session_token), + aws_region_name=resolve_secret(aws_region_name), + aws_profile_name=resolve_secret(aws_profile_name), + ) self.client = session.client("bedrock-runtime") except Exception as exception: msg = ( diff --git a/integrations/amazon_sagemaker/pyproject.toml b/integrations/amazon_sagemaker/pyproject.toml index 916307156..c1517881a 100644 --- a/integrations/amazon_sagemaker/pyproject.toml +++ b/integrations/amazon_sagemaker/pyproject.toml @@ -154,19 +154,19 @@ ban-relative-imports = "parents" "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] +source = ["haystack_integrations"] branch = true parallel = true -[tool.coverage.paths] -amazon_sagemaker_haystack = ["src"] -tests = ["tests"] - [tool.coverage.report] +omit = ["*/tests/*", "*/__init__.py"] +show_missing=true exclude_lines = [ "no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:", ] + [[tool.mypy.overrides]] module = [ "haystack.*", From db5acb7e0341c48784741797fc960900e37cf814 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Thu, 15 Feb 2024 11:50:22 +0100 Subject: [PATCH 02/14] wip --- .../generators/amazon_sagemaker/sagemaker.py | 133 +++++++++++------- 1 file changed, 85 insertions(+), 48 deletions(-) diff --git a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py index 35e54a055..4ccb9dbec 100644 --- a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py +++ b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py @@ -1,24 +1,30 @@ import json import logging -import os from typing import Any, ClassVar, Dict, List, Optional +import boto3 import requests +from botocore.exceptions import BotoCoreError from haystack import component, default_from_dict, default_to_dict -from haystack.lazy_imports import LazyImport +from haystack.utils import Secret + from haystack_integrations.components.generators.amazon_sagemaker.errors import ( AWSConfigurationError, SagemakerInferenceError, SagemakerNotReadyError, ) -with LazyImport(message="Run 'pip install boto3'") as boto3_import: - import boto3 # type: ignore - from botocore.client import BaseClient # type: ignore - logger = logging.getLogger(__name__) +AWS_CONFIGURATION_KEYS = [ + "aws_access_key_id", + "aws_secret_access_key", + "aws_session_token", + "aws_region_name", + "aws_profile_name", +] + MODEL_NOT_READY_STATUS_CODE = 429 @@ -41,9 +47,8 @@ class SagemakerGenerator: Then you can use the generator as follows: ```python - from haystack.components.generators.sagemaker import SagemakerGenerator + from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator generator = SagemakerGenerator(model="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16") - generator.warm_up() response = generator.run("What's Natural Language Processing? Be brief.") print(response) ``` @@ -59,11 +64,13 @@ class SagemakerGenerator: def __init__( self, model: str, - aws_access_key_id_var: str = "AWS_ACCESS_KEY_ID", - aws_secret_access_key_var: str = "AWS_SECRET_ACCESS_KEY", - aws_session_token_var: str = "AWS_SESSION_TOKEN", - aws_region_name_var: str = "AWS_REGION", - aws_profile_name_var: str = "AWS_PROFILE", + aws_access_key_id: Optional[Secret] = Secret.from_env_var(["AWS_ACCESS_KEY_ID"], strict=False), # noqa: B008 + aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008 + ["AWS_SECRET_ACCESS_KEY"], strict=False + ), + aws_session_token: Optional[Secret] = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False), # noqa: B008 + aws_region_name: Optional[Secret] = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008 + aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008 aws_custom_attributes: Optional[Dict[str, Any]] = None, generation_kwargs: Optional[Dict[str, Any]] = None, ): @@ -71,13 +78,16 @@ def __init__( Instantiates the session with SageMaker. :param model: The name for SageMaker Model Endpoint. - :param aws_access_key_id_var: The name of the env var where the AWS access key ID is stored. - :param aws_secret_access_key_var: The name of the env var where the AWS secret access key is stored. - :param aws_session_token_var: The name of the env var where the AWS session token is stored. - :param aws_region_name_var: The name of the env var where the AWS region name is stored. - :param aws_profile_name_var: The name of the env var where the AWS profile name is stored. + + :param aws_access_key_id: The name of the env var where the AWS access key ID is stored. + :param aws_secret_access_key: The name of the env var where the AWS secret access key is stored. + :param aws_session_token: The name of the env var where the AWS session token is stored. + :param aws_region_name: The name of the env var where the AWS region name is stored. + :param aws_profile_name: The name of the env var where the AWS profile name is stored. + :param aws_custom_attributes: Custom attributes to be passed to SageMaker, for example `{"accept_eula": True}` in case of Llama-2 models. + :param generation_kwargs: Additional keyword arguments for text generation. For a list of supported parameters see your model's documentation page, for example here for HuggingFace models: https://huggingface.co/blog/sagemaker-huggingface-llm#4-run-inference-and-chat-with-our-model @@ -95,21 +105,27 @@ def __init__( be boolean. The default value for it is `False`. """ self.model = model - self.aws_access_key_id_var = aws_access_key_id_var - self.aws_secret_access_key_var = aws_secret_access_key_var - self.aws_session_token_var = aws_session_token_var - self.aws_region_name_var = aws_region_name_var - self.aws_profile_name_var = aws_profile_name_var self.aws_custom_attributes = aws_custom_attributes or {} self.generation_kwargs = generation_kwargs or {"max_new_tokens": 1024} - self.client: Optional[BaseClient] = None - if not os.getenv(self.aws_access_key_id_var) or not os.getenv(self.aws_secret_access_key_var): + def resolve_secret(secret: Optional[Secret]) -> Optional[str]: + return secret.resolve_value() if secret else None + + try: + session = self.get_aws_session( + aws_access_key_id=resolve_secret(aws_access_key_id), + aws_secret_access_key=resolve_secret(aws_secret_access_key), + aws_session_token=resolve_secret(aws_session_token), + aws_region_name=resolve_secret(aws_region_name), + aws_profile_name=resolve_secret(aws_profile_name), + ) + self.client = session.client("sagemaker-runtime") + except Exception as e: msg = ( - f"Please provide AWS credentials via environment variables '{self.aws_access_key_id_var}' and " - f"'{self.aws_secret_access_key_var}'." + f"Could not connect to SageMaker Inference Endpoint '{self.model}'." + f"Make sure the Endpoint exists and AWS environment is configured." ) - raise AWSConfigurationError(msg) + raise AWSConfigurationError(msg) from e def _get_telemetry_data(self) -> Dict[str, Any]: """ @@ -124,11 +140,6 @@ def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, model=self.model, - aws_access_key_id_var=self.aws_access_key_id_var, - aws_secret_access_key_var=self.aws_secret_access_key_var, - aws_session_token_var=self.aws_session_token_var, - aws_region_name_var=self.aws_region_name_var, - aws_profile_name_var=self.aws_profile_name_var, aws_custom_attributes=self.aws_custom_attributes, generation_kwargs=self.generation_kwargs, ) @@ -140,25 +151,51 @@ def from_dict(cls, data) -> "SagemakerGenerator": """ return default_from_dict(cls, data) - def warm_up(self): + @classmethod + def aws_configured(cls, **kwargs) -> bool: + """ + Checks whether AWS configuration is provided. + :param kwargs: The kwargs passed down to the generator. + :return: True if AWS configuration is provided, False otherwise. """ - Initializes the SageMaker Inference client. + aws_config_provided = any(key in kwargs for key in AWS_CONFIGURATION_KEYS) + return aws_config_provided + + @classmethod + def get_aws_session( + cls, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + aws_region_name: Optional[str] = None, + aws_profile_name: Optional[str] = None, + **kwargs, + ): + """ + Creates an AWS Session with the given parameters. + Checks if the provided AWS credentials are valid and can be used to connect to AWS. + + :param aws_access_key_id: AWS access key ID. + :param aws_secret_access_key: AWS secret access key. + :param aws_session_token: AWS session token. + :param aws_region_name: AWS region name. + :param aws_profile_name: AWS profile name. + :param kwargs: The kwargs passed down to the service client. Supported kwargs depend on the model chosen. + See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html. + :raises AWSConfigurationError: If the provided AWS credentials are invalid. + :return: The created AWS session. """ - boto3_import.check() try: - session = boto3.Session( - aws_access_key_id=os.getenv(self.aws_access_key_id_var), - aws_secret_access_key=os.getenv(self.aws_secret_access_key_var), - aws_session_token=os.getenv(self.aws_session_token_var), - region_name=os.getenv(self.aws_region_name_var), - profile_name=os.getenv(self.aws_profile_name_var), - ) - self.client = session.client("sagemaker-runtime") - except Exception as e: - msg = ( - f"Could not connect to SageMaker Inference Endpoint '{self.model}'." - f"Make sure the Endpoint exists and AWS environment is configured." + return boto3.Session( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + region_name=aws_region_name, + profile_name=aws_profile_name, ) + except BotoCoreError as e: + provided_aws_config = {k: v for k, v in kwargs.items() if k in AWS_CONFIGURATION_KEYS} + msg = f"Failed to initialize the session with provided AWS credentials {provided_aws_config}" raise AWSConfigurationError(msg) from e @component.output_types(replies=List[str], meta=List[Dict[str, Any]]) From 9e00c9c31b4a25b99acbe0599f2f0ff9b42f7214 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Thu, 15 Feb 2024 21:19:33 +0100 Subject: [PATCH 03/14] adding Secret support and fixing/refactoring tests --- integrations/amazon_sagemaker/pyproject.toml | 13 +- .../generators/amazon_sagemaker/sagemaker.py | 26 +-- .../amazon_sagemaker/tests/conftest.py | 18 ++ .../amazon_sagemaker/tests/test_sagemaker.py | 160 ++++++++---------- 4 files changed, 114 insertions(+), 103 deletions(-) create mode 100644 integrations/amazon_sagemaker/tests/conftest.py diff --git a/integrations/amazon_sagemaker/pyproject.toml b/integrations/amazon_sagemaker/pyproject.toml index c1517881a..9a96e6dd6 100644 --- a/integrations/amazon_sagemaker/pyproject.toml +++ b/integrations/amazon_sagemaker/pyproject.toml @@ -174,4 +174,15 @@ module = [ "pytest.*", "numpy.*", ] -ignore_missing_imports = true \ No newline at end of file +ignore_missing_imports = true + + +[tool.pytest.ini_options] +addopts = "--strict-markers" +markers = [ + "unit: unit tests", + "integration: integration tests", + "embedders: embedders tests", + "generators: generators tests", +] +log_cli = true \ No newline at end of file diff --git a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py index 4ccb9dbec..d983aec87 100644 --- a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py +++ b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py @@ -6,7 +6,7 @@ import requests from botocore.exceptions import BotoCoreError from haystack import component, default_from_dict, default_to_dict -from haystack.utils import Secret +from haystack.utils import Secret, deserialize_secrets_inplace from haystack_integrations.components.generators.amazon_sagemaker.errors import ( AWSConfigurationError, @@ -107,6 +107,11 @@ def __init__( self.model = model self.aws_custom_attributes = aws_custom_attributes or {} self.generation_kwargs = generation_kwargs or {"max_new_tokens": 1024} + self.aws_access_key_id = aws_access_key_id + self.aws_secret_access_key = aws_secret_access_key + self.aws_session_token = aws_session_token + self.aws_region_name = aws_region_name + self.aws_profile_name = aws_profile_name def resolve_secret(secret: Optional[Secret]) -> Optional[str]: return secret.resolve_value() if secret else None @@ -140,6 +145,11 @@ def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, model=self.model, + aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None, + aws_secret_access_key=self.aws_secret_access_key.to_dict() if self.aws_secret_access_key else None, + aws_session_token=self.aws_session_token.to_dict() if self.aws_session_token else None, + aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None, + aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None, aws_custom_attributes=self.aws_custom_attributes, generation_kwargs=self.generation_kwargs, ) @@ -149,18 +159,12 @@ def from_dict(cls, data) -> "SagemakerGenerator": """ Deserialize the dictionary into an instance of SagemakerGenerator. """ + deserialize_secrets_inplace( + data["init_parameters"], + ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"], + ) return default_from_dict(cls, data) - @classmethod - def aws_configured(cls, **kwargs) -> bool: - """ - Checks whether AWS configuration is provided. - :param kwargs: The kwargs passed down to the generator. - :return: True if AWS configuration is provided, False otherwise. - """ - aws_config_provided = any(key in kwargs for key in AWS_CONFIGURATION_KEYS) - return aws_config_provided - @classmethod def get_aws_session( cls, diff --git a/integrations/amazon_sagemaker/tests/conftest.py b/integrations/amazon_sagemaker/tests/conftest.py new file mode 100644 index 000000000..d99faab2e --- /dev/null +++ b/integrations/amazon_sagemaker/tests/conftest.py @@ -0,0 +1,18 @@ +import pytest + +from unittest.mock import patch + + +@pytest.fixture +def set_env_variables(monkeypatch): + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "some_fake_id") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "some_fake_key") + monkeypatch.setenv("AWS_SESSION_TOKEN", "some_fake_token") + monkeypatch.setenv("AWS_DEFAULT_REGION", "fake_region") + monkeypatch.setenv("AWS_PROFILE", "some_fake_profile") + + +@pytest.fixture +def mock_boto3_session(): + with patch("boto3.Session") as mock_client: + yield mock_client diff --git a/integrations/amazon_sagemaker/tests/test_sagemaker.py b/integrations/amazon_sagemaker/tests/test_sagemaker.py index a22634be1..d9b7281fb 100644 --- a/integrations/amazon_sagemaker/tests/test_sagemaker.py +++ b/integrations/amazon_sagemaker/tests/test_sagemaker.py @@ -1,105 +1,83 @@ import os -from unittest.mock import Mock - import pytest + from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator -from haystack_integrations.components.generators.amazon_sagemaker.errors import AWSConfigurationError +from haystack.utils.auth import EnvVarSecret +from unittest.mock import Mock -class TestSagemakerGenerator: - def test_init_default(self, monkeypatch): - monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test-access-key") - monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key") +mocked_dict = { + "type": "haystack_integrations.components.generators.amazon_sagemaker.sagemaker.SagemakerGenerator", + "init_parameters": { + "model": "model", + "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, + "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, + "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, + "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "aws_custom_attributes": {"accept_eula": True}, + "generation_kwargs": {"max_new_tokens": 10}, + }, +} + + +@pytest.mark.unit +@pytest.mark.usefixtures("set_env_variables", "mock_boto3_session") +def test_to_dict(): + """ + Test that the to_dict method returns the correct dictionary without aws credentials + """ + + generator = SagemakerGenerator( + model="model", + generation_kwargs={"max_new_tokens": 10}, + aws_custom_attributes={"accept_eula": True}, + ) + assert generator.to_dict() == mocked_dict - component = SagemakerGenerator(model="test-model") - assert component.model == "test-model" - assert component.aws_access_key_id_var == "AWS_ACCESS_KEY_ID" - assert component.aws_secret_access_key_var == "AWS_SECRET_ACCESS_KEY" - assert component.aws_session_token_var == "AWS_SESSION_TOKEN" - assert component.aws_region_name_var == "AWS_REGION" - assert component.aws_profile_name_var == "AWS_PROFILE" - assert component.aws_custom_attributes == {} - assert component.generation_kwargs == {"max_new_tokens": 1024} - assert component.client is None - - def test_init_fail_wo_access_key_or_secret_key(self, monkeypatch): - monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False) - monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False) - with pytest.raises(AWSConfigurationError): - SagemakerGenerator(model="test-model") - monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test-access-key") - monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False) - with pytest.raises(AWSConfigurationError): - SagemakerGenerator(model="test-model") +@pytest.mark.unit +@pytest.mark.usefixtures("set_env_variables", "mock_boto3_session") +def test_from_dict(): + """ + Test that the from_dict method returns the correct object + """ - monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False) - monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key") - with pytest.raises(AWSConfigurationError): - SagemakerGenerator(model="test-model") + generator = SagemakerGenerator.from_dict(mocked_dict) + assert generator.model == "model" + assert isinstance(generator.aws_access_key_id, EnvVarSecret) - def test_init_with_parameters(self, monkeypatch): - monkeypatch.setenv("MY_ACCESS_KEY_ID", "test-access-key") - monkeypatch.setenv("MY_SECRET_ACCESS_KEY", "test-secret-key") - component = SagemakerGenerator( - model="test-model", - aws_access_key_id_var="MY_ACCESS_KEY_ID", - aws_secret_access_key_var="MY_SECRET_ACCESS_KEY", - aws_session_token_var="MY_SESSION_TOKEN", - aws_region_name_var="MY_REGION", - aws_profile_name_var="MY_PROFILE", - aws_custom_attributes={"custom": "attr"}, - generation_kwargs={"generation": "kwargs"}, - ) - assert component.model == "test-model" - assert component.aws_access_key_id_var == "MY_ACCESS_KEY_ID" - assert component.aws_secret_access_key_var == "MY_SECRET_ACCESS_KEY" - assert component.aws_session_token_var == "MY_SESSION_TOKEN" - assert component.aws_region_name_var == "MY_REGION" - assert component.aws_profile_name_var == "MY_PROFILE" - assert component.aws_custom_attributes == {"custom": "attr"} - assert component.generation_kwargs == {"generation": "kwargs"} - assert component.client is None - - def test_to_from_dict(self, monkeypatch): - monkeypatch.setenv("MY_ACCESS_KEY_ID", "test-access-key") - monkeypatch.setenv("MY_SECRET_ACCESS_KEY", "test-secret-key") +@pytest.mark.unit +@pytest.mark.usefixtures("set_env_variables", "mock_boto3_session") +def test_default_constructor(mock_boto3_session): + """ + Test that the default constructor sets the correct values + """ - component = SagemakerGenerator( - model="test-model", - aws_access_key_id_var="MY_ACCESS_KEY_ID", - aws_secret_access_key_var="MY_SECRET_ACCESS_KEY", - aws_session_token_var="MY_SESSION_TOKEN", - aws_region_name_var="MY_REGION", - aws_profile_name_var="MY_PROFILE", - aws_custom_attributes={"custom": "attr"}, - generation_kwargs={"generation": "kwargs"}, - ) - serialized = component.to_dict() - assert serialized == { - "type": "haystack_integrations.components.generators.amazon_sagemaker.sagemaker.SagemakerGenerator", - "init_parameters": { - "model": "test-model", - "aws_access_key_id_var": "MY_ACCESS_KEY_ID", - "aws_secret_access_key_var": "MY_SECRET_ACCESS_KEY", - "aws_session_token_var": "MY_SESSION_TOKEN", - "aws_region_name_var": "MY_REGION", - "aws_profile_name_var": "MY_PROFILE", - "aws_custom_attributes": {"custom": "attr"}, - "generation_kwargs": {"generation": "kwargs"}, - }, - } - deserialized = SagemakerGenerator.from_dict(serialized) - assert deserialized.model == "test-model" - assert deserialized.aws_access_key_id_var == "MY_ACCESS_KEY_ID" - assert deserialized.aws_secret_access_key_var == "MY_SECRET_ACCESS_KEY" - assert deserialized.aws_session_token_var == "MY_SESSION_TOKEN" - assert deserialized.aws_region_name_var == "MY_REGION" - assert deserialized.aws_profile_name_var == "MY_PROFILE" - assert deserialized.aws_custom_attributes == {"custom": "attr"} - assert deserialized.generation_kwargs == {"generation": "kwargs"} - assert deserialized.client is None + generator = SagemakerGenerator( + model="test-model", + ) + + assert generator.generation_kwargs == {"max_new_tokens": 1024} + assert generator.model == "test-model" + + # assert mocked boto3 client called exactly once + mock_boto3_session.assert_called_once() + + assert generator.client is not None + + # assert mocked boto3 client was called with the correct parameters + mock_boto3_session.assert_called_with( + aws_access_key_id="some_fake_id", + aws_secret_access_key="some_fake_key", + aws_session_token="some_fake_token", + profile_name="some_fake_profile", + region_name="fake_region", + ) + + +class TestSagemakerGenerator: def test_run_with_list_of_dictionaries(self, monkeypatch): monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test-access-key") From 16418b6785dbd6c9b9ad5d290a552200ca8b19ec Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Fri, 16 Feb 2024 10:17:41 +0100 Subject: [PATCH 04/14] adding mypy overrides for boto --- integrations/amazon_sagemaker/pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/integrations/amazon_sagemaker/pyproject.toml b/integrations/amazon_sagemaker/pyproject.toml index 9a96e6dd6..9c40ac683 100644 --- a/integrations/amazon_sagemaker/pyproject.toml +++ b/integrations/amazon_sagemaker/pyproject.toml @@ -169,6 +169,8 @@ exclude_lines = [ [[tool.mypy.overrides]] module = [ + "botocore.*", + "boto3.*", "haystack.*", "haystack_integrations.*", "pytest.*", @@ -177,6 +179,7 @@ module = [ ignore_missing_imports = true + [tool.pytest.ini_options] addopts = "--strict-markers" markers = [ From 35fa26200bdca21313744f7aa3c679c4d2d137a4 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Fri, 16 Feb 2024 10:22:42 +0100 Subject: [PATCH 05/14] adding mypy overrides for boto --- .../amazon_sagemaker/tests/test_sagemaker.py | 175 +++++++++--------- 1 file changed, 88 insertions(+), 87 deletions(-) diff --git a/integrations/amazon_sagemaker/tests/test_sagemaker.py b/integrations/amazon_sagemaker/tests/test_sagemaker.py index d9b7281fb..93c38a6bb 100644 --- a/integrations/amazon_sagemaker/tests/test_sagemaker.py +++ b/integrations/amazon_sagemaker/tests/test_sagemaker.py @@ -131,91 +131,92 @@ def test_run_with_single_dictionary(self, monkeypatch): assert [isinstance(reply, dict) for reply in response["meta"]] assert response["meta"][0]["other"] == "metadata" - @pytest.mark.skipif( - (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), - reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", - ) - @pytest.mark.integration - def test_run_falcon(self): - component = SagemakerGenerator( - model="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16", generation_kwargs={"max_new_tokens": 10} - ) - component.warm_up() - response = component.run("What's Natural Language Processing?") - - # check that the component returns the correct ChatMessage response - assert isinstance(response, dict) - assert "replies" in response - assert isinstance(response["replies"], list) - assert len(response["replies"]) == 1 - assert [isinstance(reply, str) for reply in response["replies"]] - - # Coarse check: assuming no more than 4 chars per token. In any case it - # will fail if the `max_new_tokens` parameter is not respected, as the - # default is either 256 or 1024 - assert all(len(reply) <= 40 for reply in response["replies"]) - - assert "meta" in response - assert isinstance(response["meta"], list) - assert len(response["meta"]) == 1 - assert [isinstance(reply, dict) for reply in response["meta"]] - - @pytest.mark.skipif( - (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), - reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", - ) - @pytest.mark.integration - def test_run_llama2(self): - component = SagemakerGenerator( - model="jumpstart-dft-meta-textgenerationneuron-llama-2-7b", - generation_kwargs={"max_new_tokens": 10}, - aws_custom_attributes={"accept_eula": True}, - ) - component.warm_up() - response = component.run("What's Natural Language Processing?") - - # check that the component returns the correct ChatMessage response - assert isinstance(response, dict) - assert "replies" in response - assert isinstance(response["replies"], list) - assert len(response["replies"]) == 1 - assert [isinstance(reply, str) for reply in response["replies"]] - - # Coarse check: assuming no more than 4 chars per token. In any case it - # will fail if the `max_new_tokens` parameter is not respected, as the - # default is either 256 or 1024 - assert all(len(reply) <= 40 for reply in response["replies"]) - - assert "meta" in response - assert isinstance(response["meta"], list) - assert len(response["meta"]) == 1 - assert [isinstance(reply, dict) for reply in response["meta"]] - - @pytest.mark.skipif( - (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), - reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", - ) - @pytest.mark.integration - def test_run_bloomz(self): - component = SagemakerGenerator( - model="jumpstart-dft-hf-textgeneration-bloomz-1b1", generation_kwargs={"max_new_tokens": 10} - ) - component.warm_up() - response = component.run("What's Natural Language Processing?") - # check that the component returns the correct ChatMessage response - assert isinstance(response, dict) - assert "replies" in response - assert isinstance(response["replies"], list) - assert len(response["replies"]) == 1 - assert [isinstance(reply, str) for reply in response["replies"]] - - # Coarse check: assuming no more than 4 chars per token. In any case it - # will fail if the `max_new_tokens` parameter is not respected, as the - # default is either 256 or 1024 - assert all(len(reply) <= 40 for reply in response["replies"]) - - assert "meta" in response - assert isinstance(response["meta"], list) - assert len(response["meta"]) == 1 - assert [isinstance(reply, dict) for reply in response["meta"]] + # @pytest.mark.skipif( + # (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), + # reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", + # ) + # @pytest.mark.integration + # def test_run_falcon(self): + # component = SagemakerGenerator( + # model="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16", generation_kwargs={"max_new_tokens": 10} + # ) + # # component.warm_up() + # response = component.run("What's Natural Language Processing?") + # + # # check that the component returns the correct ChatMessage response + # assert isinstance(response, dict) + # assert "replies" in response + # assert isinstance(response["replies"], list) + # assert len(response["replies"]) == 1 + # assert [isinstance(reply, str) for reply in response["replies"]] + # + # # Coarse check: assuming no more than 4 chars per token. In any case it + # # will fail if the `max_new_tokens` parameter is not respected, as the + # # default is either 256 or 1024 + # assert all(len(reply) <= 40 for reply in response["replies"]) + # + # assert "meta" in response + # assert isinstance(response["meta"], list) + # assert len(response["meta"]) == 1 + # assert [isinstance(reply, dict) for reply in response["meta"]] + # + # @pytest.mark.skipif( + # (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), + # reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", + # ) + # @pytest.mark.integration + # def test_run_llama2(self): + # component = SagemakerGenerator( + # model="jumpstart-dft-meta-textgenerationneuron-llama-2-7b", + # generation_kwargs={"max_new_tokens": 10}, + # aws_custom_attributes={"accept_eula": True}, + # ) + # component.warm_up() + # response = component.run("What's Natural Language Processing?") + # + # # check that the component returns the correct ChatMessage response + # assert isinstance(response, dict) + # assert "replies" in response + # assert isinstance(response["replies"], list) + # assert len(response["replies"]) == 1 + # assert [isinstance(reply, str) for reply in response["replies"]] + # + # # Coarse check: assuming no more than 4 chars per token. In any case it + # # will fail if the `max_new_tokens` parameter is not respected, as the + # # default is either 256 or 1024 + # assert all(len(reply) <= 40 for reply in response["replies"]) + # + # assert "meta" in response + # assert isinstance(response["meta"], list) + # assert len(response["meta"]) == 1 + # assert [isinstance(reply, dict) for reply in response["meta"]] + # + # @pytest.mark.skipif( + # (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), + # reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", + # ) + # @pytest.mark.integration + # def test_run_bloomz(self): + # component = SagemakerGenerator( + # model="jumpstart-dft-hf-textgeneration-bloomz-1b1", generation_kwargs={"max_new_tokens": 10} + # ) + # component.warm_up() + # response = component.run("What's Natural Language Processing?") + # + # # check that the component returns the correct ChatMessage response + # assert isinstance(response, dict) + # assert "replies" in response + # assert isinstance(response["replies"], list) + # assert len(response["replies"]) == 1 + # assert [isinstance(reply, str) for reply in response["replies"]] + # + # # Coarse check: assuming no more than 4 chars per token. In any case it + # # will fail if the `max_new_tokens` parameter is not respected, as the + # # default is either 256 or 1024 + # assert all(len(reply) <= 40 for reply in response["replies"]) + # + # assert "meta" in response + # assert isinstance(response["meta"], list) + # assert len(response["meta"]) == 1 + # assert [isinstance(reply, dict) for reply in response["meta"]] From bd298b85b81e1c7f743cbe772402331464e21340 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Fri, 16 Feb 2024 10:22:54 +0100 Subject: [PATCH 06/14] adding mypy overrides for boto --- .../amazon_sagemaker/tests/test_sagemaker.py | 176 +++++++++--------- 1 file changed, 88 insertions(+), 88 deletions(-) diff --git a/integrations/amazon_sagemaker/tests/test_sagemaker.py b/integrations/amazon_sagemaker/tests/test_sagemaker.py index 93c38a6bb..e69f25175 100644 --- a/integrations/amazon_sagemaker/tests/test_sagemaker.py +++ b/integrations/amazon_sagemaker/tests/test_sagemaker.py @@ -132,91 +132,91 @@ def test_run_with_single_dictionary(self, monkeypatch): assert response["meta"][0]["other"] == "metadata" - # @pytest.mark.skipif( - # (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), - # reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", - # ) - # @pytest.mark.integration - # def test_run_falcon(self): - # component = SagemakerGenerator( - # model="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16", generation_kwargs={"max_new_tokens": 10} - # ) - # # component.warm_up() - # response = component.run("What's Natural Language Processing?") - # - # # check that the component returns the correct ChatMessage response - # assert isinstance(response, dict) - # assert "replies" in response - # assert isinstance(response["replies"], list) - # assert len(response["replies"]) == 1 - # assert [isinstance(reply, str) for reply in response["replies"]] - # - # # Coarse check: assuming no more than 4 chars per token. In any case it - # # will fail if the `max_new_tokens` parameter is not respected, as the - # # default is either 256 or 1024 - # assert all(len(reply) <= 40 for reply in response["replies"]) - # - # assert "meta" in response - # assert isinstance(response["meta"], list) - # assert len(response["meta"]) == 1 - # assert [isinstance(reply, dict) for reply in response["meta"]] - # - # @pytest.mark.skipif( - # (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), - # reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", - # ) - # @pytest.mark.integration - # def test_run_llama2(self): - # component = SagemakerGenerator( - # model="jumpstart-dft-meta-textgenerationneuron-llama-2-7b", - # generation_kwargs={"max_new_tokens": 10}, - # aws_custom_attributes={"accept_eula": True}, - # ) - # component.warm_up() - # response = component.run("What's Natural Language Processing?") - # - # # check that the component returns the correct ChatMessage response - # assert isinstance(response, dict) - # assert "replies" in response - # assert isinstance(response["replies"], list) - # assert len(response["replies"]) == 1 - # assert [isinstance(reply, str) for reply in response["replies"]] - # - # # Coarse check: assuming no more than 4 chars per token. In any case it - # # will fail if the `max_new_tokens` parameter is not respected, as the - # # default is either 256 or 1024 - # assert all(len(reply) <= 40 for reply in response["replies"]) - # - # assert "meta" in response - # assert isinstance(response["meta"], list) - # assert len(response["meta"]) == 1 - # assert [isinstance(reply, dict) for reply in response["meta"]] - # - # @pytest.mark.skipif( - # (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), - # reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", - # ) - # @pytest.mark.integration - # def test_run_bloomz(self): - # component = SagemakerGenerator( - # model="jumpstart-dft-hf-textgeneration-bloomz-1b1", generation_kwargs={"max_new_tokens": 10} - # ) - # component.warm_up() - # response = component.run("What's Natural Language Processing?") - # - # # check that the component returns the correct ChatMessage response - # assert isinstance(response, dict) - # assert "replies" in response - # assert isinstance(response["replies"], list) - # assert len(response["replies"]) == 1 - # assert [isinstance(reply, str) for reply in response["replies"]] - # - # # Coarse check: assuming no more than 4 chars per token. In any case it - # # will fail if the `max_new_tokens` parameter is not respected, as the - # # default is either 256 or 1024 - # assert all(len(reply) <= 40 for reply in response["replies"]) - # - # assert "meta" in response - # assert isinstance(response["meta"], list) - # assert len(response["meta"]) == 1 - # assert [isinstance(reply, dict) for reply in response["meta"]] + @pytest.mark.skipif( + (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), + reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", + ) + @pytest.mark.integration + def test_run_falcon(self): + component = SagemakerGenerator( + model="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16", generation_kwargs={"max_new_tokens": 10} + ) + # component.warm_up() + response = component.run("What's Natural Language Processing?") + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, str) for reply in response["replies"]] + + # Coarse check: assuming no more than 4 chars per token. In any case it + # will fail if the `max_new_tokens` parameter is not respected, as the + # default is either 256 or 1024 + assert all(len(reply) <= 40 for reply in response["replies"]) + + assert "meta" in response + assert isinstance(response["meta"], list) + assert len(response["meta"]) == 1 + assert [isinstance(reply, dict) for reply in response["meta"]] + + @pytest.mark.skipif( + (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), + reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", + ) + @pytest.mark.integration + def test_run_llama2(self): + component = SagemakerGenerator( + model="jumpstart-dft-meta-textgenerationneuron-llama-2-7b", + generation_kwargs={"max_new_tokens": 10}, + aws_custom_attributes={"accept_eula": True}, + ) + component.warm_up() + response = component.run("What's Natural Language Processing?") + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, str) for reply in response["replies"]] + + # Coarse check: assuming no more than 4 chars per token. In any case it + # will fail if the `max_new_tokens` parameter is not respected, as the + # default is either 256 or 1024 + assert all(len(reply) <= 40 for reply in response["replies"]) + + assert "meta" in response + assert isinstance(response["meta"], list) + assert len(response["meta"]) == 1 + assert [isinstance(reply, dict) for reply in response["meta"]] + + @pytest.mark.skipif( + (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), + reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", + ) + @pytest.mark.integration + def test_run_bloomz(self): + component = SagemakerGenerator( + model="jumpstart-dft-hf-textgeneration-bloomz-1b1", generation_kwargs={"max_new_tokens": 10} + ) + component.warm_up() + response = component.run("What's Natural Language Processing?") + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, str) for reply in response["replies"]] + + # Coarse check: assuming no more than 4 chars per token. In any case it + # will fail if the `max_new_tokens` parameter is not respected, as the + # default is either 256 or 1024 + assert all(len(reply) <= 40 for reply in response["replies"]) + + assert "meta" in response + assert isinstance(response["meta"], list) + assert len(response["meta"]) == 1 + assert [isinstance(reply, dict) for reply in response["meta"]] From ec307955f955ef8707dc495e53c9282bd1b14024 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Fri, 16 Feb 2024 11:51:20 +0100 Subject: [PATCH 07/14] wip: tests --- .../generators/amazon_sagemaker/sagemaker.py | 4 - .../amazon_sagemaker/tests/test_sagemaker.py | 310 +++++++++--------- 2 files changed, 160 insertions(+), 154 deletions(-) diff --git a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py index d983aec87..461a64b6d 100644 --- a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py +++ b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py @@ -214,10 +214,6 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): :return: A list of strings containing the generated responses and a list of dictionaries containing the metadata for each response. """ - if self.client is None: - msg = "SageMaker Inference client is not initialized. Please call warm_up() first." - raise ValueError(msg) - generation_kwargs = generation_kwargs or self.generation_kwargs custom_attributes = ";".join( f"{k}={str(v).lower() if isinstance(v, bool) else str(v)}" for k, v in self.aws_custom_attributes.items() diff --git a/integrations/amazon_sagemaker/tests/test_sagemaker.py b/integrations/amazon_sagemaker/tests/test_sagemaker.py index e69f25175..b60f33c00 100644 --- a/integrations/amazon_sagemaker/tests/test_sagemaker.py +++ b/integrations/amazon_sagemaker/tests/test_sagemaker.py @@ -1,10 +1,13 @@ -import os import pytest from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator from haystack.utils.auth import EnvVarSecret -from unittest.mock import Mock +from botocore.exceptions import BotoCoreError + +from haystack_integrations.components.generators.amazon_sagemaker.errors import AWSConfigurationError + +from unittest.mock import Mock, patch mocked_dict = { "type": "haystack_integrations.components.generators.amazon_sagemaker.sagemaker.SagemakerGenerator", @@ -55,16 +58,12 @@ def test_default_constructor(mock_boto3_session): Test that the default constructor sets the correct values """ - generator = SagemakerGenerator( - model="test-model", - ) - + generator = SagemakerGenerator(model="test-model") assert generator.generation_kwargs == {"max_new_tokens": 1024} assert generator.model == "test-model" # assert mocked boto3 client called exactly once mock_boto3_session.assert_called_once() - assert generator.client is not None # assert mocked boto3 client was called with the correct parameters @@ -77,146 +76,157 @@ def test_default_constructor(mock_boto3_session): ) -class TestSagemakerGenerator: - - def test_run_with_list_of_dictionaries(self, monkeypatch): - monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test-access-key") - monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key") - client_mock = Mock() - client_mock.invoke_endpoint.return_value = { - "Body": Mock(read=lambda: b'[{"generated_text": "test-reply", "other": "metadata"}]') - } - - component = SagemakerGenerator(model="test-model") - component.client = client_mock # Simulate warm_up() - response = component.run("What's Natural Language Processing?") - - # check that the component returns the correct ChatMessage response - assert isinstance(response, dict) - assert "replies" in response - assert isinstance(response["replies"], list) - assert len(response["replies"]) == 1 - assert [isinstance(reply, str) for reply in response["replies"]] - assert "test-reply" in response["replies"][0] - - assert "meta" in response - assert isinstance(response["meta"], list) - assert len(response["meta"]) == 1 - assert [isinstance(reply, dict) for reply in response["meta"]] - assert response["meta"][0]["other"] == "metadata" - - def test_run_with_single_dictionary(self, monkeypatch): - monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test-access-key") - monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key") - client_mock = Mock() - client_mock.invoke_endpoint.return_value = { - "Body": Mock(read=lambda: b'{"generation": "test-reply", "other": "metadata"}') - } - - component = SagemakerGenerator(model="test-model") - component.client = client_mock # Simulate warm_up() - response = component.run("What's Natural Language Processing?") - - # check that the component returns the correct ChatMessage response - assert isinstance(response, dict) - assert "replies" in response - assert isinstance(response["replies"], list) - assert len(response["replies"]) == 1 - assert [isinstance(reply, str) for reply in response["replies"]] - assert "test-reply" in response["replies"][0] - - assert "meta" in response - assert isinstance(response["meta"], list) - assert len(response["meta"]) == 1 - assert [isinstance(reply, dict) for reply in response["meta"]] - assert response["meta"][0]["other"] == "metadata" - - - @pytest.mark.skipif( - (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), - reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", - ) - @pytest.mark.integration - def test_run_falcon(self): - component = SagemakerGenerator( - model="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16", generation_kwargs={"max_new_tokens": 10} - ) - # component.warm_up() - response = component.run("What's Natural Language Processing?") - - # check that the component returns the correct ChatMessage response - assert isinstance(response, dict) - assert "replies" in response - assert isinstance(response["replies"], list) - assert len(response["replies"]) == 1 - assert [isinstance(reply, str) for reply in response["replies"]] - - # Coarse check: assuming no more than 4 chars per token. In any case it - # will fail if the `max_new_tokens` parameter is not respected, as the - # default is either 256 or 1024 - assert all(len(reply) <= 40 for reply in response["replies"]) - - assert "meta" in response - assert isinstance(response["meta"], list) - assert len(response["meta"]) == 1 - assert [isinstance(reply, dict) for reply in response["meta"]] - - @pytest.mark.skipif( - (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), - reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", - ) - @pytest.mark.integration - def test_run_llama2(self): - component = SagemakerGenerator( - model="jumpstart-dft-meta-textgenerationneuron-llama-2-7b", - generation_kwargs={"max_new_tokens": 10}, - aws_custom_attributes={"accept_eula": True}, - ) - component.warm_up() - response = component.run("What's Natural Language Processing?") - - # check that the component returns the correct ChatMessage response - assert isinstance(response, dict) - assert "replies" in response - assert isinstance(response["replies"], list) - assert len(response["replies"]) == 1 - assert [isinstance(reply, str) for reply in response["replies"]] - - # Coarse check: assuming no more than 4 chars per token. In any case it - # will fail if the `max_new_tokens` parameter is not respected, as the - # default is either 256 or 1024 - assert all(len(reply) <= 40 for reply in response["replies"]) - - assert "meta" in response - assert isinstance(response["meta"], list) - assert len(response["meta"]) == 1 - assert [isinstance(reply, dict) for reply in response["meta"]] - - @pytest.mark.skipif( - (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), - reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", - ) - @pytest.mark.integration - def test_run_bloomz(self): - component = SagemakerGenerator( - model="jumpstart-dft-hf-textgeneration-bloomz-1b1", generation_kwargs={"max_new_tokens": 10} - ) - component.warm_up() - response = component.run("What's Natural Language Processing?") - - # check that the component returns the correct ChatMessage response - assert isinstance(response, dict) - assert "replies" in response - assert isinstance(response["replies"], list) - assert len(response["replies"]) == 1 - assert [isinstance(reply, str) for reply in response["replies"]] - - # Coarse check: assuming no more than 4 chars per token. In any case it - # will fail if the `max_new_tokens` parameter is not respected, as the - # default is either 256 or 1024 - assert all(len(reply) <= 40 for reply in response["replies"]) - - assert "meta" in response - assert isinstance(response["meta"], list) - assert len(response["meta"]) == 1 - assert [isinstance(reply, dict) for reply in response["meta"]] +@pytest.mark.unit +@pytest.mark.usefixtures("set_env_variables", "mock_boto3_session") +def test_init_raises_boto_error(): + with patch("boto3.Session") as mock_boto3_session: + mock_boto3_session.side_effect = BotoCoreError() + with pytest.raises( + AWSConfigurationError, + match="Could not connect to SageMaker Inference Endpoint 'test-model'." + "Make sure the Endpoint exists and AWS environment is configured.", + ): + SagemakerGenerator(model="test-model") + + +@pytest.mark.unit +@pytest.mark.usefixtures("set_env_variables", "mock_boto3_session") +def test_run_with_list_of_dictionaries(): + client_mock = Mock() + client_mock.invoke_endpoint.return_value = { + "Body": Mock(read=lambda: b'[{"generated_text": "test-reply", "other": "metadata"}]') + } + component = SagemakerGenerator(model="test-model") + component.client = client_mock + response = component.run("What's Natural Language Processing?") + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, str) for reply in response["replies"]] + assert "test-reply" in response["replies"][0] + + assert "meta" in response + assert isinstance(response["meta"], list) + assert len(response["meta"]) == 1 + assert [isinstance(reply, dict) for reply in response["meta"]] + assert response["meta"][0]["other"] == "metadata" + + +@pytest.mark.unit +@pytest.mark.usefixtures("set_env_variables", "mock_boto3_session") +def test_run_with_single_dictionary(): + client_mock = Mock() + client_mock.invoke_endpoint.return_value = { + "Body": Mock(read=lambda: b'{"generation": "test-reply", "other": "metadata"}') + } + + component = SagemakerGenerator(model="test-model") + component.client = client_mock + response = component.run("What's Natural Language Processing?") + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, str) for reply in response["replies"]] + assert "test-reply" in response["replies"][0] + + assert "meta" in response + assert isinstance(response["meta"], list) + assert len(response["meta"]) == 1 + assert [isinstance(reply, dict) for reply in response["meta"]] + assert response["meta"][0]["other"] == "metadata" + + +# @pytest.mark.skipif( +# (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), +# reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", +# ) +# @pytest.mark.integration +# def test_run_falcon(self): +# component = SagemakerGenerator( +# model="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16", generation_kwargs={"max_new_tokens": 10} +# ) +# # component.warm_up() +# response = component.run("What's Natural Language Processing?") +# +# # check that the component returns the correct ChatMessage response +# assert isinstance(response, dict) +# assert "replies" in response +# assert isinstance(response["replies"], list) +# assert len(response["replies"]) == 1 +# assert [isinstance(reply, str) for reply in response["replies"]] +# +# # Coarse check: assuming no more than 4 chars per token. In any case it +# # will fail if the `max_new_tokens` parameter is not respected, as the +# # default is either 256 or 1024 +# assert all(len(reply) <= 40 for reply in response["replies"]) +# +# assert "meta" in response +# assert isinstance(response["meta"], list) +# assert len(response["meta"]) == 1 +# assert [isinstance(reply, dict) for reply in response["meta"]] +# +# @pytest.mark.skipif( +# (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), +# reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", +# ) +# @pytest.mark.integration +# def test_run_llama2(self): +# component = SagemakerGenerator( +# model="jumpstart-dft-meta-textgenerationneuron-llama-2-7b", +# generation_kwargs={"max_new_tokens": 10}, +# aws_custom_attributes={"accept_eula": True}, +# ) +# component.warm_up() +# response = component.run("What's Natural Language Processing?") +# +# # check that the component returns the correct ChatMessage response +# assert isinstance(response, dict) +# assert "replies" in response +# assert isinstance(response["replies"], list) +# assert len(response["replies"]) == 1 +# assert [isinstance(reply, str) for reply in response["replies"]] +# +# # Coarse check: assuming no more than 4 chars per token. In any case it +# # will fail if the `max_new_tokens` parameter is not respected, as the +# # default is either 256 or 1024 +# assert all(len(reply) <= 40 for reply in response["replies"]) +# +# assert "meta" in response +# assert isinstance(response["meta"], list) +# assert len(response["meta"]) == 1 +# assert [isinstance(reply, dict) for reply in response["meta"]] +# +# @pytest.mark.skipif( +# (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), +# reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", +# ) +# @pytest.mark.integration +# def test_run_bloomz(self): +# component = SagemakerGenerator( +# model="jumpstart-dft-hf-textgeneration-bloomz-1b1", generation_kwargs={"max_new_tokens": 10} +# ) +# component.warm_up() +# response = component.run("What's Natural Language Processing?") +# +# # check that the component returns the correct ChatMessage response +# assert isinstance(response, dict) +# assert "replies" in response +# assert isinstance(response["replies"], list) +# assert len(response["replies"]) == 1 +# assert [isinstance(reply, str) for reply in response["replies"]] +# +# # Coarse check: assuming no more than 4 chars per token. In any case it +# # will fail if the `max_new_tokens` parameter is not respected, as the +# # default is either 256 or 1024 +# assert all(len(reply) <= 40 for reply in response["replies"]) +# +# assert "meta" in response +# assert isinstance(response["meta"], list) +# assert len(response["meta"]) == 1 +# assert [isinstance(reply, dict) for reply in response["meta"]] From 28701cbf8efdccc573a278705c9e13589bf55e87 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Fri, 16 Feb 2024 12:04:27 +0100 Subject: [PATCH 08/14] finishing tests --- .../amazon_sagemaker/tests/test_sagemaker.py | 133 +++++------------- 1 file changed, 38 insertions(+), 95 deletions(-) diff --git a/integrations/amazon_sagemaker/tests/test_sagemaker.py b/integrations/amazon_sagemaker/tests/test_sagemaker.py index b60f33c00..ea108c794 100644 --- a/integrations/amazon_sagemaker/tests/test_sagemaker.py +++ b/integrations/amazon_sagemaker/tests/test_sagemaker.py @@ -1,14 +1,12 @@ -import pytest - -from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator -from haystack.utils.auth import EnvVarSecret +import os +from unittest.mock import Mock, patch +import pytest from botocore.exceptions import BotoCoreError - +from haystack.utils.auth import EnvVarSecret +from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator from haystack_integrations.components.generators.amazon_sagemaker.errors import AWSConfigurationError -from unittest.mock import Mock, patch - mocked_dict = { "type": "haystack_integrations.components.generators.amazon_sagemaker.sagemaker.SagemakerGenerator", "init_parameters": { @@ -142,91 +140,36 @@ def test_run_with_single_dictionary(): assert response["meta"][0]["other"] == "metadata" -# @pytest.mark.skipif( -# (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), -# reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", -# ) -# @pytest.mark.integration -# def test_run_falcon(self): -# component = SagemakerGenerator( -# model="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16", generation_kwargs={"max_new_tokens": 10} -# ) -# # component.warm_up() -# response = component.run("What's Natural Language Processing?") -# -# # check that the component returns the correct ChatMessage response -# assert isinstance(response, dict) -# assert "replies" in response -# assert isinstance(response["replies"], list) -# assert len(response["replies"]) == 1 -# assert [isinstance(reply, str) for reply in response["replies"]] -# -# # Coarse check: assuming no more than 4 chars per token. In any case it -# # will fail if the `max_new_tokens` parameter is not respected, as the -# # default is either 256 or 1024 -# assert all(len(reply) <= 40 for reply in response["replies"]) -# -# assert "meta" in response -# assert isinstance(response["meta"], list) -# assert len(response["meta"]) == 1 -# assert [isinstance(reply, dict) for reply in response["meta"]] -# -# @pytest.mark.skipif( -# (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), -# reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", -# ) -# @pytest.mark.integration -# def test_run_llama2(self): -# component = SagemakerGenerator( -# model="jumpstart-dft-meta-textgenerationneuron-llama-2-7b", -# generation_kwargs={"max_new_tokens": 10}, -# aws_custom_attributes={"accept_eula": True}, -# ) -# component.warm_up() -# response = component.run("What's Natural Language Processing?") -# -# # check that the component returns the correct ChatMessage response -# assert isinstance(response, dict) -# assert "replies" in response -# assert isinstance(response["replies"], list) -# assert len(response["replies"]) == 1 -# assert [isinstance(reply, str) for reply in response["replies"]] -# -# # Coarse check: assuming no more than 4 chars per token. In any case it -# # will fail if the `max_new_tokens` parameter is not respected, as the -# # default is either 256 or 1024 -# assert all(len(reply) <= 40 for reply in response["replies"]) -# -# assert "meta" in response -# assert isinstance(response["meta"], list) -# assert len(response["meta"]) == 1 -# assert [isinstance(reply, dict) for reply in response["meta"]] -# -# @pytest.mark.skipif( -# (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), -# reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", -# ) -# @pytest.mark.integration -# def test_run_bloomz(self): -# component = SagemakerGenerator( -# model="jumpstart-dft-hf-textgeneration-bloomz-1b1", generation_kwargs={"max_new_tokens": 10} -# ) -# component.warm_up() -# response = component.run("What's Natural Language Processing?") -# -# # check that the component returns the correct ChatMessage response -# assert isinstance(response, dict) -# assert "replies" in response -# assert isinstance(response["replies"], list) -# assert len(response["replies"]) == 1 -# assert [isinstance(reply, str) for reply in response["replies"]] -# -# # Coarse check: assuming no more than 4 chars per token. In any case it -# # will fail if the `max_new_tokens` parameter is not respected, as the -# # default is either 256 or 1024 -# assert all(len(reply) <= 40 for reply in response["replies"]) -# -# assert "meta" in response -# assert isinstance(response["meta"], list) -# assert len(response["meta"]) == 1 -# assert [isinstance(reply, dict) for reply in response["meta"]] +@pytest.mark.skipif( + (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), + reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", +) +@pytest.mark.integration +@pytest.mark.parametrize( + "model", + [ + "jumpstart-dft-hf-llm-falcon-7b-instruct-bf16", + "jumpstart-dft-meta-textgenerationneuron-llama-2-7b", + "jumpstart-dft-hf-textgeneration-bloomz-1b1", + ], +) +def test_run(model: str): + component = SagemakerGenerator(model=model, generation_kwargs={"max_new_tokens": 10}) + response = component.run("What's Natural Language Processing?") + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, str) for reply in response["replies"]] + + # Coarse check: assuming no more than 4 chars per token. In any case it + # will fail if the `max_new_tokens` parameter is not respected, as the + # default is either 256 or 1024 + assert all(len(reply) <= 40 for reply in response["replies"]) + + assert "meta" in response + assert isinstance(response["meta"], list) + assert len(response["meta"]) == 1 + assert [isinstance(reply, dict) for reply in response["meta"]] From 97776a5759ef000f90ee47f2deecde187f3a471b Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Fri, 16 Feb 2024 12:20:12 +0100 Subject: [PATCH 09/14] nit --- .../amazon_bedrock/chat/chat_generator.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index e6edea34f..555792173 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -62,11 +62,11 @@ class AmazonBedrockChatGenerator: def __init__( self, model: str, - aws_access_key_id: Optional[Secret] = Secret.from_env_var(["AWS_ACCESS_KEY_ID"], strict=False),# noqa: B008 + aws_access_key_id: Optional[Secret] = Secret.from_env_var(["AWS_ACCESS_KEY_ID"], strict=False), # noqa: B008 aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008 ["AWS_SECRET_ACCESS_KEY"], strict=False ), - aws_session_token: Optional[Secret] = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False), # noqa: B008 + aws_session_token: Optional[Secret] = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False), # noqa: B008 aws_region_name: Optional[Secret] = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008 aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008 generation_kwargs: Optional[Dict[str, Any]] = None, @@ -85,14 +85,12 @@ def __init__( :param model: The model to use for generation. The model must be available in Amazon Bedrock. The model has to be specified in the format outlined in the Amazon Bedrock [documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html). - :param aws_access_key_id: AWS access key ID. :param aws_secret_access_key: AWS secret access key. :param aws_session_token: AWS session token. :param aws_region_name: AWS region name. :param aws_profile_name: AWS profile name. :param generation_kwargs: Additional generation keyword arguments passed to the model. The defined keyword - parameters are specific to a specific model and can be found in the model's documentation. For example, the Anthropic Claude generation parameters can be found [here](https://docs.anthropic.com/claude/reference/complete_post). :param stop_words: A list of stop words that stop model generation when encountered. They can be provided via @@ -126,11 +124,11 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: try: session = self.get_aws_session( - aws_access_key_id=resolve_secret(aws_access_key_id), - aws_secret_access_key=resolve_secret(aws_secret_access_key), - aws_session_token=resolve_secret(aws_session_token), - aws_region_name=resolve_secret(aws_region_name), - aws_profile_name=resolve_secret(aws_profile_name), + aws_access_key_id=resolve_secret(aws_access_key_id), + aws_secret_access_key=resolve_secret(aws_secret_access_key), + aws_session_token=resolve_secret(aws_session_token), + aws_region_name=resolve_secret(aws_region_name), + aws_profile_name=resolve_secret(aws_profile_name), ) self.client = session.client("bedrock-runtime") except Exception as exception: From d6053e16c8c46147059f8b5d62eccdb71b87d2ae Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Fri, 16 Feb 2024 12:21:03 +0100 Subject: [PATCH 10/14] nit --- .../generators/amazon_bedrock/chat/chat_generator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 555792173..6ce671e68 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -64,7 +64,7 @@ def __init__( model: str, aws_access_key_id: Optional[Secret] = Secret.from_env_var(["AWS_ACCESS_KEY_ID"], strict=False), # noqa: B008 aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008 - ["AWS_SECRET_ACCESS_KEY"], strict=False + ["AWS_SECRET_ACCESS_KEY"], strict=False ), aws_session_token: Optional[Secret] = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False), # noqa: B008 aws_region_name: Optional[Secret] = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008 @@ -129,7 +129,7 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: aws_session_token=resolve_secret(aws_session_token), aws_region_name=resolve_secret(aws_region_name), aws_profile_name=resolve_secret(aws_profile_name), - ) + ) self.client = session.client("bedrock-runtime") except Exception as exception: msg = ( From 59f7edd2980dac9a42ce9a2795c15686762d09ba Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Fri, 16 Feb 2024 12:21:36 +0100 Subject: [PATCH 11/14] nit --- integrations/amazon_sagemaker/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/amazon_sagemaker/pyproject.toml b/integrations/amazon_sagemaker/pyproject.toml index 9c40ac683..8a6d6d039 100644 --- a/integrations/amazon_sagemaker/pyproject.toml +++ b/integrations/amazon_sagemaker/pyproject.toml @@ -188,4 +188,4 @@ markers = [ "embedders: embedders tests", "generators: generators tests", ] -log_cli = true \ No newline at end of file +log_cli = true From 90f11eb7cdbe892c37c5534088597abdd139b6c0 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Fri, 16 Feb 2024 13:17:43 +0100 Subject: [PATCH 12/14] updating service name for boto3 session client --- .../components/generators/amazon_sagemaker/sagemaker.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py index 461a64b6d..73d0b43c9 100644 --- a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py +++ b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py @@ -42,13 +42,16 @@ class SagemakerGenerator: ```bash export AWS_ACCESS_KEY_ID= export AWS_SECRET_ACCESS_KEY= + + export AWS_SECRET_ACCESS_KEY= + ``` (Note: you may also need to set the session token and region name, depending on your AWS configuration) Then you can use the generator as follows: ```python from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator - generator = SagemakerGenerator(model="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16") + generator = SagemakerGenerator(model="jumpstart-dft-hf-llm-falcon-7b-bf16") response = generator.run("What's Natural Language Processing? Be brief.") print(response) ``` @@ -124,7 +127,7 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: aws_region_name=resolve_secret(aws_region_name), aws_profile_name=resolve_secret(aws_profile_name), ) - self.client = session.client("sagemaker-runtime") + self.client = session.client("runtime.sagemaker") except Exception as e: msg = ( f"Could not connect to SageMaker Inference Endpoint '{self.model}'." From 0a6b7413471fa813cd50fe7114c3f055a1a24cc8 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Fri, 16 Feb 2024 17:25:24 +0100 Subject: [PATCH 13/14] attending PR comments --- .../generators/amazon_sagemaker/sagemaker.py | 40 ++--------- .../amazon_sagemaker/tests/test_sagemaker.py | 68 ++++++++++--------- 2 files changed, 41 insertions(+), 67 deletions(-) diff --git a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py index 73d0b43c9..39e347f86 100644 --- a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py +++ b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py @@ -17,14 +17,6 @@ logger = logging.getLogger(__name__) -AWS_CONFIGURATION_KEYS = [ - "aws_access_key_id", - "aws_secret_access_key", - "aws_session_token", - "aws_region_name", - "aws_profile_name", -] - MODEL_NOT_READY_STATUS_CODE = 429 @@ -38,16 +30,7 @@ class SagemakerGenerator: **Example:** - First export your AWS credentials as environment variables: - ```bash - export AWS_ACCESS_KEY_ID= - export AWS_SECRET_ACCESS_KEY= - - export AWS_SECRET_ACCESS_KEY= - - ``` - (Note: you may also need to set the session token and region name, depending on your AWS configuration) - + Make sure your AWS credentials are set up correctly. You can use environment variables or a shared credentials file. Then you can use the generator as follows: ```python from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator @@ -81,13 +64,6 @@ def __init__( Instantiates the session with SageMaker. :param model: The name for SageMaker Model Endpoint. - - :param aws_access_key_id: The name of the env var where the AWS access key ID is stored. - :param aws_secret_access_key: The name of the env var where the AWS secret access key is stored. - :param aws_session_token: The name of the env var where the AWS session token is stored. - :param aws_region_name: The name of the env var where the AWS region name is stored. - :param aws_profile_name: The name of the env var where the AWS profile name is stored. - :param aws_custom_attributes: Custom attributes to be passed to SageMaker, for example `{"accept_eula": True}` in case of Llama-2 models. @@ -120,7 +96,7 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: return secret.resolve_value() if secret else None try: - session = self.get_aws_session( + session = self._get_aws_session( aws_access_key_id=resolve_secret(aws_access_key_id), aws_secret_access_key=resolve_secret(aws_secret_access_key), aws_session_token=resolve_secret(aws_session_token), @@ -168,15 +144,13 @@ def from_dict(cls, data) -> "SagemakerGenerator": ) return default_from_dict(cls, data) - @classmethod - def get_aws_session( - cls, + @staticmethod + def _get_aws_session( aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, aws_session_token: Optional[str] = None, aws_region_name: Optional[str] = None, aws_profile_name: Optional[str] = None, - **kwargs, ): """ Creates an AWS Session with the given parameters. @@ -187,8 +161,7 @@ def get_aws_session( :param aws_session_token: AWS session token. :param aws_region_name: AWS region name. :param aws_profile_name: AWS profile name. - :param kwargs: The kwargs passed down to the service client. Supported kwargs depend on the model chosen. - See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html. + :raises AWSConfigurationError: If the provided AWS credentials are invalid. :return: The created AWS session. """ @@ -201,8 +174,7 @@ def get_aws_session( profile_name=aws_profile_name, ) except BotoCoreError as e: - provided_aws_config = {k: v for k, v in kwargs.items() if k in AWS_CONFIGURATION_KEYS} - msg = f"Failed to initialize the session with provided AWS credentials {provided_aws_config}" + msg = f"Failed to initialize the session with provided AWS credentials: {e}." raise AWSConfigurationError(msg) from e @component.output_types(replies=List[str], meta=List[Dict[str, Any]]) diff --git a/integrations/amazon_sagemaker/tests/test_sagemaker.py b/integrations/amazon_sagemaker/tests/test_sagemaker.py index ea108c794..5b67e3835 100644 --- a/integrations/amazon_sagemaker/tests/test_sagemaker.py +++ b/integrations/amazon_sagemaker/tests/test_sagemaker.py @@ -7,28 +7,26 @@ from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator from haystack_integrations.components.generators.amazon_sagemaker.errors import AWSConfigurationError -mocked_dict = { - "type": "haystack_integrations.components.generators.amazon_sagemaker.sagemaker.SagemakerGenerator", - "init_parameters": { - "model": "model", - "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, - "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, - "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, - "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, - "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, - "aws_custom_attributes": {"accept_eula": True}, - "generation_kwargs": {"max_new_tokens": 10}, - }, -} - - -@pytest.mark.unit -@pytest.mark.usefixtures("set_env_variables", "mock_boto3_session") -def test_to_dict(): + +def test_to_dict(set_env_variables, mock_boto3_session): # noqa: ARG001 """ Test that the to_dict method returns the correct dictionary without aws credentials """ + mocked_dict = { + "type": "haystack_integrations.components.generators.amazon_sagemaker.sagemaker.SagemakerGenerator", + "init_parameters": { + "model": "model", + "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, + "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, + "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, + "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "aws_custom_attributes": {"accept_eula": True}, + "generation_kwargs": {"max_new_tokens": 10}, + }, + } + generator = SagemakerGenerator( model="model", generation_kwargs={"max_new_tokens": 10}, @@ -37,21 +35,31 @@ def test_to_dict(): assert generator.to_dict() == mocked_dict -@pytest.mark.unit -@pytest.mark.usefixtures("set_env_variables", "mock_boto3_session") -def test_from_dict(): +def test_from_dict(set_env_variables, mock_boto3_session): # noqa: ARG001 """ Test that the from_dict method returns the correct object """ + mocked_dict = { + "type": "haystack_integrations.components.generators.amazon_sagemaker.sagemaker.SagemakerGenerator", + "init_parameters": { + "model": "model", + "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, + "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, + "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, + "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "aws_custom_attributes": {"accept_eula": True}, + "generation_kwargs": {"max_new_tokens": 10}, + }, + } + generator = SagemakerGenerator.from_dict(mocked_dict) assert generator.model == "model" assert isinstance(generator.aws_access_key_id, EnvVarSecret) -@pytest.mark.unit -@pytest.mark.usefixtures("set_env_variables", "mock_boto3_session") -def test_default_constructor(mock_boto3_session): +def test_default_constructor(set_env_variables, mock_boto3_session): # noqa: ARG001 """ Test that the default constructor sets the correct values """ @@ -74,9 +82,7 @@ def test_default_constructor(mock_boto3_session): ) -@pytest.mark.unit -@pytest.mark.usefixtures("set_env_variables", "mock_boto3_session") -def test_init_raises_boto_error(): +def test_init_raises_boto_error(set_env_variables, mock_boto3_session): # noqa: ARG001 with patch("boto3.Session") as mock_boto3_session: mock_boto3_session.side_effect = BotoCoreError() with pytest.raises( @@ -87,9 +93,7 @@ def test_init_raises_boto_error(): SagemakerGenerator(model="test-model") -@pytest.mark.unit -@pytest.mark.usefixtures("set_env_variables", "mock_boto3_session") -def test_run_with_list_of_dictionaries(): +def test_run_with_list_of_dictionaries(set_env_variables, mock_boto3_session): # noqa: ARG001 client_mock = Mock() client_mock.invoke_endpoint.return_value = { "Body": Mock(read=lambda: b'[{"generated_text": "test-reply", "other": "metadata"}]') @@ -113,9 +117,7 @@ def test_run_with_list_of_dictionaries(): assert response["meta"][0]["other"] == "metadata" -@pytest.mark.unit -@pytest.mark.usefixtures("set_env_variables", "mock_boto3_session") -def test_run_with_single_dictionary(): +def test_run_with_single_dictionary(set_env_variables, mock_boto3_session): # noqa: ARG001 client_mock = Mock() client_mock.invoke_endpoint.return_value = { "Body": Mock(read=lambda: b'{"generation": "test-reply", "other": "metadata"}') From e1fa482f9790eb4c3669d926d66b058a29283cab Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Fri, 16 Feb 2024 18:03:59 +0100 Subject: [PATCH 14/14] attending PR comments --- .../components/generators/amazon_sagemaker/sagemaker.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py index 39e347f86..4eacf33a4 100644 --- a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py +++ b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py @@ -63,6 +63,11 @@ def __init__( """ Instantiates the session with SageMaker. + :param aws_access_key_id: The `Secret` for AWS access key ID. + :param aws_secret_access_key: The `Secret` for AWS secret access key. + :param aws_session_token: The `Secret` for AWS session token. + :param aws_region_name: The `Secret` for AWS region name. If not provided, the default region will be used. + :param aws_profile_name: The `Secret` for AWS profile name. If not provided, the default profile will be used. :param model: The name for SageMaker Model Endpoint. :param aws_custom_attributes: Custom attributes to be passed to SageMaker, for example `{"accept_eula": True}` in case of Llama-2 models.