diff --git a/backend/lcfs/services/keycloak/authentication.py b/backend/lcfs/services/keycloak/authentication.py index 86026ba6f..f3fc1f776 100644 --- a/backend/lcfs/services/keycloak/authentication.py +++ b/backend/lcfs/services/keycloak/authentication.py @@ -2,9 +2,8 @@ import httpx import jwt -from fastapi import HTTPException, Depends -from redis import ConnectionPool -from redis.asyncio import Redis +from fastapi import HTTPException +from redis.asyncio import Redis, ConnectionPool from sqlalchemy import func from sqlalchemy.exc import NoResultFound from sqlalchemy.ext.asyncio import async_sessionmaker @@ -27,7 +26,7 @@ class UserAuthentication(AuthenticationBackend): def __init__( self, - redis_pool: Redis, + redis_pool: ConnectionPool, session_factory: async_sessionmaker, settings: Settings, ): @@ -39,30 +38,46 @@ def __init__( self.test_keycloak_user = None async def refresh_jwk(self): - # Try to get the JWKS data from Redis cache - jwks_data = await self.redis_pool.get("jwks_data") - - if jwks_data: - jwks_data = json.loads(jwks_data) - self.jwks = jwks_data.get("jwks") - self.jwks_uri = jwks_data.get("jwks_uri") - return - - # If not in cache, retrieve from the well-known endpoint - async with httpx.AsyncClient() as client: - oidc_response = await client.get(self.settings.well_known_endpoint) - jwks_uri = oidc_response.json().get("jwks_uri") - certs_response = await client.get(jwks_uri) - jwks = certs_response.json() - - # Composite object containing both JWKS and JWKS URI - jwks_data = {"jwks": jwks, "jwks_uri": jwks_uri} - - # Cache the composite JWKS data with a TTL of 1 day (86400 seconds) - await self.redis_pool.set("jwks_data", json.dumps(jwks_data), ex=86400) - - self.jwks = jwks - self.jwks_uri = jwks_uri + """ + Refreshes the JSON Web Key (JWK) used for token verification. + This method attempts to retrieve the JWK from Redis cache. + If not found, it fetches it from the well-known endpoint + and stores it in Redis for future use. + """ + # Create a Redis client from the connection pool + async with Redis(connection_pool=self.redis_pool) as redis: + # Try to get the JWKS data from Redis cache + jwks_data = await redis.get("jwks_data") + + if jwks_data: + jwks_data = json.loads(jwks_data) + self.jwks = jwks_data.get("jwks") + self.jwks_uri = jwks_data.get("jwks_uri") + return + + # If not in cache, retrieve from the well-known endpoint + async with httpx.AsyncClient() as client: + oidc_response = await client.get(self.settings.well_known_endpoint) + oidc_response.raise_for_status() + jwks_uri = oidc_response.json().get("jwks_uri") + + if not jwks_uri: + raise ValueError( + "JWKS URI not found in the well-known endpoint response." + ) + + certs_response = await client.get(jwks_uri) + certs_response.raise_for_status() + jwks = certs_response.json() + + # Composite object containing both JWKS and JWKS URI + jwks_data = {"jwks": jwks, "jwks_uri": jwks_uri} + + # Cache the composite JWKS data with a TTL of 1 day (86400 seconds) + await redis.set("jwks_data", json.dumps(jwks_data), ex=86400) + + self.jwks = jwks + self.jwks_uri = jwks_uri async def authenticate(self, request): # Extract the authorization header from the request diff --git a/backend/lcfs/services/tfrs/redis_balance.py b/backend/lcfs/services/tfrs/redis_balance.py index 69dc96010..02d6f2a12 100644 --- a/backend/lcfs/services/tfrs/redis_balance.py +++ b/backend/lcfs/services/tfrs/redis_balance.py @@ -17,10 +17,19 @@ async def init_org_balance_cache(app: FastAPI): - redis = await app.state.redis_pool + """ + Initialize the organization balance cache and populate it with data. + + :param app: FastAPI application instance. + """ + # Get the Redis connection pool from app state + redis_pool: ConnectionPool = app.state.redis_pool + + # Create a Redis client using the connection pool + redis = Redis(connection_pool=redis_pool) + async with AsyncSession(async_engine) as session: async with session.begin(): - organization_repo = OrganizationsRepository(db=session) transaction_repo = TransactionRepository(db=session) @@ -29,23 +38,32 @@ async def init_org_balance_cache(app: FastAPI): # Get the current year current_year = datetime.now().year - logger.info(f"Starting balance cache population {current_year}") + logger.info(f"Starting balance cache population for {current_year}") + # Fetch all organizations all_orgs = await organization_repo.get_organizations() # Loop from the oldest year to the current year for year in range(int(oldest_year), current_year + 1): - # Call the function to process transactions for each year for org in all_orgs: + # Calculate the balance for each organization and year balance = ( await transaction_repo.calculate_available_balance_for_period( org.organization_id, year ) ) + # Set the balance in Redis await set_cache_value(org.organization_id, year, balance, redis) - logger.debug(f"Set balance for {org.name} for {year} to {balance}") + logger.debug( + f"Set balance for organization {org.name} " + f"for {year} to {balance}" + ) + logger.info(f"Cache populated with {len(all_orgs)} organizations") + # Close the Redis client + await redis.close() + class RedisBalanceService: def __init__( @@ -84,4 +102,12 @@ async def populate_organization_redis_balance( async def set_cache_value( organization_id: int, period: int, balance: int, redis: Redis ) -> None: + """ + Set a cache value in Redis for a specific organization and period. + + :param organization_id: ID of the organization. + :param period: The year or period for which the balance is being set. + :param balance: The balance value to set in the cache. + :param redis: Redis client instance. + """ await redis.set(name=f"balance_{organization_id}_{period}", value=balance) diff --git a/backend/lcfs/tests/services/tfrs/test_redis_balance.py b/backend/lcfs/tests/services/tfrs/test_redis_balance.py index 56ce31fb1..9df7f20fc 100644 --- a/backend/lcfs/tests/services/tfrs/test_redis_balance.py +++ b/backend/lcfs/tests/services/tfrs/test_redis_balance.py @@ -1,5 +1,5 @@ import pytest -from unittest.mock import AsyncMock, patch, MagicMock +from unittest.mock import AsyncMock, patch, MagicMock, call from datetime import datetime from redis.asyncio import ConnectionPool, Redis @@ -13,55 +13,50 @@ @pytest.mark.anyio async def test_init_org_balance_cache(): - # Mock the session and repositories - mock_session = AsyncMock() - - # Mock the Redis client + # Mock the Redis connection pool + mock_redis_pool = AsyncMock() mock_redis = AsyncMock() - mock_redis.set = AsyncMock() # Ensure the `set` method is mocked - - # Mock the settings - mock_settings = MagicMock() - mock_settings.redis_url = "redis://localhost" - - # Create a mock app object - mock_app = MagicMock() - - # Simulate redis_pool as an awaitable returning mock_redis - async def mock_redis_pool(): - return mock_redis + mock_redis.set = AsyncMock() - mock_app.state.redis_pool = mock_redis_pool() - mock_app.state.settings = mock_settings + # Ensure the `Redis` instance is created with the connection pool + with patch("lcfs.services.tfrs.redis_balance.Redis", return_value=mock_redis): + # Mock the app object + mock_app = MagicMock() + mock_app.state.redis_pool = mock_redis_pool - current_year = datetime.now().year - last_year = current_year - 1 + current_year = datetime.now().year + last_year = current_year - 1 - with patch( - "lcfs.web.api.organizations.services.OrganizationsRepository.get_organizations", - return_value=[ - MagicMock(organization_id=1, name="Org1"), - MagicMock(organization_id=2, name="Org2"), - ], - ): + # Mock repository methods with patch( + "lcfs.web.api.organizations.repo.OrganizationsRepository.get_organizations", + return_value=[ + MagicMock(organization_id=1, name="Org1"), + MagicMock(organization_id=2, name="Org2"), + ], + ), patch( "lcfs.web.api.transaction.repo.TransactionRepository.get_transaction_start_year", return_value=last_year, + ), patch( + "lcfs.web.api.transaction.repo.TransactionRepository.calculate_available_balance_for_period", + side_effect=[100, 200, 150, 250], ): - with patch( - "lcfs.web.api.transaction.repo.TransactionRepository.calculate_available_balance_for_period", - side_effect=[100, 200, 150, 250, 300, 350], - ): - # Pass the mock app to the function - await init_org_balance_cache(mock_app) - - # Assert that each cache set operation was called correctly - calls = mock_redis.set.mock_calls - assert len(calls) == 4 - mock_redis.set.assert_any_call(name=f"balance_1_{last_year}", value=100) - mock_redis.set.assert_any_call(name=f"balance_2_{last_year}", value=200) - mock_redis.set.assert_any_call(name=f"balance_1_{current_year}", value=150) - mock_redis.set.assert_any_call(name=f"balance_2_{current_year}", value=250) + # Execute the function with the mocked app + await init_org_balance_cache(mock_app) + + # Define expected calls to Redis `set` + expected_calls = [ + call(name=f"balance_1_{last_year}", value=100), + call(name=f"balance_2_{last_year}", value=200), + call(name=f"balance_1_{current_year}", value=150), + call(name=f"balance_2_{current_year}", value=250), + ] + + # Assert that Redis `set` method was called with the expected arguments + mock_redis.set.assert_has_calls(expected_calls, any_order=True) + + # Ensure the number of calls matches the expected count + assert mock_redis.set.call_count == len(expected_calls) @pytest.mark.anyio diff --git a/backend/lcfs/tests/test_auth_middleware.py b/backend/lcfs/tests/test_auth_middleware.py index d59076107..083c188a5 100644 --- a/backend/lcfs/tests/test_auth_middleware.py +++ b/backend/lcfs/tests/test_auth_middleware.py @@ -1,9 +1,11 @@ from unittest.mock import AsyncMock, patch, MagicMock, Mock import pytest -import asyncio +import json +import redis from starlette.exceptions import HTTPException from starlette.requests import Request +from redis.asyncio import Redis, ConnectionPool from lcfs.db.models import UserProfile from lcfs.services.keycloak.authentication import UserAuthentication @@ -35,43 +37,112 @@ def auth_backend(redis_pool, session_generator, settings): @pytest.mark.anyio -async def test_load_jwk_from_redis(auth_backend): - # Mock auth_backend.redis_pool.get to return a JSON string directly - with patch.object(auth_backend.redis_pool, "get", new_callable=AsyncMock) as mock_redis_get: - mock_redis_get.return_value = '{"jwks": "jwks", "jwks_uri": "jwks_uri"}' +async def test_load_jwk_from_redis(): + # Initialize mock Redis client + mock_redis = AsyncMock(spec=Redis) + mock_redis.get = AsyncMock( + return_value='{"jwks": "jwks_data", "jwks_uri": "jwks_uri_data"}' + ) + # Mock the async context manager (__aenter__ and __aexit__) + mock_redis.__aenter__.return_value = mock_redis + mock_redis.__aexit__.return_value = AsyncMock() + + # Initialize mock ConnectionPool + mock_redis_pool = AsyncMock(spec=ConnectionPool) + + # Patch the Redis class in the UserAuthentication module to return mock_redis + with patch("lcfs.services.keycloak.authentication.Redis", return_value=mock_redis): + # Initialize UserAuthentication with the mocked ConnectionPool + auth_backend = UserAuthentication( + redis_pool=mock_redis_pool, + session_factory=AsyncMock(), + settings=MagicMock( + well_known_endpoint="https://example.com/.well-known/openid-configuration" + ), + ) + + # Call refresh_jwk await auth_backend.refresh_jwk() - assert auth_backend.jwks == "jwks" - assert auth_backend.jwks_uri == "jwks_uri" + # Assertions to verify JWKS data was loaded correctly + assert auth_backend.jwks == "jwks_data" + assert auth_backend.jwks_uri == "jwks_uri_data" + + # Verify that Redis `get` was called with the correct key + mock_redis.get.assert_awaited_once_with("jwks_data") @pytest.mark.anyio @patch("httpx.AsyncClient.get") -async def test_refresh_jwk_sets_new_keys_in_redis(mock_get, auth_backend): - # Create a mock response object - mock_response = MagicMock() - - # Set up the json method to return a dictionary with a .get method - mock_json = MagicMock() - mock_json.get.return_value = "{}" - - # Assign the mock_json to the json method of the response - mock_response.json.return_value = mock_json - - mock_response_2 = MagicMock() - mock_response_2.json.return_value = "{}" - - mock_get.side_effect = [ - mock_response, - mock_response_2, - ] - - with patch.object(auth_backend.redis_pool, "get", new_callable=AsyncMock) as mock_redis_get: - mock_redis_get.return_value = None - +async def test_refresh_jwk_sets_new_keys_in_redis(mock_httpx_get): + # Mock responses for the well-known endpoint and JWKS URI + mock_oidc_response = MagicMock() + mock_oidc_response.json.return_value = {"jwks_uri": "https://example.com/jwks"} + mock_oidc_response.raise_for_status = MagicMock() + + mock_certs_response = MagicMock() + mock_certs_response.json.return_value = { + "keys": [{"kty": "RSA", "kid": "key2", "use": "sig", "n": "def", "e": "AQAB"}] + } + mock_certs_response.raise_for_status = MagicMock() + + # Configure the mock to return the above responses in order + mock_httpx_get.side_effect = [mock_oidc_response, mock_certs_response] + + # Initialize mock Redis client + mock_redis = AsyncMock(spec=Redis) + mock_redis.get = AsyncMock(return_value=None) # JWKS data not in cache + mock_redis.set = AsyncMock() + + # Mock the async context manager (__aenter__ and __aexit__) + mock_redis.__aenter__.return_value = mock_redis + mock_redis.__aexit__.return_value = AsyncMock() + + # Initialize mock ConnectionPool + mock_redis_pool = AsyncMock(spec=ConnectionPool) + + # Patch the Redis class in the UserAuthentication module to return mock_redis + with patch("lcfs.services.keycloak.authentication.Redis", return_value=mock_redis): + # Initialize UserAuthentication with the mocked ConnectionPool + auth_backend = UserAuthentication( + redis_pool=mock_redis_pool, + session_factory=AsyncMock(), + settings=MagicMock( + well_known_endpoint="https://example.com/.well-known/openid-configuration" + ), + ) + + # Call refresh_jwk await auth_backend.refresh_jwk() + # Assertions to verify JWKS data was fetched and set correctly + expected_jwks = { + "keys": [ + {"kty": "RSA", "kid": "key2", "use": "sig", "n": "def", "e": "AQAB"} + ] + } + assert auth_backend.jwks == expected_jwks + assert auth_backend.jwks_uri == "https://example.com/jwks" + + # Verify that Redis `get` was called with "jwks_data" + mock_redis.get.assert_awaited_once_with("jwks_data") + + # Verify that the well-known endpoint was called twice + assert mock_httpx_get.call_count == 2 + mock_httpx_get.assert_any_call( + "https://example.com/.well-known/openid-configuration" + ) + mock_httpx_get.assert_any_call("https://example.com/jwks") + + # Verify that Redis `set` was called with the correct parameters + expected_jwks_data = { + "jwks": expected_jwks, + "jwks_uri": "https://example.com/jwks", + } + mock_redis.set.assert_awaited_once_with( + "jwks_data", json.dumps(expected_jwks_data), ex=86400 + ) @pytest.mark.anyio diff --git a/backend/lcfs/web/lifetime.py b/backend/lcfs/web/lifetime.py index 0215689e7..cbeb556b2 100644 --- a/backend/lcfs/web/lifetime.py +++ b/backend/lcfs/web/lifetime.py @@ -4,7 +4,7 @@ import boto3 from fastapi_cache import FastAPICache from fastapi_cache.backends.redis import RedisBackend -from redis import asyncio as aioredis +from redis.asyncio import Redis from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from lcfs.services.rabbitmq.consumers import start_consumers, stop_consumers @@ -57,8 +57,11 @@ async def _startup() -> None: # noqa: WPS430 # Assign settings to app state for global access app.state.settings = settings - # Initialize the cache with Redis backend using app.state.redis_pool - FastAPICache.init(RedisBackend(app.state.redis_pool), prefix="lcfs") + # Create a Redis client from the connection pool + redis_client = Redis(connection_pool=app.state.redis_pool) + + # Initialize FastAPI cache with the Redis client + FastAPICache.init(RedisBackend(redis_client), prefix="lcfs") await init_org_balance_cache(app)