diff --git a/api/.gitignore b/api/.gitignore index 3db52578e..4568f96f0 100644 --- a/api/.gitignore +++ b/api/.gitignore @@ -19,6 +19,7 @@ coverage.* # Environment variables .env .envrc +override.env # mypy .mypy_cache @@ -31,4 +32,9 @@ coverage.* /test-results/ # localstack -/volume \ No newline at end of file +/volume + +# All pem/pub/secret keys +*.key +*.pub +*.pem diff --git a/api/Makefile b/api/Makefile index 41cce23f1..8688c3fc2 100644 --- a/api/Makefile +++ b/api/Makefile @@ -84,6 +84,9 @@ setup-local: # Install dependencies poetry install --no-root --all-extras --with dev +setup-env-override-file: + ./bin/setup-env-override-file.sh $(args) + ################################################## # API Build & Run ################################################## @@ -100,7 +103,7 @@ start-debug: run-logs: start ## Start the API and follow the logs docker compose logs --follow --no-color $(APP_NAME) -init: build init-db init-opensearch init-localstack +init: setup-env-override-file build init-db init-opensearch init-localstack clean-volumes: ## Remove project docker volumes - which includes the DB, and OpenSearch state docker compose down --volumes diff --git a/api/bin/setup-env-override-file.sh b/api/bin/setup-env-override-file.sh new file mode 100755 index 000000000..3f314a64a --- /dev/null +++ b/api/bin/setup-env-override-file.sh @@ -0,0 +1,123 @@ +#!/usr/bin/env bash +# setup-env-override-file.sh +# +# Generate an override.env file +# with secrets pre-populated for local development. +# +# Examples: +# ./setup-env-override-file.sh +# ./setup-env-override-file.sh --recreate +# + +set -o errexit -o pipefail + +PROGRAM_NAME=$(basename "$0") + +CYAN='\033[96m' +GREEN='\033[92m' +RED='\033[01;31m' +END='\033[0m' + +USAGE="Usage: $PROGRAM_NAME [OPTION] + + --recreate Recreate the override.env file, fully overwriting any existing file +" + +main() { + print_log "Running $PROGRAM_NAME" + + for arg in "$@" + do + if [ "$arg" == "--recreate" ]; then + recreate=1 + else + echo "$USAGE" + exit 1 + fi + done + + OVERRIDE_FILE="override.env" + + if [ -f "$OVERRIDE_FILE" ] ; then + if [ $recreate ] ; then + print_log "Recreating existing override.env file" + else + print_log "override.env already exists, not recreating" + exit 0 + fi + fi + + # Delete any key files that may be leftover from a prior run + cleanup_files + + # Generate RSA keys + # note ssh-keygen generates a different format for + # the public key so we run it through openssl to fix it + ssh-keygen -t rsa -b 2048 -m PEM -N '' -f tmp_jwk.key 2>&1 >/dev/null + openssl rsa -in tmp_jwk.key -pubout -outform PEM -out tmp_jwk.pub + + PUBLIC_KEY=`cat tmp_jwk.pub` + PRIVATE_KEY=`cat tmp_jwk.key` + + cat > $OVERRIDE_FILE < APIFlask: register_index(app) register_search_client(app) + # TODO - once we merge the auth changes for setting up the initial route + # will reuse the config from it, for now we'll do this a bit hacky + # This cannot be removed non-locally until we setup RSA keys for non-local envs + if os.getenv("ENVIRONMENT") == "local": + initialize_jwt_auth() + return app diff --git a/api/src/auth/api_jwt_auth.py b/api/src/auth/api_jwt_auth.py index 14cf2c8c5..5c0e979d1 100644 --- a/api/src/auth/api_jwt_auth.py +++ b/api/src/auth/api_jwt_auth.py @@ -1,19 +1,27 @@ import logging import uuid from datetime import timedelta +from typing import Tuple import jwt +from apiflask import HTTPTokenAuth from pydantic import Field from sqlalchemy import select +from sqlalchemy.orm import selectinload import src.util.datetime_util as datetime_util from src.adapters import db +from src.adapters.db import flask_db +from src.api.route_utils import raise_flask_error from src.auth.auth_errors import JwtValidationError from src.db.models.user_models import User, UserTokenSession +from src.logging.flask_logger import add_extra_data_to_current_request_logs from src.util.env_config import PydanticBaseEnvConfig logger = logging.getLogger(__name__) +api_jwt_auth = HTTPTokenAuth("ApiKey", header="X-SGG-Token", security_scheme_name="ApiJwtAuth") + class ApiJwtConfig(PydanticBaseEnvConfig): @@ -32,7 +40,7 @@ class ApiJwtConfig(PydanticBaseEnvConfig): _config: ApiJwtConfig | None = None -def initialize() -> None: +def initialize_jwt_auth() -> None: global _config if not _config: _config = ApiJwtConfig() @@ -53,14 +61,14 @@ def get_config() -> ApiJwtConfig: global _config if _config is None: - raise Exception("No JWT configuration - initialize() must be run first") + raise Exception("No JWT configuration - initialize_jwt_auth() must be run first") return _config def create_jwt_for_user( user: User, db_session: db.Session, config: ApiJwtConfig | None = None -) -> str: +) -> Tuple[str, UserTokenSession]: if config is None: config = get_config() @@ -72,13 +80,8 @@ def create_jwt_for_user( expiration_time = current_time + timedelta(minutes=config.token_expiration_minutes) # Create the session in the DB - db_session.add( - UserTokenSession( - user=user, - token_id=token_id, - expires_at=expiration_time, - ) - ) + user_token_session = UserTokenSession(user=user, token_id=token_id, expires_at=expiration_time) + db_session.add(user_token_session) # Create the JWT with information we'll want to receive back payload = { @@ -89,13 +92,21 @@ def create_jwt_for_user( "iss": config.issuer, } - return jwt.encode(payload, config.private_key, algorithm="RS256") + logger.info( + "Created JWT token", + extra={ + "auth.user_id": str(user_token_session.user_id), + "auth.token_id": str(user_token_session.token_id), + }, + ) + + return jwt.encode(payload, config.private_key, algorithm="RS256"), user_token_session def parse_jwt_for_user( token: str, db_session: db.Session, config: ApiJwtConfig | None = None -) -> User: - # TODO - more implementation/validation to come in https://github.com/HHS/simpler-grants-gov/issues/2809 +) -> UserTokenSession: + """Handle processing a jwt token, and connecting it to a user token session in our DB""" if config is None: config = get_config() @@ -135,8 +146,10 @@ def parse_jwt_for_user( raise JwtValidationError("Token missing sub field") token_session: UserTokenSession | None = db_session.execute( - select(UserTokenSession).join(User).where(UserTokenSession.token_id == sub_id) - ).scalar_one_or_none() + select(UserTokenSession) + .where(UserTokenSession.token_id == sub_id) + .options(selectinload("*")) + ).scalar() # We check both the token expires_at timestamp as well as an # is_valid flag to make sure the token is still valid. @@ -147,4 +160,48 @@ def parse_jwt_for_user( if token_session.is_valid is False: raise JwtValidationError("Token is no longer valid") - return token_session.user + return token_session + + +@api_jwt_auth.verify_token +@flask_db.with_db_session() +def decode_token(db_session: db.Session, token: str) -> UserTokenSession: + """ + Process an internal jwt token as created by the above create_jwt_for_user method. + + To add this auth to an endpoint, simply put:: + + from src.auth.api_jwt_auth import api_jwt_auth + + @example_blueprint.get("/example") + @example_blueprint.auth_required(api_jwt_auth) + @flask_db.with_db_session() + def example_method(db_session: db.Session) -> response.ApiResponse: + # The token session object can be fetched from the auth object + token_session: UserTokenSession = api_jwt_auth.current_user + + # If you want to modify the token_session or user, you will + # need to add it to the DB session otherwise it won't do anything + db_session.add(token_session) + token_session.expires_at = ... + ... + """ + + try: + user_token_session = parse_jwt_for_user(token, db_session) + + add_extra_data_to_current_request_logs( + { + "auth.user_id": str(user_token_session.user_id), + "auth.token_id": str(user_token_session.token_id), + } + ) + logger.info("JWT Authentication Successful") + + # Return the user token session object + return user_token_session + except JwtValidationError as e: + # If validation of the jwt fails, pass the error message back to the user + # The message is just the value we set when constructing the JwtValidationError + logger.info("JWT Authentication Failed for provided token", extra={"auth.issue": e.message}) + raise_flask_error(401, e.message) diff --git a/api/src/auth/api_key_auth.py b/api/src/auth/api_key_auth.py index 2359d29b3..0554cac5c 100644 --- a/api/src/auth/api_key_auth.py +++ b/api/src/auth/api_key_auth.py @@ -1,7 +1,6 @@ import logging import os from dataclasses import dataclass -from typing import Any import flask from apiflask import HTTPTokenAuth @@ -15,11 +14,7 @@ # this needs to be attached to your # routes as `your_blueprint.auth_required(api_key_auth)` # in order to enable authorization -api_key_auth = HTTPTokenAuth("ApiKey", header="X-Auth") - - -def get_app_security_scheme() -> dict[str, Any]: - return {"ApiKeyAuth": {"type": "apiKey", "in": "header", "name": "X-Auth"}} +api_key_auth = HTTPTokenAuth("ApiKey", header="X-Auth", security_scheme_name="ApiKeyAuth") @dataclass diff --git a/api/src/auth/auth_errors.py b/api/src/auth/auth_errors.py index 1e2bdee1b..bb48dd1ff 100644 --- a/api/src/auth/auth_errors.py +++ b/api/src/auth/auth_errors.py @@ -5,4 +5,6 @@ class JwtValidationError(Exception): cause the endpoint to raise a 401 """ - pass + def __init__(self, message: str): + super().__init__(message) + self.message = message diff --git a/api/src/auth/auth_utils.py b/api/src/auth/auth_utils.py new file mode 100644 index 000000000..0db25d732 --- /dev/null +++ b/api/src/auth/auth_utils.py @@ -0,0 +1,8 @@ +from typing import Any + + +def get_app_security_scheme() -> dict[str, Any]: + return { + "ApiKeyAuth": {"type": "apiKey", "in": "header", "name": "X-Auth"}, + "ApiJwtAuth": {"type": "apiKey", "in": "header", "name": "X-SGG-Token"}, + } diff --git a/api/tests/src/auth/test_api_jwt_auth.py b/api/tests/src/auth/test_api_jwt_auth.py index 4770327f0..43832b615 100644 --- a/api/tests/src/auth/test_api_jwt_auth.py +++ b/api/tests/src/auth/test_api_jwt_auth.py @@ -5,7 +5,14 @@ import pytest from freezegun import freeze_time -from src.auth.api_jwt_auth import ApiJwtConfig, create_jwt_for_user, parse_jwt_for_user +import src.app as app_entry +import src.logging +from src.auth.api_jwt_auth import ( + ApiJwtConfig, + api_jwt_auth, + create_jwt_for_user, + parse_jwt_for_user, +) from src.db.models.user_models import UserTokenSession from tests.src.db.models.factories import UserFactory @@ -18,12 +25,39 @@ def jwt_config(private_rsa_key, public_rsa_key): ) +@pytest.fixture(scope="module") +def mini_app(monkeypatch_module): + def stub(app): + pass + + """Create a separate app that we can modify separate from the base one used by other tests""" + # We want all the configurational setup for the app, but + # don't want blueprints to keep setup simpler + monkeypatch_module.setattr(app_entry, "register_blueprints", stub) + monkeypatch_module.setattr(app_entry, "setup_logging", stub) + mini_app = app_entry.create_app() + + @mini_app.get("/dummy_auth_endpoint") + @mini_app.auth_required(api_jwt_auth) + def dummy_endpoint(): + # For the tests that actually get past auth + # make sure the current user is set to the user session + assert api_jwt_auth.current_user is not None + assert isinstance(api_jwt_auth.current_user, UserTokenSession) + + return {"message": "ok"} + + # To avoid re-initializing logging everytime we + # setup the app, we disabled it above and do it here + # in case you want it while running your tests + with src.logging.init(__package__): + yield mini_app + + @freeze_time("2024-11-14 12:00:00", tz_offset=0) def test_create_jwt_for_user(enable_factory_create, db_session, jwt_config): user = UserFactory.create() - - token = create_jwt_for_user(user, db_session, jwt_config) - + token, token_session = create_jwt_for_user(user, db_session, jwt_config) decoded_token = jwt.decode( token, algorithms=[jwt_config.algorithm], options={"verify_signature": False} ) @@ -49,6 +83,116 @@ def test_create_jwt_for_user(enable_factory_create, db_session, jwt_config): assert token_session.expires_at == datetime.fromisoformat("2024-11-14 12:30:00+00:00") # Basic testing that the JWT we create for a user can in turn be fetched and processed later - # TODO - more in https://github.com/HHS/simpler-grants-gov/issues/2809 - parsed_user = parse_jwt_for_user(token, db_session, jwt_config) - assert parsed_user.user_id == user.user_id + user_session = parse_jwt_for_user(token, db_session, jwt_config) + assert user_session.user_id == user.user_id + + +def test_api_jwt_auth_happy_path(mini_app, enable_factory_create, db_session): + user = UserFactory.create() + token, _ = create_jwt_for_user(user, db_session) + db_session.commit() # need to commit here to push the session to the DB + + resp = mini_app.test_client().get("/dummy_auth_endpoint", headers={"X-SGG-Token": token}) + + assert resp.status_code == 200 + assert resp.get_json()["message"] == "ok" + + +def test_api_jwt_auth_expired_token(mini_app, enable_factory_create, db_session): + user = UserFactory.create() + token, session = create_jwt_for_user(user, db_session) + session.expires_at = datetime.fromisoformat("1980-01-01 12:00:00+00:00") + db_session.commit() + + resp = mini_app.test_client().get("/dummy_auth_endpoint", headers={"X-SGG-Token": token}) + + assert resp.status_code == 401 + assert resp.get_json()["message"] == "Token expired" + + +def test_api_jwt_auth_invalid_token(mini_app, enable_factory_create, db_session): + user = UserFactory.create() + token, session = create_jwt_for_user(user, db_session) + session.is_valid = False + db_session.commit() + + resp = mini_app.test_client().get("/dummy_auth_endpoint", headers={"X-SGG-Token": token}) + + assert resp.status_code == 401 + assert resp.get_json()["message"] == "Token is no longer valid" + + +def test_api_jwt_auth_token_missing_in_db(mini_app, enable_factory_create, db_session): + user = UserFactory.create() + token, session = create_jwt_for_user(user, db_session) + db_session.expunge(session) # Just drop it, never sending to the DB + + resp = mini_app.test_client().get("/dummy_auth_endpoint", headers={"X-SGG-Token": token}) + + assert resp.status_code == 401 + assert resp.get_json()["message"] == "Token session does not exist" + + +def test_api_jwt_auth_token_not_jwt(mini_app, enable_factory_create, db_session): + # Just call with a random set of characters + resp = mini_app.test_client().get("/dummy_auth_endpoint", headers={"X-SGG-Token": "abc123"}) + assert resp.status_code == 401 + assert resp.get_json()["message"] == "Unable to process token" + + +def test_api_jwt_auth_token_created_with_different_key( + mini_app, enable_factory_create, db_session, jwt_config +): + # Note - jwt_config uses a key generated in the conftest within this directory + # while the config the app picks up grabs a key from our override.env file + user = UserFactory.create() + token, _ = create_jwt_for_user(user, db_session, jwt_config) + db_session.commit() + + resp = mini_app.test_client().get("/dummy_auth_endpoint", headers={"X-SGG-Token": token}) + + assert resp.status_code == 401 + assert resp.get_json()["message"] == "Unable to process token" + + +def test_api_jwt_auth_token_iat_future(mini_app, enable_factory_create, db_session): + # Set time to the 14th so the iat value will be then + with freeze_time("2024-11-14 12:00:00", tz_offset=0): + user = UserFactory.create() + token, _ = create_jwt_for_user(user, db_session) + db_session.commit() + + # Set time to the 12th when calling the API so the iat will be in the future now + with freeze_time("2024-11-12 12:00:00", tz_offset=0): + resp = mini_app.test_client().get("/dummy_auth_endpoint", headers={"X-SGG-Token": token}) + + assert resp.status_code == 401 + assert resp.get_json()["message"] == "Token not yet valid" + + +def test_api_jwt_auth_token_unknown_issuer(mini_app, enable_factory_create, db_session): + config = ApiJwtConfig(API_JWT_ISSUER="some-guy") + user = UserFactory.create() + token, _ = create_jwt_for_user(user, db_session, config) + db_session.commit() + + resp = mini_app.test_client().get("/dummy_auth_endpoint", headers={"X-SGG-Token": token}) + assert resp.status_code == 401 + assert resp.get_json()["message"] == "Unknown Issuer" + + +def test_api_jwt_auth_token_unknown_audience(mini_app, enable_factory_create, db_session): + config = ApiJwtConfig(API_JWT_AUDIENCE="someone-else") + user = UserFactory.create() + token, _ = create_jwt_for_user(user, db_session, config) + db_session.commit() + + resp = mini_app.test_client().get("/dummy_auth_endpoint", headers={"X-SGG-Token": token}) + assert resp.status_code == 401 + assert resp.get_json()["message"] == "Unknown Audience" + + +def test_api_jwt_auth_no_token(mini_app, enable_factory_create, db_session): + resp = mini_app.test_client().get("/dummy_auth_endpoint", headers={}) + assert resp.status_code == 401 + assert resp.get_json()["message"] == "Unable to process token" diff --git a/api/tests/src/db/models/factories.py b/api/tests/src/db/models/factories.py index 05024951f..5a387f254 100644 --- a/api/tests/src/db/models/factories.py +++ b/api/tests/src/db/models/factories.py @@ -1857,3 +1857,17 @@ class Meta: external_user_type_id = factory.fuzzy.FuzzyChoice(ExternalUserType) email = factory.Faker("email") + + +class UserTokenSessionFactory(BaseFactory): + class Meta: + model = user_models.UserTokenSession + + user = factory.SubFactory(UserFactory) + user_id = factory.LazyAttribute(lambda s: s.user.user_id) + + token_id = Generators.UuidObj + + expires_at = factory.Faker("date_time_between", start_date="+1d", end_date="+10d") + + is_valid = True