Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PR-2930 add tests #169

Merged
merged 14 commits into from
Feb 4, 2022
Merged
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@

# OS
.DS_Store

# Tests
.hypothesis
.coverage
156 changes: 83 additions & 73 deletions rain_api_core/aws_util.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,54 @@

import functools
import json
import logging
import os
import sys
import urllib
from netaddr import IPAddress, IPNetwork
from json import loads
import urllib.request
from time import time
from yaml import safe_load
from boto3 import client as botoclient, resource as botoresource, session as botosession, Session as boto_Session

from boto3 import Session as boto_Session
from boto3 import client as botoclient
from boto3 import resource as botoresource
from boto3 import session as botosession
from boto3.resources.base import ServiceResource
from botocore.config import Config as bc_Config
from botocore.exceptions import ClientError
from netaddr import IPAddress, IPNetwork
from yaml import safe_load

from rain_api_core.general_util import return_timing_object, duration
from rain_api_core.general_util import duration, return_timing_object

log = logging.getLogger(__name__)
sts = botoclient('sts')
secret_cache = {}
session_cache = {}
region_list_cache = []
s3_resource = None
region = ''
botosess = botosession.Session()
role_creds_cache = {os.getenv('EGRESS_APP_DOWNLOAD_ROLE_INREGION_ARN'): {}, os.getenv('EGRESS_APP_DOWNLOAD_ROLE_ARN'): {}}
role_creds_cache = {
os.getenv('EGRESS_APP_DOWNLOAD_ROLE_INREGION_ARN'): {},
os.getenv('EGRESS_APP_DOWNLOAD_ROLE_ARN'): {}
}


def get_region():
def get_region() -> str:
"""
Will determine and return current AWS region.
:return: string describing AWS region
:type: string
"""
global region #pylint: disable=global-statement
global botosess #pylint: disable=global-statement
global region # pylint: disable=global-statement
global botosess # pylint: disable=global-statement
if not region:
region = botosess.region_name
return region


def retrieve_secret(secret_name):

global secret_cache # pylint: disable=global-statement
global botosess # pylint: disable=global-statement
@functools.lru_cache(maxsize=None)
def retrieve_secret(secret_name: str) -> dict:
global region # pylint: disable=global-statement
global botosess # pylint: disable=global-statement
t0 = time()

if secret_name in secret_cache:
log.debug('ET for retrieving secret {} from cache: {} sec'.format(secret_name, round(time() - t0, 4)))
return secret_cache[secret_name]

region_name = os.getenv('AWS_DEFAULT_REGION')

# Create a Secrets Manager client
Expand All @@ -64,40 +66,43 @@ def retrieve_secret(secret_name):
get_secret_value_response = client.get_secret_value(
SecretId=secret_name
)
log.info(return_timing_object(service="secretsmanager", endpoint=f"client().get_secret_value({secret_name})", duration=duration(timer)))
log.info(return_timing_object(
service="secretsmanager",
endpoint=f"client().get_secret_value({secret_name})",
duration=duration(timer)
))
except ClientError as e:
log.error("Encountered fatal error trying to reading URS Secret: {0}".format(e))
raise e
else:
# Decrypts secret using the associated KMS CMK.
# Depending on whether the secret is a string or binary, one of these fields will be populated.
if 'SecretString' in get_secret_value_response:

secret = loads(get_secret_value_response['SecretString'])
secret_cache[secret_name] = secret
log.debug('ET for retrieving secret {} from secret store: {} sec'.format(secret_name, round(time() - t0, 4)))
secret = json.loads(get_secret_value_response['SecretString'])
log.debug(f'ET for retrieving secret {secret_name} from secret store: {time() - t0:.4f} sec')
return secret

return {}


def get_s3_resource():
def get_s3_resource() -> boto_Session.resource:
"""

:return: subclass of boto3.resources.base.ServiceResource
"""
global s3_resource #pylint: disable=global-statement
global s3_resource # pylint: disable=global-statement
if not s3_resource:
params = {}
# Swift signature compatability
if os.getenv('S3_SIGNATURE_VERSION'):
params['config'] = bc_Config(signature_version=os.getenv('S3_SIGNATURE_VERSION'))
signature_version = os.getenv('S3_SIGNATURE_VERSION')
if signature_version:
params['config'] = bc_Config(signature_version=signature_version)
s3_resource = botoresource('s3', **params)

