diff --git a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py index 3fdc7ab57..60dadfb94 100644 --- a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py +++ b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py @@ -1,13 +1,14 @@ -from typing import Optional, List, Dict, Any - -import os -import logging import json +import logging +import os +from typing import Any, Dict, List, Optional import requests -from haystack.lazy_imports import LazyImport from haystack import component -from haystack_integrations.components.generators.amazon_sagemaker.errors import AWSConfigurationError, SagemakerInferenceError, SagemakerNotReadyError +from haystack.lazy_imports import LazyImport +from haystack_integrations.components.generators.amazon_sagemaker.errors import ( + AWSConfigurationError, SagemakerInferenceError, SagemakerNotReadyError +) with LazyImport(message="Run 'pip install boto3'") as boto3_import: import boto3 diff --git a/integrations/amazon_sagemaker/tests/test_sagemaker.py b/integrations/amazon_sagemaker/tests/test_sagemaker.py index 07e4405d0..befb445e8 100644 --- a/integrations/amazon_sagemaker/tests/test_sagemaker.py +++ b/integrations/amazon_sagemaker/tests/test_sagemaker.py @@ -1,10 +1,7 @@ -from typing import List - import os -from unittest.mock import patch, Mock +from unittest.mock import Mock import pytest - from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator from haystack_integrations.components.generators.amazon_sagemaker.errors import AWSConfigurationError @@ -91,7 +88,6 @@ def test_run_with_list_of_dictionaries(self, monkeypatch): assert [isinstance(reply, dict) for reply in response["meta"]] assert response["meta"][0]["other"] == "metadata" - def test_run_with_single_dictionary(self, monkeypatch): monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test-access-key") monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key") @@ -118,12 +114,8 @@ def test_run_with_single_dictionary(self, monkeypatch): assert [isinstance(reply, dict) for reply in response["meta"]] assert response["meta"][0]["other"] == "metadata" - @pytest.mark.skipif( - ( - not os.environ.get("AWS_ACCESS_KEY_ID", None) - or not os.environ.get("AWS_SECRET_ACCESS_KEY", None) - ), + (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY containing the AWS credentials to run this test.", ) @pytest.mark.integration @@ -152,16 +144,15 @@ def test_run_falcon(self): assert [isinstance(reply, dict) for reply in response["meta"]] @pytest.mark.skipif( - ( - not os.environ.get("AWS_ACCESS_KEY_ID", None) - or not os.environ.get("AWS_SECRET_ACCESS_KEY", None) - ), + (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY containing the AWS credentials to run this test.", ) @pytest.mark.integration def test_run_llama2(self): component = SagemakerGenerator( - model="jumpstart-dft-meta-textgenerationneuron-llama-2-7b", generation_kwargs={"max_new_tokens": 10}, aws_custom_attributes={"accept_eula": True} + model="jumpstart-dft-meta-textgenerationneuron-llama-2-7b", + generation_kwargs={"max_new_tokens": 10}, + aws_custom_attributes={"accept_eula": True}, ) component.warm_up() response = component.run("What's Natural Language Processing?") @@ -184,10 +175,7 @@ def test_run_llama2(self): assert [isinstance(reply, dict) for reply in response["meta"]] @pytest.mark.skipif( - ( - not os.environ.get("AWS_ACCESS_KEY_ID", None) - or not os.environ.get("AWS_SECRET_ACCESS_KEY", None) - ), + (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY containing the AWS credentials to run this test.", ) @pytest.mark.integration