From 487d217e709132fd11fc3b6852693ea797bb56ea Mon Sep 17 00:00:00 2001 From: Michael Chouinard <46358556+chouinar@users.noreply.github.com> Date: Thu, 21 Nov 2024 12:18:30 -0500 Subject: [PATCH] [Issue #2809] Handle parsing the jwt we created, and connect to a user (#2959) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Fixes #2809 ### Time to review: __15 mins__ ## Changes proposed Setup logic to process the jwt we created in https://github.com/HHS/simpler-grants-gov/pull/2898 Setup a method to automatically generate a key for local development in a secure way via an override file ## Context for reviewers The core part of this PR is pretty straightforward, we parse the JWT, do some validation, raise specific error messages for certain scenarios, and have tests for that behavior. For the auth token in the API request header, instead of using `Bearer ..` I left it as a dedicated header field. The bearer format doesn't let you specify the header name and if we ever need multiple tokens supported in an endpoint will lead to more headache. --- Where things got complex was setting up the private/public key for the API. These just need to be stored in env vars, but putting them directly in our local.env file isn't ideal - even though the key will be distinctly local-only, it will always be something flagged in security scans and just generally look problematic. To work around this fun problem, I realized I could solve another annoyance at the same time, Docker as of January 2024 allows you to specify multiple env files + make them optional. So - I used that. I setup a script that creates an `override.env` file that you can freely modify and won't be checked in, and more importantly, automatically contains secrets like those public/private keys we didn't want to check in. (Note - if you're wondering why I didn't use Docker secrets, they're far more complex and this PR would've been 20+ files to make that half-work). ## Additional information Locally I confirmed we can set tokens in the swagger docs and they work - we don't yet have an endpoint that uses this outside of the unit test I setup, but I temporarily modified the healthcheck endpoint to validate things work outside of tests as well. Screenshot 2024-11-20 at 4 06 48 PM The override file we generate looks like this (with the relevant key info removed): ``` # override.env # # Any environment variables written to this file # will take precedence over those defined in local.env # # This file will not be checked into github and it is safe # to store secrets here, however you should still follow caution # with using any secrets locally if they cause the app to interact # with external systems. # # This file was generated by running: # make setup-env-overrides # # Which runs as part of our "make init" flow. # # If you would like to re-generate this file, please run: # make setup-env-overrides --recreate # # Note that this will completely erase any existing configuration you may have ############################ # Authentication ############################ API_JWT_PRIVATE_KEY="-----BEGIN RSA PRIVATE KEY----- ... -----END RSA PRIVATE KEY-----" API_JWT_PUBLIC_KEY="-----BEGIN PUBLIC KEY----- ... -----END PUBLIC KEY-----" ``` --------- Co-authored-by: nava-platform-bot --- api/.gitignore | 8 +- api/Makefile | 5 +- api/bin/setup-env-override-file.sh | 123 ++++++++++++++++++ api/docker-compose.yml | 6 +- api/local.env | 6 + api/openapi.generated.yml | 4 + api/src/app.py | 9 +- api/src/auth/api_jwt_auth.py | 89 ++++++++++--- api/src/auth/api_key_auth.py | 7 +- api/src/auth/auth_errors.py | 4 +- api/src/auth/auth_utils.py | 8 ++ api/tests/src/auth/test_api_jwt_auth.py | 158 ++++++++++++++++++++++-- api/tests/src/db/models/factories.py | 14 +++ 13 files changed, 407 insertions(+), 34 deletions(-) create mode 100755 api/bin/setup-env-override-file.sh create mode 100644 api/src/auth/auth_utils.py diff --git a/api/.gitignore b/api/.gitignore index 3db52578e..4568f96f0 100644 --- a/api/.gitignore +++ b/api/.gitignore @@ -19,6 +19,7 @@ coverage.* # Environment variables .env .envrc +override.env # mypy .mypy_cache @@ -31,4 +32,9 @@ coverage.* /test-results/ # localstack -/volume \ No newline at end of file +/volume + +# All pem/pub/secret keys +*.key +*.pub +*.pem diff --git a/api/Makefile b/api/Makefile index 41cce23f1..8688c3fc2 100644 --- a/api/Makefile +++ b/api/Makefile @@ -84,6 +84,9 @@ setup-local: # Install dependencies poetry install --no-root --all-extras --with dev +setup-env-override-file: + ./bin/setup-env-override-file.sh $(args) + ################################################## # API Build & Run ################################################## @@ -100,7 +103,7 @@ start-debug: run-logs: start ## Start the API and follow the logs docker compose logs --follow --no-color $(APP_NAME) -init: build init-db init-opensearch init-localstack +init: setup-env-override-file build init-db init-opensearch init-localstack clean-volumes: ## Remove project docker volumes - which includes the DB, and OpenSearch state docker compose down --volumes diff --git a/api/bin/setup-env-override-file.sh b/api/bin/setup-env-override-file.sh new file mode 100755 index 000000000..3f314a64a --- /dev/null +++ b/api/bin/setup-env-override-file.sh @@ -0,0 +1,123 @@ +#!/usr/bin/env bash +# setup-env-override-file.sh +# +# Generate an override.env file +# with secrets pre-populated for local development. +# +# Examples: +# ./setup-env-override-file.sh +# ./setup-env-override-file.sh --recreate +# + +set -o errexit -o pipefail + +PROGRAM_NAME=$(basename "$0") + +CYAN='\033[96m' +GREEN='\033[92m' +RED='\033[01;31m' +END='\033[0m' + +USAGE="Usage: $PROGRAM_NAME [OPTION] + + --recreate Recreate the override.env file, fully overwriting any existing file +" + +main() { + print_log "Running $PROGRAM_NAME" + + for arg in "$@" + do + if [ "$arg" == "--recreate" ]; then + recreate=1 + else + echo "$USAGE" + exit 1 + fi + done + + OVERRIDE_FILE="override.env" + + if [ -f "$OVERRIDE_FILE" ] ; then + if [ $recreate ] ; then + print_log "Recreating existing override.env file" + else + print_log "override.env already exists, not recreating" + exit 0 + fi + fi + + # Delete any key files that may be leftover from a prior run + cleanup_files + + # Generate RSA keys + # note ssh-keygen generates a different format for + # the public key so we run it through openssl to fix it + ssh-keygen -t rsa -b 2048 -m PEM -N '' -f tmp_jwk.key 2>&1 >/dev/null + openssl rsa -in tmp_jwk.key -pubout -outform PEM -out tmp_jwk.pub + + PUBLIC_KEY=`cat tmp_jwk.pub` + PRIVATE_KEY=`cat tmp_jwk.key` + + cat > $OVERRIDE_FILE < APIFlask: register_index(app) register_search_client(app) + # TODO - once we merge the auth changes for setting up the initial route + # will reuse the config from it, for now we'll do this a bit hacky + # This cannot be removed non-locally until we setup RSA keys for non-local envs + if os.getenv("ENVIRONMENT") == "local": + initialize_jwt_auth() + return app diff --git a/api/src/auth/api_jwt_auth.py b/api/src/auth/api_jwt_auth.py index 14cf2c8c5..5c0e979d1 100644 --- a/api/src/auth/api_jwt_auth.py +++ b/api/src/auth/api_jwt_auth.py @@ -1,19 +1,27 @@ import logging import uuid from datetime import timedelta +from typing import Tuple import jwt +from apiflask import HTTPTokenAuth from pydantic import Field from sqlalchemy import select +from sqlalchemy.orm import selectinload import src.util.datetime_util as datetime_util from src.adapters import db +from src.adapters.db import flask_db +from src.api.route_utils import raise_flask_error from src.auth.auth_errors import JwtValidationError from src.db.models.user_models import User, UserTokenSession +from src.logging.flask_logger import add_extra_data_to_current_request_logs from src.util.env_config import PydanticBaseEnvConfig logger = logging.getLogger(__name__) +api_jwt_auth = HTTPTokenAuth("ApiKey", header="X-SGG-Token", security_scheme_name="ApiJwtAuth") + class ApiJwtConfig(PydanticBaseEnvConfig): @@ -32,7 +40,7 @@ class ApiJwtConfig(PydanticBaseEnvConfig): _config: ApiJwtConfig | None = None -def initialize() -> None: +def initialize_jwt_auth() -> None: global _config if not _config: _config = ApiJwtConfig() @@ -53,14 +61,14 @@ def get_config() -> ApiJwtConfig: global _config if _config is None: - raise Exception("No JWT configuration - initialize() must be run first") + raise Exception("No JWT configuration - initialize_jwt_auth() must be run first") return _config def create_jwt_for_user( user: User, db_session: db.Session, config: ApiJwtConfig | None = None -) -> str: +) -> Tuple[str, UserTokenSession]: if config is None: config = get_config() @@ -72,13 +80,8 @@ def create_jwt_for_user( expiration_time = current_time + timedelta(minutes=config.token_expiration_minutes) # Create the session in the DB - db_session.add( - UserTokenSession( - user=user, - token_id=token_id, - expires_at=expiration_time, - ) - ) + user_token_session = UserTokenSession(user=user, token_id=token_id, expires_at=expiration_time) + db_session.add(user_token_session) # Create the JWT with information we'll want to receive back payload = { @@ -89,13 +92,21 @@ def create_jwt_for_user( "iss": config.issuer, } - return jwt.encode(payload, config.private_key, algorithm="RS256") + logger.info( + "Created JWT token", + extra={ + "auth.user_id": str(user_token_session.user_id), + "auth.token_id": str(user_token_session.token_id), + }, + ) + + return jwt.encode(payload, config.private_key, algorithm="RS256"), user_token_session def parse_jwt_for_user( token: str, db_session: db.Session, config: ApiJwtConfig | None = None -) -> User: - # TODO - more implementation/validation to come in https://github.com/HHS/simpler-grants-gov/issues/2809 +) -> UserTokenSession: + """Handle processing a jwt token, and connecting it to a user token session in our DB""" if config is None: config = get_config() @@ -135,8 +146,10 @@ def parse_jwt_for_user( raise JwtValidationError("Token missing sub field") token_session: UserTokenSession | None = db_session.execute( - select(UserTokenSession).join(User).where(UserTokenSession.token_id == sub_id) - ).scalar_one_or_none() + select(UserTokenSession) + .where(UserTokenSession.token_id == sub_id) + .options(selectinload("*")) + ).scalar() # We check both the token expires_at timestamp as well as an # is_valid flag to make sure the token is still valid. @@ -147,4 +160,48 @@ def parse_jwt_for_user( if token_session.is_valid is False: raise JwtValidationError("Token is no longer valid") - return token_session.user + return token_session + + +@api_jwt_auth.verify_token +@flask_db.with_db_session() +def decode_token(db_session: db.Session, token: str) -> UserTokenSession: + """ + Process an internal jwt token as created by the above create_jwt_for_user method. + + To add this auth to an endpoint, simply put:: + + from src.auth.api_jwt_auth import api_jwt_auth + + @example_blueprint.get("/example") + @example_blueprint.auth_required(api_jwt_auth) + @flask_db.with_db_session() + def example_method(db_session: db.Session) -> response.ApiResponse: + # The token session object can be fetched from the auth object + token_session: UserTokenSession = api_jwt_auth.current_user + + # If you want to modify the token_session or user, you will + # need to add it to the DB session otherwise it won't do anything + db_session.add(token_session) + token_session.expires_at = ... + ... + """ + + try: + user_token_session = parse_jwt_for_user(token, db_session) + + add_extra_data_to_current_request_logs( + { + "auth.user_id": str(user_token_session.user_id), + "auth.token_id": str(user_token_session.token_id), + } + ) + logger.info("JWT Authentication Successful") + + # Return the user token session object + return user_token_session + except JwtValidationError as e: + # If validation of the jwt fails, pass the error message back to the user + # The message is just the value we set when constructing the JwtValidationError + logger.info("JWT Authentication Failed for provided token", extra={"auth.issue": e.message}) + raise_flask_error(401, e.message) diff --git a/api/src/auth/api_key_auth.py b/api/src/auth/api_key_auth.py index 2359d29b3..0554cac5c 100644 --- a/api/src/auth/api_key_auth.py +++ b/api/src/auth/api_key_auth.py @@ -1,7 +1,6 @@ import logging import os from dataclasses import dataclass -from typing import Any import flask from apiflask import HTTPTokenAuth @@ -15,11 +14,7 @@ # this needs to be attached to your # routes as `your_blueprint.auth_required(api_key_auth)` # in order to enable authorization -api_key_auth = HTTPTokenAuth("ApiKey", header="X-Auth") - - -def get_app_security_scheme() -> dict[str, Any]: - return {"ApiKeyAuth": {"type": "apiKey", "in": "header", "name": "X-Auth"}} +api_key_auth = HTTPTokenAuth("ApiKey", header="X-Auth", security_scheme_name="ApiKeyAuth") @dataclass diff --git a/api/src/auth/auth_errors.py b/api/src/auth/auth_errors.py index 1e2bdee1b..bb48dd1ff 100644 --- a/api/src/auth/auth_errors.py +++ b/api/src/auth/auth_errors.py @@ -5,4 +5,6 @@ class JwtValidationError(Exception): cause the endpoint to raise a 401 """ - pass + def __init__(self, message: str): + super().__init__(message) + self.message = message diff --git a/api/src/auth/auth_utils.py b/api/src/auth/auth_utils.py new file mode 100644 index 000000000..0db25d732 --- /dev/null +++ b/api/src/auth/auth_utils.py @@ -0,0 +1,8 @@ +from typing import Any + + +def get_app_security_scheme() -> dict[str, Any]: + return { + "ApiKeyAuth": {"type": "apiKey", "in": "header", "name": "X-Auth"}, + "ApiJwtAuth": {"type": "apiKey", "in": "header", "name": "X-SGG-Token"}, + } diff --git a/api/tests/src/auth/test_api_jwt_auth.py b/api/tests/src/auth/test_api_jwt_auth.py index 4770327f0..43832b615 100644 --- a/api/tests/src/auth/test_api_jwt_auth.py +++ b/api/tests/src/auth/test_api_jwt_auth.py @@ -5,7 +5,14 @@ import pytest from freezegun import freeze_time -from src.auth.api_jwt_auth import ApiJwtConfig, create_jwt_for_user, parse_jwt_for_user +import src.app as app_entry +import src.logging +from src.auth.api_jwt_auth import ( + ApiJwtConfig, + api_jwt_auth, + create_jwt_for_user, + parse_jwt_for_user, +) from src.db.models.user_models import UserTokenSession from tests.src.db.models.factories import UserFactory @@ -18,12 +25,39 @@ def jwt_config(private_rsa_key, public_rsa_key): ) +@pytest.fixture(scope="module") +def mini_app(monkeypatch_module): + def stub(app): + pass + + """Create a separate app that we can modify separate from the base one used by other tests""" + # We want all the configurational setup for the app, but + # don't want blueprints to keep setup simpler + monkeypatch_module.setattr(app_entry, "register_blueprints", stub) + monkeypatch_module.setattr(app_entry, "setup_logging", stub) + mini_app = app_entry.create_app() + + @mini_app.get("/dummy_auth_endpoint") + @mini_app.auth_required(api_jwt_auth) + def dummy_endpoint(): + # For the tests that actually get past auth + # make sure the current user is set to the user session + assert api_jwt_auth.current_user is not None + assert isinstance(api_jwt_auth.current_user, UserTokenSession) + + return {"message": "ok"} + + # To avoid re-initializing logging everytime we + # setup the app, we disabled it above and do it here + # in case you want it while running your tests + with src.logging.init(__package__): + yield mini_app + + @freeze_time("2024-11-14 12:00:00", tz_offset=0) def test_create_jwt_for_user(enable_factory_create, db_session, jwt_config): user = UserFactory.create() - - token = create_jwt_for_user(user, db_session, jwt_config) - + token, token_session = create_jwt_for_user(user, db_session, jwt_config) decoded_token = jwt.decode( token, algorithms=[jwt_config.algorithm], options={"verify_signature": False} ) @@ -49,6 +83,116 @@ def test_create_jwt_for_user(enable_factory_create, db_session, jwt_config): assert token_session.expires_at == datetime.fromisoformat("2024-11-14 12:30:00+00:00") # Basic testing that the JWT we create for a user can in turn be fetched and processed later - # TODO - more in https://github.com/HHS/simpler-grants-gov/issues/2809 - parsed_user = parse_jwt_for_user(token, db_session, jwt_config) - assert parsed_user.user_id == user.user_id + user_session = parse_jwt_for_user(token, db_session, jwt_config) + assert user_session.user_id == user.user_id + + +def test_api_jwt_auth_happy_path(mini_app, enable_factory_create, db_session): + user = UserFactory.create() + token, _ = create_jwt_for_user(user, db_session) + db_session.commit() # need to commit here to push the session to the DB + + resp = mini_app.test_client().get("/dummy_auth_endpoint", headers={"X-SGG-Token": token}) + + assert resp.status_code == 200 + assert resp.get_json()["message"] == "ok" + + +def test_api_jwt_auth_expired_token(mini_app, enable_factory_create, db_session): + user = UserFactory.create() + token, session = create_jwt_for_user(user, db_session) + session.expires_at = datetime.fromisoformat("1980-01-01 12:00:00+00:00") + db_session.commit() + + resp = mini_app.test_client().get("/dummy_auth_endpoint", headers={"X-SGG-Token": token}) + + assert resp.status_code == 401 + assert resp.get_json()["message"] == "Token expired" + + +def test_api_jwt_auth_invalid_token(mini_app, enable_factory_create, db_session): + user = UserFactory.create() + token, session = create_jwt_for_user(user, db_session) + session.is_valid = False + db_session.commit() + + resp = mini_app.test_client().get("/dummy_auth_endpoint", headers={"X-SGG-Token": token}) + + assert resp.status_code == 401 + assert resp.get_json()["message"] == "Token is no longer valid" + + +def test_api_jwt_auth_token_missing_in_db(mini_app, enable_factory_create, db_session): + user = UserFactory.create() + token, session = create_jwt_for_user(user, db_session) + db_session.expunge(session) # Just drop it, never sending to the DB + + resp = mini_app.test_client().get("/dummy_auth_endpoint", headers={"X-SGG-Token": token}) + + assert resp.status_code == 401 + assert resp.get_json()["message"] == "Token session does not exist" + + +def test_api_jwt_auth_token_not_jwt(mini_app, enable_factory_create, db_session): + # Just call with a random set of characters + resp = mini_app.test_client().get("/dummy_auth_endpoint", headers={"X-SGG-Token": "abc123"}) + assert resp.status_code == 401 + assert resp.get_json()["message"] == "Unable to process token" + + +def test_api_jwt_auth_token_created_with_different_key( + mini_app, enable_factory_create, db_session, jwt_config +): + # Note - jwt_config uses a key generated in the conftest within this directory + # while the config the app picks up grabs a key from our override.env file + user = UserFactory.create() + token, _ = create_jwt_for_user(user, db_session, jwt_config) + db_session.commit() + + resp = mini_app.test_client().get("/dummy_auth_endpoint", headers={"X-SGG-Token": token}) + + assert resp.status_code == 401 + assert resp.get_json()["message"] == "Unable to process token" + + +def test_api_jwt_auth_token_iat_future(mini_app, enable_factory_create, db_session): + # Set time to the 14th so the iat value will be then + with freeze_time("2024-11-14 12:00:00", tz_offset=0): + user = UserFactory.create() + token, _ = create_jwt_for_user(user, db_session) + db_session.commit() + + # Set time to the 12th when calling the API so the iat will be in the future now + with freeze_time("2024-11-12 12:00:00", tz_offset=0): + resp = mini_app.test_client().get("/dummy_auth_endpoint", headers={"X-SGG-Token": token}) + + assert resp.status_code == 401 + assert resp.get_json()["message"] == "Token not yet valid" + + +def test_api_jwt_auth_token_unknown_issuer(mini_app, enable_factory_create, db_session): + config = ApiJwtConfig(API_JWT_ISSUER="some-guy") + user = UserFactory.create() + token, _ = create_jwt_for_user(user, db_session, config) + db_session.commit() + + resp = mini_app.test_client().get("/dummy_auth_endpoint", headers={"X-SGG-Token": token}) + assert resp.status_code == 401 + assert resp.get_json()["message"] == "Unknown Issuer" + + +def test_api_jwt_auth_token_unknown_audience(mini_app, enable_factory_create, db_session): + config = ApiJwtConfig(API_JWT_AUDIENCE="someone-else") + user = UserFactory.create() + token, _ = create_jwt_for_user(user, db_session, config) + db_session.commit() + + resp = mini_app.test_client().get("/dummy_auth_endpoint", headers={"X-SGG-Token": token}) + assert resp.status_code == 401 + assert resp.get_json()["message"] == "Unknown Audience" + + +def test_api_jwt_auth_no_token(mini_app, enable_factory_create, db_session): + resp = mini_app.test_client().get("/dummy_auth_endpoint", headers={}) + assert resp.status_code == 401 + assert resp.get_json()["message"] == "Unable to process token" diff --git a/api/tests/src/db/models/factories.py b/api/tests/src/db/models/factories.py index 05024951f..5a387f254 100644 --- a/api/tests/src/db/models/factories.py +++ b/api/tests/src/db/models/factories.py @@ -1857,3 +1857,17 @@ class Meta: external_user_type_id = factory.fuzzy.FuzzyChoice(ExternalUserType) email = factory.Faker("email") + + +class UserTokenSessionFactory(BaseFactory): + class Meta: + model = user_models.UserTokenSession + + user = factory.SubFactory(UserFactory) + user_id = factory.LazyAttribute(lambda s: s.user.user_id) + + token_id = Generators.UuidObj + + expires_at = factory.Faker("date_time_between", start_date="+1d", end_date="+10d") + + is_valid = True