Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Issue #2809] Handle parsing the jwt we created, and connect to a user #2959

Merged
merged 34 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
4271713
WIP
chouinar Nov 12, 2024
b5e8ff4
Merge branch 'main' into chouinar/2721-jwk-backend
chouinar Nov 12, 2024
f83f0d1
Tests and cleanup
chouinar Nov 13, 2024
638d332
Merge branch 'main' into chouinar/2721-jwk-backend
chouinar Nov 13, 2024
8eff255
Minor config update
chouinar Nov 13, 2024
8a38474
Testing whether this would fix it (although probably is too much)
chouinar Nov 13, 2024
6be928a
trivy yaml
chouinar Nov 13, 2024
6f588d0
trivyignore stuff
chouinar Nov 13, 2024
4d2e229
Merge branch 'main' into chouinar/2721-jwk-backend
chouinar Nov 13, 2024
ddef681
Fix path
chouinar Nov 13, 2024
7942476
Wallkicks will work
chouinar Nov 13, 2024
f2c0e75
Merge branch 'main' into chouinar/2721-jwk-backend
chouinar Nov 13, 2024
970e54d
direct path
chouinar Nov 14, 2024
c74d1e8
Adjust path?
chouinar Nov 14, 2024
9756844
Try skip files
chouinar Nov 14, 2024
f2ef7cf
Try a glob for even more simplicity
chouinar Nov 14, 2024
3883512
Undo extra trivy changes
chouinar Nov 14, 2024
1e53074
[Issue #2808] Setup logic for creating a JWT
chouinar Nov 15, 2024
020521c
Merge branch 'main' into chouinar/2721-jwk-backend
chouinar Nov 18, 2024
99660a8
Merge branch 'main' into chouinar/2808-create-a-jwt
chouinar Nov 18, 2024
25635ce
Add migration and cleanup
chouinar Nov 18, 2024
844b0f0
Merge branch 'chouinar/2721-jwk-backend' into chouinar/2808-create-a-jwt
chouinar Nov 18, 2024
bbbb65a
Create ERD diagram and Update OpenAPI spec
nava-platform-bot Nov 18, 2024
6691783
Merge branch 'main' into chouinar/2808-create-a-jwt
chouinar Nov 18, 2024
09f0528
Merge branch 'main' into chouinar/2808-create-a-jwt
chouinar Nov 19, 2024
92664ed
[Issue #2809] Handle parsing the jwt we created, and connect to a user
chouinar Nov 20, 2024
3dc50ad
Merge branch 'main' into chouinar/2808-create-a-jwt
chouinar Nov 20, 2024
5b8dab5
Merge branch 'chouinar/2808-create-a-jwt' into chouinar/2809-parse-a-jwt
chouinar Nov 20, 2024
f40f286
Don't initialize jwt non-locally
chouinar Nov 20, 2024
76cd414
Setup the env file, but with shell scripts
chouinar Nov 20, 2024
c0e77c1
Create ERD diagram and Update OpenAPI spec
nava-platform-bot Nov 20, 2024
039bf15
Merge branch 'main' into chouinar/2809-parse-a-jwt
chouinar Nov 21, 2024
c4631ae
Merge branch 'main' into chouinar/2809-parse-a-jwt
chouinar Nov 21, 2024
da76c75
Create ERD diagram and Update OpenAPI spec
nava-platform-bot Nov 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion api/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ coverage.*
# Environment variables
.env
.envrc
override.env

# mypy
.mypy_cache
Expand All @@ -31,4 +32,4 @@ coverage.*
/test-results/

# localstack
/volume
/volume
5 changes: 4 additions & 1 deletion api/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ setup-local:
# Install dependencies
poetry install --no-root --all-extras --with dev

setup-env-override-file:
$(PY_RUN_CMD) setup-env-override-file $(args)

##################################################
# API Build & Run
##################################################
Expand All @@ -100,7 +103,7 @@ start-debug:
run-logs: start ## Start the API and follow the logs
docker compose logs --follow --no-color $(APP_NAME)

init: build init-db init-opensearch init-localstack
init: setup-env-override-file build init-db init-opensearch init-localstack

clean-volumes: ## Remove project docker volumes - which includes the DB, and OpenSearch state
docker compose down --volumes
Expand Down
100 changes: 100 additions & 0 deletions api/bin/setup_env_override_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import logging
import pathlib
from typing import Tuple

import click
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa

import src.logging
from src.util.local import error_if_not_local

logger = logging.getLogger(__name__)

EMPTY_LINE = "\n"

DEFAULT_DESCRIPTION = """# override.env
#
# Any environment variables written to this file
# will take precedence over those defined in local.env
#
# This file will not be checked into github and it is safe
# to store secrets here, however you should still follow caution
# with using any secrets locally if they cause the app to interact
# with external systems.
#
# This file was generated by running:
# make setup-env-override-file
#
# Which runs as part of our "make init" flow.
#
# If you would like to re-generate this file, please run:
# make setup-env-override-file args="--recreate"
#
# Note that this will completely erase any existing configuration you may have
"""

AUTHENTICATION_HEADER = """
############################
# Authentication
############################
"""


def get_keys() -> Tuple[str, str]:
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

private_key = key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)

public_key = key.public_key().public_bytes(
encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo
)

# decode produces a multi-line key that looks like:
# ------BEGIN...
# ...
# ------END...
#
# With a newline at the end we want to strip for formatting simplicity
return private_key.decode().removesuffix("\n"), public_key.decode().removesuffix("\n")


def create_override_file(recreate: bool) -> None:
override_file_path = pathlib.Path(__file__).parent.parent / "override.env"
if override_file_path.exists():
if not recreate:
logger.info("override.env already exists, not recreating")
return
logger.info("Recreating existing override.env file")

private_key, public_key = get_keys()

with open(override_file_path, "w") as override_file:
override_file.writelines(DEFAULT_DESCRIPTION)
override_file.write(EMPTY_LINE)
override_file.write(AUTHENTICATION_HEADER)
override_file.write(EMPTY_LINE)
override_file.write(f'API_JWT_PRIVATE_KEY="{private_key}"')
override_file.write(EMPTY_LINE)
override_file.write(EMPTY_LINE)
override_file.write(f'API_JWT_PUBLIC_KEY="{public_key}"')
override_file.write(EMPTY_LINE)

logger.info("Created override.env file")


@click.command()
@click.option(
"--recreate",
is_flag=True,
default=False,
help="Whether to recreate the override file if it already exists",
)
def setup_env_override_file(recreate: bool) -> None:
with src.logging.init("create_env_override_file"):
error_if_not_local()
create_override_file(recreate)
6 changes: 5 additions & 1 deletion api/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,11 @@ services:
"--reload",
]
container_name: grants-api
env_file: ./local.env
env_file:
- path: ./local.env
required: true
- path: ./override.env
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool! It's just a list so lower ones override, nice

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, when I came across it in the docs I realized it was exactly what I'd been looking for in env var management from Docker for years

