Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update docs for Amazon Sagemaker #514

Merged
merged 1 commit into from
Mar 1, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@
@component
class SagemakerGenerator:
"""
Enables text generation using Sagemaker. It supports Large Language Models (LLMs) hosted and deployed on a SageMaker
Inference Endpoint. For guidance on how to deploy a model to SageMaker, refer to the
Enables text generation using Amazon Sagemaker.

SagemakerGenerator supports Large Language Models (LLMs) hosted and deployed on a SageMaker Inference Endpoint.
For guidance on how to deploy a model to SageMaker, refer to the
[SageMaker JumpStart foundation models documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-foundation-models-use.html).

**Example:**
Usage example:

Make sure your AWS credentials are set up correctly. You can use environment variables or a shared credentials file.
Then you can use the generator as follows:
Expand Down Expand Up @@ -118,13 +120,18 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
Returns data that is sent to Posthog for usage analytics.
:returns: a dictionary with following keys:
- model: The name of the model.

"""
return {"model": self.model}

def to_dict(self) -> Dict[str, Any]:
"""
Serialize the object to a dictionary.
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
self,
Expand All @@ -141,7 +148,11 @@ def to_dict(self) -> Dict[str, Any]:
@classmethod
def from_dict(cls, data) -> "SagemakerGenerator":
"""
Deserialize the dictionary into an instance of SagemakerGenerator.
Deserializes the component from a dictionary.
:param data:
Dictionary to deserialize from.
:returns:
Deserialized component.
"""
deserialize_secrets_inplace(
data["init_parameters"],
Expand Down Expand Up @@ -185,14 +196,15 @@ def _get_aws_session(
@component.output_types(replies=List[str], meta=List[Dict[str, Any]])
def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
"""
Invoke the text generation inference based on the provided messages and generation parameters.
Invoke the text generation inference based on the provided prompt and generation parameters.

:param prompt: The string prompt to use for text generation.
:param generation_kwargs: Additional keyword arguments for text generation. These parameters will
potentially override the parameters passed in the `__init__` method.

:return: A list of strings containing the generated responses and a list of dictionaries containing the metadata
for each response.
:return: A dictionary with the following keys:
- `replies`: A list of strings containing the generated responses
- `meta`: A list of dictionaries containing the metadata for each response.
"""
generation_kwargs = generation_kwargs or self.generation_kwargs
custom_attributes = ";".join(
Expand Down