Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adopt Secret to AmazonSagemaker #432

Merged
merged 17 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions integrations/amazon_sagemaker/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -154,24 +154,38 @@ ban-relative-imports = "parents"
"tests/**/*" = ["PLR2004", "S101", "TID252"]

[tool.coverage.run]
source = ["haystack_integrations"]
branch = true
parallel = true

[tool.coverage.paths]
amazon_sagemaker_haystack = ["src"]
tests = ["tests"]

[tool.coverage.report]
omit = ["*/tests/*", "*/__init__.py"]
show_missing=true
exclude_lines = [
"no cov",
"if __name__ == .__main__.:",
"if TYPE_CHECKING:",
]

[[tool.mypy.overrides]]
module = [
"botocore.*",
"boto3.*",
"haystack.*",
"haystack_integrations.*",
"pytest.*",
"numpy.*",
]
ignore_missing_imports = true
ignore_missing_imports = true



[tool.pytest.ini_options]
addopts = "--strict-markers"
markers = [
"unit: unit tests",
"integration: integration tests",
"embedders: embedders tests",
"generators: generators tests",
]
log_cli = true
Original file line number Diff line number Diff line change
@@ -1,24 +1,30 @@
import json
import logging
import os
from typing import Any, ClassVar, Dict, List, Optional

import boto3
import requests
from botocore.exceptions import BotoCoreError
from haystack import component, default_from_dict, default_to_dict
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace

from haystack_integrations.components.generators.amazon_sagemaker.errors import (
AWSConfigurationError,
SagemakerInferenceError,
SagemakerNotReadyError,
)

with LazyImport(message="Run 'pip install boto3'") as boto3_import:
import boto3 # type: ignore
from botocore.client import BaseClient # type: ignore


logger = logging.getLogger(__name__)

AWS_CONFIGURATION_KEYS = [
"aws_access_key_id",
"aws_secret_access_key",
"aws_session_token",
"aws_region_name",
"aws_profile_name",
]
davidsbatista marked this conversation as resolved.
Show resolved Hide resolved


MODEL_NOT_READY_STATUS_CODE = 429

Expand All @@ -36,14 +42,16 @@ class SagemakerGenerator:
```bash
export AWS_ACCESS_KEY_ID=<your_access_key_id>
export AWS_SECRET_ACCESS_KEY=<your_secret_access_key>

export AWS_SECRET_ACCESS_KEY=<your_secret_access_key>

davidsbatista marked this conversation as resolved.
Show resolved Hide resolved
```
(Note: you may also need to set the session token and region name, depending on your AWS configuration)

Then you can use the generator as follows:
```python
from haystack.components.generators.sagemaker import SagemakerGenerator
generator = SagemakerGenerator(model="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16")
generator.warm_up()
from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator
generator = SagemakerGenerator(model="jumpstart-dft-hf-llm-falcon-7b-bf16")
response = generator.run("What's Natural Language Processing? Be brief.")
print(response)
```
Expand All @@ -59,25 +67,30 @@ class SagemakerGenerator:
def __init__(
self,
model: str,
aws_access_key_id_var: str = "AWS_ACCESS_KEY_ID",
aws_secret_access_key_var: str = "AWS_SECRET_ACCESS_KEY",
aws_session_token_var: str = "AWS_SESSION_TOKEN",
aws_region_name_var: str = "AWS_REGION",
aws_profile_name_var: str = "AWS_PROFILE",
aws_access_key_id: Optional[Secret] = Secret.from_env_var(["AWS_ACCESS_KEY_ID"], strict=False), # noqa: B008
aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008
["AWS_SECRET_ACCESS_KEY"], strict=False
),
aws_session_token: Optional[Secret] = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False), # noqa: B008
aws_region_name: Optional[Secret] = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008
aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008
aws_custom_attributes: Optional[Dict[str, Any]] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Instantiates the session with SageMaker.

:param model: The name for SageMaker Model Endpoint.
:param aws_access_key_id_var: The name of the env var where the AWS access key ID is stored.
:param aws_secret_access_key_var: The name of the env var where the AWS secret access key is stored.
:param aws_session_token_var: The name of the env var where the AWS session token is stored.
:param aws_region_name_var: The name of the env var where the AWS region name is stored.
:param aws_profile_name_var: The name of the env var where the AWS profile name is stored.

:param aws_access_key_id: The `Secret` holding the AWS access key ID. Defaults to the `AWS_ACCESS_KEY_ID` env var.
:param aws_secret_access_key: The `Secret` holding the AWS secret access key. Defaults to the `AWS_SECRET_ACCESS_KEY` env var.
:param aws_session_token: The `Secret` holding the AWS session token. Defaults to the `AWS_SESSION_TOKEN` env var.
:param aws_region_name: The `Secret` holding the AWS region name. Defaults to the `AWS_DEFAULT_REGION` env var.
:param aws_profile_name: The `Secret` holding the AWS profile name. Defaults to the `AWS_PROFILE` env var.
davidsbatista marked this conversation as resolved.
Show resolved Hide resolved

:param aws_custom_attributes: Custom attributes to be passed to SageMaker, for example `{"accept_eula": True}`
in case of Llama-2 models.

:param generation_kwargs: Additional keyword arguments for text generation. For a list of supported parameters
see your model's documentation page, for example here for HuggingFace models:
https://huggingface.co/blog/sagemaker-huggingface-llm#4-run-inference-and-chat-with-our-model
Expand All @@ -95,21 +108,32 @@ def __init__(
be boolean. The default value for it is `False`.
"""
self.model = model
self.aws_access_key_id_var = aws_access_key_id_var
self.aws_secret_access_key_var = aws_secret_access_key_var
self.aws_session_token_var = aws_session_token_var
self.aws_region_name_var = aws_region_name_var
self.aws_profile_name_var = aws_profile_name_var
self.aws_custom_attributes = aws_custom_attributes or {}
self.generation_kwargs = generation_kwargs or {"max_new_tokens": 1024}
self.client: Optional[BaseClient] = None
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_session_token = aws_session_token
self.aws_region_name = aws_region_name
self.aws_profile_name = aws_profile_name

def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
return secret.resolve_value() if secret else None

if not os.getenv(self.aws_access_key_id_var) or not os.getenv(self.aws_secret_access_key_var):
try:
session = self.get_aws_session(
aws_access_key_id=resolve_secret(aws_access_key_id),
aws_secret_access_key=resolve_secret(aws_secret_access_key),
aws_session_token=resolve_secret(aws_session_token),
aws_region_name=resolve_secret(aws_region_name),
aws_profile_name=resolve_secret(aws_profile_name),
)
self.client = session.client("runtime.sagemaker")
except Exception as e:
msg = (
f"Please provide AWS credentials via environment variables '{self.aws_access_key_id_var}' and "
f"'{self.aws_secret_access_key_var}'."
f"Could not connect to SageMaker Inference Endpoint '{self.model}'."
f"Make sure the Endpoint exists and AWS environment is configured."
)
raise AWSConfigurationError(msg)
raise AWSConfigurationError(msg) from e

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Expand All @@ -124,11 +148,11 @@ def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
model=self.model,
aws_access_key_id_var=self.aws_access_key_id_var,
aws_secret_access_key_var=self.aws_secret_access_key_var,
aws_session_token_var=self.aws_session_token_var,
aws_region_name_var=self.aws_region_name_var,
aws_profile_name_var=self.aws_profile_name_var,
aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None,
aws_secret_access_key=self.aws_secret_access_key.to_dict() if self.aws_secret_access_key else None,
aws_session_token=self.aws_session_token.to_dict() if self.aws_session_token else None,
aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None,
aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
aws_custom_attributes=self.aws_custom_attributes,
generation_kwargs=self.generation_kwargs,
)
Expand All @@ -138,27 +162,47 @@ def from_dict(cls, data) -> "SagemakerGenerator":
"""
Deserialize the dictionary into an instance of SagemakerGenerator.
"""
deserialize_secrets_inplace(
data["init_parameters"],
["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"],
)
return default_from_dict(cls, data)

def warm_up(self):
@classmethod
def get_aws_session(
davidsbatista marked this conversation as resolved.
Show resolved Hide resolved
cls,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
**kwargs,
davidsbatista marked this conversation as resolved.
Show resolved Hide resolved
):
"""
Initializes the SageMaker Inference client.
Creates an AWS Session with the given parameters.
Checks if the provided AWS credentials are valid and can be used to connect to AWS.

:param aws_access_key_id: AWS access key ID.
:param aws_secret_access_key: AWS secret access key.
:param aws_session_token: AWS session token.
:param aws_region_name: AWS region name.
:param aws_profile_name: AWS profile name.
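:param kwargs: Additional keyword arguments. They are not forwarded to `boto3.Session`; they are only inspected
for AWS configuration keys when building the error message raised on failure.
NOTE(review): the link below points to Bedrock documentation and may not apply to SageMaker — confirm.
See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html.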
:raises AWSConfigurationError: If the provided AWS credentials are invalid.
:return: The created AWS session.
"""
boto3_import.check()
try:
session = boto3.Session(
aws_access_key_id=os.getenv(self.aws_access_key_id_var),
aws_secret_access_key=os.getenv(self.aws_secret_access_key_var),
aws_session_token=os.getenv(self.aws_session_token_var),
region_name=os.getenv(self.aws_region_name_var),
profile_name=os.getenv(self.aws_profile_name_var),
)
self.client = session.client("sagemaker-runtime")
except Exception as e:
msg = (
f"Could not connect to SageMaker Inference Endpoint '{self.model}'."
f"Make sure the Endpoint exists and AWS environment is configured."
return boto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
region_name=aws_region_name,
profile_name=aws_profile_name,
)
except BotoCoreError as e:
provided_aws_config = {k: v for k, v in kwargs.items() if k in AWS_CONFIGURATION_KEYS}
davidsbatista marked this conversation as resolved.
Show resolved Hide resolved
msg = f"Failed to initialize the session with provided AWS credentials {provided_aws_config}"
raise AWSConfigurationError(msg) from e

@component.output_types(replies=List[str], meta=List[Dict[str, Any]])
Expand All @@ -173,10 +217,6 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
:return: A list of strings containing the generated responses and a list of dictionaries containing the metadata
for each response.
"""
if self.client is None:
msg = "SageMaker Inference client is not initialized. Please call warm_up() first."
raise ValueError(msg)

generation_kwargs = generation_kwargs or self.generation_kwargs
custom_attributes = ";".join(
f"{k}={str(v).lower() if isinstance(v, bool) else str(v)}" for k, v in self.aws_custom_attributes.items()
Expand Down
18 changes: 18 additions & 0 deletions integrations/amazon_sagemaker/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest

from unittest.mock import patch


@pytest.fixture
def set_env_variables(monkeypatch):
    """Populate the AWS-related environment variables with fake values.

    The values are set through pytest's ``monkeypatch`` fixture, so they are
    scoped to the test that requests this fixture.
    """
    fake_env = {
        "AWS_ACCESS_KEY_ID": "some_fake_id",
        "AWS_SECRET_ACCESS_KEY": "some_fake_key",
        "AWS_SESSION_TOKEN": "some_fake_token",
        "AWS_DEFAULT_REGION": "fake_region",
        "AWS_PROFILE": "some_fake_profile",
    }
    for name, value in fake_env.items():
        monkeypatch.setenv(name, value)


@pytest.fixture
def mock_boto3_session():
    """Yield a mock that stands in for ``boto3.Session`` while the test runs.

    The patch is active only for the duration of the requesting test and is
    undone when the fixture is torn down.
    """
    session_patcher = patch("boto3.Session")
    with session_patcher as mocked_session:
        yield mocked_session
Loading