required: false
ports:
- 8080:8080
volumes:
Expand Down
1 change: 1 addition & 0 deletions api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ db-seed-local = "tests.lib.seed_local_db:seed_local_db"
create-erds = "bin.create_erds:main"
setup-postgres-db = "src.db.migrations.setup_local_postgres_db:setup_local_postgres_db"
setup-localstack = "bin.setup_localstack:main"
setup-env-override-file = "bin.setup_env_override_file:setup_env_override_file"

[tool.black]
line-length = 100
Expand Down
9 changes: 8 additions & 1 deletion api/src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from src.api.response import restructure_error_response
from src.api.schemas import response_schema
from src.app_config import AppConfig
from src.auth.api_key_auth import get_app_security_scheme
from src.auth.api_jwt_auth import initialize_jwt_auth
from src.auth.auth_utils import get_app_security_scheme
from src.data_migration.data_migration_blueprint import data_migration_blueprint
from src.search.backend.load_search_data_blueprint import load_search_data_blueprint
from src.task import task_blueprint
Expand Down Expand Up @@ -51,6 +52,12 @@ def create_app() -> APIFlask:
register_index(app)
register_search_client(app)

# TODO - once we merge the auth changes for setting up the initial route
# will reuse the config from it, for now we'll do this a bit hacky
# This cannot be removed non-locally until we setup RSA keys for non-local envs
if os.getenv("ENVIRONMENT") == "local":
initialize_jwt_auth()

return app


Expand Down
207 changes: 207 additions & 0 deletions api/src/auth/api_jwt_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
import logging
import uuid
from datetime import timedelta
from typing import Tuple

import jwt
from apiflask import HTTPTokenAuth
from pydantic import Field
from sqlalchemy import select
from sqlalchemy.orm import selectinload

import src.util.datetime_util as datetime_util
from src.adapters import db
from src.adapters.db import flask_db
from src.api.route_utils import raise_flask_error
from src.auth.auth_errors import JwtValidationError
from src.db.models.user_models import User, UserTokenSession
from src.logging.flask_logger import add_extra_data_to_current_request_logs
from src.util.env_config import PydanticBaseEnvConfig

logger = logging.getLogger(__name__)

api_jwt_auth = HTTPTokenAuth("ApiKey", header="X-SGG-Token", security_scheme_name="ApiJwtAuth")


class ApiJwtConfig(PydanticBaseEnvConfig):

private_key: str = Field(alias="API_JWT_PRIVATE_KEY")
public_key: str = Field(alias="API_JWT_PUBLIC_KEY")

issuer: str = Field("simpler-grants-api", alias="API_JWT_ISSUER")
audience: str = Field("simpler-grants-api", alias="API_JWT_AUDIENCE")

algorithm: str = Field("RS256", alias="API_JWT_ALGORITHM")

