diff --git a/.github/workflows/integration_tests.yaml b/.github/workflows/integration_tests.yaml index 5a2fe97..1cf0f05 100644 --- a/.github/workflows/integration_tests.yaml +++ b/.github/workflows/integration_tests.yaml @@ -22,5 +22,7 @@ jobs: CI_TEST_ORCID_PASSWORD: ${{ secrets.CI_TEST_ORCID_PASSWORD }} CI_TEST_RAS_USERID: ${{ secrets.CI_TEST_RAS_USERID }} CI_TEST_RAS_PASSWORD: ${{ secrets.CI_TEST_RAS_PASSWORD }} + CI_TEST_RAS_2_USERID: ${{ secrets.CI_TEST_RAS_2_USERID }} + CI_TEST_RAS_2_PASSWORD: ${{ secrets.CI_TEST_RAS_2_PASSWORD }} CI_SLACK_BOT_TOKEN: ${{ secrets.CI_SLACK_BOT_TOKEN }} CI_SLACK_CHANNEL_ID: ${{ secrets.CI_SLACK_CHANNEL_ID }} diff --git a/tests/app_test.py b/tests/app_test.py index 140346b..9356a23 100644 --- a/tests/app_test.py +++ b/tests/app_test.py @@ -1,3 +1,4 @@ +from urllib.parse import urljoin import flask import json import mock @@ -6,6 +7,9 @@ import uuid import urllib +from authlib.oauth2.client import OAuth2Client +from authlib.integrations.requests_client import OAuth2Session + from wts.models import RefreshToken from wts.resources.oauth2 import find_valid_refresh_token @@ -156,6 +160,81 @@ def test_authorize_endpoint(client, test_user, db_session, auth_header): assert original_refresh_token == fake_tokens["idp_a"] +def test_fetch_token_header(client, test_user, db_session, auth_header, app): + fake_tokens = {"default": "eyJhbGciOiJvvvv", "idp_a": "eyJhbGciOiJwwww"} + app_version = app.config.get("APP_VERSION") + + # mock `fetch_access_token` to avoid external calls + mocked_response = mock.MagicMock() + with mock.patch.object(OAuth2Client, "fetch_token", return_value=mocked_response): + + # mock `jwt.decode` to return fake data + now = int(time.time()) + mocked_jwt_response = mock.MagicMock() + mocked_jwt_response.side_effect = [ + # decoded id_token for IdP "default": + {"context": {"user": {"name": test_user.username}}}, + # decoded refresh_token for IdP "default": + { + "jti": str(uuid.uuid4()), + "exp": now + 100, + "sub": test_user.userid, + "scope": ["openid", "access", "user", "test_aud"], + "aud": "https://localhost/user", + "iss": "https://localhost/user", + }, + # decoded id_token for IdP "idp_a": + {"context": {"user": {"name": test_user.username}}}, + # decoded refresh_token for IdP "idp_a": + { + "jti": str(uuid.uuid4()), + "exp": now + 100, + "sub": test_user.userid, + "scope": ["openid", "access", "user", "test_aud"], + "aud": "https://localhost/user", + "iss": "https://localhost/user", + }, + ] + patched_jwt_decode = mock.patch("jose.jwt.decode", mocked_jwt_response) + patched_jwt_decode.start() + + # get refresh token for IdP "default" + OAuth2Client.fetch_token.return_value = { + "refresh_token": fake_tokens["default"], + "id_token": "eyJhbGciOiJ", + } + fake_state = "qwerty" + with client.session_transaction() as session: + session["state"] = fake_state + res = client.get( + "/oauth2/authorize?state={}".format(fake_state), headers=auth_header + ) + OAuth2Client.fetch_token.assert_called_with( + "https://localhost/user/oauth2/token", + headers={"User-Agent": f"Gen3WTS/{app_version}"}, + state=fake_state, + ) + assert res.status_code == 200, res.json + + # get refresh token for IdP "idp_a" + OAuth2Client.fetch_token.return_value = { + "refresh_token": fake_tokens["idp_a"], + "id_token": "eyJhbGciOiJ", + } + with client.session_transaction() as session: + session["state"] = fake_state + session["idp"] = "idp_a" + res = client.get( + "/oauth2/authorize?state={}".format(fake_state), headers=auth_header + ) + OAuth2Client.fetch_token.assert_called_with( + "https://some.data.commons/user/oauth2/token", + headers={"User-Agent": f"Gen3WTS/{app_version}"}, + state=fake_state, + ) + assert res.status_code == 200 + + def test_authorization_url_endpoint(client): res = client.get("/oauth2/authorization_url?idp=idp_a") assert res.status_code == 302 @@ -207,6 +286,23 @@ def test_external_oidc_endpoint_with_persisted_refresh_tokens( assert provider["refresh_token_expiration"] == None +def test_revoke_token_header(client, auth_header, app): + + url = urljoin(app.config.get("USER_API"), "/oauth2/revoke") + app_version = app.config.get("APP_VERSION") + + with mock.patch.object( + OAuth2Session, + "revoke_token", + ): + res = client.get("/oauth2/logout", headers=auth_header) + assert res.status_code == 204 + assert res.text == "" + OAuth2Session.revoke_token.assert_called_with( + url, None, headers={"User-Agent": f"Gen3WTS/{app_version}"} + ) + + def test_app_config(app): assert ( app.config["OIDC"]["idp_a"]["redirect_uri"] diff --git a/wts/api.py b/wts/api.py index dabfeac..65eb1fb 100644 --- a/wts/api.py +++ b/wts/api.py @@ -3,6 +3,7 @@ from cryptography.fernet import Fernet import flask from flask import Flask +from importlib import metadata import json from urllib.parse import urlparse, urljoin from cdislogging import get_logger @@ -120,6 +121,7 @@ def load_settings(app): app.config["SESSION_COOKIE_NAME"] = "wts" app.config["SESSION_COOKIE_SECURE"] = True app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False + app.config["APP_VERSION"] = metadata.version("wts") def _log_and_jsonify_exception(e): diff --git a/wts/blueprints/oauth2.py b/wts/blueprints/oauth2.py index 8aa6afe..ab3beee 100644 --- a/wts/blueprints/oauth2.py +++ b/wts/blueprints/oauth2.py @@ -95,7 +95,10 @@ def logout_oauth(): client = get_oauth_client(idp="default") try: - client.session.revoke_token(url, token) + app_version = flask.current_app.config.get("APP_VERSION", "0.0.0") + client.session.revoke_token( + url, token, headers={"User-Agent": f"Gen3WTS/{app_version}"} + ) except APIError as e: msg = "could not log out, failed to revoke token: {}".format(e.message) return msg, 400 diff --git a/wts/resources/oauth2.py b/wts/resources/oauth2.py index 189ec2b..2bec1a6 100644 --- a/wts/resources/oauth2.py +++ b/wts/resources/oauth2.py @@ -28,7 +28,12 @@ def client_do_authorize(): if mismatched_state: raise AuthError("could not authorize; state did not match across auth requests") try: - tokens = client.fetch_token(token_url, **flask.request.args.to_dict()) + app_version = flask.current_app.config.get("APP_VERSION", "0.0.0") + tokens = client.fetch_token( + token_url, + headers={"User-Agent": f"Gen3WTS/{app_version}"}, + **flask.request.args.to_dict(), + ) refresh_refresh_token(tokens, requested_idp, username_field) except KeyError as e: raise AuthError("error in token response: {}".format(tokens)) @@ -46,6 +51,7 @@ def find_valid_refresh_token(username, idp): flask.current_app.logger.info("Purging expired token {}".format(token.jti)) else: has_valid = True + db.session.close() return has_valid @@ -112,3 +118,4 @@ def refresh_refresh_token(tokens, idp, username_field): ) db.session.add(new_token) db.session.commit() + db.session.close()