diff --git a/.gitignore b/.gitignore index 4a76c3a3e..7e18527a4 100644 --- a/.gitignore +++ b/.gitignore @@ -108,3 +108,6 @@ tests/resources/keys/*.pem .DS_Store .vscode .idea + +# snyk +.dccache \ No newline at end of file diff --git a/fence/__init__.py b/fence/__init__.py index e1aec601d..fdcc9943d 100755 --- a/fence/__init__.py +++ b/fence/__init__.py @@ -470,6 +470,7 @@ def _setup_oidc_clients(app): logger=logger, HTTP_PROXY=config.get("HTTP_PROXY"), idp=settings.get("name") or idp.title(), + arborist=app.arborist, ) clean_idp = idp.lower().replace(" ", "") setattr(app, f"{clean_idp}_client", client) diff --git a/fence/blueprints/login/base.py b/fence/blueprints/login/base.py index 08fcab61d..d65060cfe 100644 --- a/fence/blueprints/login/base.py +++ b/fence/blueprints/login/base.py @@ -1,14 +1,19 @@ +import time +import base64 +import json +from urllib.parse import urlparse, urlencode, parse_qsl +import jwt +from flask import current_app import flask from cdislogging import get_logger from flask_restful import Resource -from urllib.parse import urlparse, urlencode, parse_qsl - from fence.auth import login_user from fence.blueprints.login.redirect import validate_redirect from fence.config import config from fence.errors import UserError from fence.metrics import metrics + logger = get_logger(__name__) @@ -20,7 +25,7 @@ def __init__(self, idp_name, client): Args: idp_name (str): name for the identity provider client (fence.resources.openid.idp_oauth2.Oauth2ClientBase): - Some instaniation of this base client class or a child class + Some instantiation of this base client class or a child class """ self.idp_name = idp_name self.client = client @@ -67,6 +72,9 @@ def __init__( username_field="email", email_field="email", id_from_idp_field="sub", + firstname_claim_field="given_name", + lastname_claim_field="family_name", + organization_claim_field="org", app=flask.current_app, ): """ @@ -92,8 +100,25 @@ def __init__( self.is_mfa_enabled = "multifactor_auth_claim_info" in config[ "OPENID_CONNECT" ].get(self.idp_name, {}) + + # Config option to explicitly persist refresh tokens + self.persist_refresh_token = False + + self.read_authz_groups_from_tokens = False + self.app = app + self.persist_refresh_token = ( + config["OPENID_CONNECT"].get(self.idp_name, {}).get("persist_refresh_token") + ) + + if "is_authz_groups_sync_enabled" in config["OPENID_CONNECT"].get( + self.idp_name, {} + ): + self.read_authz_groups_from_tokens = config["OPENID_CONNECT"][ + self.idp_name + ]["is_authz_groups_sync_enabled"] + def get(self): # Check if user granted access if flask.request.args.get("error"): @@ -119,7 +144,11 @@ def get(self): code = flask.request.args.get("code") result = self.client.get_auth_info(code) + + refresh_token = result.get("refresh_token") + username = result.get(self.username_field) + if not username: raise UserError( f"OAuth2 callback error: no '{self.username_field}' in {result}" @@ -128,12 +157,114 @@ def get(self): email = result.get(self.email_field) id_from_idp = result.get(self.id_from_idp_field) - resp = _login(username, self.idp_name, email=email, id_from_idp=id_from_idp) - self.post_login(user=flask.g.user, token_result=result, id_from_idp=id_from_idp) + resp = _login( + username, + self.idp_name, + email=email, + id_from_idp=id_from_idp, + token_result=result, + ) + + if not flask.g.user: + raise UserError("Authentication failed: flask.g.user is missing.") + + expires = self.extract_exp(refresh_token) + + # if the access token is not a JWT, or does not carry exp, + # default to 
now + REFRESH_TOKEN_EXPIRES_IN + if expires is None: + expires = int(time.time()) + config["REFRESH_TOKEN_EXPIRES_IN"] + logger.info(f"Refresh token expiration not found in JWT, using default: {expires}") + + # Store refresh token in db + should_persist_token = ( + self.persist_refresh_token or self.read_authz_groups_from_tokens + ) + if should_persist_token: + # Ensure flask.g.user exists to avoid a potential AttributeError + if getattr(flask.g, "user", None): + self.client.store_refresh_token(flask.g.user, refresh_token, expires) + else: + logger.error( + "User information is missing from flask.g; cannot store refresh token." + ) + + self.post_login( + user=flask.g.user, + token_result=result, + id_from_idp=id_from_idp, + ) + return resp + def extract_exp(self, refresh_token): + """ + Extract the expiration time (`exp`) from a refresh token. + + This function attempts to retrieve the expiration time from the provided + refresh token using two methods: + + 1. Using PyJWT to decode the token (without signature verification). + 2. Manually base64 decoding the token's payload (if it's a JWT). + + **Disclaimer:** This function assumes that the refresh token is valid and + does not perform any JWT validation. For JWTs from an OpenID Connect (OIDC) + provider, validation should be done using the public keys provided by the + identity provider (from the JWKS endpoint) before using this function to + extract the expiration time. Without validation, the token's integrity and + authenticity cannot be guaranteed, which may expose your system to security + risks. Ensure validation is handled prior to calling this function, + especially in any public or production-facing contexts. + + Args: + refresh_token (str): The JWT refresh token from which to extract the expiration. + + Returns: + int or None: The expiration time (`exp`) in seconds since the epoch, + or None if extraction fails. 
+ """ + + # Method 1: PyJWT + try: + # Skipping keys since we're not verifying the signature + decoded_refresh_token = jwt.decode( + refresh_token, + options={ + "verify_aud": False, + "verify_at_hash": False, + "verify_signature": False, + }, + algorithms=["RS256", "HS512"], + ) + exp = decoded_refresh_token.get("exp") + + if exp is not None: + return exp + except Exception as e: + logger.info(f"Refresh token expiry: Method (PyJWT) failed: {e}") + + # Method 2: Manual base64 decoding + try: + # Assuming the token is a JWT (header.payload.signature) + payload_encoded = refresh_token.split(".")[1] + # Add necessary padding for base64 decoding + payload_encoded += "=" * (4 - len(payload_encoded) % 4) + payload_decoded = base64.urlsafe_b64decode(payload_encoded) + payload_json = json.loads(payload_decoded) + exp = payload_json.get("exp") + + if exp is not None: + return exp + except Exception as e: + logger.info(f"Method 3 (Manual decoding) failed: {e}") + + # If all methods fail, return None + return None + def post_login(self, user=None, token_result=None, **kwargs): prepare_login_log(self.idp_name) + metrics.add_login_event( user_sub=flask.g.user.id, idp=self.idp_name, @@ -142,6 +273,11 @@ def post_login(self, user=None, token_result=None, **kwargs): client_id=flask.session.get("client_id"), ) + if self.read_authz_groups_from_tokens: + self.client.update_user_authorization( + user=user, pkey_cache=None, db_session=None, idp_name=self.idp_name + ) + if token_result: username = token_result.get(self.username_field) if self.is_mfa_enabled: @@ -171,19 +307,76 @@ def prepare_login_log(idp_name): } -def _login(username, idp_name, email=None, id_from_idp=None): +def _login( + username, + idp_name, + email=None, + id_from_idp=None, + token_result=None, +): """ - Login user with given username, then redirect if session has a saved - redirect. + Login user with given username, then automatically register if needed, + and finally redirect if session has a saved redirect. """ login_user(username, idp_name, email=email, id_from_idp=id_from_idp) + register_idp_users = ( + config["OPENID_CONNECT"] + .get(idp_name, {}) + .get("enable_idp_users_registration", False) + ) + if config["REGISTER_USERS_ON"]: - if not flask.g.user.additional_info.get("registration_info"): - return flask.redirect( - config["BASE_URL"] + flask.url_for("register.register_user") - ) + user = flask.g.user + if not user.additional_info.get("registration_info"): + # If enabled, automatically register user from Idp + if register_idp_users: + firstname = token_result.get("firstname") + lastname = token_result.get("lastname") + organization = token_result.get("org") + email = token_result.get("email") + if email is None: + raise UserError("OAuth2 id token is missing email claim") + # Log warnings and set defaults if needed + if not firstname or not lastname: + logger.warning( + f"User {username} missing name fields. Proceeding with minimal info." + ) + firstname = firstname or "Unknown" + lastname = lastname or "User" + + if not organization: + organization = None + logger.info( + f"User {username} missing organization. Defaulting to None." 
+ ) + + # Store registration info + registration_info = { + "firstname": firstname, + "lastname": lastname, + "org": organization, + "email": email, + } + user.additional_info = user.additional_info or {} + user.additional_info["registration_info"] = registration_info + + # Persist to database + current_app.scoped_session().add(user) + current_app.scoped_session().commit() + + # Ensure user exists in Arborist and assign to group + with current_app.arborist.context(): + current_app.arborist.create_user(dict(name=username)) + current_app.arborist.add_user_to_group( + username=username, + group_name=config["REGISTERED_USERS_GROUP"], + ) + else: + return flask.redirect( + config["BASE_URL"] + flask.url_for("register.register_user") + ) if flask.session.get("redirect"): return flask.redirect(flask.session.get("redirect")) - return flask.jsonify({"username": username}) + return flask.jsonify({"username": username, "registered": True}) diff --git a/fence/config-default.yaml b/fence/config-default.yaml index f87e53282..e7bd5e2ed 100755 --- a/fence/config-default.yaml +++ b/fence/config-default.yaml @@ -45,7 +45,7 @@ ENCRYPTION_KEY: '' # ////////////////////////////////////////////////////////////////////////////////////// # flask's debug setting # WARNING: DO NOT ENABLE IN PRODUCTION (for testing purposes only) -DEBUG: true +DEBUG: false # if true, will automatically login a user with username "test" # WARNING: DO NOT ENABLE IN PRODUCTION (for testing purposes only) MOCK_AUTH: false @@ -94,6 +94,7 @@ DB_MIGRATION_POSTGRES_LOCK_KEY: 100 # - WARNING: Be careful changing the *_ALLOWED_SCOPES as you can break basic # and optional functionality # ////////////////////////////////////////////////////////////////////////////////////// + OPENID_CONNECT: # any OIDC IDP that does not differ from the generic implementation can be # configured without code changes @@ -115,6 +116,40 @@ OPENID_CONNECT: multifactor_auth_claim_info: # optional, include if you're using arborist to enforce mfa on a per-file level claim: '' # claims field that indicates mfa, either the acr or acm claim. values: [ "" ] # possible values that indicate mfa was used. At least one value configured here is required to be in the token + # When true, it allows refresh tokens to be stored even if is_authz_groups_sync_enabled is set false. + # When false, the system will only store refresh tokens if is_authz_groups_sync_enabled is enabled + persist_refresh_token: false + # is_authz_groups_sync_enabled: A configuration flag that determines whether the application should + # verify and synchronize user group memberships between the identity provider (IdP) + # and the local authorization system (Arborist). When enabled, the refresh token is stored, the system retrieves + # the user's group information from their token issued by the IdP and compares it against + # the groups defined in the local system. Based on the comparison, the user is added to + # or removed from relevant groups in the local system to ensure their group memberships + # remain up-to-date. If this flag is disabled, no group synchronization occurs + is_authz_groups_sync_enabled: true + # Key used to retrieve group information from the token + group_claim_field: "groups" + # IdP group membership expiration (seconds). + group_membership_expiration_duration: 604800 + authz_groups_sync: + # This defines the prefix used to identify authorization groups. 
+ group_prefix: "some_prefix" + # This flag indicates whether the audience (aud) claim in the JWT should be verified during token validation. + verify_aud: true + # This specifies the expected audience (aud) value for the JWT, ensuring that the token is intended for use with the 'fence' service. + audience: fence + # default refresh token expiration duration + default_refresh_token_exp: 3600 + # firstname claim field from idp + firstname_claim_field: 'firstName' + # lastname claim field from idp + lastname_claim_field: 'lastName' + # organization claim field from idp + organization_claim_field: 'org' + # default organization + default_organization: 'AU BioCommons' + # automatically register users from Idp + enable_idp_users_registration: true # These Google values must be obtained from Google's Cloud Console # Follow: https://developers.google.com/identity/protocols/OpenIDConnect # diff --git a/fence/config.py b/fence/config.py index 36348b0b1..08bd04390 100644 --- a/fence/config.py +++ b/fence/config.py @@ -140,12 +140,23 @@ def post_process(self): ) for idp_id, idp in self._configs.get("OPENID_CONNECT", {}).items(): + if not isinstance(idp, dict): + raise TypeError( + "Expected 'OPENID_CONNECT' configuration to be a dictionary." + ) mfa_info = idp.get("multifactor_auth_claim_info") if mfa_info and mfa_info["claim"] not in ["amr", "acr"]: logger.warning( f"IdP '{idp_id}' is using multifactor_auth_claim_info '{mfa_info['claim']}', which is neither AMR or ACR. Unable to determine if a user used MFA. Fence will continue and assume they have not used MFA." ) + groups_sync_enabled = idp.get("is_authz_groups_sync_enabled", False) + # when is_authz_groups_sync_enabled, then you must provide authz_groups_sync, with group prefix + if groups_sync_enabled and not idp.get("authz_groups_sync"): + error = f"Error: is_authz_groups_sync_enabled is enabled, required values not configured, for idp: {idp_id}" + logger.error(error) + raise Exception(error) + self._validate_parent_child_studies(self._configs["dbGaP"]) @staticmethod diff --git a/fence/error_handler.py b/fence/error_handler.py index 5b3a0cfdb..e59e86ae3 100644 --- a/fence/error_handler.py +++ b/fence/error_handler.py @@ -8,37 +8,44 @@ from fence.errors import APIError from fence.config import config +import traceback logger = get_logger(__name__) -def get_error_response(error): +def get_error_response(error: Exception): + """ + Generates a response for the given error with detailed logs and appropriate status codes. + + Args: + error (Exception): The error that occurred. + + Returns: + Tuple (str, int): Rendered error HTML and HTTP status code. + """ details, status_code = get_error_details_and_status(error) support_email = config.get("SUPPORT_EMAIL_FOR_ERRORS") app_name = config.get("APP_NAME", "Gen3 Data Commons") - message = details.get("message") - error_id = _get_error_identifier() logger.error( - "{} HTTP error occured. ID: {}\nDetails: {}".format( - status_code, error_id, str(details) + "{} HTTP error occurred. 
ID: {}\nDetails: {}\nTraceback: {}".format( + status_code, error_id, details, traceback.format_exc() ) ) - # don't include internal details in the public error message - # to do this, only include error messages for known http status codes - # that are less that 500 + # Prepare user-facing message + message = details.get("message") valid_http_status_codes = [ int(code) for code in list(http_responses.keys()) if int(code) < 500 ] + try: status_code = int(status_code) if status_code not in valid_http_status_codes: message = None except (ValueError, TypeError): - # this handles case where status_code is NOT a valid integer (e.g. HTTP status code) message = None status_code = 500 @@ -59,6 +66,15 @@ def get_error_response(error): def get_error_details_and_status(error): + """ + Extracts details and HTTP status code from the given error. + + Args: + error (Exception): The error to process. + + Returns: + Tuple (dict, int): Error details as a dictionary and HTTP status code. + """ message = error.message if hasattr(error, "message") else str(error) if isinstance(error, APIError): if hasattr(error, "json") and error.json: @@ -70,11 +86,11 @@ def get_error_details_and_status(error): error_response = {"message": error.description}, error.status_code elif isinstance(error, HTTPException): error_response = ( - {"message": getattr(error, "description")}, + {"message": getattr(error, "description", str(error))}, error.get_response().status_code, ) else: - logger.exception("Catch exception") + logger.exception("Unexpected exception occurred") error_code = 500 if hasattr(error, "code"): error_code = error.code @@ -86,4 +102,10 @@ def get_error_details_and_status(error): def _get_error_identifier(): + """ + Generates a unique identifier for tracking the error. + + Returns: + UUID: A unique identifier for the error. + """ return uuid.uuid4() diff --git a/fence/job/visa_update_cronjob.py b/fence/job/access_token_updater.py similarity index 76% rename from fence/job/visa_update_cronjob.py rename to fence/job/access_token_updater.py index cac8d9182..8c1c15b1c 100644 --- a/fence/job/visa_update_cronjob.py +++ b/fence/job/access_token_updater.py @@ -3,16 +3,18 @@ import time from cdislogging import get_logger +from flask import current_app from fence.config import config from fence.models import User from fence.resources.openid.ras_oauth2 import RASOauth2Client as RASClient +from fence.resources.openid.idp_oauth2 import Oauth2ClientBase as OIDCClient logger = get_logger(__name__, log_level="debug") -class Visa_Token_Update(object): +class TokenAndAuthUpdater(object): def __init__( self, chunk_size=None, @@ -20,6 +22,7 @@ def __init__( thread_pool_size=None, buffer_size=None, logger=logger, + arborist=None, ): """ args: @@ -44,17 +47,36 @@ def __init__( self.visa_types = config.get("USERSYNC", {}).get("visa_types", {}) - # Initialize visa clients: + # Dict on self which contains all clients that need update + self.oidc_clients_requiring_token_refresh = {} + + # keep this as a special case, because RAS will not set group information configuration. 
oidc = config.get("OPENID_CONNECT", {}) + if "ras" not in oidc: self.logger.error("RAS client not configured") - self.ras_client = None else: - self.ras_client = RASClient( + ras_client = RASClient( oidc["ras"], HTTP_PROXY=config.get("HTTP_PROXY"), logger=logger, ) + self.oidc_clients_requiring_token_refresh["ras"] = ras_client + + self.arborist = arborist + + # Initialise a client for each OIDC client in oidc, which does have is_authz_groups_sync_enabled set to true and add them + # to oidc_clients_requiring_token_refresh + for oidc_name, settings in oidc.items(): + if settings.get("is_authz_groups_sync_enabled", False): + oidc_client = OIDCClient( + settings=settings, + HTTP_PROXY=config.get("HTTP_PROXY"), + logger=logger, + idp=oidc_name, + arborist=arborist, + ) + self.oidc_clients_requiring_token_refresh[oidc_name] = oidc_client async def update_tokens(self, db_session): """ @@ -68,7 +90,7 @@ async def update_tokens(self, db_session): """ start_time = time.time() - self.logger.info("Initializing Visa Update Cronjob . . .") + self.logger.info("Initializing Visa Update and Token refreshing Cronjob . . .") self.logger.info("Total concurrency size: {}".format(self.concurrency)) self.logger.info("Total thread pool size: {}".format(self.thread_pool_size)) self.logger.info("Total buffer size: {}".format(self.buffer_size)) @@ -139,13 +161,12 @@ async def worker(self, name, queue, updater_queue): queue.task_done() async def updater(self, name, updater_queue, db_session): - """ - Update visas in the updater_queue. - Note that only visas which pass validation will be saved. - """ while True: - user = await updater_queue.get() try: + user = await updater_queue.get() + if user is None: # Use None to signal termination + break + client = self._pick_client(user) if client: self.logger.info( @@ -160,30 +181,35 @@ async def updater(self, name, updater_queue, db_session): pkey_cache=self.pkey_cache, db_session=db_session, ) + else: self.logger.debug( f"Updater {name} NOT updating authorization for " f"user {user.username} because no client was found for IdP: {user.identity_provider}" ) + + updater_queue.task_done() + except Exception as exc: self.logger.error( f"Updater {name} could not update authorization " - f"for {user.username}. Error: {exc}. Continuing." + f"for {user.username if user else 'unknown user'}. Error: {exc}. Continuing." ) - pass - - updater_queue.task_done() + # Ensure task is marked done if exception occurs + updater_queue.task_done() def _pick_client(self, user): """ - Pick oidc client according to the identity provider + Select OIDC client based on identity provider. 
""" - client = None - if ( - user.identity_provider - and getattr(user.identity_provider, "name") == self.ras_client.idp - ): - client = self.ras_client + + client = self.oidc_clients_requiring_token_refresh.get( + getattr(user.identity_provider, "name"), None + ) + if client: + self.logger.info(f"Picked client: {client.idp} for user {user.username}") + else: + self.logger.info(f"No client found for user {user.username}") return client def _pick_client_from_visa(self, visa): diff --git a/fence/resources/openid/cilogon_oauth2.py b/fence/resources/openid/cilogon_oauth2.py index 163663420..cc1b8674d 100644 --- a/fence/resources/openid/cilogon_oauth2.py +++ b/fence/resources/openid/cilogon_oauth2.py @@ -39,7 +39,9 @@ def get_auth_info(self, code): jwks_endpoint = self.get_value_from_discovery_doc( "jwks_uri", "https://cilogon.org/oauth2/certs" ) - claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code) + claims, refresh_token, access_token = self.get_jwt_claims_identity( + token_endpoint, jwks_endpoint, code + ) if claims.get("sub"): return {"sub": claims["sub"]} diff --git a/fence/resources/openid/cognito_oauth2.py b/fence/resources/openid/cognito_oauth2.py index 73038c87f..3da80fb9d 100644 --- a/fence/resources/openid/cognito_oauth2.py +++ b/fence/resources/openid/cognito_oauth2.py @@ -45,7 +45,9 @@ def get_auth_info(self, code): try: token_endpoint = self.get_value_from_discovery_doc("token_endpoint", "") jwks_endpoint = self.get_value_from_discovery_doc("jwks_uri", "") - claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code) + claims, refresh_token, access_token = self.get_jwt_claims_identity( + token_endpoint, jwks_endpoint, code + ) self.logger.info(f"Received id token from Cognito: {claims}") diff --git a/fence/resources/openid/google_oauth2.py b/fence/resources/openid/google_oauth2.py index b396fe9ca..60c2d69c2 100644 --- a/fence/resources/openid/google_oauth2.py +++ b/fence/resources/openid/google_oauth2.py @@ -47,7 +47,9 @@ def get_auth_info(self, code): jwks_endpoint = self.get_value_from_discovery_doc( "jwks_uri", "https://www.googleapis.com/oauth2/v3/certs" ) - claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code) + claims, refresh_token, access_token = self.get_jwt_claims_identity( + token_endpoint, jwks_endpoint, code + ) if claims.get("email") and claims.get("email_verified"): return {"email": claims["email"], "sub": claims.get("sub")} diff --git a/fence/resources/openid/idp_oauth2.py b/fence/resources/openid/idp_oauth2.py index 463db77cf..80e874382 100644 --- a/fence/resources/openid/idp_oauth2.py +++ b/fence/resources/openid/idp_oauth2.py @@ -1,13 +1,20 @@ from authlib.integrations.requests_client import OAuth2Session from cached_property import cached_property from flask import current_app -from jose import jwt + +from jose.exceptions import JWTError, JWTClaimsError import requests import time - +import datetime +import backoff +import jwt +from fence.utils import DEFAULT_BACKOFF_SETTINGS from fence.errors import AuthError from fence.models import UpstreamRefreshToken +from fence.jwt.validate import validate_jwt +from authutils.token.keys import get_public_key_for_token + class Oauth2ClientBase(object): """ @@ -15,7 +22,14 @@ class Oauth2ClientBase(object): """ def __init__( - self, settings, logger, idp, scope=None, discovery_url=None, HTTP_PROXY=None + self, + settings, + logger, + idp, + scope=None, + discovery_url=None, + HTTP_PROXY=None, + arborist=None, ): self.logger = logger self.settings = settings @@ 
-25,14 +39,17 @@ def __init__( scope=scope or settings.get("scope") or "openid", redirect_uri=settings["redirect_url"], ) + self.discovery_url = ( discovery_url or settings.get("discovery_url") or getattr(self, "DISCOVERY_URL", None) or "" ) - self.idp = idp # display name for use in logs and error messages + # display name for use in logs and error messages + self.idp = idp self.HTTP_PROXY = HTTP_PROXY + self.authz_groups_from_idp = [] if not self.discovery_url and not settings.get("discovery"): self.logger.warning( @@ -40,6 +57,12 @@ def __init__( f"Some calls for this client may fail if they rely on the OIDC Discovery page. Use 'discovery' to configure clients without a discovery page." ) + self.read_authz_groups_from_tokens = self.settings.get( + "is_authz_groups_sync_enabled", False + ) + + self.arborist = arborist + @cached_property def discovery_doc(self): return requests.get(self.discovery_url) @@ -53,6 +76,7 @@ def get_proxies(self): return None def get_token(self, token_endpoint, code): + return self.session.fetch_token( url=token_endpoint, code=code, proxies=self.get_proxies() ) @@ -63,6 +87,7 @@ def get_jwt_keys(self, jwks_uri): Return None if there is an error while retrieving keys from the api """ resp = requests.get(url=jwks_uri, proxies=self.get_proxies()) + if resp.status_code != requests.codes.ok: self.logger.error( "{} ERROR: Can not retrieve jwt keys from IdP's API {}".format( @@ -76,15 +101,48 @@ def get_jwt_claims_identity(self, token_endpoint, jwks_endpoint, code): """ Get jwt identity claims """ + token = self.get_token(token_endpoint, code) + + refresh_token = token.get("refresh_token", None) + keys = self.get_jwt_keys(jwks_endpoint) - return jwt.decode( - token["id_token"], - keys, - options={"verify_aud": False, "verify_at_hash": False}, - algorithms=["RS256"], - ) + # Extract issuer from the id token without signature verification + try: + decoded_id_token = jwt.decode( + token["id_token"], + options={"verify_signature": False}, + algorithms=["RS256"], + key=keys, + ) + issuer = decoded_id_token.get("iss") + except JWTError as e: + raise JWTError(f"Invalid token: {e}") + + # validate audience and hash. also ensure that the algorithm is correctly derived from the token. + # hash verification has not been implemented yet + verify_aud = self.settings.get("verify_aud", False) + audience = self.settings.get("audience", self.settings.get("client_id")) + + decoded_access_token = None + + if self.read_authz_groups_from_tokens: + try: + decoded_access_token = validate_jwt( + encoded_token=token["access_token"], + aud=audience, + scope=None, + issuers=[issuer], + purpose=None, + require_purpose=False, + options={"verify_aud": verify_aud, "verify_hash": False}, + attempt_refresh=True, + ) + except JWTError as e: + raise JWTError(f"Invalid token: {e}") + + return decoded_id_token, refresh_token, decoded_access_token def get_value_from_discovery_doc(self, key, default_value): """ @@ -163,10 +221,41 @@ def get_auth_info(self, code): user OR "error" field with details of the error. 
""" user_id_field = self.settings.get("user_id_field", "sub") + try: token_endpoint = self.get_value_from_discovery_doc("token_endpoint", "") jwks_endpoint = self.get_value_from_discovery_doc("jwks_uri", "") - claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code) + claims, refresh_token, access_token = self.get_jwt_claims_identity( + token_endpoint, jwks_endpoint, code + ) + + groups = None + group_prefix = None + + organization_claim_field = self.settings.get( + "organization_claim_field", "org" + ) + firstname_claim_field = self.settings.get( + "firstname_claim_field", "given_name" + ) + lastname_claim_field = self.settings.get( + "lastname_claim_field", "family_name" + ) + email_claim_field = self.settings.get("email_claim_field", "email") + + if self.read_authz_groups_from_tokens: + try: + group_claim_field = self.settings.get("group_claim_field", "groups") + # Get groups from access token + groups = access_token.get(group_claim_field) + group_prefix = self.settings.get("authz_groups_sync", {}).get( + "group_prefix", "" + ) + except KeyError as e: + self.logger.error( + f"Error: is_authz_groups_sync_enabled is enabled, however groups not found in claims: {e}" + ) + raise Exception(e) if claims.get(user_id_field): if user_id_field == "email" and not claims.get("email_verified"): @@ -174,6 +263,15 @@ def get_auth_info(self, code): return { user_id_field: claims[user_id_field], "mfa": self.has_mfa_claim(claims), + "refresh_token": refresh_token, + "iat": claims.get("iat"), + "exp": claims.get("exp"), + "groups": groups, + "group_prefix": group_prefix, + "org": claims.get(organization_claim_field), + "firstname": claims.get(firstname_claim_field), + "lastname": claims.get(lastname_claim_field), + "email": claims.get(email_claim_field), } else: self.logger.exception( @@ -183,7 +281,7 @@ def get_auth_info(self, code): except Exception as e: self.logger.exception(f"Can't get user info from {self.idp}: {e}") - return {"error": f"Can't get user info from {self.idp}"} + return {"error": f"Can't get user info from {self.idp}: {e}"} def get_access_token(self, user, token_endpoint, db_session=None): """ @@ -192,11 +290,12 @@ def get_access_token(self, user, token_endpoint, db_session=None): refresh_token = None expires = None - # get refresh_token and expiration from db + # Get the refresh_token and expiration from the database for row in sorted(user.upstream_refresh_tokens, key=lambda row: row.expires): refresh_token = row.refresh_token expires = row.expires + # Check if the token is expired if time.time() > expires: # reset to check for next token refresh_token = None @@ -209,21 +308,40 @@ def get_access_token(self, user, token_endpoint, db_session=None): if not refresh_token: raise AuthError("User doesn't have a valid, non-expired refresh token") - token_response = self.session.refresh_token( - url=token_endpoint, - proxies=self.get_proxies(), - refresh_token=refresh_token, - ) - refresh_token = token_response["refresh_token"] + verify_aud = self.settings.get("verify_aud", False) + audience = self.settings.get("audience", self.settings.get("client_id")) - self.store_refresh_token( - user, - refresh_token=refresh_token, - expires=expires, - db_session=db_session, - ) + refresh_kwargs = { + "url": token_endpoint, + "proxies": self.get_proxies(), + "refresh_token": refresh_token, + } - return token_response + if verify_aud: + refresh_kwargs["audience"] = audience + + try: + token_response = self.session.refresh_token(**refresh_kwargs) + + refresh_token = 
token_response["refresh_token"] + # Fetching the expires at from token_response. + # Defaulting to config settings. + default_refresh_token_exp = self.settings.get("default_refresh_token_exp") + expires_at = token_response.get( + "expires_at", time.time() + default_refresh_token_exp + ) + + self.store_refresh_token( + user, + refresh_token=refresh_token, + expires=expires_at, + db_session=db_session, + ) + + return token_response + except Exception as e: + self.logger.exception(f"Error refreshing token for user {user.id}: {e}") + raise AuthError("Failed to refresh access token.") def has_mfa_claim(self, decoded_id_token): """ @@ -276,3 +394,160 @@ def store_refresh_token(self, user, refresh_token, expires, db_session=None): current_db_session = db_session.object_session(upstream_refresh_token) current_db_session.add(upstream_refresh_token) db_session.commit() + self.logger.info( + f"Refresh token has been persisted for user: {user} , with expiration of {expires}" + ) + + def get_groups_from_token(self, decoded_access_token, group_prefix=""): + """ + Retrieve and format groups from the decoded token based on a configurable field name. + + Args: + decoded_access_token (dict): The decoded token containing claims. + group_prefix (str): The prefix to strip from group names. + + Returns: + list: A list of formatted group names. + + Variables: + group_claim_field (str): The field name in the token that contains the group information. + authz_groups_from_idp (list): The list of groups retrieved from the token, potentially empty. + """ + # Retrieve the configured field name for groups, defaulting to 'groups' + group_claim_field = self.settings.get("group_claim_field", "groups") + authz_groups_from_idp = decoded_access_token.get(group_claim_field, []) + + if authz_groups_from_idp: + authz_groups_from_idp = [ + group.removeprefix(group_prefix).lstrip("/") + for group in authz_groups_from_idp + ] + return authz_groups_from_idp + + @backoff.on_exception(backoff.expo, Exception, **DEFAULT_BACKOFF_SETTINGS) + def update_user_authorization(self, user, pkey_cache, db_session=None, **kwargs): + """ + Update the user's authorization by refreshing their access token and synchronizing + their group memberships with Arborist. + + This method refreshes the user's access token using an identity provider (IdP), + retrieves and decodes the token, and optionally synchronizes the user's group + memberships between the IdP and Arborist if the `groups` configuration is enabled. + + Args: + user (User): The user object, which contains details like username and identity provider. + pkey_cache (dict): A cache of public keys used for verifying JWT signatures. + db_session (SQLAlchemy Session, optional): A database session object. If not provided, + it defaults to the scoped session of the current application context. + **kwargs: Additional keyword arguments. + + Raises: + Exception: If there is an issue with retrieving the access token, decoding the token, + or synchronizing the user's groups. + + Workflow: + 1. Retrieves the token endpoint and JWKS URI from the identity provider's discovery document. + 2. Uses the user's refresh token to get a new access token and persists it in the database. + 3. Decodes the ID token using the JWKS (JSON Web Key Set) retrieved from the IdP. + 4. If group synchronization is enabled: + a. Retrieves the list of groups from Arborist. + b. Retrieves the user's groups from the IdP. + c. Adds the user to groups in Arborist that match the groups from the IdP. + d. 
Removes the user from groups in Arborist that they are no longer part of in the IdP. + + Logging: + - Logs the group membership synchronization activities (adding/removing users from groups). + - Logs any issues encountered while refreshing the token or during group synchronization. + + Warnings: + - If groups are not received from the IdP but group synchronization is enabled, logs a warning. + + """ + db_session = db_session or current_app.scoped_session() + + expires_at = None + + try: + token_endpoint = self.get_value_from_discovery_doc("token_endpoint", "") + + # this get_access_token also persists the refresh token in the db + token = self.get_access_token(user, token_endpoint, db_session) + + verify_aud = self.settings.get("verify_aud", False) + audience = self.settings.get("audience", self.settings.get("client_id")) + + key = get_public_key_for_token( + token["id_token"], attempt_refresh=True, pkey_cache={} + ) + + decoded_access_token = jwt.decode( + token["access_token"], + key=key, + options={"verify_aud": verify_aud, "verify_at_hash": False}, + algorithms=["RS256"], + audience=audience, + ) + self.logger.info("Token decoded and validated successfully.") + + except Exception as e: + err_msg = "Could not refresh token" + self.logger.exception("{}: {}".format(err_msg, e)) + raise + + if self.read_authz_groups_from_tokens: + group_prefix = self.settings.get("authz_groups_sync", {}).get( + "group_prefix", "" + ) + + # grab all groups defined in arborist + arborist_groups = self.arborist.list_groups().get("groups") + + # groups defined in idp + authz_groups_from_idp = self.get_groups_from_token( + decoded_access_token, group_prefix + ) + + # if group name is in the list from arborist: + if authz_groups_from_idp: + authz_groups_from_idp = [ + group.removeprefix(group_prefix).lstrip("/") + for group in authz_groups_from_idp + ] + + idp_group_names = set(authz_groups_from_idp) + + # Expiration for group membership. 
Default 7 days + group_membership_duration = self.settings.get( + "group_membership_expiration_duration", 3600 * 24 * 7 + ) + + # Get the refresh token expiration from the token response + refresh_token_expires_at = datetime.datetime.fromtimestamp( + token.get("expires_at", time.time()), tz=datetime.timezone.utc + ) + + # Calculate the configured group membership expiration + configured_expires_at = datetime.datetime.now( + tz=datetime.timezone.utc + ) + datetime.timedelta(seconds=group_membership_duration) + + # Ensure group membership does not exceed refresh token expiration + group_membership_expires_at = min( + refresh_token_expires_at, configured_expires_at + ) + + # Add user to all matching groups from IDP + for arborist_group in arborist_groups: + if arborist_group["name"] in idp_group_names: + self.logger.info( + f"Adding {user.username} to group: {arborist_group['name']}, sub: {user.id} exp: {group_membership_expires_at}" + ) + self.arborist.add_user_to_group( + username=user.username, + group_name=arborist_group["name"], + expires_at=group_membership_expires_at, + ) + else: + self.logger.warning( + f"is_authz_groups_sync_enabled feature is enabled, but did not receive groups from idp {self.idp} for user: {user.username}" + ) diff --git a/fence/resources/openid/microsoft_oauth2.py b/fence/resources/openid/microsoft_oauth2.py index 916a4a2b1..d4fa6b634 100755 --- a/fence/resources/openid/microsoft_oauth2.py +++ b/fence/resources/openid/microsoft_oauth2.py @@ -48,7 +48,9 @@ def get_auth_info(self, code): "jwks_uri", "https://login.microsoftonline.com/organizations/discovery/v2.0/keys", ) - claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code) + claims, refresh_token, access_token = self.get_jwt_claims_identity( + token_endpoint, jwks_endpoint, code + ) if claims.get("email"): return {"email": claims["email"], "sub": claims.get("sub")} diff --git a/fence/resources/openid/okta_oauth2.py b/fence/resources/openid/okta_oauth2.py index 572031623..bb590915b 100644 --- a/fence/resources/openid/okta_oauth2.py +++ b/fence/resources/openid/okta_oauth2.py @@ -37,7 +37,9 @@ def get_auth_info(self, code): "jwks_uri", "", ) - claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code) + claims, refresh_token, access_token = self.get_jwt_claims_identity( + token_endpoint, jwks_endpoint, code + ) if claims.get("email"): return {"email": claims["email"], "sub": claims.get("sub")} diff --git a/fence/resources/openid/orcid_oauth2.py b/fence/resources/openid/orcid_oauth2.py index ee8711f33..3e0056984 100644 --- a/fence/resources/openid/orcid_oauth2.py +++ b/fence/resources/openid/orcid_oauth2.py @@ -41,7 +41,9 @@ def get_auth_info(self, code): jwks_endpoint = self.get_value_from_discovery_doc( "jwks_uri", "https://orcid.org/oauth/jwks" ) - claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code) + claims, refresh_token, access_token = self.get_jwt_claims_identity( + token_endpoint, jwks_endpoint, code + ) if claims.get("sub"): return {"orcid": claims["sub"], "sub": claims["sub"]} diff --git a/fence/scripting/fence_create.py b/fence/scripting/fence_create.py index a4b15aff8..77d080a5d 100644 --- a/fence/scripting/fence_create.py +++ b/fence/scripting/fence_create.py @@ -38,7 +38,7 @@ generate_signed_refresh_token, issued_and_expiration_times, ) -from fence.job.visa_update_cronjob import Visa_Token_Update +from fence.job.access_token_updater import TokenAndAuthUpdater from fence.models import ( Client, GoogleServiceAccount, @@ -1814,12 +1814,19 @@ def 
access_token_polling_job( thread_pool_size (int): number of Docker container CPU used for jwt verifcation buffer_size (int): max size of queue """ + # Instantiating a new client here because the existing + # client uses authz_provider + arborist = ArboristClient( + arborist_base_url=config["ARBORIST"], + logger=get_logger("user_syncer.arborist_client"), + ) driver = get_SQLAlchemyDriver(db) - job = Visa_Token_Update( + job = TokenAndAuthUpdater( chunk_size=int(chunk_size) if chunk_size else None, concurrency=int(concurrency) if concurrency else None, thread_pool_size=int(thread_pool_size) if thread_pool_size else None, buffer_size=int(buffer_size) if buffer_size else None, + arborist=arborist, ) with driver.session as db_session: loop = asyncio.get_event_loop() diff --git a/tests/conftest.py b/tests/conftest.py index 9baba01a1..90c81d2fa 100755 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -76,6 +76,7 @@ "cilogon", "generic1", "generic2", + "generic3", ] @@ -396,6 +397,12 @@ def do_patch(urls_to_responses=None): defaults = { "arborist/health": {"GET": ("", 200)}, "arborist/auth/mapping": {"POST": ({}, "200")}, + "arborist/group": { + "GET": ( + {"groups": [{"name": "data_uploaders", "users": ["test_user"]}]}, + 200, + ) + }, } defaults.update(urls_to_responses) urls_to_responses = defaults @@ -479,6 +486,33 @@ def app(kid, rsa_private_key, rsa_public_key): mocker.unmock_functions() +@pytest.fixture +def mock_app(): + return MagicMock() + + +@pytest.fixture +def mock_user(): + return MagicMock() + + +@pytest.fixture +def mock_db_session(): + """Mock the database session.""" + db_session = MagicMock() + return db_session + + +@pytest.fixture +def expired_mock_user(): + """Mock a user object with upstream refresh tokens.""" + user = MagicMock() + user.upstream_refresh_tokens = [ + MagicMock(refresh_token="expired_token", expires=0), # Expired token + ] + return user + + @pytest.fixture(scope="function") def auth_client(request): """ diff --git a/tests/dbgap_sync/test_user_sync.py b/tests/dbgap_sync/test_user_sync.py index 13b07d45b..58e170ea8 100644 --- a/tests/dbgap_sync/test_user_sync.py +++ b/tests/dbgap_sync/test_user_sync.py @@ -10,7 +10,7 @@ from fence import models from fence.resources.google.access_utils import GoogleUpdateException from fence.config import config -from fence.job.visa_update_cronjob import Visa_Token_Update +from fence.job.access_token_updater import TokenAndAuthUpdater from fence.utils import DEFAULT_BACKOFF_SETTINGS from tests.dbgap_sync.conftest import ( @@ -1011,7 +1011,7 @@ def test_user_sync_with_visa_sync_job( # use refresh tokens from users to call access token polling "fence-create update-visa" # and sync authorization from visas - job = Visa_Token_Update() + job = TokenAndAuthUpdater() job.pkey_cache = { "https://stsstg.nih.gov": { kid: rsa_public_key, diff --git a/tests/job/test_access_token_updater.py b/tests/job/test_access_token_updater.py new file mode 100644 index 000000000..f22df71f9 --- /dev/null +++ b/tests/job/test_access_token_updater.py @@ -0,0 +1,206 @@ +import pytest +import asyncio +from unittest.mock import AsyncMock, patch, MagicMock +from fence.models import User +from fence.resources.openid.idp_oauth2 import Oauth2ClientBase as OIDCClient +from fence.resources.openid.ras_oauth2 import RASOauth2Client as RASClient +from fence.job.access_token_updater import TokenAndAuthUpdater + + +@pytest.fixture(scope="session", autouse=True) +def event_loop(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + yield loop + 
loop.close() + + +@pytest.fixture +def run_async(event_loop): + """Run an async coroutine in the current event loop.""" + + def _run(coro): + return event_loop.run_until_complete(coro) + + return _run + + +@pytest.fixture +def mock_db_session(): + """Fixture to mock the DB session.""" + return MagicMock() + + +@pytest.fixture +def mock_users(): + """Fixture to mock the user list.""" + user1 = MagicMock(spec=User) + user1.username = "testuser1" + user1.identity_provider.name = "ras" + + user2 = MagicMock(spec=User) + user2.username = "testuser2" + user2.identity_provider.name = "test_oidc" + + return [user1, user2] + + +@pytest.fixture +def mock_oidc_clients(): + """Fixture to mock OIDC clients.""" + ras_client = MagicMock(spec=RASClient) + ras_client.idp = "ras" + + oidc_client = MagicMock(spec=OIDCClient) + oidc_client.idp = "test_oidc" + + return [ras_client, oidc_client] + + +@pytest.fixture +def access_token_updater_config(mock_oidc_clients): + """Fixture to instantiate TokenAndAuthUpdater with mocked OIDC clients.""" + with patch( + "fence.config", + { + "OPENID_CONNECT": { + "ras": {}, + "test_oidc": {"groups": {"read_authz_groups_from_tokens": True}}, + }, + "ENABLE_AUTHZ_GROUPS_FROM_OIDC": True, + }, + ): + updater = TokenAndAuthUpdater() + + # Ensure this is a dictionary rather than a list + updater.oidc_clients_requiring_token_refresh = { + client.idp: client for client in mock_oidc_clients + } + + return updater + + +def test_get_user_from_db( + run_async, access_token_updater_config, mock_db_session, mock_users +): + """Test the get_user_from_db method.""" + mock_db_session.query().slice().all.return_value = mock_users + + users = run_async( + access_token_updater_config.get_user_from_db(mock_db_session, chunk_idx=0) + ) + assert len(users) == 2 + assert users[0].username == "testuser1" + assert users[1].username == "testuser2" + + +def test_producer(run_async, access_token_updater_config, mock_db_session, mock_users): + """Test the producer method.""" + queue = asyncio.Queue() + mock_db_session.query().slice().all.return_value = mock_users + + # Run producer to add users to queue + run_async(access_token_updater_config.producer(mock_db_session, queue, chunk_idx=0)) + + assert queue.qsize() == len(mock_users) + assert not queue.empty() + + # Dequeue to check correctness + user = run_async(queue.get()) + assert user.username == "testuser1" + + +def test_worker(run_async, access_token_updater_config, mock_users): + """Test the worker method.""" + queue = asyncio.Queue() + updater_queue = asyncio.Queue() + + # Add users to the queue + for user in mock_users: + run_async(queue.put(user)) + + # Run the worker to transfer users from queue to updater_queue + run_async(access_token_updater_config.worker("worker_1", queue, updater_queue)) + + assert updater_queue.qsize() == len(mock_users) + assert queue.empty() + + +async def updater_with_timeout(updater, queue, db_session, timeout=5): + return await asyncio.wait_for(updater(queue, db_session), timeout) + + +def test_updater( + run_async, + access_token_updater_config, + mock_users, + mock_db_session, + mock_oidc_clients, +): + """Test the updater method.""" + updater_queue = asyncio.Queue() + + # Add a user to the updater_queue + run_async(updater_queue.put(mock_users[0])) + + # Mock the client to return a valid update process + mock_oidc_clients[0].update_user_authorization = AsyncMock() + + # Ensure _pick_client returns the correct client + with patch.object( + access_token_updater_config, "_pick_client", 
return_value=mock_oidc_clients[0] + ): + # Signal the updater to stop after processing + run_async(updater_queue.put(None)) # This should be an awaited call + + # Run the updater to process the user and update authorization + run_async( + access_token_updater_config.updater( + "updater_1", updater_queue, mock_db_session + ) + ) + + # Verify that the OIDC client was called with the correct user + mock_oidc_clients[0].update_user_authorization.assert_called_once_with( + mock_users[0], + pkey_cache=access_token_updater_config.pkey_cache, + db_session=mock_db_session, + ) + + +def test_no_client_found(run_async, access_token_updater_config, mock_users): + """Test that updater does not crash if no client is found.""" + updater_queue = asyncio.Queue() + + # Modify the user to have an unrecognized identity provider + mock_users[0].identity_provider.name = "unknown_provider" + + run_async(updater_queue.put(mock_users[0])) # Ensure this is awaited + run_async(updater_queue.put(None)) # Signal the updater to terminate + + # Mock the client selection to return None + with patch.object(access_token_updater_config, "_pick_client", return_value=None): + # Run the updater and ensure it skips the user with no client + run_async( + access_token_updater_config.updater("updater_1", updater_queue, MagicMock()) + ) + + assert updater_queue.empty() # The user should still be dequeued + + +def test_pick_client( + run_async, access_token_updater_config, mock_users, mock_oidc_clients +): + """Test that the correct OIDC client is selected based on the user's IDP.""" + # Pick the client for a RAS user + client = access_token_updater_config._pick_client(mock_users[0]) + assert client.idp == "ras" + + # Pick the client for a test OIDC user + client = access_token_updater_config._pick_client(mock_users[1]) + assert client.idp == "test_oidc" + + # Ensure no client is returned for a user with no matching IDP + mock_users[0].identity_provider.name = "nonexistent_idp" + client = access_token_updater_config._pick_client(mock_users[0]) + assert client is None diff --git a/tests/login/test_base.py b/tests/login/test_base.py index a32452b2c..a0b4633e4 100644 --- a/tests/login/test_base.py +++ b/tests/login/test_base.py @@ -1,6 +1,12 @@ +import flask +import pytest from fence.blueprints.login import DefaultOAuth2Callback from fence.config import config from unittest.mock import MagicMock, patch +from fence.errors import UserError +from fence.auth import login_user +from fence.blueprints.login.base import _login +from fence.models import User, IdentityProvider @patch("fence.blueprints.login.base.prepare_login_log") @@ -54,3 +60,126 @@ def test_post_login_no_mfa_enabled(app, monkeypatch, mock_authn_user_flask_conte token_result = {"username": "lisasimpson"} callback.post_login(token_result=token_result) app.arborist.revoke_user_policy.assert_not_called() + yield + + +@pytest.fixture +def mock_user(): + """Fixture to mock a logged-in user with additional_info.""" + user = MagicMock() + user.additional_info = {} + return user + + +@patch("fence.auth.login_user") +def test_login_existing_user(mock_login_user, db_session, app): + """ + Test logging in an existing user without registration. 
+ """ + with app.app_context(): + email = "test@example.com" + provider = "Test Provider" + + response = _login(email, provider) + + mock_login_user.assert_called_once_with( + email, provider, email=None, id_from_idp=None + ) + + assert response.status_code == 200 + assert response.json == {"username": email, "registered": True} + yield + + +@patch("fence.auth.login_user") +@patch("fence.blueprints.login.base.current_app.scoped_session") +def test_login_with_registration(mock_scoped_session, mock_login_user, db_session, app): + """ + Test logging in a user when registration is enabled. + """ + with app.app_context(): + config.REGISTER_USERS_ON = True + config["OPENID_CONNECT"]["mock_idp"] = {"enable_idp_users_registration": True} + config.REGISTERED_USERS_GROUP = "test_group" + + email = "lisa@example.com" + provider = "mock_idp" + token_result = { + "firstname": "Lisa", + "lastname": "Simpson", + "org": "Springfield Elementary", + "email": email, + } + + response = _login(email, provider, token_result=token_result) + + # Ensure login was called + mock_login_user.assert_called_once_with( + email, provider, email=email, id_from_idp=None + ) + + # Ensure user was added to the database + mock_scoped_session.add.assert_called() + mock_scoped_session.commit.assert_called() + + # Ensure response is a JSON response + assert response.status_code == 200 + assert response.json == {"username": email, "registered": True} + yield + + +@patch("fence.auth.login_user") +def test_login_with_missing_email(mock_login_user, app, monkeypatch): + """ + Test that a missing email raises a UserError. + """ + with app.app_context(): + config["REGISTER_USERS_ON"] = True + config["OPENID_CONNECT"]["mock_idp"] = {"enable_idp_users_registration": True} + + provider = "mock_idp" + token_result = { + "firstname": "Lisa", + "lastname": "Simpson", + "org": "Springfield Elementary", + } + yield + + with pytest.raises(UserError, match="OAuth2 id token is missing email claim"): + _login("lisa", provider, token_result=token_result) + + +@patch("fence.auth.login_user") +def test_login_redirect_to_registration_page( + mock_login_user, app, monkeypatch, db_session +): + """ + Test that users are redirected to the registration page when IDP registration is disabled. + """ + with app.app_context(): + config["REGISTER_USERS_ON"] = True + config["OPENID_CONNECT"]["mock_idp"] = {"enable_idp_users_registration": False} + + db_session.query(User).delete() # Remove all users from DB for this test + db_session.commit() + + response = _login("lisaferf", "mock_idp") + + assert response.status_code == 302 + assert response.location == "http://localhost/user/register" + yield + + +@patch("fence.auth.login_user") +def test_login_redirect_after_authentication(mock_login_user, app): + """ + Test that users are redirected to their stored session redirect after authentication. 
+ """ + with app.app_context(): + flask.session["redirect"] = "http://localhost/" + + response = _login("lisa", "mock_idp") + + assert response.status_code == 302 + assert response.location == "http://localhost/" + yield diff --git a/tests/login/test_cilogon_oauth2.py b/tests/login/test_cilogon_oauth2.py new file mode 100644 index 000000000..df8cfdeba --- /dev/null +++ b/tests/login/test_cilogon_oauth2.py @@ -0,0 +1,157 @@ +import pytest +from unittest.mock import MagicMock, patch + +from fence import Oauth2ClientBase +from fence.resources.openid.cilogon_oauth2 import CilogonOauth2Client + + +@pytest.fixture +def mock_settings(): + """Fixture to create mock settings.""" + return { + "client_id": "mock_client_id", + "client_secret": "mock_client_secret", + "redirect_url": "https://mock-redirect.com", + "scope": "openid email profile", + } + + +@pytest.fixture +def mock_logger(): + """Fixture to create a mock logger.""" + return MagicMock() + + +@pytest.fixture +def cilogon_client(mock_settings, mock_logger): + """Fixture to create a CilogonOauth2Client instance.""" + return CilogonOauth2Client(mock_settings, mock_logger) + + +@pytest.fixture +def oauth2_client(): + """Fixture to create an instance of Oauth2ClientBase with a mocked session.""" + mock_settings = { + "client_id": "mock_client_id", + "client_secret": "mock_client_secret", + "redirect_url": "https://mock-redirect.com", + "scope": "openid email profile", + "discovery_url": "https://mock-discovery.com", + } + mock_logger = MagicMock() + + client = Oauth2ClientBase(mock_settings, mock_logger, idp="MockIDP") + client.session = MagicMock() + + return client + + +@patch("fence.resources.openid.cilogon_oauth2.Oauth2ClientBase.__init__") +def test_cilogon_client_init(mock_super_init, mock_settings, mock_logger): + """ + Test that the CilogonOauth2Client initializes correctly and calls the parent class. + """ + client = CilogonOauth2Client(mock_settings, mock_logger) + + mock_super_init.assert_called_once_with( + mock_settings, + mock_logger, + scope="openid email profile", + idp="CILogon", + HTTP_PROXY=None, + ) + + assert ( + client.DISCOVERY_URL == "https://cilogon.org/.well-known/openid-configuration" + ) + + +@patch( + "fence.resources.openid.idp_oauth2.Oauth2ClientBase.get_value_from_discovery_doc" +) +def test_get_auth_url(mock_get_value_from_discovery_doc, oauth2_client): + """ + Test that get_auth_url correctly constructs the authorization URL. + """ + mock_get_value_from_discovery_doc.return_value = "https://cilogon.org/authorize" + oauth2_client.session.create_authorization_url.return_value = ( + "https://mock-auth-url.com", + None, + ) + + auth_url = oauth2_client.get_auth_url() + + assert auth_url == "https://mock-auth-url.com" + mock_get_value_from_discovery_doc.assert_called_once_with( + "authorization_endpoint", "" + ) + oauth2_client.session.create_authorization_url.assert_called_once_with( + "https://cilogon.org/authorize", prompt="login" + ) + + +@patch( + "fence.resources.openid.cilogon_oauth2.Oauth2ClientBase.get_value_from_discovery_doc" +) +@patch("fence.resources.openid.cilogon_oauth2.Oauth2ClientBase.get_jwt_claims_identity") +def test_get_auth_info_success( + mock_get_jwt_claims_identity, mock_get_value_from_discovery_doc, cilogon_client +): + """ + Test that get_auth_info correctly extracts user claims when authentication is successful. 
+ """ + mock_get_value_from_discovery_doc.side_effect = [ + "https://cilogon.org/oauth2/token", + "https://cilogon.org/oauth2/certs", + ] + + mock_get_jwt_claims_identity.return_value = ( + {"sub": "mock_user_id"}, + "mock_refresh_token", + "mock_access_token", + ) + + auth_info = cilogon_client.get_auth_info("mock_code") + + assert auth_info == {"sub": "mock_user_id"} + mock_get_value_from_discovery_doc.assert_any_call( + "token_endpoint", "https://cilogon.org/oauth2/token" + ) + mock_get_value_from_discovery_doc.assert_any_call( + "jwks_uri", "https://cilogon.org/oauth2/certs" + ) + mock_get_jwt_claims_identity.assert_called_once_with( + "https://cilogon.org/oauth2/token", + "https://cilogon.org/oauth2/certs", + "mock_code", + ) + + +@patch("fence.resources.openid.cilogon_oauth2.Oauth2ClientBase.get_jwt_claims_identity") +def test_get_auth_info_missing_sub(mock_get_jwt_claims_identity, cilogon_client): + """ + Test that get_auth_info returns an error when 'sub' claim is missing. + """ + mock_get_jwt_claims_identity.return_value = ( + {}, # No 'sub' in claims + "mock_refresh_token", + "mock_access_token", + ) + + auth_info = cilogon_client.get_auth_info("mock_code") + + assert auth_info == {"error": "Can't get user's CILogon sub"} + + +@patch("fence.resources.openid.cilogon_oauth2.Oauth2ClientBase.get_jwt_claims_identity") +def test_get_auth_info_exception(mock_get_jwt_claims_identity, cilogon_client): + """ + Test that get_auth_info handles exceptions and logs an error. + """ + mock_get_jwt_claims_identity.side_effect = Exception("Test Exception") + + auth_info = cilogon_client.get_auth_info("mock_code") + + assert "error" in auth_info + assert "Can't get your CILogon sub" in auth_info["error"] + cilogon_client.logger.exception.assert_called_once_with("Can't get user info") diff --git a/tests/login/test_idp_oauth2.py b/tests/login/test_idp_oauth2.py index 40ae2349a..2145e6027 100644 --- a/tests/login/test_idp_oauth2.py +++ b/tests/login/test_idp_oauth2.py @@ -1,7 +1,16 @@ +import jwt import pytest +import datetime +from jose.exceptions import JWTClaimsError, JWTError +from unittest.mock import ANY +from flask import Flask, g from cdislogging import get_logger +from unittest.mock import MagicMock, Mock, patch + +from fence.resources.openid.idp_oauth2 import Oauth2ClientBase, AuthError +from fence.blueprints.login.base import DefaultOAuth2Callback +from fence.config import config -from fence import Oauth2ClientBase MOCK_SETTINGS_ACR = { "client_id": "client", @@ -39,11 +48,6 @@ def test_has_mfa_claim_acr(oauth_client_acr): assert has_mfa -def test_has_mfa_claim_acr(oauth_client_acr): - has_mfa = oauth_client_acr.has_mfa_claim({"acr": "mfa"}) - assert has_mfa - - def test_has_mfa_claim_multiple_acr(oauth_client_acr): has_mfa = oauth_client_acr.has_mfa_claim({"acr": "mfa otp duo"}) assert has_mfa @@ -83,3 +87,371 @@ def test_does_not_has_mfa_claim_amr(oauth_client_amr): def test_does_not_has_mfa_claim_multiple_amr(oauth_client_amr): has_mfa = oauth_client_amr.has_mfa_claim({"amr": ["pwd, trustme"]}) assert not has_mfa + + +# To test the store_refresh_token method of the Oauth2ClientBase class +def test_store_refresh_token(mock_user, mock_app): + """ + Test the `store_refresh_token` method of the `Oauth2ClientBase` class to ensure that + refresh tokens are correctly stored in the database using the `UpstreamRefreshToken` model. 
+ """ + mock_logger = MagicMock() + mock_settings = { + "client_id": "test_client_id", + "client_secret": "test_client_secret", + "redirect_url": "http://localhost/callback", + "discovery_url": "http://localhost/.well-known/openid-configuration", + "groups": {"read_authz_groups_from_tokens": True, "group_prefix": "/"}, + "user_id_field": "sub", + } + + # Ensure oauth_client is correctly instantiated + oauth_client = Oauth2ClientBase( + settings=mock_settings, logger=mock_logger, idp="test_idp" + ) + + refresh_token = "mock_refresh_token" + expires = 1700000000 + + # Patch the UpstreamRefreshToken to prevent actual database interactions + with patch( + "fence.resources.openid.idp_oauth2.UpstreamRefreshToken", autospec=True + ) as MockUpstreamRefreshToken: + # Mock the db_session's object_session method to return a mocked session object + mock_session = MagicMock() + mock_app.arborist.object_session.return_value = mock_session + + # Call the method to test + oauth_client.store_refresh_token( + mock_user, refresh_token, expires, db_session=mock_app.arborist + ) + + # Check if UpstreamRefreshToken was instantiated correctly + MockUpstreamRefreshToken.assert_called_once_with( + user=mock_user, + refresh_token=refresh_token, + expires=expires, + ) + + # Check if the mock session's `add` and `commit` methods were called + mock_app.arborist.object_session.assert_called_once() + mock_session.add.assert_called_once_with(MockUpstreamRefreshToken.return_value) + mock_app.arborist.commit.assert_called_once() + + +# To test if a user is granted access using the get_auth_info method in Oauth2ClientBase +@patch("fence.resources.openid.idp_oauth2.Oauth2ClientBase.get_jwt_keys") +@patch("jwt.decode") +@patch("authlib.integrations.requests_client.OAuth2Session.fetch_token") +@patch( + "fence.resources.openid.idp_oauth2.Oauth2ClientBase.get_value_from_discovery_doc" +) +def test_get_auth_info_granted_access( + mock_get_value_from_discovery_doc, + mock_fetch_token, + mock_jwt_decode, + mock_get_jwt_keys, + app, +): + """ + Test that the `get_auth_info` method correctly retrieves, processes, and decodes + an OAuth2 authentication token, including access, refresh, and ID tokens, while also + handling JWT decoding and discovery document lookups. + + Raises: + AssertionError: If the expected claims or tokens are not present in the returned authentication information. 
+ """ + + mock_settings = { + "client_id": "test_client_id", + "client_secret": "test_client_secret", + "redirect_url": "http://localhost/callback", + "discovery_url": "http://localhost/.well-known/openid-configuration", + "is_authz_groups_sync_enabled": True, + "authz_groups_sync": {"group_prefix": "/"}, + "user_id_field": "sub", + } + + # Mock logger + mock_logger = MagicMock() + + with app.app_context(): + yield + oauth2_client = Oauth2ClientBase( + settings=mock_settings, logger=mock_logger, idp="test_idp" + ) + + # Mock token endpoint and jwks_uri + mock_get_value_from_discovery_doc.side_effect = lambda key, default=None: ( + "http://localhost/token" + if key == "token_endpoint" + else "http://localhost/jwks" + ) + + # Setup mock response for fetch_token + mock_fetch_token.return_value = { + "access_token": "mock_access_token", + "id_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJtb2NrX3VzZXJfaWQiLCJpYXQiOjE2MDk0NTkyMDAsImV4cCI6MTYwOTQ2MjgwMCwiZ3JvdXBzIjpbImdyb3VwMSIsImdyb3VwMiJdfQ.XYZ", + "refresh_token": "mock_refresh_token", + } + + # Setup mock JWT keys response + mock_get_jwt_keys.return_value = [ + { + "kty": "RSA", + "kid": "1e9gdk7", + "use": "sig", + "n": "example-key", + "e": "AQAB", + } + ] + + # Setup mock decoded JWT token + mock_jwt_decode.return_value = { + "sub": "mock_user_id", + "email_verified": True, + "iat": 1609459200, + "exp": 1609462800, + "groups": ["group1", "group2"], + } + + # Log mock setups + logger.debug( + f"Mock token endpoint: {mock_get_value_from_discovery_doc('token_endpoint', '')}" + ) + logger.debug( + f"Mock jwks_uri: {mock_get_value_from_discovery_doc('jwks_uri', '')}" + ) + logger.debug(f"Mock fetch_token response: {mock_fetch_token.return_value}") + logger.debug(f"Mock JWT decode response: {mock_jwt_decode.return_value}") + + # Call the method + code = "mock_code" + auth_info = oauth2_client.get_auth_info(code) + logger.debug(f"Mock auth_info: {auth_info}") + + # Debug: Check if decode was called + logger.debug(f"JWT decode call count: {mock_jwt_decode.call_count}") + logger.debug(f"Returned auth_info: {auth_info}") + logger.debug(f"JWT decode call args: {mock_jwt_decode.call_args_list}") + logger.debug(f"Fetch token response: {mock_fetch_token.return_value}") + + # Assertions + assert "sub" in auth_info, f"Expected 'sub' in auth_info, got {auth_info}" + assert auth_info["sub"] == "mock_user_id" + assert "refresh_token" in auth_info + assert auth_info["refresh_token"] == "mock_refresh_token" + assert "iat" in auth_info + assert auth_info["iat"] == 1609459200 + assert "exp" in auth_info + assert auth_info["exp"] == 1609462800 + assert "groups" in auth_info + assert auth_info["groups"] == ["group1", "group2"] + + +def test_get_access_token_expired(expired_mock_user, mock_db_session): + """ + Test that attempting to retrieve an access token for a user with an expired refresh token + results in an `AuthError`, the user's token is deleted, and the session is committed. + + + Raises: + AuthError: When the user does not have a valid, non-expired refresh token. 
+ """ + mock_settings = { + "client_id": "test_client_id", + "client_secret": "test_client_secret", + "redirect_url": "http://localhost/callback", + "discovery_url": "http://localhost/.well-known/openid-configuration", + "is_authz_groups_sync_enabled": True, + "authz_groups_sync:": {"group_prefix": "/"}, + "user_id_field": "sub", + } + + # Initialize the Oauth2 client object + oauth2_client = Oauth2ClientBase( + settings=mock_settings, logger=MagicMock(), idp="test_idp" + ) + + # Simulate the token expiration and user not having access + with pytest.raises(AuthError) as excinfo: + logger.debug("get_access_token about to be called") + oauth2_client.get_access_token( + expired_mock_user, + token_endpoint="https://token.endpoint", + db_session=mock_db_session, + ) + + logger.debug(f"Raised exception message: {excinfo.value}") + + assert "User doesn't have a valid, non-expired refresh token" in str(excinfo.value) + + mock_db_session.delete.assert_called() + mock_db_session.commit.assert_called() + + +@patch("fence.resources.openid.idp_oauth2.Oauth2ClientBase.get_auth_info") +def test_post_login_with_group_prefix(mock_get_auth_info, app): + """ + Test the `post_login` method of the `DefaultOAuth2Callback` class, ensuring that user groups + fetched from an identity provider (IdP) are processed correctly and prefixed before being added + to the user in the Arborist service. + """ + with app.app_context(): + yield + with patch.dict(config, {"ENABLE_AUTHZ_GROUPS_FROM_OIDC": True}, clear=False): + mock_user = MagicMock() + mock_user.username = "test_user" + mock_user.id = "user_id" + g.user = mock_user + + # Set up mock responses for user info and groups from the IdP + mock_get_auth_info.return_value = { + "username": "test_user", + "groups": ["group1", "group2", "covid/group3", "group4", "group5"], + "exp": datetime.datetime.now(tz=datetime.timezone.utc).timestamp(), + "group_prefix": "covid/", + } + + # Mock the Arborist client and its methods + mock_arborist = MagicMock() + mock_arborist.list_groups.return_value = { + "groups": [ + {"name": "group1"}, + {"name": "group2"}, + {"name": "group3"}, + {"name": "reviewers"}, + ] + } + mock_arborist.add_user_to_group = MagicMock() + mock_arborist.remove_user_from_group = MagicMock() + + # Mock the Flask app + app = MagicMock() + app.arborist = mock_arborist + + # Create the callback object with the mock app + callback = DefaultOAuth2Callback( + idp_name="generic3", client=MagicMock(), app=app + ) + + # Mock user and call post_login + mock_user = MagicMock() + mock_user.username = "test_user" + + # Simulate calling post_login + callback.post_login( + user=g.user, + token_result=mock_get_auth_info.return_value, + groups_from_idp=mock_get_auth_info.return_value["groups"], + group_prefix=mock_get_auth_info.return_value["group_prefix"], + expires_at=mock_get_auth_info.return_value["exp"], + username=mock_user.username, + ) + + # Assertions to check if groups were processed with the correct prefix + mock_arborist.add_user_to_group.assert_any_call( + username="test_user", + group_name="group1", + expires_at=datetime.datetime.fromtimestamp( + mock_get_auth_info.return_value["exp"], tz=datetime.timezone.utc + ), + ) + mock_arborist.add_user_to_group.assert_any_call( + username="test_user", + group_name="group2", + expires_at=datetime.datetime.fromtimestamp( + mock_get_auth_info.return_value["exp"], tz=datetime.timezone.utc + ), + ) + mock_arborist.add_user_to_group.assert_any_call( + username="test_user", + group_name="group3", + 
+ + +@patch("fence.resources.openid.idp_oauth2.Oauth2ClientBase.get_jwt_keys") +@patch("authlib.integrations.requests_client.OAuth2Session.fetch_token") +@patch("fence.resources.openid.idp_oauth2.jwt.decode") # Mock jwt.decode +def test_jwt_audience_verification_fails( + mock_jwt_decode, mock_fetch_token, mock_get_jwt_keys +): + """ + Test the JWT audience verification failure scenario. + + This test mocks various components used in the OIDC flow to simulate the + process of obtaining a token, fetching JWKS (JSON Web Key Set), and verifying + the JWT token's claims. Specifically, it focuses on the audience verification + step and tests that an invalid audience raises the expected `JWTError`. + + + Raises: + JWTError: When the audience in the JWT token is invalid. + """ + # Mock fetch_token to simulate a successful token fetch + mock_fetch_token.return_value = { + "id_token": "mock-id-token", + "access_token": "mock_access_token", + "refresh_token": "mock-refresh-token", + } + + # Mock JWKS response + mock_jwks_response = { + "keys": [ + { + "kty": "RSA", + "kid": "test-key-id", + "use": "sig", + # Simulate RSA public key values + "n": "mock-n-value", + "e": "mock-e-value", + } + ] + } + + mock_get_jwt_keys.return_value = mock_jwks_response + + # Mock jwt.decode to raise JWTError for audience verification failure + mock_jwt_decode.side_effect = JWTError("Invalid audience") + + # Setup the mock instance of Oauth2ClientBase + client = Oauth2ClientBase( + settings={ + "client_id": "mock-client-id", + "client_secret": "mock-client-secret", + "redirect_url": "mock-redirect-url", + "discovery_url": "http://localhost/discovery", + "audience": "expected-audience", + "verify_aud": True, + }, + logger=MagicMock(), + idp="mock-idp", + ) + + # Invoke the method and expect JWTError to be raised + with pytest.raises(JWTError, match="Invalid audience"): + client.get_jwt_claims_identity( + token_endpoint="https://token.endpoint", + jwks_endpoint="https://jwks.uri", + code="auth_code", + ) + + # Verify fetch_token was called correctly + mock_fetch_token.assert_called_once_with( + url="https://token.endpoint", code="auth_code", proxies=None + ) + + # Verify jwt.decode was called with the mock id_token + mock_jwt_decode.assert_called_with( + "mock-id-token", # The mock token + key=mock_jwks_response, + options={"verify_signature": False}, + algorithms=["RS256"], + ) diff --git a/tests/login/test_login_shib.py b/tests/login/test_login_shib.py index db18aa483..f0335fcb3 100644 --- a/tests/login/test_login_shib.py +++ b/tests/login/test_login_shib.py @@ -1,6 +1,5 @@ from fence.config import config - def test_shib_redirect(client, app): r = client.get("/login/shib?redirect=http://localhost") assert r.status_code == 302 diff --git a/tests/login/test_microsoft_login.py b/tests/login/test_microsoft_login.py index 972b8a07f..bbe8b9ebf 100755 --- a/tests/login/test_microsoft_login.py +++ b/tests/login/test_microsoft_login.py @@ -1,8 +1,11 @@ """ Tests for fence.resources.openid.microsoft_oauth2.MicrosoftOauth2Client """ + from unittest.mock import patch +from tests.rfc6749.conftest import access_token + def test_get_auth_url(microsoft_oauth2_client): """ @@ -34,9 +37,11 @@ def 
test_get_auth_info_missing_claim(microsoft_oauth2_client): """ return_value = {"not_email_claim": "user@contoso.com"} expected_value = {"error": "Can't get user's Microsoft email!"} + refresh_token = {} + access_token = {} with patch( "fence.resources.openid.idp_oauth2.Oauth2ClientBase.get_jwt_claims_identity", - return_value=return_value, + return_value=(return_value, refresh_token, access_token), ): user_id = microsoft_oauth2_client.get_auth_info(code="123") assert user_id == expected_value # nosec diff --git a/tests/ras/test_ras.py b/tests/ras/test_ras.py index f3be7575c..ab2bb2258 100644 --- a/tests/ras/test_ras.py +++ b/tests/ras/test_ras.py @@ -25,7 +25,7 @@ from tests.utils import add_test_ras_user, TEST_RAS_USERNAME, TEST_RAS_SUB from tests.dbgap_sync.conftest import add_visa_manually -from fence.job.visa_update_cronjob import Visa_Token_Update +from fence.job.access_token_updater import TokenAndAuthUpdater import tests.utils from tests.conftest import get_subjects_to_passports @@ -95,6 +95,7 @@ def test_update_visa_token( """ Test to check visa table is updated when getting new visa """ + # ensure we don't actually try to reach out to external sites to refresh public keys def validate_jwt_no_key_refresh(*args, **kwargs): kwargs.update({"attempt_refresh": False}) @@ -713,7 +714,7 @@ def _get_userinfo(*args, **kwargs): mock_userinfo.side_effect = _get_userinfo # test "fence-create update-visa" - job = Visa_Token_Update() + job = TokenAndAuthUpdater() job.pkey_cache = { "https://stsstg.nih.gov": { kid: rsa_public_key, diff --git a/tests/test-fence-config.yaml b/tests/test-fence-config.yaml index d0638a091..47a68894a 100755 --- a/tests/test-fence-config.yaml +++ b/tests/test-fence-config.yaml @@ -44,7 +44,7 @@ ENCRYPTION_KEY: '' # ////////////////////////////////////////////////////////////////////////////////////// # flask's debug setting # WARNING: DO NOT ENABLE IN PRODUCTION -DEBUG: true +DEBUG: false # if true, will automatically login a user with username "test" MOCK_AUTH: true # if true, will only fake a successful login response from Google in /login/google @@ -141,6 +141,38 @@ OPENID_CONNECT: redirect_url: '{{BASE_URL}}/login/generic2/login' discovery: authorization_endpoint: 'https://generic2/authorization_endpoint' + generic3: + name: 'generic3' # optional; display name for this IDP + client_id: '' + client_secret: '' + redirect_url: '{{BASE_URL}}/login/generic3/login' # replace IDP name + # use `discovery` to configure IDPs that do not expose a discovery + # endpoint. One of `discovery_url` or `discovery` should be configured + discovery_url: 'https://localhost/.well-known/openid-configuration' + # When true, it allows refresh tokens to be stored even if is_authz_groups_sync_enabled is set false. + # When false, the system will only store refresh tokens if is_authz_groups_sync_enabled is enabled + persist_refresh_token: false + # is_authz_groups_sync_enabled: A configuration flag that determines whether the application should + # verify and synchronize user group memberships between the identity provider (IdP) + # and the local authorization system (Arborist). When enabled, the system retrieves + # the user's group information from their token issued by the IdP and compares it against + # the groups defined in the local system. Based on the comparison, the user is added to + # or removed from relevant groups in the local system to ensure their group memberships + # remain up-to-date. 
If this flag is disabled, no group synchronization occurs + is_authz_groups_sync_enabled: false + # Key used to retrieve group information from the token + group_claim_field: "groups" + # IdP group membership expiration (seconds). + group_membership_expiration_duration: 604800 + authz_groups_sync: + # This defines the prefix used to identify authorization groups. + group_prefix: /covid + # This flag indicates whether the audience (aud) claim in the JWT should be verified during token validation. + verify_aud: false + # This specifies the expected audience (aud) value for the JWT, ensuring that the token is intended for use with the 'fence' service. + audience: fence + # default refresh token expiration duration + default_refresh_token_exp: 3600 # these are the *possible* scopes a client can be given, NOT scopes that are # given to all clients. You can be more restrictive during client creation diff --git a/tests/test_metrics.py b/tests/test_metrics.py index be7d6b2ab..d0d47786d 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -519,6 +519,18 @@ def test_login_log_login_endpoint( get_auth_info_value = {"generic1_username": username} elif idp == "generic2": get_auth_info_value = {"sub": username} + elif idp == "generic3": + # get_auth_info_value specific to generic3 + # TODO: Need test when is_authz_groups_sync_enabled == true + get_auth_info_value = { + "username": username, + "sub": username, + "email_verified": True, + "iat": 1609459200, + "exp": 1609462800, + "refresh_token": "mock_refresh_token", + "groups": ["group1", "group2"], + } if idp in ["google", "microsoft", "okta", "synapse", "cognito"]: get_auth_info_value["email"] = username @@ -538,6 +550,7 @@ def test_login_log_login_endpoint( ) path = f"/login/{idp}/{callback_endpoint}" # SEE fence/blueprints/login/fence_login.py L91 response = client.get(path, headers=headers) + print(f"Response: {response.status_code}, Body: {response.data}") assert response.status_code == 200, response user_sub = db_session.query(User).filter(User.username == username).first().id audit_service_requests.post.assert_called_once_with(