Skip to content

Commit

Permalink
Merge branch 'master' into feat/al2
Browse files Browse the repository at this point in the history
  • Loading branch information
EliseCastle23 authored Oct 16, 2024
2 parents 1130143 + 6121f2c commit 237d66d
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 2 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/integration_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,7 @@ jobs:
CI_TEST_ORCID_PASSWORD: ${{ secrets.CI_TEST_ORCID_PASSWORD }}
CI_TEST_RAS_USERID: ${{ secrets.CI_TEST_RAS_USERID }}
CI_TEST_RAS_PASSWORD: ${{ secrets.CI_TEST_RAS_PASSWORD }}
CI_TEST_RAS_2_USERID: ${{ secrets.CI_TEST_RAS_2_USERID }}
CI_TEST_RAS_2_PASSWORD: ${{ secrets.CI_TEST_RAS_2_PASSWORD }}
CI_SLACK_BOT_TOKEN: ${{ secrets.CI_SLACK_BOT_TOKEN }}
CI_SLACK_CHANNEL_ID: ${{ secrets.CI_SLACK_CHANNEL_ID }}
96 changes: 96 additions & 0 deletions tests/app_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from urllib.parse import urljoin
import flask
import json
import mock
Expand All @@ -6,6 +7,9 @@
import uuid
import urllib

from authlib.oauth2.client import OAuth2Client
from authlib.integrations.requests_client import OAuth2Session

from wts.models import RefreshToken
from wts.resources.oauth2 import find_valid_refresh_token

Expand Down Expand Up @@ -156,6 +160,81 @@ def test_authorize_endpoint(client, test_user, db_session, auth_header):
assert original_refresh_token == fake_tokens["idp_a"]


def test_fetch_token_header(client, test_user, db_session, auth_header, app):
fake_tokens = {"default": "eyJhbGciOiJvvvv", "idp_a": "eyJhbGciOiJwwww"}
app_version = app.config.get("APP_VERSION")

# mock `fetch_access_token` to avoid external calls
mocked_response = mock.MagicMock()
with mock.patch.object(OAuth2Client, "fetch_token", return_value=mocked_response):

# mock `jwt.decode` to return fake data
now = int(time.time())
mocked_jwt_response = mock.MagicMock()
mocked_jwt_response.side_effect = [
# decoded id_token for IdP "default":
{"context": {"user": {"name": test_user.username}}},
# decoded refresh_token for IdP "default":
{
"jti": str(uuid.uuid4()),
"exp": now + 100,
"sub": test_user.userid,
"scope": ["openid", "access", "user", "test_aud"],
"aud": "https://localhost/user",
"iss": "https://localhost/user",
},
# decoded id_token for IdP "idp_a":
{"context": {"user": {"name": test_user.username}}},
# decoded refresh_token for IdP "idp_a":
{
"jti": str(uuid.uuid4()),
"exp": now + 100,
"sub": test_user.userid,
"scope": ["openid", "access", "user", "test_aud"],
"aud": "https://localhost/user",
"iss": "https://localhost/user",
},
]
patched_jwt_decode = mock.patch("jose.jwt.decode", mocked_jwt_response)
patched_jwt_decode.start()

# get refresh token for IdP "default"
OAuth2Client.fetch_token.return_value = {
"refresh_token": fake_tokens["default"],
"id_token": "eyJhbGciOiJ",
}
fake_state = "qwerty"
with client.session_transaction() as session:
session["state"] = fake_state
res = client.get(
"/oauth2/authorize?state={}".format(fake_state), headers=auth_header
)
OAuth2Client.fetch_token.assert_called_with(
"https://localhost/user/oauth2/token",
headers={"User-Agent": f"Gen3WTS/{app_version}"},
state=fake_state,
)
assert res.status_code == 200, res.json

# get refresh token for IdP "idp_a"
OAuth2Client.fetch_token.return_value = {
"refresh_token": fake_tokens["idp_a"],
"id_token": "eyJhbGciOiJ",
}
with client.session_transaction() as session:
session["state"] = fake_state
session["idp"] = "idp_a"
res = client.get(
"/oauth2/authorize?state={}".format(fake_state), headers=auth_header
)
OAuth2Client.fetch_token.assert_called_with(
"https://some.data.commons/user/oauth2/token",
headers={"User-Agent": f"Gen3WTS/{app_version}"},
state=fake_state,
)
assert res.status_code == 200


def test_authorization_url_endpoint(client):
res = client.get("/oauth2/authorization_url?idp=idp_a")
assert res.status_code == 302
Expand Down Expand Up @@ -207,6 +286,23 @@ def test_external_oidc_endpoint_with_persisted_refresh_tokens(
assert provider["refresh_token_expiration"] == None


def test_revoke_token_header(client, auth_header, app):

url = urljoin(app.config.get("USER_API"), "/oauth2/revoke")
app_version = app.config.get("APP_VERSION")

with mock.patch.object(
OAuth2Session,
"revoke_token",
):
res = client.get("/oauth2/logout", headers=auth_header)
assert res.status_code == 204
assert res.text == ""
OAuth2Session.revoke_token.assert_called_with(
url, None, headers={"User-Agent": f"Gen3WTS/{app_version}"}
)


def test_app_config(app):
assert (
app.config["OIDC"]["idp_a"]["redirect_uri"]
Expand Down
2 changes: 2 additions & 0 deletions wts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from cryptography.fernet import Fernet
import flask
from flask import Flask
from importlib import metadata
import json
from urllib.parse import urlparse, urljoin
from cdislogging import get_logger
Expand Down Expand Up @@ -120,6 +121,7 @@ def load_settings(app):
app.config["SESSION_COOKIE_NAME"] = "wts"
app.config["SESSION_COOKIE_SECURE"] = True
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
app.config["APP_VERSION"] = metadata.version("wts")


def _log_and_jsonify_exception(e):
Expand Down
5 changes: 4 additions & 1 deletion wts/blueprints/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,10 @@ def logout_oauth():
client = get_oauth_client(idp="default")

try:
client.session.revoke_token(url, token)
app_version = flask.current_app.config.get("APP_VERSION", "0.0.0")
client.session.revoke_token(
url, token, headers={"User-Agent": f"Gen3WTS/{app_version}"}
)
except APIError as e:
msg = "could not log out, failed to revoke token: {}".format(e.message)
return msg, 400
Expand Down
9 changes: 8 additions & 1 deletion wts/resources/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ def client_do_authorize():
if mismatched_state:
raise AuthError("could not authorize; state did not match across auth requests")
try:
tokens = client.fetch_token(token_url, **flask.request.args.to_dict())
app_version = flask.current_app.config.get("APP_VERSION", "0.0.0")
tokens = client.fetch_token(
token_url,
headers={"User-Agent": f"Gen3WTS/{app_version}"},
**flask.request.args.to_dict(),
)
refresh_refresh_token(tokens, requested_idp, username_field)
except KeyError as e:
raise AuthError("error in token response: {}".format(tokens))
Expand All @@ -46,6 +51,7 @@ def find_valid_refresh_token(username, idp):
flask.current_app.logger.info("Purging expired token {}".format(token.jti))
else:
has_valid = True
db.session.close()
return has_valid


Expand Down Expand Up @@ -112,3 +118,4 @@ def refresh_refresh_token(tokens, idp, username_field):
)
db.session.add(new_token)
db.session.commit()
db.session.close()

0 comments on commit 237d66d

Please sign in to comment.