return s3_resource


def read_s3(bucket: str, key: str, s3: ServiceResource=None):
def read_s3(bucket: str, key: str, s3: ServiceResource = None) -> str:
"""
returns file
:type bucket: str
Expand All @@ -117,12 +122,16 @@ def read_s3(bucket: str, key: str, s3: ServiceResource=None):
obj = s3.Object(bucket, key)
log.debug('ET for reading {} from S3: {} sec'.format(key, round(time() - t0, 4)))
timer = time()
body = obj.get()['Body'].read().decode('utf-8')
log.info(return_timing_object(service="s3", endpoint=f"resource().Object(s3://{bucket}/{key}).get()", duration=duration(timer)))
body = obj.get()['Body'].read().decode('utf-8')
log.info(return_timing_object(
service="s3",
endpoint=f"resource().Object(s3://{bucket}/{key}).get()",
duration=duration(timer)
))
return body


def get_yaml(bucket: str, file_name: str):
def get_yaml(bucket: str, file_name: str) -> dict:
"""
Loads the YAML from a given bucket/filename
:param bucket: bucket name
Expand All @@ -133,68 +142,71 @@ def get_yaml(bucket: str, file_name: str):
cfg_yaml = read_s3(bucket, file_name)
return safe_load(cfg_yaml)
except ClientError as e:
log.error('Had trouble getting yaml file s3://{}/{}, {}'.format(bucket, file_name, e))
log.error('Could not download yaml file s3://{}/{}, {}'.format(bucket, file_name, e))
raise


def get_yaml_file(bucket, key):

def get_yaml_file(bucket: str, key: str) -> dict:
if not key:
# No file was provided, send empty dict
return {}
try:
log.info("Attempting to download yaml s3://{0}/{1}".format(bucket, key))
optional_file = get_yaml(bucket, key)
return optional_file
except ClientError as e:
# The specified file did not exist
log.error("Could not download yaml @ s3://{0}/{1}: {2}".format(bucket, key, e))
sys.exit()

def get_role_creds(user_id: str='', in_region: bool=False):
log.info("Attempting to download yaml file s3://{0}/{1}".format(bucket, key))
return get_yaml(bucket, key)


def get_role_creds(user_id: str = None, in_region: bool = False):
"""
:param user_id: string with URS username
:param in_region: boolean If True a download role that works only in region will be returned
:return: Returns a set of temporary security credentials (consisting of an access key ID, a secret access key, and a security token)
:return: Returns a set of temporary security credentials (consisting of an access key ID, a secret access key, and
a security token)
:return: Offset, in seconds for how long the STS session has been active
"""
global sts #pylint: disable=global-statement
global sts # pylint: disable=global-statement
if not user_id:
user_id = 'unauthenticated'

if in_region:
download_role_arn = os.getenv('EGRESS_APP_DOWNLOAD_ROLE_INREGION_ARN')
else:
download_role_arn = os.getenv('EGRESS_APP_DOWNLOAD_ROLE_ARN')
dl_arn_name=download_role_arn.split("/")[-1]
dl_arn_name = download_role_arn.split("/")[-1]

# chained role assumption like this CANNOT currently be extended past 1 Hour.
# https://aws.amazon.com/premiumsupport/knowledge-center/iam-role-chaining-limit/
now = time()
session_params = {"RoleArn": download_role_arn, "RoleSessionName": f"{user_id}@{round(now)}", "DurationSeconds": 3600 }
session_params = {
"RoleArn": download_role_arn,
"RoleSessionName": f"{user_id}@{round(now)}",
"DurationSeconds": 3600
}
session_offset = 0

