diff --git a/backend/lcfs/services/keycloak/authentication.py b/backend/lcfs/services/keycloak/authentication.py index 86026ba6f..f3fc1f776 100644 --- a/backend/lcfs/services/keycloak/authentication.py +++ b/backend/lcfs/services/keycloak/authentication.py @@ -2,9 +2,8 @@ import httpx import jwt -from fastapi import HTTPException, Depends -from redis import ConnectionPool -from redis.asyncio import Redis +from fastapi import HTTPException +from redis.asyncio import Redis, ConnectionPool from sqlalchemy import func from sqlalchemy.exc import NoResultFound from sqlalchemy.ext.asyncio import async_sessionmaker @@ -27,7 +26,7 @@ class UserAuthentication(AuthenticationBackend): def __init__( self, - redis_pool: Redis, + redis_pool: ConnectionPool, session_factory: async_sessionmaker, settings: Settings, ): @@ -39,30 +38,46 @@ def __init__( self.test_keycloak_user = None async def refresh_jwk(self): - # Try to get the JWKS data from Redis cache - jwks_data = await self.redis_pool.get("jwks_data") - - if jwks_data: - jwks_data = json.loads(jwks_data) - self.jwks = jwks_data.get("jwks") - self.jwks_uri = jwks_data.get("jwks_uri") - return - - # If not in cache, retrieve from the well-known endpoint - async with httpx.AsyncClient() as client: - oidc_response = await client.get(self.settings.well_known_endpoint) - jwks_uri = oidc_response.json().get("jwks_uri") - certs_response = await client.get(jwks_uri) - jwks = certs_response.json() - - # Composite object containing both JWKS and JWKS URI - jwks_data = {"jwks": jwks, "jwks_uri": jwks_uri} - - # Cache the composite JWKS data with a TTL of 1 day (86400 seconds) - await self.redis_pool.set("jwks_data", json.dumps(jwks_data), ex=86400) - - self.jwks = jwks - self.jwks_uri = jwks_uri + """ + Refreshes the JSON Web Key (JWK) used for token verification. + This method attempts to retrieve the JWK from Redis cache. + If not found, it fetches it from the well-known endpoint + and stores it in Redis for future use. + """ + # Create a Redis client from the connection pool + async with Redis(connection_pool=self.redis_pool) as redis: + # Try to get the JWKS data from Redis cache + jwks_data = await redis.get("jwks_data") + + if jwks_data: + jwks_data = json.loads(jwks_data) + self.jwks = jwks_data.get("jwks") + self.jwks_uri = jwks_data.get("jwks_uri") + return + + # If not in cache, retrieve from the well-known endpoint + async with httpx.AsyncClient() as client: + oidc_response = await client.get(self.settings.well_known_endpoint) + oidc_response.raise_for_status() + jwks_uri = oidc_response.json().get("jwks_uri") + + if not jwks_uri: + raise ValueError( + "JWKS URI not found in the well-known endpoint response." + ) + + certs_response = await client.get(jwks_uri) + certs_response.raise_for_status() + jwks = certs_response.json() + + # Composite object containing both JWKS and JWKS URI + jwks_data = {"jwks": jwks, "jwks_uri": jwks_uri} + + # Cache the composite JWKS data with a TTL of 1 day (86400 seconds) + await redis.set("jwks_data", json.dumps(jwks_data), ex=86400) + + self.jwks = jwks + self.jwks_uri = jwks_uri async def authenticate(self, request): # Extract the authorization header from the request diff --git a/backend/lcfs/web/lifetime.py b/backend/lcfs/web/lifetime.py index 0215689e7..cbeb556b2 100644 --- a/backend/lcfs/web/lifetime.py +++ b/backend/lcfs/web/lifetime.py @@ -4,7 +4,7 @@ import boto3 from fastapi_cache import FastAPICache from fastapi_cache.backends.redis import RedisBackend -from redis import asyncio as aioredis +from redis.asyncio import Redis from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from lcfs.services.rabbitmq.consumers import start_consumers, stop_consumers @@ -57,8 +57,11 @@ async def _startup() -> None: # noqa: WPS430 # Assign settings to app state for global access app.state.settings = settings - # Initialize the cache with Redis backend using app.state.redis_pool - FastAPICache.init(RedisBackend(app.state.redis_pool), prefix="lcfs") + # Create a Redis client from the connection pool + redis_client = Redis(connection_pool=app.state.redis_pool) + + # Initialize FastAPI cache with the Redis client + FastAPICache.init(RedisBackend(redis_client), prefix="lcfs") await init_org_balance_cache(app)