diff --git a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py index 4eacf33a4..106698558 100644 --- a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py +++ b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py @@ -24,11 +24,13 @@ @component class SagemakerGenerator: """ - Enables text generation using Sagemaker. It supports Large Language Models (LLMs) hosted and deployed on a SageMaker - Inference Endpoint. For guidance on how to deploy a model to SageMaker, refer to the + Enables text generation using Amazon Sagemaker. + + SagemakerGenerator supports Large Language Models (LLMs) hosted and deployed on a SageMaker Inference Endpoint. + For guidance on how to deploy a model to SageMaker, refer to the [SageMaker JumpStart foundation models documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-foundation-models-use.html). - **Example:** + Usage example: Make sure your AWS credentials are set up correctly. You can use environment variables or a shared credentials file. Then you can use the generator as follows: @@ -118,13 +120,18 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: def _get_telemetry_data(self) -> Dict[str, Any]: """ - Data that is sent to Posthog for usage analytics. + Returns data that is sent to Posthog for usage analytics. + :returns: a dictionary with following keys: + - model: The name of the model. + """ return {"model": self.model} def to_dict(self) -> Dict[str, Any]: """ - Serialize the object to a dictionary. + Serializes the component to a dictionary. + :returns: + Dictionary with serialized data. """ return default_to_dict( self, @@ -141,7 +148,11 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data) -> "SagemakerGenerator": """ - Deserialize the dictionary into an instance of SagemakerGenerator. + Deserializes the component from a dictionary. + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. """ deserialize_secrets_inplace( data["init_parameters"], @@ -185,14 +196,15 @@ def _get_aws_session( @component.output_types(replies=List[str], meta=List[Dict[str, Any]]) def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): """ - Invoke the text generation inference based on the provided messages and generation parameters. + Invoke the text generation inference based on the provided prompt and generation parameters. :param prompt: The string prompt to use for text generation. :param generation_kwargs: Additional keyword arguments for text generation. These parameters will potentially override the parameters passed in the `__init__` method. - :return: A list of strings containing the generated responses and a list of dictionaries containing the metadata - for each response. + :return: A dictionary with the following keys: + - `replies`: A list of strings containing the generated responses + - `meta`: A list of dictionaries containing the metadata for each response. """ generation_kwargs = generation_kwargs or self.generation_kwargs custom_attributes = ";".join(