From e067ba84f8064d7d8f8dfe37b171e14283abd2de Mon Sep 17 00:00:00 2001 From: Rohan Weeden Date: Fri, 21 Jan 2022 16:12:20 -0900 Subject: [PATCH] Refactor util files --- rain_api_core/aws_util.py | 101 +++++++++++++++++----------------- rain_api_core/egress_util.py | 58 +++++++++---------- rain_api_core/general_util.py | 98 ++++++++++++++++++--------------- rain_api_core/urs_util.py | 87 ++++++++++++++--------------- rain_api_core/view_util.py | 85 +++++++++++++--------------- 5 files changed, 217 insertions(+), 212 deletions(-) diff --git a/rain_api_core/aws_util.py b/rain_api_core/aws_util.py index 5bb756e..880e5fd 100644 --- a/rain_api_core/aws_util.py +++ b/rain_api_core/aws_util.py @@ -1,28 +1,35 @@ - +import functools +import json import logging import os import sys import urllib -from netaddr import IPAddress, IPNetwork -from json import loads from time import time -from yaml import safe_load -from boto3 import client as botoclient, resource as botoresource, session as botosession, Session as boto_Session + +from boto3 import Session as boto_Session +from boto3 import client as botoclient +from boto3 import resource as botoresource +from boto3 import session as botosession from boto3.resources.base import ServiceResource from botocore.config import Config as bc_Config from botocore.exceptions import ClientError +from netaddr import IPAddress, IPNetwork +from yaml import safe_load -from rain_api_core.general_util import return_timing_object, duration +from rain_api_core.general_util import duration, return_timing_object log = logging.getLogger(__name__) sts = botoclient('sts') -secret_cache = {} session_cache = {} region_list_cache = [] s3_resource = None region = '' botosess = botosession.Session() -role_creds_cache = {os.getenv('EGRESS_APP_DOWNLOAD_ROLE_INREGION_ARN'): {}, os.getenv('EGRESS_APP_DOWNLOAD_ROLE_ARN'): {}} +role_creds_cache = { + os.getenv('EGRESS_APP_DOWNLOAD_ROLE_INREGION_ARN'): {}, + os.getenv('EGRESS_APP_DOWNLOAD_ROLE_ARN'): {} +} + def get_region(): """ @@ -30,23 +37,19 @@ def get_region(): :return: string describing AWS region :type: string """ - global region #pylint: disable=global-statement - global botosess #pylint: disable=global-statement + global region # pylint: disable=global-statement + global botosess # pylint: disable=global-statement if not region: region = botosess.region_name return region +@functools.lru_cache(maxsize=None) def retrieve_secret(secret_name): - - global secret_cache # pylint: disable=global-statement - global botosess # pylint: disable=global-statement + global region # pylint: disable=global-statement + global botosess # pylint: disable=global-statement t0 = time() - if secret_name in secret_cache: - log.debug('ET for retrieving secret {} from cache: {} sec'.format(secret_name, round(time() - t0, 4))) - return secret_cache[secret_name] - region_name = os.getenv('AWS_DEFAULT_REGION') # Create a Secrets Manager client @@ -72,9 +75,7 @@ def retrieve_secret(secret_name): # Decrypts secret using the associated KMS CMK. # Depending on whether the secret is a string or binary, one of these fields will be populated. 
if 'SecretString' in get_secret_value_response: - - secret = loads(get_secret_value_response['SecretString']) - secret_cache[secret_name] = secret + secret = json.loads(get_secret_value_response['SecretString']) log.debug('ET for retrieving secret {} from secret store: {} sec'.format(secret_name, round(time() - t0, 4))) return secret @@ -86,18 +87,19 @@ def get_s3_resource(): :return: subclass of boto3.resources.base.ServiceResource """ - global s3_resource #pylint: disable=global-statement + global s3_resource # pylint: disable=global-statement if not s3_resource: params = {} # Swift signature compatability - if os.getenv('S3_SIGNATURE_VERSION'): - params['config'] = bc_Config(signature_version=os.getenv('S3_SIGNATURE_VERSION')) + signature_version = os.getenv('S3_SIGNATURE_VERSION') + if signature_version: + params['config'] = bc_Config(signature_version=signature_version) s3_resource = botoresource('s3', **params) return s3_resource -def read_s3(bucket: str, key: str, s3: ServiceResource=None): +def read_s3(bucket: str, key: str, s3: ServiceResource = None): """ returns file :type bucket: str @@ -117,7 +119,7 @@ def read_s3(bucket: str, key: str, s3: ServiceResource=None): obj = s3.Object(bucket, key) log.debug('ET for reading {} from S3: {} sec'.format(key, round(time() - t0, 4))) timer = time() - body = obj.get()['Body'].read().decode('utf-8') + body = obj.get()['Body'].read().decode('utf-8') log.info(return_timing_object(service="s3", endpoint=f"resource().Object(s3://{bucket}/{key}).get()", duration=duration(timer))) return body @@ -138,7 +140,6 @@ def get_yaml(bucket: str, file_name: str): def get_yaml_file(bucket, key): - if not key: # No file was provided, send empty dict return {} @@ -152,14 +153,15 @@ def get_yaml_file(bucket, key): # TODO(reweeden): remove this, why is this here!!!? sys.exit() -def get_role_creds(user_id: str='', in_region: bool=False): + +def get_role_creds(user_id: str = '', in_region: bool = False): """ :param user_id: string with URS username :param in_region: boolean If True a download role that works only in region will be returned :return: Returns a set of temporary security credentials (consisting of an access key ID, a secret access key, and a security token) :return: Offset, in seconds for how long the STS session has been active """ - global sts #pylint: disable=global-statement + global sts # pylint: disable=global-statement if not user_id: user_id = 'unauthenticated' @@ -167,35 +169,38 @@ def get_role_creds(user_id: str='', in_region: bool=False): download_role_arn = os.getenv('EGRESS_APP_DOWNLOAD_ROLE_INREGION_ARN') else: download_role_arn = os.getenv('EGRESS_APP_DOWNLOAD_ROLE_ARN') - dl_arn_name=download_role_arn.split("/")[-1] + dl_arn_name = download_role_arn.split("/")[-1] # chained role assumption like this CANNOT currently be extended past 1 Hour. 
     # https://aws.amazon.com/premiumsupport/knowledge-center/iam-role-chaining-limit/
     now = time()
-    session_params = {"RoleArn": download_role_arn, "RoleSessionName": f"{user_id}@{round(now)}", "DurationSeconds": 3600 }
+    session_params = {
+        "RoleArn": download_role_arn,
+        "RoleSessionName": f"{user_id}@{round(now)}",
+        "DurationSeconds": 3600
+    }
     session_offset = 0
 
     if user_id not in role_creds_cache[download_role_arn]:
         fresh_session = sts.assume_role(**session_params)
         log.info(return_timing_object(service="sts", endpoint=f"client().assume_role({dl_arn_name}/{user_id})", duration=duration(now)))
-        role_creds_cache[download_role_arn][user_id] = {"session": fresh_session, "timestamp": now }
+        role_creds_cache[download_role_arn][user_id] = {"session": fresh_session, "timestamp": now}
     elif now - role_creds_cache[download_role_arn][user_id]["timestamp"] > 600:
         # If the session has been active for more than 10 minutes, grab a new one.
         log.info("Replacing 10 minute old session for {0}".format(user_id))
         fresh_session = sts.assume_role(**session_params)
         log.info(return_timing_object(service="sts", endpoint="client().assume_role()", duration=duration(now)))
-        role_creds_cache[download_role_arn][user_id] = {"session": fresh_session, "timestamp": now }
+        role_creds_cache[download_role_arn][user_id] = {"session": fresh_session, "timestamp": now}
     else:
         log.info("Reusing role credentials for {0}".format(user_id))
-        session_offset = round( now - role_creds_cache[download_role_arn][user_id]["timestamp"] )
+        session_offset = round(now - role_creds_cache[download_role_arn][user_id]["timestamp"])
 
-    log.debug(f'assuming role: {0}, role session username: {1}'.format(download_role_arn,user_id))
+    log.debug('assuming role: {0}, role session username: {1}'.format(download_role_arn, user_id))
 
     return role_creds_cache[download_role_arn][user_id]["session"], session_offset
 
 
 def get_role_session(creds=None, user_id=None):
-
-    global session_cache #pylint: disable=global-statement
+    global session_cache  # pylint: disable=global-statement
 
     sts_resp = creds if creds else get_role_creds(user_id)[0]
     log.debug('sts_resp: {0}'.format(sts_resp))
@@ -216,23 +221,21 @@
 def get_region_cidr_ranges():
     """
     :return: Utility function to download AWS regions
     """
+    global region_list_cache  # pylint: disable=global-statement
 
-    global region_list_cache #pylint: disable=global-statement
-
-    if not region_list_cache: #pylint: disable=used-before-assignment
+    if not region_list_cache:  # pylint: disable=used-before-assignment
         url = 'https://ip-ranges.amazonaws.com/ip-ranges.json'
         now = time()
         req = urllib.request.Request(url)
-        r = urllib.request.urlopen(req).read() #nosec URL is *always* https://ip-ranges...
+        r = urllib.request.urlopen(req).read()  # nosec URL is *always* https://ip-ranges...
log.info(return_timing_object(service="AWS", endpoint=url, duration=duration(now))) - region_list_json = loads(r.decode('utf-8')) - region_list_cache = [] - + region_list_json = json.loads(r.decode('utf-8')) # Sort out ONLY values from this AWS region - for pre in region_list_json["prefixes"]: - if "ip_prefix" in pre and "region" in pre: - if pre["region"] == get_region(): - region_list_cache.append(IPNetwork(pre["ip_prefix"])) + this_region = get_region() + region_list_cache = [ + IPNetwork(pre["ip_prefix"]) for pre in region_list_json["prefixes"] + if "ip_prefix" in pre and "region" in pre and pre["region"] == this_region + ] return region_list_cache @@ -244,9 +247,9 @@ def check_in_region_request(ip_addr: str): :type: Boolean """ + addr = IPAddress(ip_addr) for cidr in get_region_cidr_ranges(): - #log.debug("Checking ip {0} vs cidr {1}".format(user_ip, cidr)) - if IPAddress(ip_addr) in cidr: + if addr in cidr: log.info("IP {0} matched in-region CIDR {1}".format(ip_addr, cidr)) return True diff --git a/rain_api_core/egress_util.py b/rain_api_core/egress_util.py index a961bfe..ff12f4e 100644 --- a/rain_api_core/egress_util.py +++ b/rain_api_core/egress_util.py @@ -1,31 +1,29 @@ -import logging import hmac -from hashlib import sha256 +import logging import os import urllib from datetime import datetime +from hashlib import sha256 log = logging.getLogger(__name__) # This warning is stupid # pylint: disable=logging-fstring-interpolation -def prepend_bucketname(name): - prefix = os.getenv('BUCKETNAME_PREFIX', "gsfc-ngap-{}-".format(os.getenv('MATURITY', 'DEV')[0:1].lower())) +def prepend_bucketname(name): + prefix = os.getenv('BUCKETNAME_PREFIX', "gsfc-ngap-{}-".format(os.getenv('MATURITY', 'DEV')[:1].lower())) return "{}{}".format(prefix, name) def hmacsha256(key, string): - return hmac.new(key, string.encode('utf-8'), sha256) def get_presigned_url(session, bucket_name, object_name, region_name, expire_seconds, user_id, method='GET'): - timez = datetime.utcnow().strftime('%Y%m%dT%H%M%SZ') datez = timez[:8] - hostname = "{0}.s3{1}.amazonaws.com".format(bucket_name, "."+region_name if region_name != "us-east-1" else "") + hostname = "{0}.s3{1}.amazonaws.com".format(bucket_name, "." + region_name if region_name != "us-east-1" else "") cred = session['Credentials']['AccessKeyId'] secret = session['Credentials']['SecretAccessKey'] @@ -53,11 +51,10 @@ def get_presigned_url(session, bucket_name, object_name, region_name, expire_sec stringtosign = "\n".join(["AWS4-HMAC-SHA256", timez, aws4_request, can_req_hash]) # Signing Key - StepOne = hmacsha256( "AWS4{0}".format(secret).encode('utf-8'), datez).digest() - StepTwo = hmacsha256( StepOne, region_name ).digest() - StepThree = hmacsha256( StepTwo, "s3").digest() - SigningKey = hmacsha256( StepThree, "aws4_request").digest() - + StepOne = hmacsha256("AWS4{0}".format(secret).encode('utf-8'), datez).digest() + StepTwo = hmacsha256(StepOne, region_name).digest() + StepThree = hmacsha256(StepTwo, "s3").digest() + SigningKey = hmacsha256(StepThree, "aws4_request").digest() # Final Signature Signature = hmacsha256(SigningKey, stringtosign).hexdigest() @@ -68,7 +65,6 @@ def get_presigned_url(session, bucket_name, object_name, region_name, expire_sec def get_bucket_dynamic_path(path_list, b_map): - # Old and REVERSE format has no 'MAP'. In either case, we don't want it fouling our dict. 
if 'MAP' in b_map: map_dict = b_map['MAP'] @@ -81,7 +77,7 @@ def get_bucket_dynamic_path(path_list, b_map): # walk the bucket map to see if this path is valid for path_part in path_list: # Check if we hit a leaf of the YAML tree - if (mapping and isinstance(map_dict, str)) or 'bucket' in map_dict: # + if (mapping and isinstance(map_dict, str)) or 'bucket' in map_dict: customheaders = {} if isinstance(map_dict, dict) and 'bucket' in map_dict: bucketname = map_dict['bucket'] @@ -130,31 +126,33 @@ def process_varargs(varargs: list, b_map: dict): def process_request(varargs, b_map): - - varargs = varargs.split("/") + split_args = varargs.split("/") # Make sure we got at least 1 path, and 1 file name: - if len(varargs) < 2: - return "/".join(varargs), None, None, [] + if len(split_args) < 2: + return varargs, None, None, {} # Watch for ASF-ish reverse URL mapping formats: - if len(varargs) == 3: + if len(split_args) == 3: if os.getenv('USE_REVERSE_BUCKET_MAP', 'FALSE').lower() == 'true': - varargs[0], varargs[1] = varargs[1], varargs[0] + split_args[0], split_args[1] = split_args[1], split_args[0] # Look up the bucket from path parts - bucket, path, object_name, headers = get_bucket_dynamic_path(varargs, b_map) + bucket, path, object_name, headers = get_bucket_dynamic_path(split_args, b_map) # If we didn't figure out the bucket, we don't know the path/object_name if not bucket: - object_name = varargs.pop(-1) - path = "/".join(varargs) + object_name = split_args.pop(-1) + path = "/".join(split_args) return path, bucket, object_name, headers + def bucket_prefix_match(bucket_check, bucket_map, object_name=""): + # NOTE: https://github.com/asfadmin/thin-egress-app/issues/188 log.debug(f"bucket_prefix_match(): checking if {bucket_check} matches {bucket_map} w/ optional obj '{object_name}'") - if bucket_check == bucket_map.split('/')[0] and object_name.startswith("/".join(bucket_map.split('/')[1:])): + prefix, *tail = bucket_map.split("/", 1) + if bucket_check == prefix and object_name.startswith("/".join(tail)): log.debug(f"Prefixed Bucket Map matched: s3://{bucket_check}/{object_name} => {bucket_map}") return True return False @@ -168,21 +166,22 @@ def get_sorted_bucket_list(b_map, bucket_group): return [] # b_map[bucket_group] SHOULD be a dict, but list actually works too. - if isinstance(b_map[bucket_group], dict): - return sorted(list(b_map[bucket_group].keys()), key=lambda e: e.count("/"), reverse=True ) + if isinstance(b_map[bucket_group], dict): + return sorted(list(b_map[bucket_group].keys()), key=lambda e: e.count("/"), reverse=True) if isinstance(b_map[bucket_group], list): - return sorted(list(b_map[bucket_group]), key=lambda e: e.count("/"), reverse=True ) + return sorted(list(b_map[bucket_group]), key=lambda e: e.count("/"), reverse=True) # Something went wrong. return [] -def check_private_bucket(bucket, b_map, object_name=""): +def check_private_bucket(bucket, b_map, object_name=""): log.debug('check_private_buckets(): bucket: {}'.format(bucket)) # Check public bucket file: if 'PRIVATE_BUCKETS' in b_map: # Prioritize prefixed buckets first, the deeper the better! 
+ # TODO(reweeden): cache the sorted list (refactoring to object would be easiest) sorted_buckets = get_sorted_bucket_list(b_map, 'PRIVATE_BUCKETS') log.debug(f"Sorted PRIVATE buckets are {sorted_buckets}") for priv_bucket in sorted_buckets: @@ -192,10 +191,11 @@ def check_private_bucket(bucket, b_map, object_name=""): return False -def check_public_bucket(bucket, b_map, object_name=""): +def check_public_bucket(bucket, b_map, object_name=""): # Check for PUBLIC_BUCKETS in bucket map file if 'PUBLIC_BUCKETS' in b_map: + # TODO(reweeden): cache the sorted list (refactoring to object would be easiest) sorted_buckets = get_sorted_bucket_list(b_map, 'PUBLIC_BUCKETS') log.debug(f"Sorted PUBLIC buckets are {sorted_buckets}") for pub_bucket in sorted_buckets: diff --git a/rain_api_core/general_util.py b/rain_api_core/general_util.py index 2fd08fc..cc60f46 100644 --- a/rain_api_core/general_util.py +++ b/rain_api_core/general_util.py @@ -1,3 +1,4 @@ +import contextlib import json import logging import os @@ -16,8 +17,8 @@ "replace": "\\g<1>XXXXXX\\g<2>", "description": "X-out non-JWT EDL token" }, - { "regex": r"(Basic [A-Za-z0-9-_]{5})[A-Za-z0-9]*([A-Za-z0-9-_]{5})", - "replace": "\\g<1>XXXXXX\\g<2>", + { "regex": r"(Basic )[A-Za-z0-9+/=]{4,}", + "replace": "\\g<1>XXXXXX", "description": "X-out Basic Auth Credentials" }, { "regex": r"([^A-Za-z0-9/+=][A-Za-z0-9/+=]{5})[A-Za-z0-9/+=]{30}([A-Za-z0-9/+=]{5}[^A-Za-z0-9/+=])", @@ -26,15 +27,24 @@ } ] + def return_timing_object(**timing): - timing_object = { "service": "Unknown", "endpoint": "Unknown", "method": "GET", "duration": 0, "unit": "milliseconds"} - timing_object.update({k.lower(): v for k,v in timing.items()}) - return {"timing":timing_object } + timing_object = { + "service": "Unknown", + "endpoint": "Unknown", + "method": "GET", + "duration": 0, + "unit": "milliseconds" + } + timing_object.update({k.lower(): v for k, v in timing.items()}) + return {"timing": timing_object} + def duration(time_in): # Return the time duration in milliseconds delta = time.time() - time_in - return(float("{:.2f}".format(delta*1000))) + return round(delta * 1000, ndigits=2) + def filter_log_credentials(msg): if UNCENSORED_LOGGING: @@ -49,29 +59,27 @@ def filter_log_credentials(msg): def reformat_for_json(msg): - if type(msg) is dict: - return json.dumps(msg).replace("'", '"') - if isinstance(msg, str) and '{' in msg: - try: - json_obj = json.loads(msg) - return json.dumps(json_obj).replace("'", '"') - except json.decoder.JSONDecodeError: - # Not JSON. 
-            pass
+    if isinstance(msg, dict):
+        return json.dumps(msg)
+    if isinstance(msg, str):
+        if '{' in msg:
+            with contextlib.suppress(json.decoder.JSONDecodeError):
+                return json.dumps(json.loads(msg))
+        return msg
     return str(msg)
 
 
 class CustomLogFilter(logging.Filter):
-
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.params = { 'build_vers': os.getenv("BUILD_VERSION", "NOBUILD"),
-                        'maturity': os.getenv('MATURITY', 'DEV'),
-                        'request_id': None,
-                        'origin_request_id': None,
-                        'user_id': None,
-                        'route': None
-                      }
+        self.params = {
+            'build_vers': os.getenv("BUILD_VERSION", "NOBUILD"),
+            'maturity': os.getenv('MATURITY', 'DEV'),
+            'request_id': None,
+            'origin_request_id': None,
+            'user_id': None,
+            'route': None
+        }
 
     def filter(self, record):
         record.msg = filter_log_credentials(reformat_for_json(record.msg))
@@ -84,8 +92,7 @@ def filter(self, record):
         return True
 
     def update(self, **context):
-        for key in context:
-            self.params.update({key: context[key]})
+        self.params.update(context)
 
 
 custom_log_filter = CustomLogFilter()
@@ -96,35 +103,38 @@ def log_context(**context):
 
 
 def get_log():
-
     loglevel = os.getenv('LOGLEVEL', 'INFO')
     logtype = os.getenv('LOGTYPE', 'json')
     if logtype == 'flat':
-        log_fmt_str = "%(levelname)s: %(message)s (%(filename)s line " + \
-                      "%(lineno)d/%(build_vers)s/%(maturity)s) - " + \
-                      "RequestId: %(request_id)s; OriginRequestId: %(origin_request_id)s; user_id: %(user_id)s; route: %(route)s"
+        log_fmt_str = (
+            "%(levelname)s: %(message)s (%(filename)s line %(lineno)d/%(build_vers)s/%(maturity)s) - "
+            "RequestId: %(request_id)s; OriginRequestId: %(origin_request_id)s; user_id: %(user_id)s; route: %(route)s"
+        )
     else:
-        log_fmt_str = '{"level": "%(levelname)s", ' + \
-                      '"RequestId": "%(request_id)s", ' + \
-                      '"OriginRequestId": "%(origin_request_id)s", ' + \
-                      '"message": %(message)s, ' + \
-                      '"maturity": "%(maturity)s", ' + \
-                      '"user_id": "%(user_id)s", ' + \
-                      '"route": "%(route)s", ' + \
-                      '"build": "%(build_vers)s", ' + \
-                      '"filename": "%(filename)s", ' + \
-                      '"lineno": %(lineno)d } '
+        log_fmt_str = (
+            '{"level": "%(levelname)s", '
+            '"RequestId": "%(request_id)s", '
+            '"OriginRequestId": "%(origin_request_id)s", '
+            '"message": %(message)s, '
+            '"maturity": "%(maturity)s", '
+            '"user_id": "%(user_id)s", '
+            '"route": "%(route)s", '
+            '"build": "%(build_vers)s", '
+            '"filename": "%(filename)s", '
+            '"lineno": %(lineno)d}'
+        )
 
     logger = logging.getLogger()
     for h in logger.handlers:
         logger.removeHandler(h)
 
-    h = logging.StreamHandler(sys.stdout)
-    h.setFormatter(logging.Formatter(log_fmt_str))
-    h.addFilter(custom_log_filter)
-    logger.addHandler(h)
-    logger.setLevel(getattr(logging, loglevel))
+    handler = logging.StreamHandler(sys.stdout)
+    handler.setFormatter(logging.Formatter(log_fmt_str))
+    handler.addFilter(custom_log_filter)
+
+    logger.addHandler(handler)
+    logger.setLevel(loglevel)
 
     if os.getenv("QUIETBOTO", 'TRUE').upper() == 'TRUE':
         # BOTO, be quiet plz
diff --git a/rain_api_core/urs_util.py b/rain_api_core/urs_util.py
index 3ee8ccc..86aff5d 100644
--- a/rain_api_core/urs_util.py
+++ b/rain_api_core/urs_util.py
@@ -1,13 +1,12 @@
-
+import json
 import logging
 import os
 import urllib
 from time import time
-from json import loads
 
-from rain_api_core.general_util import log_context, return_timing_object, duration
-from rain_api_core.view_util import make_set_cookie_headers_jwt, get_exp_time, JWT_COOKIE_NAME
-from rain_api_core.aws_util import retrieve_secret
+from rain_api_core.aws_util import retrieve_secret
+from rain_api_core.general_util import duration, log_context, return_timing_object
+from rain_api_core.view_util import JWT_COOKIE_NAME, get_exp_time, make_set_cookie_headers_jwt
 
 log = logging.getLogger(__name__)
 
@@ -15,8 +14,8 @@ def get_base_url(ctxt=False):
     # Make a redirect url using optional custom domain_name, otherwise use raw domain/stage provided by API Gateway.
     try:
-        return 'https://{}/'.format(
-            os.getenv('DOMAIN_NAME', '{}/{}'.format(ctxt['domainName'], ctxt['stage'])))
+        domain = os.getenv('DOMAIN_NAME') or '{}/{}'.format(ctxt['domainName'], ctxt['stage'])
+        return f'https://{domain}/'
     except (TypeError, IndexError) as e:
         # TODO(reweeden): Should `IndexError` actually be `KeyError`?
         log.error('could not create a redirect_url, because {}'.format(e))
@@ -28,7 +27,7 @@ def get_redirect_url(ctxt=False):
 
 
 def do_auth(code, redirect_url, aux_headers=None):
-    aux_headers = aux_headers or {} # A safer default
+    aux_headers = aux_headers or {}  # A safer default
     url = os.getenv('AUTH_BASE_URL', 'https://urs.earthdata.nasa.gov') + "/oauth/token"
 
     # App U:P from URS Application
@@ -50,13 +49,13 @@
         log.debug('url: {}'.format(url))
         log.debug('post_data: {}'.format(post_data))
 
-        response = urllib.request.urlopen(post_request) #nosec URL is *always* URS.
+        response = urllib.request.urlopen(post_request)  # nosec URL is *always* URS.
         t1 = time()
         packet = response.read()
         log.debug('ET to do_auth() urlopen(): {} sec'.format(t1 - t0))
         log.debug('ET to do_auth() request to URS: {} sec'.format(time() - t0))
         log.info(return_timing_object(service="EDL", endpoint=url, method="POST", duration=duration(t0)))
-        return loads(packet)
+        return json.loads(packet)
 
     # TODO(reweeden): can there be other errors such as HTTPError?
     except urllib.error.URLError as e:
@@ -66,7 +65,6 @@
 
 
 def get_urs_url(ctxt, to=False):
-
     base_url = os.getenv('AUTH_BASE_URL', 'https://urs.earthdata.nasa.gov') + '/oauth/authorize'
 
     # From URS Application
@@ -77,12 +75,12 @@
     urs_url = '{0}?client_id={1}&response_type=code&redirect_uri={2}'.format(base_url, client_id, get_redirect_url(ctxt))
     if to:
-        urs_url += "&state={0}".format(to)
+        urs_url += f"&state={to}"
 
     # Try to handle scripts
     try:
         download_agent = ctxt['identity']['userAgent']
-    except IndexError:
+    except KeyError:
         log.debug("No User Agent!")
         return urs_url
@@ -93,7 +91,7 @@
 
 
 def get_profile(user_id, token, temptoken=False, aux_headers=None):
-    aux_headers = aux_headers or {} # Safer Default
+    aux_headers = aux_headers or {}  # Safer Default
     if not user_id or not token:
         return {}
 
@@ -115,13 +113,13 @@
         response = urllib.request.urlopen(req)  # nosec URL is *always* URS.
packet = response.read() log.info(return_timing_object(service="EDL", endpoint=url, duration=duration(timer))) - user_profile = loads(packet) + user_profile = json.loads(packet) return user_profile except urllib.error.URLError as e: log.warning("Error fetching profile: {0}".format(e)) - if not temptoken: # This keeps get_new_token_and_profile() from calling this over and over + if not temptoken: # This keeps get_new_token_and_profile() from calling this over and over log.debug('because error above, going to get_new_token_and_profile()') return get_new_token_and_profile(user_id, token, aux_headers) @@ -130,14 +128,14 @@ def get_profile(user_id, token, temptoken=False, aux_headers=None): def get_new_token_and_profile(user_id, cookietoken, aux_headers=None): - aux_headers = aux_headers or {} # A safer default + aux_headers = aux_headers or {} # A safer default # get a new token url = os.getenv('AUTH_BASE_URL', 'https://urs.earthdata.nasa.gov') + "/oauth/token" # App U:P from URS Application auth = get_urs_creds()['UrsAuth'] - post_data = {"grant_type": "client_credentials" } + post_data = {"grant_type": "client_credentials"} headers = {"Authorization": "Basic " + auth} headers.update(aux_headers) @@ -149,15 +147,15 @@ def get_new_token_and_profile(user_id, cookietoken, aux_headers=None): try: log.info("Attempting to get new Token") - response = urllib.request.urlopen(post_request) #nosec URL is *always* URS. + response = urllib.request.urlopen(post_request) # nosec URL is *always* URS. t1 = time() packet = response.read() log.info(return_timing_object(service="EDL", endpoint=url, duration=duration(t0))) - new_token = loads(packet)['access_token'] + new_token = json.loads(packet)['access_token'] t2 = time() log.info("Retrieved new token: {0}".format(new_token)) log.debug('ET for get_new_token_and_profile() urlopen() {} sec'.format(t1 - t0)) - log.debug('ET for get_new_token_and_profile() response.read() and loads() {} sec'.format(t2- t1)) + log.debug('ET for get_new_token_and_profile() response.read() and json.loads() {} sec'.format(t2 - t1)) # Get user profile with new token return get_profile(user_id, cookietoken, new_token, aux_headers=aux_headers) @@ -170,17 +168,18 @@ def get_new_token_and_profile(user_id, cookietoken, aux_headers=None): def user_in_group_list(private_groups, user_groups): client_id = get_urs_creds()['UrsId'] log.info("Searching for private groups {0} in {1}".format(private_groups, user_groups)) - for u_g in user_groups: - if u_g['client_id'] == client_id: - for p_g in private_groups: - if p_g == u_g['name']: - # Found the matching group! 
-                    log.info("User belongs to private group {}".format(p_g))
-                    return True
+
+    group_names = {group["name"] for group in user_groups if group["client_id"] == client_id}
+
+    for group in private_groups:
+        if group in group_names:
+            log.info("User belongs to private group {}".format(group))
+            return True
+
     return False
 
 
 def user_in_group_urs(private_groups, user_id, token, user_profile=None, refresh_first=False, aux_headers=None):
-    aux_headers = aux_headers or {} # A safer default
+    aux_headers = aux_headers or {}  # A safer default
     new_profile = {}
 
     if refresh_first or not user_profile:
@@ -203,7 +202,7 @@
 
 
 def user_in_group(private_groups, cookievars, refresh_first=False, aux_headers=None):
-    aux_headers = aux_headers or {} # A safer default
+    aux_headers = aux_headers or {}  # A safer default
 
     # If a new profile is fetched, it is assigned to this var, and returned so that a fresh jwt cookie can be set.
     new_profile = {}
@@ -242,11 +241,12 @@
     }
     :type: dict
     """
-    secret_name = os.getenv('URS_CREDS_SECRET_NAME', None)
+    secret_name = os.getenv('URS_CREDS_SECRET_NAME')
     if not secret_name:
         log.error('URS_CREDS_SECRET_NAME not set')
         return {}
+
     secret = retrieve_secret(secret_name)
     if not ('UrsId' in secret and 'UrsAuth' in secret):
         log.error('AWS secret {} does not contain required keys "UrsId" and "UrsAuth"'.format(secret_name))
@@ -256,21 +256,21 @@
 
 def user_profile_2_jwt_payload(user_id, access_token, user_profile):
     return {
-            # Do we want more items in here?
-            'first_name': user_profile['first_name'],
-            'last_name': user_profile['last_name'],
-            'email': user_profile['email_address'],
-            'urs-user-id': user_id,
-            'urs-access-token': access_token,
-            'urs-groups': user_profile['user_groups'],
-            'iat': int(time()),
-            'exp': get_exp_time(),
-        }
+        # Do we want more items in here?
+        'first_name': user_profile['first_name'],
+        'last_name': user_profile['last_name'],
+        'email': user_profile['email_address'],
+        'urs-user-id': user_id,
+        'urs-access-token': access_token,
+        'urs-groups': user_profile['user_groups'],
+        'iat': int(time()),
+        'exp': get_exp_time(),
+    }
 
 
 # This do_login() is mainly for chalice clients.
 def do_login(args, context, cookie_domain='', aux_headers=None):
-    aux_headers = aux_headers or {} # A safer default
+    aux_headers = aux_headers or {}  # A safer default
 
     log.debug('the query_params: {}'.format(args))
 
@@ -295,7 +295,7 @@
 
     if 'code' not in args:
         contentstring = 'Did not get the required CODE from URS'
-        template_vars = {'contentstring': contentstring, 'title': 'Could not login.'}
+        template_vars = {'contentstring': contentstring, 'title': 'Could Not Login'}
         headers = {}
         return 400, template_vars, headers
 
@@ -322,6 +322,7 @@
     else:
         redirect_to = get_base_url(context)
 
+    # TODO(reweeden): Why is the last check there if we're setting to the empty list anyways?
if 'user_groups' not in user_profile or not user_profile['user_groups']: user_profile['user_groups'] = [] diff --git a/rain_api_core/view_util.py b/rain_api_core/view_util.py index c4a328b..8d138b9 100644 --- a/rain_api_core/view_util.py +++ b/rain_api_core/view_util.py @@ -1,8 +1,11 @@ import base64 +import contextlib +import functools import json import logging import os import urllib +from http.cookies import CookieError, SimpleCookie from pathlib import Path from time import time from wsgiref.handlers import format_date_time as format_7231_date @@ -23,46 +26,39 @@ HTML_TEMPLATE_LOCAL_CACHEDIR = '/tmp/templates/' # nosec We want to leverage instance persistance HTML_TEMPLATE_PROJECT_DIR = Path().resolve() / 'templates' +# TODO(reweeden): explain what is going on here SESSTTL = int(os.getenv('SESSION_TTL', '168')) * 60 * 60 JWT_ALGO = os.getenv('JWT_ALGO', 'RS256') -JWT_KEYS = {} JWT_COOKIE_NAME = os.getenv('JWT_COOKIENAME', 'asf-urs') JWT_BLACKLIST = {} +@functools.lru_cache(maxsize=None) def get_jwt_keys(): - global JWT_KEYS # pylint: disable=global-statement - - if JWT_KEYS: - # Cached - return JWT_KEYS raw_keys = retrieve_secret(os.getenv('JWT_KEY_SECRET_NAME', '')) - return_dict = {} - - for k in raw_keys: - return_dict[k] = base64.b64decode(raw_keys[k].encode('utf-8')) - - JWT_KEYS = return_dict # Cache it - return return_dict + return { + k: base64.b64decode(v.encode('utf-8')) + for k, v in raw_keys.items() + } -def cache_html_templates(): +def cache_html_templates() -> str: try: os.mkdir(HTML_TEMPLATE_LOCAL_CACHEDIR, 0o700) except FileExistsError: # good. log.debug('somehow, {} exists already'.format(HTML_TEMPLATE_LOCAL_CACHEDIR)) - if os.getenv('HTML_TEMPLATE_DIR', '') == '': + templatedir = os.getenv('HTML_TEMPLATE_DIR') + if not templatedir: return 'DEFAULT' bucket = os.getenv('CONFIG_BUCKET') - templatedir = os.getenv('HTML_TEMPLATE_DIR') - if not templatedir[-1] == '/': # we need a trailing slash - templatedir = '{}/'.format(templatedir) + if not templatedir.endswith('/'): # we need a trailing slash + templatedir = f'{templatedir}/' timer = time() client = botoclient('s3') @@ -84,7 +80,7 @@ def cache_html_templates(): return 'ERROR' -def get_html_body(template_vars: dict, templatefile: str = 'root.html'): +def get_html_body(template_vars: dict, templatefile: str = 'root.html') -> str: global HTML_TEMPLATE_STATUS # pylint: disable=global-statement if HTML_TEMPLATE_STATUS == '': @@ -100,13 +96,12 @@ def get_html_body(template_vars: dict, templatefile: str = 'root.html'): autoescape=select_autoescape(['html', 'xml']) ) try: - jin_tmp = jin_env.get_template(templatefile) - + template = jin_env.get_template(templatefile) except TemplateNotFound as e: log.error('Template not found: {}'.format(e)) return 'Cannot find the HTML template directory' - return jin_tmp.render(**template_vars) + return template.render(**template_vars) def get_cookie_vars(headers: dict): @@ -117,19 +112,16 @@ def get_cookie_vars(headers: dict): :type: dict """ cooks = get_cookies(headers) - # log.debug('cooks: {}'.format(cooks)) - cookie_vars = {} try: if JWT_COOKIE_NAME in cooks: decoded_payload = decode_jwt_payload(cooks[JWT_COOKIE_NAME], JWT_ALGO) - cookie_vars.update({JWT_COOKIE_NAME: decoded_payload}) + return {JWT_COOKIE_NAME: decoded_payload} else: log.debug('could not find jwt cookie in get_cookie_vars()') - cookie_vars = {} except KeyError as e: log.debug('Key error trying to get cookie vars: {}'.format(e)) - cookie_vars = {} - return cookie_vars + + return {} def get_exp_time(): @@ -141,17 +133,18 
@@ def get_cookie_expiration_date_str(): def get_cookies(hdrs: dict): - cookies = {} - pre_cookies = [] - c = hdrs.get('cookie', hdrs.get('Cookie', hdrs.get('COOKIE', None))) - if c: - pre_cookies = c.split(';') - for cook in pre_cookies: - # print('x: {}'.format(cook)) - splitcook = cook.split('=') - cookies.update({splitcook[0].strip(): splitcook[1].strip()}) + cookie_string = hdrs.get('cookie') or hdrs.get('Cookie') or hdrs.get('COOKIE') + if not cookie_string: + return {} - return cookies + cookie = SimpleCookie() + with contextlib.suppress(CookieError): + cookie.load(cookie_string) + + return { + key: morsel.value + for key, morsel in cookie.items() + } def make_jwt_payload(payload, algo=JWT_ALGO): @@ -178,12 +171,12 @@ def decode_jwt_payload(jwt_payload, algo=JWT_ALGO): timer = time() cookiedecoded = jwt.decode(jwt_payload, rsa_pub_key, algo) log.info(return_timing_object(service="jwt", endpoint="jwt.decode()", duration=duration(timer))) - except jwt.ExpiredSignatureError as e: + except jwt.ExpiredSignatureError: # Signature has expired log.info('JWT has expired') # TODO what more to do with this, if anything? return {} - except jwt.InvalidSignatureError as e: + except jwt.InvalidSignatureError: log.info('JWT has failed verification. returning empty dict') return {} @@ -203,7 +196,7 @@ def decode_jwt_payload(jwt_payload, algo=JWT_ALGO): def craft_cookie_domain_payloadpiece(cookie_domain): if cookie_domain: - return '; Domain={}'.format(cookie_domain) + return f'; Domain={cookie_domain}' return '' @@ -214,20 +207,18 @@ def make_set_cookie_headers_jwt(payload, expdate='', cookie_domain=''): if not expdate: expdate = get_cookie_expiration_date_str() - headers = {'SET-COOKIE': '{}={}; Expires={}; Path=/{}'.format(JWT_COOKIE_NAME, - jwt_payload, - expdate, - cookie_domain_payloadpiece)} + headers = {'SET-COOKIE': f'{JWT_COOKIE_NAME}={jwt_payload}; Expires={expdate}; Path=/{cookie_domain_payloadpiece}'} return headers def is_jwt_blacklisted(decoded_jwt): set_jwt_blacklist() urs_user_id = decoded_jwt["urs-user-id"] + blacklist = JWT_BLACKLIST["blacklist"] + user_blacklist_time = blacklist.get(urs_user_id) - if urs_user_id in JWT_BLACKLIST["blacklist"]: + if user_blacklist_time is not None: jwt_mint_time = decoded_jwt["iat"] - user_blacklist_time = JWT_BLACKLIST["blacklist"][urs_user_id] log.debug(f"JWT was minted @: {jwt_mint_time}, the Blacklist is for cookies BEFORE: {user_blacklist_time}") if user_blacklist_time >= jwt_mint_time:
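
A few short sketches follow to illustrate the patterns used in this patch; none of this code is part of the diff itself.

The hand-rolled secret_cache dict (and the JWT_KEYS dict in view_util.py) is replaced with functools.lru_cache. A minimal sketch of why that works; fetch_config and its body are hypothetical stand-ins for the Secrets Manager lookup in retrieve_secret():

    import functools

    @functools.lru_cache(maxsize=None)
    def fetch_config(name: str) -> dict:
        # Hypothetical stand-in for the boto3 secretsmanager call.
        # The body runs only on the first call for each distinct name.
        print(f"fetching {name}")
        return {"name": name}

    fetch_config("jwt-keys")  # prints "fetching jwt-keys"
    fetch_config("jwt-keys")  # cache hit: no print, same dict returned

Unlike the old dicts, lru_cache does not cache exceptions (a failed lookup is retried on the next call), and the cache can be reset in tests with fetch_config.cache_clear().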
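The split("/", 1) in the new bucket_prefix_match() keeps everything after the first slash intact, so prefixed bucket-map entries still match whole object prefixes. A sketch with made-up map entries:

    bucket_map_entry = "pa-dt1/OCN"  # hypothetical "bucket/object-prefix" entry
    prefix, *tail = bucket_map_entry.split("/", 1)
    assert prefix == "pa-dt1"
    assert "/".join(tail) == "OCN"  # object_name "OCN/GRD/x.tif" starts with this

    plain_entry = "pa-dt1"  # hypothetical entry with no object prefix
    prefix, *tail = plain_entry.split("/", 1)
    assert "/".join(tail) == ""  # startswith("") is True for every object_name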
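The reworked "X-out Basic Auth Credentials" entry can be exercised on its own. A sketch of applying such a rule with re.sub; the header value is a throwaway user:password string, and applying the rule directly like this is an assumption about filter_log_credentials(), whose body is not shown in full here:

    import re

    rule = {
        "regex": r"(Basic )[A-Za-z0-9+/=]{4,}",
        "replace": "\\g<1>XXXXXX",
    }
    msg = "Authorization: Basic dXNlcjpwYXNzd29yZA=="
    print(re.sub(rule["regex"], rule["replace"], msg))
    # Authorization: Basic XXXXXX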
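The switch from naive split('=') parsing to http.cookies.SimpleCookie in get_cookies() matters for JWT cookies, whose base64 segments can contain '=' padding. A small demonstration with made-up cookie names and values:

    from http.cookies import CookieError, SimpleCookie

    header = 'asf-urs=abc.def.ghi==; other="x=y"'
    cookie = SimpleCookie()
    try:
        cookie.load(header)
    except CookieError:
        pass
    print({key: morsel.value for key, morsel in cookie.items()})
    # {'asf-urs': 'abc.def.ghi==', 'other': 'x=y'}

The old cook.split('=') parser would have truncated both values at the first '='.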