token_expiration_minutes: int = Field(30, alias="API_JWT_TOKEN_EXPIRATION_MINUTES")


# Initialize a config at startup that we'll use below
_config: ApiJwtConfig | None = None


def initialize_jwt_auth() -> None:
global _config
if not _config:
_config = ApiJwtConfig()
logger.info(
"Constructed JWT configuration",
extra={
# NOTE: We don't just log the entire config
# because that would include the encryption keys
"issuer": _config.issuer,
"audience": _config.audience,
"algorithm": _config.algorithm,
"token_expiration_minutes": _config.token_expiration_minutes,
},
)


def get_config() -> ApiJwtConfig:
global _config

if _config is None:
raise Exception("No JWT configuration - initialize_jwt_auth() must be run first")

return _config


def create_jwt_for_user(
user: User, db_session: db.Session, config: ApiJwtConfig | None = None
) -> Tuple[str, UserTokenSession]:
if config is None:
config = get_config()

# Generate a random ID
token_id = uuid.uuid4()

# Always do all time checks in UTC for consistency
current_time = datetime_util.utcnow()
expiration_time = current_time + timedelta(minutes=config.token_expiration_minutes)

# Create the session in the DB
user_token_session = UserTokenSession(user=user, token_id=token_id, expires_at=expiration_time)
db_session.add(user_token_session)

# Create the JWT with information we'll want to receive back
payload = {
"sub": str(token_id),
# iat -> issued at
"iat": current_time,
"aud": config.audience,
"iss": config.issuer,
}

logger.info(
"Created JWT token",
extra={
"auth.user_id": str(user_token_session.user_id),
"auth.token_id": str(user_token_session.token_id),
},
)

return jwt.encode(payload, config.private_key, algorithm="RS256"), user_token_session


def parse_jwt_for_user(
token: str, db_session: db.Session, config: ApiJwtConfig | None = None
) -> UserTokenSession:
"""Handle processing a jwt token, and connecting it to a user token session in our DB"""
if config is None:
config = get_config()

current_timestamp = datetime_util.utcnow()

try:
parsed_jwt: dict = jwt.decode(
token,
config.public_key,
algorithms=[config.algorithm],
issuer=config.issuer,
audience=config.audience,
options={
"verify_signature": True,
"verify_iat": True,
"verify_aud": True,
"verify_iss": True,
# We do not set the following fields
# so do not want to validate.
"verify_exp": False, # expiration is managed in the DB
"verify_nbf": False, # Tokens are always fine to use immediately
},
)

except jwt.ImmatureSignatureError as e: # IAT errors hit this
raise JwtValidationError("Token not yet valid") from e
except jwt.InvalidIssuerError as e:
raise JwtValidationError("Unknown Issuer") from e
except jwt.InvalidAudienceError as e:
raise JwtValidationError("Unknown Audience") from e
except jwt.PyJWTError as e:
# Every other error case wrap in the same generic error message.
raise JwtValidationError("Unable to process token") from e

sub_id = parsed_jwt.get("sub", None)
if sub_id is None:
raise JwtValidationError("Token missing sub field")

token_session: UserTokenSession | None = db_session.execute(
select(UserTokenSession)
.where(UserTokenSession.token_id == sub_id)
.options(selectinload("*"))
).scalar()

# We check both the token expires_at timestamp as well as an
# is_valid flag to make sure the token is still valid.
if token_session is None:
raise JwtValidationError("Token session does not exist")
if token_session.expires_at < current_timestamp:
raise JwtValidationError("Token expired")
if token_session.is_valid is False:
raise JwtValidationError("Token is no longer valid")

return token_session


@api_jwt_auth.verify_token
@flask_db.with_db_session()
def decode_token(db_session: db.Session, token: str) -> UserTokenSession:
"""
Process an internal jwt token as created by the above create_jwt_for_user method.

To add this auth to an endpoint, simply put::

from src.auth.api_jwt_auth import api_jwt_auth

@example_blueprint.get("/example")
@example_blueprint.auth_required(api_jwt_auth)
@flask_db.with_db_session()
def example_method(db_session: db.Session) -> response.ApiResponse:
# The token session object can be fetched from the auth object
token_session: UserTokenSession = api_jwt_auth.current_user

# If you want to modify the token_session or user, you will
# need to add it to the DB session otherwise it won't do anything
db_session.add(token_session)
token_session.expires_at = ...
...
"""

try:
user_token_session = parse_jwt_for_user(token, db_session)

add_extra_data_to_current_request_logs(
{
"auth.user_id": str(user_token_session.user_id),
"auth.token_id": str(user_token_session.token_id),
}
)
logger.info("JWT Authentication Successful")

# Return the user token session object
return user_token_session
except JwtValidationError as e:
# If validation of the jwt fails, pass the error message back to the user
# The message is just the value we set when constructing the JwtValidationError
logger.info("JWT Authentication Failed for provided token", extra={"auth.issue": e.message})
raise_flask_error(401, e.message)
Loading
Loading