From 2d15065b5777eae3d7171556d0462b47a035b338 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Fri, 26 Jan 2024 12:06:11 +0100 Subject: [PATCH 1/3] initial implementation --- haystack/components/generators/sagemaker.py | 216 ++++++++++++++++++++ 1 file changed, 216 insertions(+) create mode 100644 haystack/components/generators/sagemaker.py diff --git a/haystack/components/generators/sagemaker.py b/haystack/components/generators/sagemaker.py new file mode 100644 index 0000000000..985ed25735 --- /dev/null +++ b/haystack/components/generators/sagemaker.py @@ -0,0 +1,216 @@ +from typing import Optional, List, Dict, Any + +import os +import logging +import json + +import requests +from haystack.lazy_imports import LazyImport +from haystack import component, default_from_dict, default_to_dict, ComponentError + +with LazyImport(message="Run 'pip install boto3'") as boto3_import: + import boto3 + from botocore.client import BaseClient + + +logger = logging.getLogger(__name__) + + +@component +class SagemakerGenerator: + model_generation_keys = ["generated_text", "generation"] + + """ + Enables text generation using Sagemaker. It supports Large Language Models (LLMs) hosted and deployed on a SageMaker + Inference Endpoint. For guidance on how to deploy a model to SageMaker, refer to the + [SageMaker JumpStart foundation models documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-foundation-models-use.html). + + Example: + + ```bash + export AWS_ACCESS_KEY_ID= + export AWS_SECRET_ACCESS_KEY= + export AWS_SESSION_TOKEN= # This is optional + export AWS_REGION_NAME= + ``` + + ```python + from haystack.components.generators.sagemaker import SagemakerGenerator + generator = SagemakerGenerator(model="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16") + generator.warm_up() + response = generator.run("What's Natural Language Processing? Be brief.") + print(response) + ``` + + TODO review reply format + + >> {'replies': ['Natural Language Processing (NLP) is a branch of artificial intelligence that focuses on + >> the interaction between computers and human language. It involves enabling computers to understand, interpret, + >> and respond to natural human language in a way that is both meaningful and useful.'], 'meta': [{}]} + ``` + """ + + def __init__( + self, + model: str, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + aws_region_name: Optional[str] = None, + aws_profile_name: Optional[str] = None, + aws_custom_attributes: Optional[Dict[str, Any]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Instantiates the session with SageMaker. + + :param model: The name for SageMaker Model Endpoint. + :param aws_access_key_id: AWS access key ID. + :param aws_secret_access_key: AWS secret access key. + :param aws_session_token: AWS session token. + :param aws_region_name: AWS region name. + :param aws_profile_name: AWS profile name. + :param aws_custom_attributes: Custom attributes to be passed to SageMaker, for example `{"accept_eula": True}` + in case of Llama-2 models. + :param generation_kwargs: Additional keyword arguments for text generation. For a list of supported parameters + see your model's documentation page, for example here for HuggingFace models: + https://huggingface.co/blog/sagemaker-huggingface-llm#4-run-inference-and-chat-with-our-model + + Specifically, Llama-2 models support the following inference payload parameters: + + - `max_new_tokens`: Model generates text until the output length (excluding the input context length) reaches + `max_new_tokens`. If specified, it must be a positive integer. + - `temperature`: Controls the randomness in the output. Higher temperature results in output sequence with + low-probability words and lower temperature results in output sequence with high-probability words. + If `temperature=0`, it results in greedy decoding. If specified, it must be a positive float. + - `top_p`: In each step of text generation, sample from the smallest possible set of words with cumulative + probability `top_p`. If specified, it must be a float between 0 and 1. + - `return_full_text`: If `True`, input text will be part of the output generated text. If specified, it must + be boolean. The default value for it is `False`. + """ + self.model = model + self.aws_access_key_id = aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID", None) + self.aws_secret_access_key = aws_secret_access_key or os.getenv("AWS_SECRET_KEY", None) + self.aws_session_token = aws_session_token or os.getenv("AWS_SESSION_TOKEN", None) + self.aws_region_name = aws_region_name or os.getenv("AWS_REGION_NAME", None) + self.aws_profile_name = aws_profile_name or os.getenv("AWS_PROFILE_NAME", None) + self.aws_custom_attributes = aws_custom_attributes or {} + self.generation_kwargs = generation_kwargs or {"max_new_tokens": 1024} + + self.client: Optional[BaseClient] = None + + def _get_telemetry_data(self) -> Dict[str, Any]: + """ + Data that is sent to Posthog for usage analytics. + """ + return {"model": self.model} + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. We must avoid serializing AWS credentials, so + we serialize only region and profile names. + + :return: The serialized component as a dictionary. + """ + return default_to_dict( + self, + model=self.model, + aws_region_name=self.aws_region_name, + aws_profile_name=self.aws_profile_name, + aws_custom_attributes=self.aws_custom_attributes, + generation_kwargs=self.generation_kwargs, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "SagemakerGenerator": + """ + Deserialize this component from a dictionary. + + :param data: The dictionary representation of this component. + :return: The deserialized component instance. + """ + return default_from_dict(cls, data) + + def warm_up(self): + boto3_import.check() + try: + session = boto3.Session( + aws_access_key_id=self.aws_access_key_id, + aws_secret_access_key=self.aws_secret_access_key, + aws_session_token=self.aws_session_token, + region_name=self.aws_region_name, + profile_name=self.aws_profile_name, + ) + self.client = session.client("sagemaker-runtime") + except Exception as e: + raise ComponentError( + f"Could not connect to SageMaker Inference Endpoint '{self.model}'." + f"Make sure the Endpoint exists and AWS environment is configured." + ) from e + + @component.output_types(replies=List[str], meta=List[Dict[str, Any]]) + def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): + """ + Invoke the text generation inference based on the provided messages and generation parameters. + + :param prompt: The string prompt to use for text generation. + :param generation_kwargs: Additional keyword arguments for text generation. These parameters will + potentially override the parameters passed in the __init__ method. + + :return: A list of strings containing the generated responses and a list of dictionaries containing the metadata + for each response. + """ + if self.client is None: + raise ValueError("SageMaker Inference client is not initialized. Please call warm_up() first.") + + generation_kwargs = generation_kwargs or self.generation_kwargs + custom_attributes = ";".join( + f"{k}={str(v).lower() if isinstance(v, bool) else str(v)}" for k, v in self.aws_custom_attributes.items() + ) + self.client: BaseClient + try: + body = json.dumps({"inputs": prompt, "parameters": generation_kwargs}) + response = self.client.invoke_endpoint( + EndpointName=self.model, + Body=body, + ContentType="application/json", + Accept="application/json", + CustomAttributes=custom_attributes, + ) + response_json = response.get("Body").read().decode("utf-8") + output: Dict[str, Dict[str, Any]] = json.loads(response_json) + + # Find the key that contains the generated text + # It can be any of the keys in model_generation_keys, depending on the model + for key in self.model_generation_keys: + if key in output[0]: + break + + replies = [o.pop(key, None) for o in output] + return {"replies": replies, "meta": output * len(replies)} + except requests.HTTPError as err: + res = err.response + if res.status_code == 429: + raise ComponentError(f"Sagemaker model not ready: {res.text}") from err + + raise ComponentError( + f"SageMaker Inference returned an error. Status code: {res.status_code} Response body: {res.text}", + status_code=res.status_code, + ) from err + + +""" + +""" + +# import os +# from haystack.nodes import PromptNode + +# # We can also configure Sagemaker via AWS environment variables without AWS profile name +# pn = PromptNode(model_name_or_path="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16", max_length=256, +# model_kwargs={"aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"), +# "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), +# "aws_region_name": "eu-west-1"}) + +# response = pn("Tell me more about Berlin, be elaborate") +# print(response) From 7402130815d6ff7fc7670cc93157b002cdf9742b Mon Sep 17 00:00:00 2001 From: ZanSara Date: Fri, 26 Jan 2024 13:10:38 +0100 Subject: [PATCH 2/3] add check on credentials in init --- haystack/components/generators/sagemaker.py | 104 +++++++------------- 1 file changed, 35 insertions(+), 69 deletions(-) diff --git a/haystack/components/generators/sagemaker.py b/haystack/components/generators/sagemaker.py index 985ed25735..28f39e36ba 100644 --- a/haystack/components/generators/sagemaker.py +++ b/haystack/components/generators/sagemaker.py @@ -25,15 +25,16 @@ class SagemakerGenerator: Inference Endpoint. For guidance on how to deploy a model to SageMaker, refer to the [SageMaker JumpStart foundation models documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-foundation-models-use.html). - Example: + **Example:** + First export your AWS credentials as environment variables: ```bash export AWS_ACCESS_KEY_ID= export AWS_SECRET_ACCESS_KEY= - export AWS_SESSION_TOKEN= # This is optional - export AWS_REGION_NAME= ``` + (Note: you may also need to set the session token and region name, depending on your AWS configuration) + Then you can use the generator as follows: ```python from haystack.components.generators.sagemaker import SagemakerGenerator generator = SagemakerGenerator(model="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16") @@ -42,8 +43,7 @@ class SagemakerGenerator: print(response) ``` - TODO review reply format - + ``` >> {'replies': ['Natural Language Processing (NLP) is a branch of artificial intelligence that focuses on >> the interaction between computers and human language. It involves enabling computers to understand, interpret, >> and respond to natural human language in a way that is both meaningful and useful.'], 'meta': [{}]} @@ -53,11 +53,11 @@ class SagemakerGenerator: def __init__( self, model: str, - aws_access_key_id: Optional[str] = None, - aws_secret_access_key: Optional[str] = None, - aws_session_token: Optional[str] = None, - aws_region_name: Optional[str] = None, - aws_profile_name: Optional[str] = None, + aws_access_key_id_var: str = "AWS_ACCESS_KEY_ID", + aws_secret_access_key_var: str = "AWS_SECRET_ACCESS_KEY", + aws_session_token_var: str = "AWS_SESSION_TOKEN", + aws_region_name_var: str = "AWS_REGION", + aws_profile_name_var: str = "AWS_PROFILE", aws_custom_attributes: Optional[Dict[str, Any]] = None, generation_kwargs: Optional[Dict[str, Any]] = None, ): @@ -65,11 +65,11 @@ def __init__( Instantiates the session with SageMaker. :param model: The name for SageMaker Model Endpoint. - :param aws_access_key_id: AWS access key ID. - :param aws_secret_access_key: AWS secret access key. - :param aws_session_token: AWS session token. - :param aws_region_name: AWS region name. - :param aws_profile_name: AWS profile name. + :param aws_access_key_id_var: The name of the env var where the AWS access key ID is stored. + :param aws_secret_access_key_var: The name of the env var where the AWS secret access key is stored. + :param aws_session_token_var: The name of the env var where the AWS session token is stored. + :param aws_region_name_var: The name of the env var where the AWS region name is stored. + :param aws_profile_name_var: The name of the env var where the AWS profile name is stored. :param aws_custom_attributes: Custom attributes to be passed to SageMaker, for example `{"accept_eula": True}` in case of Llama-2 models. :param generation_kwargs: Additional keyword arguments for text generation. For a list of supported parameters @@ -89,57 +89,39 @@ def __init__( be boolean. The default value for it is `False`. """ self.model = model - self.aws_access_key_id = aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID", None) - self.aws_secret_access_key = aws_secret_access_key or os.getenv("AWS_SECRET_KEY", None) - self.aws_session_token = aws_session_token or os.getenv("AWS_SESSION_TOKEN", None) - self.aws_region_name = aws_region_name or os.getenv("AWS_REGION_NAME", None) - self.aws_profile_name = aws_profile_name or os.getenv("AWS_PROFILE_NAME", None) + self.aws_access_key_id_var = aws_access_key_id_var + self.aws_secret_access_key_var = aws_secret_access_key_var + self.aws_session_token_var = aws_session_token_var + self.aws_region_name_var = aws_region_name_var + self.aws_profile_name_var = aws_profile_name_var self.aws_custom_attributes = aws_custom_attributes or {} self.generation_kwargs = generation_kwargs or {"max_new_tokens": 1024} - self.client: Optional[BaseClient] = None + if not os.getenv(self.aws_access_key_id_var) or not os.getenv(self.aws_secret_access_key_var): + raise ValueError( + f"Please provide AWS credentials via environment variables '{self.aws_access_key_id_var}' and " + f"'{self.aws_secret_access_key_var}'." + ) + def _get_telemetry_data(self) -> Dict[str, Any]: """ Data that is sent to Posthog for usage analytics. """ return {"model": self.model} - def to_dict(self) -> Dict[str, Any]: - """ - Serialize this component to a dictionary. We must avoid serializing AWS credentials, so - we serialize only region and profile names. - - :return: The serialized component as a dictionary. + def warm_up(self): """ - return default_to_dict( - self, - model=self.model, - aws_region_name=self.aws_region_name, - aws_profile_name=self.aws_profile_name, - aws_custom_attributes=self.aws_custom_attributes, - generation_kwargs=self.generation_kwargs, - ) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "SagemakerGenerator": + Initializes the SageMaker Inference client. """ - Deserialize this component from a dictionary. - - :param data: The dictionary representation of this component. - :return: The deserialized component instance. - """ - return default_from_dict(cls, data) - - def warm_up(self): boto3_import.check() try: session = boto3.Session( - aws_access_key_id=self.aws_access_key_id, - aws_secret_access_key=self.aws_secret_access_key, - aws_session_token=self.aws_session_token, - region_name=self.aws_region_name, - profile_name=self.aws_profile_name, + aws_access_key_id=os.getenv(self.aws_access_key_id_var), + aws_secret_access_key=os.getenv(self.aws_secret_access_key_var), + aws_session_token=os.getenv(self.aws_session_token_var), + region_name=os.getenv(self.aws_region_name_var), + profile_name=os.getenv(self.aws_profile_name_var), ) self.client = session.client("sagemaker-runtime") except Exception as e: @@ -155,7 +137,7 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): :param prompt: The string prompt to use for text generation. :param generation_kwargs: Additional keyword arguments for text generation. These parameters will - potentially override the parameters passed in the __init__ method. + potentially override the parameters passed in the `__init__` method. :return: A list of strings containing the generated responses and a list of dictionaries containing the metadata for each response. @@ -188,6 +170,7 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): replies = [o.pop(key, None) for o in output] return {"replies": replies, "meta": output * len(replies)} + except requests.HTTPError as err: res = err.response if res.status_code == 429: @@ -197,20 +180,3 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): f"SageMaker Inference returned an error. Status code: {res.status_code} Response body: {res.text}", status_code=res.status_code, ) from err - - -""" - -""" - -# import os -# from haystack.nodes import PromptNode - -# # We can also configure Sagemaker via AWS environment variables without AWS profile name -# pn = PromptNode(model_name_or_path="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16", max_length=256, -# model_kwargs={"aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"), -# "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), -# "aws_region_name": "eu-west-1"}) - -# response = pn("Tell me more about Berlin, be elaborate") -# print(response) From 884752be4c20295521326ff9b1f3a5e2c95ec2a7 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Fri, 26 Jan 2024 14:47:23 +0100 Subject: [PATCH 3/3] add some tests --- haystack/components/generators/sagemaker.py | 2 +- test/components/generators/test_sagemaker.py | 125 +++++++++++++++++++ 2 files changed, 126 insertions(+), 1 deletion(-) create mode 100644 test/components/generators/test_sagemaker.py diff --git a/haystack/components/generators/sagemaker.py b/haystack/components/generators/sagemaker.py index 28f39e36ba..d863c55677 100644 --- a/haystack/components/generators/sagemaker.py +++ b/haystack/components/generators/sagemaker.py @@ -6,7 +6,7 @@ import requests from haystack.lazy_imports import LazyImport -from haystack import component, default_from_dict, default_to_dict, ComponentError +from haystack import component, ComponentError with LazyImport(message="Run 'pip install boto3'") as boto3_import: import boto3 diff --git a/test/components/generators/test_sagemaker.py b/test/components/generators/test_sagemaker.py new file mode 100644 index 0000000000..1ae26bd846 --- /dev/null +++ b/test/components/generators/test_sagemaker.py @@ -0,0 +1,125 @@ +from typing import List + +import os +from unittest.mock import patch, Mock + +import pytest +from openai import OpenAIError + +from haystack.components.generators.sagemaker import SagemakerGenerator + + +class TestSagemakerGenerator: + def test_init_default(self, monkeypatch): + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test-access-key") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key") + + component = SagemakerGenerator(model="test-model") + assert component.model == "test-model" + assert component.aws_access_key_id_var == "AWS_ACCESS_KEY_ID" + assert component.aws_secret_access_key_var == "AWS_SECRET_ACCESS_KEY" + assert component.aws_session_token_var == "AWS_SESSION_TOKEN" + assert component.aws_region_name_var == "AWS_REGION" + assert component.aws_profile_name_var == "AWS_PROFILE" + assert component.aws_custom_attributes == {} + assert component.generation_kwargs == {"max_new_tokens": 1024} + assert component.client is None + + def test_init_fail_wo_access_key_or_secret_key(self, monkeypatch): + monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False) + monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False) + with pytest.raises(ValueError): + SagemakerGenerator(model="test-model") + + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test-access-key") + monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False) + with pytest.raises(ValueError): + SagemakerGenerator(model="test-model") + + monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False) + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key") + with pytest.raises(ValueError): + SagemakerGenerator(model="test-model") + + def test_init_with_parameters(self, monkeypatch): + monkeypatch.setenv("MY_ACCESS_KEY_ID", "test-access-key") + monkeypatch.setenv("MY_SECRET_ACCESS_KEY", "test-secret-key") + + component = SagemakerGenerator( + model="test-model", + aws_access_key_id_var="MY_ACCESS_KEY_ID", + aws_secret_access_key_var="MY_SECRET_ACCESS_KEY", + aws_session_token_var="MY_SESSION_TOKEN", + aws_region_name_var="MY_REGION", + aws_profile_name_var="MY_PROFILE", + aws_custom_attributes={"custom": "attr"}, + generation_kwargs={"generation": "kwargs"}, + ) + assert component.model == "test-model" + assert component.aws_access_key_id_var == "MY_ACCESS_KEY_ID" + assert component.aws_secret_access_key_var == "MY_SECRET_ACCESS_KEY" + assert component.aws_session_token_var == "MY_SESSION_TOKEN" + assert component.aws_region_name_var == "MY_REGION" + assert component.aws_profile_name_var == "MY_PROFILE" + assert component.aws_custom_attributes == {"custom": "attr"} + assert component.generation_kwargs == {"generation": "kwargs"} + assert component.client is None + + def test_run(self, monkeypatch): + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test-access-key") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key") + client_mock = Mock() + client_mock.invoke_endpoint.return_value = { + "Body": Mock(read=lambda: b'[{"generated_text": "test-reply", "other": "metadata"}]') + } + + component = SagemakerGenerator(model="test-model") + component.client = client_mock # Simulate warm_up() + response = component.run("What's Natural Language Processing?") + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, str) for reply in response["replies"]] + assert "test-reply" in response["replies"][0] + + assert "meta" in response + assert isinstance(response["meta"], list) + assert len(response["meta"]) == 1 + assert [isinstance(reply, dict) for reply in response["meta"]] + assert response["meta"][0]["other"] == "metadata" + + @pytest.mark.skipif( + ( + not os.environ.get("AWS_ACCESS_KEY_ID", None) + or not os.environ.get("AWS_SECRET_ACCESS_KEY", None) + or not os.environ.get("AWS_SAGEMAKER_TEST_MODEL", None) + ), + reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY containing the AWS credentials to run this test.", + ) + @pytest.mark.integration + def test_run_falcon(self): + component = SagemakerGenerator( + model=os.getenv("AWS_SAGEMAKER_TEST_MODEL"), generation_kwargs={"max_new_tokens": 10} + ) + component.warm_up() + response = component.run("What's Natural Language Processing?") + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, str) for reply in response["replies"]] + + # Coarse check: assuming no more than 4 chars per token. In any case it + # will fail if the `max_new_tokens` parameter is not respected, as the + # default is either 256 or 1024 + assert all(len(reply) <= 40 for reply in response["replies"]) + + assert "meta" in response + assert isinstance(response["meta"], list) + assert len(response["meta"]) == 1 + assert [isinstance(reply, dict) for reply in response["meta"]]