Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
ZanSara committed Jan 29, 2024
1 parent 89d4651 commit f1de01c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 25 deletions.
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import Optional, List, Dict, Any

import os
import logging
import json
import logging
import os
from typing import Any, Dict, List, Optional

import requests
from haystack.lazy_imports import LazyImport
from haystack import component
from haystack_integrations.components.generators.amazon_sagemaker.errors import AWSConfigurationError, SagemakerInferenceError, SagemakerNotReadyError
from haystack.lazy_imports import LazyImport
from haystack_integrations.components.generators.amazon_sagemaker.errors import (
AWSConfigurationError, SagemakerInferenceError, SagemakerNotReadyError
)

with LazyImport(message="Run 'pip install boto3'") as boto3_import:
import boto3
Expand Down
26 changes: 7 additions & 19 deletions integrations/amazon_sagemaker/tests/test_sagemaker.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from typing import List

import os
from unittest.mock import patch, Mock
from unittest.mock import Mock

import pytest

from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator
from haystack_integrations.components.generators.amazon_sagemaker.errors import AWSConfigurationError

Expand Down Expand Up @@ -91,7 +88,6 @@ def test_run_with_list_of_dictionaries(self, monkeypatch):
assert [isinstance(reply, dict) for reply in response["meta"]]
assert response["meta"][0]["other"] == "metadata"


def test_run_with_single_dictionary(self, monkeypatch):
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test-access-key")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key")
Expand All @@ -118,12 +114,8 @@ def test_run_with_single_dictionary(self, monkeypatch):
assert [isinstance(reply, dict) for reply in response["meta"]]
assert response["meta"][0]["other"] == "metadata"


@pytest.mark.skipif(
(
not os.environ.get("AWS_ACCESS_KEY_ID", None)
or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)
),
(not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)),
reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY containing the AWS credentials to run this test.",
)
@pytest.mark.integration
Expand Down Expand Up @@ -152,16 +144,15 @@ def test_run_falcon(self):
assert [isinstance(reply, dict) for reply in response["meta"]]

@pytest.mark.skipif(
(
not os.environ.get("AWS_ACCESS_KEY_ID", None)
or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)
),
(not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)),
reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY containing the AWS credentials to run this test.",
)
@pytest.mark.integration
def test_run_llama2(self):
component = SagemakerGenerator(
model="jumpstart-dft-meta-textgenerationneuron-llama-2-7b", generation_kwargs={"max_new_tokens": 10}, aws_custom_attributes={"accept_eula": True}
model="jumpstart-dft-meta-textgenerationneuron-llama-2-7b",
generation_kwargs={"max_new_tokens": 10},
aws_custom_attributes={"accept_eula": True},
)
component.warm_up()
response = component.run("What's Natural Language Processing?")
Expand All @@ -184,10 +175,7 @@ def test_run_llama2(self):
assert [isinstance(reply, dict) for reply in response["meta"]]

@pytest.mark.skipif(
(
not os.environ.get("AWS_ACCESS_KEY_ID", None)
or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)
),
(not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)),
reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY containing the AWS credentials to run this test.",
)
@pytest.mark.integration
Expand Down

0 comments on commit f1de01c

Please sign in to comment.