docs: review integrations sagemaker #544

Merged: 4 commits, Mar 6, 2024
@@ -1,46 +1,16 @@
from typing import Optional
Contributor Author:
that was all unnecessary reimplementation



class SagemakerError(Exception):
    """
    Error generated by the Amazon Sagemaker integration.
    Parent class for all exceptions raised by the Sagemaker component
    """

    def __init__(
        self,
        message: Optional[str] = None,
    ):
        super().__init__()
        if message:
            self.message = message

    def __getattr__(self, attr):
        # If self.__cause__ is None, it will raise the expected AttributeError
        getattr(self.__cause__, attr)

    def __str__(self):
        return self.message

    def __repr__(self):
        return str(self)


class AWSConfigurationError(SagemakerError):
    """Exception raised when AWS is not configured correctly"""

    def __init__(self, message: Optional[str] = None):
        super().__init__(message=message)


class SagemakerNotReadyError(SagemakerError):
    """Exception raised when the SageMaker model is not ready to accept requests"""

    def __init__(self, message: Optional[str] = None):
        super().__init__(message=message)


class SagemakerInferenceError(SagemakerError):
    """Exception for issues that occur during Sagemaker inference"""

    def __init__(self, message: Optional[str] = None):
        super().__init__(message=message)
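
Since `SagemakerError` is the parent class of the exceptions above, a caller can handle all of them with a single except clause. A minimal sketch, assuming the classes are importable from an `errors` module alongside the generator (the exact module path is not shown in this diff):

```python
# Hypothetical import path: adjust to wherever these exception classes live in the package.
from haystack_integrations.components.generators.amazon_sagemaker.errors import (
    AWSConfigurationError,
    SagemakerError,
)


def generate_safely(generator, prompt: str):
    try:
        return generator.run(prompt)
    except AWSConfigurationError as err:
        # Raised when AWS is not configured correctly (see the class docstring above).
        print(f"Fix your AWS setup: {err}")
    except SagemakerError as err:
        # Catches the remaining subclasses: SagemakerNotReadyError and SagemakerInferenceError.
        print(f"Sagemaker integration error: {err}")
    return None
```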
@@ -31,19 +31,17 @@ class SagemakerGenerator:
    [SageMaker JumpStart foundation models documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-foundation-models-use.html).

    Usage example:

    Make sure your AWS credentials are set up correctly. You can use environment variables or a shared credentials file.
    Then you can use the generator as follows:
    ```python
    from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator

    generator = SagemakerGenerator(model="jumpstart-dft-hf-llm-falcon-7b-bf16")
    response = generator.run("What's Natural Language Processing? Be brief.")
    print(response)
    ```
    ```
    >>> {'replies': ['Natural Language Processing (NLP) is a branch of artificial intelligence that focuses on
    >>> the interaction between computers and human language. It involves enabling computers to understand, interpret,
    >>> and respond to natural human language in a way that is both meaningful and useful.'], 'meta': [{}]}
    ```
    """

@@ -73,7 +71,6 @@ def __init__(
        :param model: The name for SageMaker Model Endpoint.
        :param aws_custom_attributes: Custom attributes to be passed to SageMaker, for example `{"accept_eula": True}`
            in case of Llama-2 models.
        :param generation_kwargs: Additional keyword arguments for text generation. For a list of supported parameters
            see your model's documentation page, for example here for HuggingFace models:
            https://huggingface.co/blog/sagemaker-huggingface-llm#4-run-inference-and-chat-with-our-model
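
To make the parameters above concrete, here is a hedged construction sketch: the endpoint name is hypothetical, and `max_new_tokens`/`temperature` are typical HuggingFace generation parameters rather than values taken from this PR:

```python
from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator

generator = SagemakerGenerator(
    model="jumpstart-dft-meta-textgeneration-llama-2-7b",  # hypothetical endpoint name
    aws_custom_attributes={"accept_eula": True},  # required for Llama-2 models, per the docstring above
    generation_kwargs={"max_new_tokens": 128, "temperature": 0.7},  # assumed HuggingFace-style parameters
)
```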
@@ -121,15 +118,15 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Returns data that is sent to Posthog for usage analytics.

        :returns: A dictionary with the following keys:
            - `model`: The name of the model.
        """
        return {"model": self.model}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
@@ -149,10 +146,11 @@ def to_dict(self) -> Dict[str, Any]:
    def from_dict(cls, data) -> "SagemakerGenerator":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        deserialize_secrets_inplace(
            data["init_parameters"],
@@ -170,6 +168,7 @@ def _get_aws_session(
    ):
        """
        Creates an AWS Session with the given parameters.

        Checks if the provided AWS credentials are valid and can be used to connect to AWS.

        :param aws_access_key_id: AWS access key ID.
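
For context, a rough sketch of what a helper like this typically wraps with boto3; all values are placeholders, and the STS call is just one common way to verify that the credentials are usable:

```python
import boto3

# Placeholder credentials; in practice these come from the resolved Secret parameters.
session = boto3.session.Session(
    aws_access_key_id="AKIA...",
    aws_secret_access_key="...",
    aws_session_token=None,  # only needed for temporary credentials
    profile_name=None,       # or the name of a profile from the shared credentials file
    region_name="us-east-1",
)

# One common validity check: ask STS who the credentials belong to.
session.client("sts").get_caller_identity()
```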
@@ -200,8 +199,10 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):

        :param prompt: The string prompt to use for text generation.
        :param generation_kwargs: Additional keyword arguments for text generation. These parameters will
            potentially override the parameters passed in the `__init__` method.
        :raises ValueError: If the model response type is not a list of dictionaries or a single dictionary.
        :raises SagemakerNotReadyError: If the SageMaker model is not ready to accept requests.
        :raises SagemakerInferenceError: If the SageMaker Inference returns an error.
        :returns: A dictionary with the following keys:
            - `replies`: A list of strings containing the generated responses
            - `meta`: A list of dictionaries containing the metadata for each response.
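
Putting the `run` contract above together with the exception classes from the first file, a hedged usage sketch (the exceptions' import path is assumed, and `max_new_tokens` is a typical HuggingFace parameter used only for illustration):

```python
from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator
# Hypothetical import path for the exception classes shown earlier in this PR.
from haystack_integrations.components.generators.amazon_sagemaker.errors import (
    SagemakerInferenceError,
    SagemakerNotReadyError,
)

generator = SagemakerGenerator(model="jumpstart-dft-hf-llm-falcon-7b-bf16")

try:
    # Per-call kwargs can override those passed to __init__.
    result = generator.run("What's Natural Language Processing? Be brief.", generation_kwargs={"max_new_tokens": 64})
    print(result["replies"][0])  # `replies` is a list of generated strings
except SagemakerNotReadyError:
    print("The endpoint is not ready yet; try again in a moment.")
except SagemakerInferenceError as err:
    print(f"Inference failed: {err}")
```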
@@ -249,5 +250,5 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
msg = f"Sagemaker model not ready: {res.text}"
raise SagemakerNotReadyError(msg) from err

msg = f"SageMaker Inference returned an error. Status code: {res.status_code} Response body: {res.text}"
raise SagemakerInferenceError(msg, status_code=res.status_code) from err
Contributor Author:
this exception doesn't take a param status_code

msg = f"SageMaker Inference returned an error. Status code: {res.status_code}. Response body: {res.text}"
raise SagemakerInferenceError(msg) from err