if user_id not in role_creds_cache[download_role_arn]:
fresh_session = sts.assume_role(**session_params)
log.info(return_timing_object(service="sts", endpoint=f"client().assume_role({dl_arn_name}/{user_id})", duration=duration(now)))
role_creds_cache[download_role_arn][user_id] = {"session": fresh_session, "timestamp": now }
log.info(return_timing_object(
service="sts",
endpoint=f"client().assume_role({dl_arn_name}/{user_id})",
duration=duration(now)
))
role_creds_cache[download_role_arn][user_id] = {"session": fresh_session, "timestamp": now}
elif now - role_creds_cache[download_role_arn][user_id]["timestamp"] > 600:
# If the session has been active for more than 10 minutes, grab a new one.
log.info("Replacing 10 minute old session for {0}".format(user_id))
fresh_session = sts.assume_role(**session_params)
log.info(return_timing_object(service="sts", endpoint="client().assume_role()", duration=duration(now)))
role_creds_cache[download_role_arn][user_id] = {"session": fresh_session, "timestamp": now }
role_creds_cache[download_role_arn][user_id] = {"session": fresh_session, "timestamp": now}
else:
log.info("Reusing role credentials for {0}".format(user_id))
session_offset = round( now - role_creds_cache[download_role_arn][user_id]["timestamp"] )
session_offset = round(now - role_creds_cache[download_role_arn][user_id]["timestamp"])

log.debug(f'assuming role: {0}, role session username: {1}'.format(download_role_arn,user_id))
log.debug(f'assuming role: {0}, role session username: {1}'.format(download_role_arn, user_id))
return role_creds_cache[download_role_arn][user_id]["session"], session_offset


def get_role_session(creds=None, user_id=None):

global session_cache #pylint: disable=global-statement
def get_role_session(creds: dict = None, user_id: str = None) -> boto_Session:
global session_cache # pylint: disable=global-statement
sts_resp = creds if creds else get_role_creds(user_id)[0]
log.debug('sts_resp: {0}'.format(sts_resp))

Expand All @@ -211,41 +223,39 @@ def get_role_session(creds=None, user_id=None):
return session_cache[session_id]


def get_region_cidr_ranges():
def get_region_cidr_ranges() -> list:
"""
:return: Utility function to download AWS regions
"""
global region_list_cache # pylint: disable=global-statement

global region_list_cache #pylint: disable=global-statement

if not region_list_cache: #pylint: disable=used-before-assignment
if not region_list_cache: # pylint: disable=used-before-assignment
url = 'https://ip-ranges.amazonaws.com/ip-ranges.json'
now = time()
req = urllib.request.Request(url)
r = urllib.request.urlopen(req).read() #nosec URL is *always* https://ip-ranges...
r = urllib.request.urlopen(req).read() # nosec URL is *always* https://ip-ranges...
log.info(return_timing_object(service="AWS", endpoint=url, duration=duration(now)))
region_list_json = loads(r.decode('utf-8'))
region_list_cache = []

region_list_json = json.loads(r.decode('utf-8'))
# Sort out ONLY values from this AWS region
for pre in region_list_json["prefixes"]:
if "ip_prefix" in pre and "region" in pre:
if pre["region"] == get_region():
region_list_cache.append(IPNetwork(pre["ip_prefix"]))
this_region = get_region()
region_list_cache = [
IPNetwork(pre["ip_prefix"]) for pre in region_list_json["prefixes"]
if "ip_prefix" in pre and "region" in pre and pre["region"] == this_region
]

return region_list_cache


def check_in_region_request(ip_addr: str):
def check_in_region_request(ip_addr: str) -> bool:
"""
:param ip_addr: string with ip address to be checked for in-regionness
:return: boolean True if ip_addr is in_region, False otherwise
:type: Boolean
"""

addr = IPAddress(ip_addr)
for cidr in get_region_cidr_ranges():
#log.debug("Checking ip {0} vs cidr {1}".format(user_ip, cidr))
if IPAddress(ip_addr) in cidr:
if addr in cidr:
log.info("IP {0} matched in-region CIDR {1}".format(ip_addr, cidr))
return True

Expand Down
Loading