diff --git a/noxfile.py b/noxfile.py index 4c4949413..0250db204 100644 --- a/noxfile.py +++ b/noxfile.py @@ -54,6 +54,7 @@ "flake8-annotations", "flake8-docstrings", "mypy", + "moto", ] diff --git a/pyproject.toml b/pyproject.toml index e05d685d0..d638b7cd5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,6 +125,7 @@ darglint = "^1.8.0" flake8 = "^3.9.0" flake8-annotations = "^2.9.1" flake8-docstrings = "^1.7.0" +moto = "^4.1.8" [tool.black] exclude = ".*simpleeval.*" diff --git a/singer_sdk/connectors/__init__.py b/singer_sdk/connectors/__init__.py index 32799417a..401157da5 100644 --- a/singer_sdk/connectors/__init__.py +++ b/singer_sdk/connectors/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations +from .aws_boto3 import AWSBoto3Connector from .sql import SQLConnector -__all__ = ["SQLConnector"] +__all__ = ["SQLConnector", "AWSBoto3Connector"] diff --git a/singer_sdk/connectors/aws_boto3.py b/singer_sdk/connectors/aws_boto3.py new file mode 100644 index 000000000..cfce466df --- /dev/null +++ b/singer_sdk/connectors/aws_boto3.py @@ -0,0 +1,272 @@ +"""AWS Boto3 Connector.""" + +from __future__ import annotations + +import logging +import os + +from singer_sdk import typing as th # JSON schema typing helpers + +try: + import boto3 +except ImportError: + msg = "boto3 is required for this authenticator. Please install it with `poetry add boto3`." + raise Exception( + msg, + ) + +AWS_AUTH_CONFIG = th.PropertiesList( + th.Property( + "aws_access_key_id", + th.StringType, + secret=True, + description="The access key for your AWS account.", + ), + th.Property( + "aws_secret_access_key", + th.StringType, + secret=True, + description="The secret key for your AWS account.", + ), + th.Property( + "aws_session_token", + th.StringType, + secret=True, + description=( + "The session key for your AWS account. This is only needed when" + " you are using temporary credentials." + ), + ), + th.Property( + "aws_profile", + th.StringType, + description=( + "The AWS credentials profile name to use. The profile must be " + "configured and accessible." + ), + ), + th.Property( + "aws_default_region", + th.StringType, + description="The default AWS region name (e.g. us-east-1) ", + ), + th.Property( + "aws_endpoint_url", + th.StringType, + description="The complete URL to use for the constructed client.", + ), + th.Property( + "aws_assume_role_arn", + th.StringType, + description="The role ARN to assume.", + ), + th.Property( + "use_aws_env_vars", + th.BooleanType, + default=False, + description=("Whether to retrieve aws credentials from environment variables."), + ), +).to_dict() + + +class AWSBoto3Connector: + """Base class for AWS boto3-based connectors. + + The connector class serves as a wrapper around boto3 package. + + The functions of the connector are: + - initializing a client, resource, or session with a simple interface + - accessing AWS credentials via config, env vars, or profile + - supports assuming roles + - enables configurable endpoint_url for testing + """ + + def __init__( + self: AWSBoto3Connector, + config: dict, + service_name: str, + ) -> None: + """Initialize the AWSBotoAuthenticator. + + Args: + config (dict): The config for the connector. + service_name (str): The name of the AWS service. + """ + self._service_name = service_name + self._config = config + self._client = None + self._resource = None + # config for use environment variables + if config.get("use_aws_env_vars"): + self.aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID") + self.aws_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY") + self.aws_session_token = os.environ.get("AWS_SESSION_TOKEN") + self.aws_profile = os.environ.get("AWS_PROFILE") + self.aws_default_region = os.environ.get("AWS_DEFAULT_REGION") + else: + self.aws_access_key_id = config.get("aws_access_key_id") + self.aws_secret_access_key = config.get("aws_secret_access_key") + self.aws_session_token = config.get("aws_session_token") + self.aws_profile = config.get("aws_profile") + self.aws_default_region = config.get("aws_default_region") + + self.aws_endpoint_url = config.get("aws_endpoint_url") + self.aws_assume_role_arn = config.get("aws_assume_role_arn") + + @property + def config(self: AWSBoto3Connector) -> dict: + """If set, provides access to the tap or target config. + + Returns: + The settings as a dict. + """ + return self._config + + @property + def logger(self: AWSBoto3Connector) -> logging.Logger: + """Get logger. + + Returns: + Plugin logger. + """ + return logging.getLogger("aws_boto_connector") + + @property + def client(self: AWSBoto3Connector) -> boto3.client: + """Return the boto3 client for the service. + + Returns: + boto3.client: The boto3 client for the service. + """ + if self._client: + return self._client + else: + session = self._get_session() + self._client = self._get_client(session, self._service_name) + return self._client + + @property + def resource(self: AWSBoto3Connector) -> boto3.resource: + """Return the boto3 resource for the service. + + Returns: + boto3.resource: The boto3 resource for the service. + """ + if self._resource: + return self._resource + else: + session = self._get_session() + self._resource = self._get_resource(session, self._service_name) + return self._resource + + def _get_session(self: AWSBoto3Connector) -> boto3.session: + """Return the boto3 session. + + Returns: + boto3.session: The boto3 session. + """ + session = None + if ( + self.aws_access_key_id + and self.aws_secret_access_key + and self.aws_session_token + and self.aws_default_region + ): + session = boto3.Session( + aws_access_key_id=self.aws_access_key_id, + aws_secret_access_key=self.aws_secret_access_key, + aws_session_token=self.aws_session_token, + region_name=self.aws_default_region, + ) + self.logger.info( + "Authenticating using access key id, secret access key, and " + "session token.", + ) + elif ( + self.aws_access_key_id + and self.aws_secret_access_key + and self.aws_default_region + ): + session = boto3.Session( + aws_access_key_id=self.aws_access_key_id, + aws_secret_access_key=self.aws_secret_access_key, + region_name=self.aws_default_region, + ) + self.logger.info( + "Authenticating using access key id and secret access key.", + ) + elif self.aws_profile: + session = boto3.Session(profile_name=self.aws_profile) + self.logger.info("Authenticating using profile.") + else: + session = boto3.Session() + self.logger.info("Authenticating using implicit pre-installed credentials.") + + if self.aws_assume_role_arn: + session = self._assume_role(session, self.aws_assume_role_arn) + return session + + def _factory( + self: AWSBoto3Connector, + aws_obj: boto3.resource | boto3.client, + service_name: str, + ) -> boto3.resource | boto3.client: + if self.aws_endpoint_url: + return aws_obj( + service_name, + endpoint_url=self.aws_endpoint_url, + ) + else: + return aws_obj( + service_name, + ) + + def _get_resource( + self: AWSBoto3Connector, + session: boto3.session, + service_name: str, + ) -> boto3.resource: + """Return the boto3 resource for the service. + + Args: + session (boto3.session.Session): The boto3 session. + service_name (str): The name of the AWS service. + + Returns: + boto3.resource: The boto3 resource for the service. + """ + return self._factory(session.resource, service_name) + + def _get_client( + self: AWSBoto3Connector, + session: boto3.session.Session, + service_name: str, + ) -> boto3.client: + """Return the boto3 client for the service. + + Args: + session (boto3.session.Session): The boto3 session. + service_name (str): The name of the AWS service. + + Returns: + boto3.client: The boto3 client for the service. + """ + return self._factory(session.client, service_name) + + def _assume_role( + self: AWSBoto3Connector, + session: boto3.session.Session, + role_arn: str, + ) -> boto3.session.Session: + # TODO: use for auto refresh https://github.com/benkehoe/aws-assume-role-lib + sts_client = self._get_client(session, "sts") + response = sts_client.assume_role( + RoleArn=role_arn, + RoleSessionName="tap-dynamodb", + ) + return boto3.Session( + aws_access_key_id=response["Credentials"]["AccessKeyId"], + aws_secret_access_key=response["Credentials"]["SecretAccessKey"], + aws_session_token=response["Credentials"]["SessionToken"], + region_name=self.aws_default_region, + ) diff --git a/tests/core/test_connector_aws_boto3.py b/tests/core/test_connector_aws_boto3.py new file mode 100644 index 000000000..7c27eede4 --- /dev/null +++ b/tests/core/test_connector_aws_boto3.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +from unittest.mock import patch + +from moto import mock_dynamodb, mock_sts + +from singer_sdk.connectors import AWSBoto3Connector + + +@patch( + "singer_sdk.connectors.aws_boto3.boto3.Session", + return_value="mock_session", +) +@mock_dynamodb +def test_get_session_base(patch): + auth = AWSBoto3Connector( + { + "aws_access_key_id": "foo", + "aws_secret_access_key": "bar", + "aws_default_region": "baz", + }, + "dynamodb", + ) + session = auth._get_session() + patch.assert_called_with( + aws_access_key_id="foo", + aws_secret_access_key="bar", + region_name="baz", + ) + assert session == "mock_session" + + +@patch( + "singer_sdk.connectors.aws_boto3.boto3.Session", + return_value="mock_session", +) +@mock_dynamodb +def test_get_session_w_token(patch): + auth = AWSBoto3Connector( + { + "aws_access_key_id": "foo", + "aws_secret_access_key": "bar", + "aws_session_token": "abc", + "aws_default_region": "baz", + }, + "dynamodb", + ) + session = auth._get_session() + patch.assert_called_with( + aws_access_key_id="foo", + aws_secret_access_key="bar", + aws_session_token="abc", + region_name="baz", + ) + assert session == "mock_session" + + +@patch( + "singer_sdk.connectors.aws_boto3.boto3.Session", + return_value="mock_session", +) +@mock_dynamodb +def test_get_session_w_profile(patch): + auth = AWSBoto3Connector( + { + "aws_profile": "foo", + }, + "dynamodb", + ) + session = auth._get_session() + patch.assert_called_with( + profile_name="foo", + ) + assert session == "mock_session" + + +@patch( + "singer_sdk.connectors.aws_boto3.boto3.Session", + return_value="mock_session", +) +@mock_dynamodb +def test_get_session_implicit(patch): + auth = AWSBoto3Connector({}, "dynamodb") + session = auth._get_session() + patch.assert_called_with() + assert session == "mock_session" + + +@mock_dynamodb +@mock_sts +def test_get_session_assume_role(): + auth = AWSBoto3Connector( + { + "aws_access_key_id": "foo", + "aws_secret_access_key": "bar", + "aws_default_region": "baz", + "aws_assume_role_arn": "arn:aws:iam::123456778910:role/my-role-name", + }, + "dynamodb", + ) + auth._get_session() + + +@mock_dynamodb +def test_get_client(): + auth = AWSBoto3Connector( + { + "aws_access_key_id": "foo", + "aws_secret_access_key": "bar", + "aws_default_region": "baz", + }, + "dynamodb", + ) + session = auth._get_session() + auth._get_client(session, "dynamodb") + + +@mock_dynamodb +def test_get_resource(): + auth = AWSBoto3Connector( + { + "aws_access_key_id": "foo", + "aws_secret_access_key": "bar", + "aws_default_region": "baz", + }, + "dynamodb", + ) + session = auth._get_session() + auth._get_resource(session, "dynamodb") + + +@patch( + "singer_sdk.connectors.aws_boto3.AWSBoto3Connector._get_client", + return_value="mock_client", +) +@mock_dynamodb +def test_client_property(patch): + auth = AWSBoto3Connector( + { + "aws_access_key_id": "foo", + "aws_secret_access_key": "bar", + "aws_default_region": "baz", + }, + "dynamodb", + ) + assert auth.client == auth._client + assert auth.client == auth._client + patch.assert_called_once() + + +@patch( + "singer_sdk.connectors.aws_boto3.AWSBoto3Connector._get_resource", + return_value="mock_resource", +) +@mock_dynamodb +def test_resource_property(patch): + auth = AWSBoto3Connector( + { + "aws_access_key_id": "foo", + "aws_secret_access_key": "bar", + "aws_default_region": "baz", + }, + "dynamodb", + ) + assert auth.resource == auth._resource + assert auth.resource == auth._resource + patch.assert_called_once() + + +@patch( + "singer_sdk.connectors.aws_boto3.boto3.session.Session.resource", +) +@mock_dynamodb +def test_resource_property_endpoint(patch): + auth = AWSBoto3Connector( + { + "aws_access_key_id": "foo", + "aws_secret_access_key": "bar", + "aws_default_region": "baz", + "aws_endpoint_url": "http://localhost:8000", + }, + "dynamodb", + ) + assert auth.resource == auth._resource + patch.assert_called_with("dynamodb", endpoint_url="http://localhost:8000") + + +def test_use_env_vars(): + import os + + with patch.dict( + os.environ, + { + "AWS_ACCESS_KEY_ID": "key_id", + "AWS_SECRET_ACCESS_KEY": "access_key", + "AWS_SESSION_TOKEN": "token", + "AWS_PROFILE": "profile", + "AWS_DEFAULT_REGION": "region", + }, + ): + auth = AWSBoto3Connector( + { + "use_aws_env_vars": True, + }, + "dynamodb", + ) + assert auth.aws_access_key_id == "key_id" + assert auth.aws_secret_access_key == "access_key" + assert auth.aws_session_token == "token" + assert auth.aws_profile == "profile" + assert auth.aws_default_region == "region"