Skip to content

Commit

Permalink
Merge pull request #1358 from bcgov/fix/alex-services-config-241203
Browse files Browse the repository at this point in the history
Fix: Jwk update for Redis Pool
  • Loading branch information
AlexZorkin authored Dec 4, 2024
2 parents 2a5cfd3 + c7d47e0 commit 6868f8c
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 31 deletions.
71 changes: 43 additions & 28 deletions backend/lcfs/services/keycloak/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

import httpx
import jwt
from fastapi import HTTPException, Depends
from redis import ConnectionPool
from redis.asyncio import Redis
from fastapi import HTTPException
from redis.asyncio import Redis, ConnectionPool
from sqlalchemy import func
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import async_sessionmaker
Expand All @@ -27,7 +26,7 @@ class UserAuthentication(AuthenticationBackend):

def __init__(
self,
redis_pool: Redis,
redis_pool: ConnectionPool,
session_factory: async_sessionmaker,
settings: Settings,
):
Expand All @@ -39,30 +38,46 @@ def __init__(
self.test_keycloak_user = None

async def refresh_jwk(self):
# Try to get the JWKS data from Redis cache
jwks_data = await self.redis_pool.get("jwks_data")

if jwks_data:
jwks_data = json.loads(jwks_data)
self.jwks = jwks_data.get("jwks")
self.jwks_uri = jwks_data.get("jwks_uri")
return

# If not in cache, retrieve from the well-known endpoint
async with httpx.AsyncClient() as client:
oidc_response = await client.get(self.settings.well_known_endpoint)
jwks_uri = oidc_response.json().get("jwks_uri")
certs_response = await client.get(jwks_uri)
jwks = certs_response.json()

# Composite object containing both JWKS and JWKS URI
jwks_data = {"jwks": jwks, "jwks_uri": jwks_uri}

# Cache the composite JWKS data with a TTL of 1 day (86400 seconds)
await self.redis_pool.set("jwks_data", json.dumps(jwks_data), ex=86400)

self.jwks = jwks
self.jwks_uri = jwks_uri
"""
Refreshes the JSON Web Key (JWK) used for token verification.
This method attempts to retrieve the JWK from Redis cache.
If not found, it fetches it from the well-known endpoint
and stores it in Redis for future use.
"""
# Create a Redis client from the connection pool
async with Redis(connection_pool=self.redis_pool) as redis:
# Try to get the JWKS data from Redis cache
jwks_data = await redis.get("jwks_data")

if jwks_data:
jwks_data = json.loads(jwks_data)
self.jwks = jwks_data.get("jwks")
self.jwks_uri = jwks_data.get("jwks_uri")
return

# If not in cache, retrieve from the well-known endpoint
async with httpx.AsyncClient() as client:
oidc_response = await client.get(self.settings.well_known_endpoint)
oidc_response.raise_for_status()
jwks_uri = oidc_response.json().get("jwks_uri")

if not jwks_uri:
raise ValueError(
"JWKS URI not found in the well-known endpoint response."
)

certs_response = await client.get(jwks_uri)
certs_response.raise_for_status()
jwks = certs_response.json()

# Composite object containing both JWKS and JWKS URI
jwks_data = {"jwks": jwks, "jwks_uri": jwks_uri}

# Cache the composite JWKS data with a TTL of 1 day (86400 seconds)
await redis.set("jwks_data", json.dumps(jwks_data), ex=86400)

self.jwks = jwks
self.jwks_uri = jwks_uri

async def authenticate(self, request):
# Extract the authorization header from the request
Expand Down
9 changes: 6 additions & 3 deletions backend/lcfs/web/lifetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import boto3
from fastapi_cache import FastAPICache
from fastapi_cache.backends.redis import RedisBackend
from redis import asyncio as aioredis
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

from lcfs.services.rabbitmq.consumers import start_consumers, stop_consumers
Expand Down Expand Up @@ -57,8 +57,11 @@ async def _startup() -> None: # noqa: WPS430
# Assign settings to app state for global access
app.state.settings = settings

# Initialize the cache with Redis backend using app.state.redis_pool
FastAPICache.init(RedisBackend(app.state.redis_pool), prefix="lcfs")
# Create a Redis client from the connection pool
redis_client = Redis(connection_pool=app.state.redis_pool)

# Initialize FastAPI cache with the Redis client
FastAPICache.init(RedisBackend(redis_client), prefix="lcfs")

await init_org_balance_cache(app)

Expand Down

0 comments on commit 6868f8c

Please sign in to comment.