diff --git a/.github/workflows/prod-ci.yaml b/.github/workflows/prod-ci.yaml index f96807cc9..3478be8ff 100644 --- a/.github/workflows/prod-ci.yaml +++ b/.github/workflows/prod-ci.yaml @@ -42,16 +42,31 @@ jobs: echo "IMAGE_TAG retrieved from Test is $imagetag" echo "IMAGE_TAG=$imagetag" >> $GITHUB_OUTPUT + get-current-time: + name: Get Current Time + runs-on: ubuntu-latest + needs: get-image-tag + + outputs: + CURRENT_TIME: ${{ steps.get-current-time.outputs.CURRENT_TIME }} + + steps: + - id: get-current-time + run: | + TZ="America/Vancouver" + echo "CURRENT_TIME=$(date '+%Y-%m-%d %H:%M:%S %Z')" >> $GITHUB_OUTPUT + # Deplog the image which is running on test to prod deploy-on-prod: name: Deploy LCFS on Prod runs-on: ubuntu-latest - needs: get-image-tag + needs: [get-image-tag, get-current-time] timeout-minutes: 60 env: IMAGE_TAG: ${{ needs.get-image-tag.outputs.IMAGE_TAG }} + CURRENT_TIME: ${{ needs.get-current-time.outputs.CURRENT_TIME }} steps: @@ -66,9 +81,17 @@ jobs: uses: trstringer/manual-approval@v1.6.0 with: secret: ${{ github.TOKEN }} - approvers: AlexZorkin,kuanfandevops,hamed-valiollahi,airinggov,areyeslo,dhaselhan,Grulin,justin-lepitzki,kevin-hashimoto + approvers: AlexZorkin,kuanfandevops,hamed-valiollahi,airinggov,areyeslo,dhaselhan,Grulin minimum-approvals: 2 - issue-title: "LCFS ${{env.IMAGE_TAG }} Prod Deployment" + issue-title: "LCFS ${{env.IMAGE_TAG }} Prod Deployment at ${{ env.CURRENT_TIME }}." + + - name: Log in to Openshift + uses: redhat-actions/oc-login@v1.3 + with: + openshift_server_url: ${{ secrets.OPENSHIFT_SERVER }} + openshift_token: ${{ secrets.OPENSHIFT_TOKEN }} + insecure_skip_tls_verify: true + namespace: ${{ env.PROD_NAMESPACE }} - name: Tag LCFS images from Test to Prod run: | @@ -88,6 +111,6 @@ jobs: git config --global user.name "GitHub Actions" git add lcfs/charts/lcfs-frontend/values-prod.yaml git add lcfs/charts/lcfs-backend/values-prod.yaml - git commit -m "update the version with pre-release number for prod" + git commit -m "Update image tag ${{env.IMAGE_TAG }} for prod" git push \ No newline at end of file diff --git a/.github/workflows/test-ci.yaml b/.github/workflows/test-ci.yaml index e8ca4820d..1119b9432 100644 --- a/.github/workflows/test-ci.yaml +++ b/.github/workflows/test-ci.yaml @@ -225,7 +225,7 @@ jobs: uses: trstringer/manual-approval@v1.6.0 with: secret: ${{ github.TOKEN }} - approvers: AlexZorkin,kuanfandevops,hamed-valiollahi,airinggov,areyeslo,dhaselhan,Grulin,justin-lepitzki,kevin-hashimoto + approvers: AlexZorkin,kuanfandevops,hamed-valiollahi,airinggov,areyeslo,dhaselhan,Grulin,kevin-hashimoto minimum-approvals: 1 issue-title: "LCFS ${{ env.VERSION }}-${{ env.PRE_RELEASE }} Test Deployment" diff --git a/backend/lcfs/dependencies/dependencies.py b/backend/lcfs/dependencies/dependencies.py deleted file mode 100644 index 9d160d0dc..000000000 --- a/backend/lcfs/dependencies/dependencies.py +++ /dev/null @@ -1,9 +0,0 @@ -from fastapi import Request -from redis.asyncio import Redis -import boto3 - -async def get_redis_pool(request: Request) -> Redis: - return request.app.state.redis_pool - -async def get_s3_client(request: Request) -> boto3.client: - return request.app.state.s3_client \ No newline at end of file diff --git a/backend/lcfs/services/keycloak/authentication.py b/backend/lcfs/services/keycloak/authentication.py index 86026ba6f..f3fc1f776 100644 --- a/backend/lcfs/services/keycloak/authentication.py +++ b/backend/lcfs/services/keycloak/authentication.py @@ -2,9 +2,8 @@ import httpx import jwt -from fastapi import HTTPException, Depends -from redis import ConnectionPool -from redis.asyncio import Redis +from fastapi import HTTPException +from redis.asyncio import Redis, ConnectionPool from sqlalchemy import func from sqlalchemy.exc import NoResultFound from sqlalchemy.ext.asyncio import async_sessionmaker @@ -27,7 +26,7 @@ class UserAuthentication(AuthenticationBackend): def __init__( self, - redis_pool: Redis, + redis_pool: ConnectionPool, session_factory: async_sessionmaker, settings: Settings, ): @@ -39,30 +38,46 @@ def __init__( self.test_keycloak_user = None async def refresh_jwk(self): - # Try to get the JWKS data from Redis cache - jwks_data = await self.redis_pool.get("jwks_data") - - if jwks_data: - jwks_data = json.loads(jwks_data) - self.jwks = jwks_data.get("jwks") - self.jwks_uri = jwks_data.get("jwks_uri") - return - - # If not in cache, retrieve from the well-known endpoint - async with httpx.AsyncClient() as client: - oidc_response = await client.get(self.settings.well_known_endpoint) - jwks_uri = oidc_response.json().get("jwks_uri") - certs_response = await client.get(jwks_uri) - jwks = certs_response.json() - - # Composite object containing both JWKS and JWKS URI - jwks_data = {"jwks": jwks, "jwks_uri": jwks_uri} - - # Cache the composite JWKS data with a TTL of 1 day (86400 seconds) - await self.redis_pool.set("jwks_data", json.dumps(jwks_data), ex=86400) - - self.jwks = jwks - self.jwks_uri = jwks_uri + """ + Refreshes the JSON Web Key (JWK) used for token verification. + This method attempts to retrieve the JWK from Redis cache. + If not found, it fetches it from the well-known endpoint + and stores it in Redis for future use. + """ + # Create a Redis client from the connection pool + async with Redis(connection_pool=self.redis_pool) as redis: + # Try to get the JWKS data from Redis cache + jwks_data = await redis.get("jwks_data") + + if jwks_data: + jwks_data = json.loads(jwks_data) + self.jwks = jwks_data.get("jwks") + self.jwks_uri = jwks_data.get("jwks_uri") + return + + # If not in cache, retrieve from the well-known endpoint + async with httpx.AsyncClient() as client: + oidc_response = await client.get(self.settings.well_known_endpoint) + oidc_response.raise_for_status() + jwks_uri = oidc_response.json().get("jwks_uri") + + if not jwks_uri: + raise ValueError( + "JWKS URI not found in the well-known endpoint response." + ) + + certs_response = await client.get(jwks_uri) + certs_response.raise_for_status() + jwks = certs_response.json() + + # Composite object containing both JWKS and JWKS URI + jwks_data = {"jwks": jwks, "jwks_uri": jwks_uri} + + # Cache the composite JWKS data with a TTL of 1 day (86400 seconds) + await redis.set("jwks_data", json.dumps(jwks_data), ex=86400) + + self.jwks = jwks + self.jwks_uri = jwks_uri async def authenticate(self, request): # Extract the authorization header from the request diff --git a/backend/lcfs/services/rabbitmq/transaction_consumer.py b/backend/lcfs/services/rabbitmq/transaction_consumer.py index 10e5367f4..09626551a 100644 --- a/backend/lcfs/services/rabbitmq/transaction_consumer.py +++ b/backend/lcfs/services/rabbitmq/transaction_consumer.py @@ -4,7 +4,7 @@ from redis.asyncio import Redis from sqlalchemy.ext.asyncio import AsyncSession -from lcfs.dependencies.dependencies import get_redis_pool +from lcfs.services.redis.dependency import get_redis_pool from fastapi import Request from lcfs.db.dependencies import async_engine @@ -50,14 +50,14 @@ async def process_message(self, body: bytes, request: Request): compliance_units = message_content.get("compliance_units_amount") org_id = message_content.get("organization_id") - redis = await get_redis_pool(request) + redis_pool = await get_redis_pool(request) async with AsyncSession(async_engine) as session: async with session.begin(): repo = OrganizationsRepository(db=session) transaction_repo = TransactionRepository(db=session) redis_balance_service = RedisBalanceService( - transaction_repo=transaction_repo, redis_pool=redis.connection_pool + transaction_repo=transaction_repo, redis_pool=redis_pool ) org_service = OrganizationsService( repo=repo, diff --git a/backend/lcfs/services/redis/dependency.py b/backend/lcfs/services/redis/dependency.py index 368994ffd..14a19190c 100644 --- a/backend/lcfs/services/redis/dependency.py +++ b/backend/lcfs/services/redis/dependency.py @@ -1,26 +1,23 @@ -from typing import AsyncGenerator - -from redis.asyncio import Redis +from redis.asyncio import ConnectionPool from starlette.requests import Request +# Redis Pool Dependency async def get_redis_pool( request: Request, -) -> AsyncGenerator[Redis, None]: # pragma: no cover +) -> ConnectionPool: """ - Returns connection pool. - - You can use it like this: - - >>> from redis.asyncio import ConnectionPool, Redis - >>> - >>> async def handler(redis_pool: ConnectionPool = Depends(get_redis_pool)): - >>> async with Redis(connection_pool=redis_pool) as redis: - >>> await redis.get('key') + Returns the Redis connection pool. - I use pools, so you don't acquire connection till the end of the handler. + Usage: + >>> from redis.asyncio import ConnectionPool, Redis + >>> + >>> async def handler(redis_pool: ConnectionPool = Depends(get_redis_pool)): + >>> redis = Redis(connection_pool=redis_pool) + >>> await redis.get('key') + >>> await redis.close() - :param request: current request. - :returns: redis connection pool. + :param request: Current request object. + :returns: Redis connection pool. """ return request.app.state.redis_pool diff --git a/backend/lcfs/services/redis/lifetime.py b/backend/lcfs/services/redis/lifetime.py index 3959edbff..2e007adbd 100644 --- a/backend/lcfs/services/redis/lifetime.py +++ b/backend/lcfs/services/redis/lifetime.py @@ -1,12 +1,13 @@ import logging from fastapi import FastAPI -from redis import asyncio as aioredis +from redis.asyncio import ConnectionPool, Redis from redis.exceptions import RedisError from lcfs.settings import settings logger = logging.getLogger(__name__) + async def init_redis(app: FastAPI) -> None: """ Creates connection pool for redis. @@ -14,13 +15,16 @@ async def init_redis(app: FastAPI) -> None: :param app: current fastapi application. """ try: - app.state.redis_pool = aioredis.from_url( + app.state.redis_pool = ConnectionPool.from_url( str(settings.redis_url), encoding="utf8", decode_responses=True, - max_connections=200 + max_connections=200, ) - await app.state.redis_pool.ping() + # Test the connection + redis = Redis(connection_pool=app.state.redis_pool) + await redis.ping() + await redis.close() logger.info("Redis pool initialized successfully.") except RedisError as e: logger.error(f"Redis error during initialization: {e}") @@ -29,6 +33,7 @@ async def init_redis(app: FastAPI) -> None: logger.error(f"Unexpected error during Redis initialization: {e}") raise + async def shutdown_redis(app: FastAPI) -> None: # pragma: no cover """ Closes redis connection pool. @@ -37,8 +42,7 @@ async def shutdown_redis(app: FastAPI) -> None: # pragma: no cover """ try: if hasattr(app.state, "redis_pool"): - await app.state.redis_pool.close() - await app.state.redis_pool.wait_closed() + await app.state.redis_pool.disconnect(inuse_connections=True) logger.info("Redis pool closed successfully.") except RedisError as e: logger.error(f"Redis error during shutdown: {e}") diff --git a/backend/lcfs/services/s3/client.py b/backend/lcfs/services/s3/client.py index c03b54993..11ee27397 100644 --- a/backend/lcfs/services/s3/client.py +++ b/backend/lcfs/services/s3/client.py @@ -7,7 +7,7 @@ from sqlalchemy import select from sqlalchemy.exc import InvalidRequestError from sqlalchemy.ext.asyncio import AsyncSession -from lcfs.dependencies.dependencies import get_s3_client +from lcfs.services.s3.dependency import get_s3_client from lcfs.db.dependencies import get_async_db_session from lcfs.db.models.compliance import ComplianceReport @@ -28,13 +28,13 @@ class DocumentService: def __init__( self, - request: Request, db: AsyncSession = Depends(get_async_db_session), clamav_service: ClamAVService = Depends(), + s3_client=Depends(get_s3_client), ): self.db = db self.clamav_service = clamav_service - self.s3_client = request.app.state.s3_client + self.s3_client = s3_client @repo_handler async def upload_file(self, file, parent_id: str, parent_type="compliance_report"): diff --git a/backend/lcfs/services/s3/dependency.py b/backend/lcfs/services/s3/dependency.py new file mode 100644 index 000000000..d46027c42 --- /dev/null +++ b/backend/lcfs/services/s3/dependency.py @@ -0,0 +1,19 @@ +from starlette.requests import Request +import boto3 + + +# S3 Client Dependency +async def get_s3_client( + request: Request, +) -> boto3.client: + """ + Returns the S3 client from the application state. + + Usage: + >>> async def handler(s3_client = Depends(get_s3_client)): + >>> s3_client.upload_file('file.txt', 'my-bucket', 'file.txt') + + :param request: Current request object. + :returns: S3 client. + """ + return request.app.state.s3_client diff --git a/backend/lcfs/services/s3/lifetime.py b/backend/lcfs/services/s3/lifetime.py new file mode 100644 index 000000000..443f49eae --- /dev/null +++ b/backend/lcfs/services/s3/lifetime.py @@ -0,0 +1,30 @@ +import boto3 +from fastapi import FastAPI +from lcfs.settings import settings + + +async def init_s3(app: FastAPI) -> None: + """ + Initialize the S3 client and store it in the app state. + + :param app: FastAPI application. + """ + app.state.s3_client = boto3.client( + "s3", + aws_access_key_id=settings.s3_access_key, + aws_secret_access_key=settings.s3_secret_key, + endpoint_url=settings.s3_endpoint, + region_name="us-east-1", + ) + print("S3 client initialized.") + + +async def shutdown_s3(app: FastAPI) -> None: + """ + Cleanup the S3 client from the app state. + + :param app: FastAPI application. + """ + if hasattr(app.state, "s3_client"): + del app.state.s3_client + print("S3 client shutdown.") diff --git a/backend/lcfs/services/tfrs/redis_balance.py b/backend/lcfs/services/tfrs/redis_balance.py index 69dc96010..02d6f2a12 100644 --- a/backend/lcfs/services/tfrs/redis_balance.py +++ b/backend/lcfs/services/tfrs/redis_balance.py @@ -17,10 +17,19 @@ async def init_org_balance_cache(app: FastAPI): - redis = await app.state.redis_pool + """ + Initialize the organization balance cache and populate it with data. + + :param app: FastAPI application instance. + """ + # Get the Redis connection pool from app state + redis_pool: ConnectionPool = app.state.redis_pool + + # Create a Redis client using the connection pool + redis = Redis(connection_pool=redis_pool) + async with AsyncSession(async_engine) as session: async with session.begin(): - organization_repo = OrganizationsRepository(db=session) transaction_repo = TransactionRepository(db=session) @@ -29,23 +38,32 @@ async def init_org_balance_cache(app: FastAPI): # Get the current year current_year = datetime.now().year - logger.info(f"Starting balance cache population {current_year}") + logger.info(f"Starting balance cache population for {current_year}") + # Fetch all organizations all_orgs = await organization_repo.get_organizations() # Loop from the oldest year to the current year for year in range(int(oldest_year), current_year + 1): - # Call the function to process transactions for each year for org in all_orgs: + # Calculate the balance for each organization and year balance = ( await transaction_repo.calculate_available_balance_for_period( org.organization_id, year ) ) + # Set the balance in Redis await set_cache_value(org.organization_id, year, balance, redis) - logger.debug(f"Set balance for {org.name} for {year} to {balance}") + logger.debug( + f"Set balance for organization {org.name} " + f"for {year} to {balance}" + ) + logger.info(f"Cache populated with {len(all_orgs)} organizations") + # Close the Redis client + await redis.close() + class RedisBalanceService: def __init__( @@ -84,4 +102,12 @@ async def populate_organization_redis_balance( async def set_cache_value( organization_id: int, period: int, balance: int, redis: Redis ) -> None: + """ + Set a cache value in Redis for a specific organization and period. + + :param organization_id: ID of the organization. + :param period: The year or period for which the balance is being set. + :param balance: The balance value to set in the cache. + :param redis: Redis client instance. + """ await redis.set(name=f"balance_{organization_id}_{period}", value=balance) diff --git a/backend/lcfs/tests/services/tfrs/test_redis_balance.py b/backend/lcfs/tests/services/tfrs/test_redis_balance.py index 56ce31fb1..9df7f20fc 100644 --- a/backend/lcfs/tests/services/tfrs/test_redis_balance.py +++ b/backend/lcfs/tests/services/tfrs/test_redis_balance.py @@ -1,5 +1,5 @@ import pytest -from unittest.mock import AsyncMock, patch, MagicMock +from unittest.mock import AsyncMock, patch, MagicMock, call from datetime import datetime from redis.asyncio import ConnectionPool, Redis @@ -13,55 +13,50 @@ @pytest.mark.anyio async def test_init_org_balance_cache(): - # Mock the session and repositories - mock_session = AsyncMock() - - # Mock the Redis client + # Mock the Redis connection pool + mock_redis_pool = AsyncMock() mock_redis = AsyncMock() - mock_redis.set = AsyncMock() # Ensure the `set` method is mocked - - # Mock the settings - mock_settings = MagicMock() - mock_settings.redis_url = "redis://localhost" - - # Create a mock app object - mock_app = MagicMock() - - # Simulate redis_pool as an awaitable returning mock_redis - async def mock_redis_pool(): - return mock_redis + mock_redis.set = AsyncMock() - mock_app.state.redis_pool = mock_redis_pool() - mock_app.state.settings = mock_settings + # Ensure the `Redis` instance is created with the connection pool + with patch("lcfs.services.tfrs.redis_balance.Redis", return_value=mock_redis): + # Mock the app object + mock_app = MagicMock() + mock_app.state.redis_pool = mock_redis_pool - current_year = datetime.now().year - last_year = current_year - 1 + current_year = datetime.now().year + last_year = current_year - 1 - with patch( - "lcfs.web.api.organizations.services.OrganizationsRepository.get_organizations", - return_value=[ - MagicMock(organization_id=1, name="Org1"), - MagicMock(organization_id=2, name="Org2"), - ], - ): + # Mock repository methods with patch( + "lcfs.web.api.organizations.repo.OrganizationsRepository.get_organizations", + return_value=[ + MagicMock(organization_id=1, name="Org1"), + MagicMock(organization_id=2, name="Org2"), + ], + ), patch( "lcfs.web.api.transaction.repo.TransactionRepository.get_transaction_start_year", return_value=last_year, + ), patch( + "lcfs.web.api.transaction.repo.TransactionRepository.calculate_available_balance_for_period", + side_effect=[100, 200, 150, 250], ): - with patch( - "lcfs.web.api.transaction.repo.TransactionRepository.calculate_available_balance_for_period", - side_effect=[100, 200, 150, 250, 300, 350], - ): - # Pass the mock app to the function - await init_org_balance_cache(mock_app) - - # Assert that each cache set operation was called correctly - calls = mock_redis.set.mock_calls - assert len(calls) == 4 - mock_redis.set.assert_any_call(name=f"balance_1_{last_year}", value=100) - mock_redis.set.assert_any_call(name=f"balance_2_{last_year}", value=200) - mock_redis.set.assert_any_call(name=f"balance_1_{current_year}", value=150) - mock_redis.set.assert_any_call(name=f"balance_2_{current_year}", value=250) + # Execute the function with the mocked app + await init_org_balance_cache(mock_app) + + # Define expected calls to Redis `set` + expected_calls = [ + call(name=f"balance_1_{last_year}", value=100), + call(name=f"balance_2_{last_year}", value=200), + call(name=f"balance_1_{current_year}", value=150), + call(name=f"balance_2_{current_year}", value=250), + ] + + # Assert that Redis `set` method was called with the expected arguments + mock_redis.set.assert_has_calls(expected_calls, any_order=True) + + # Ensure the number of calls matches the expected count + assert mock_redis.set.call_count == len(expected_calls) @pytest.mark.anyio diff --git a/backend/lcfs/tests/test_auth_middleware.py b/backend/lcfs/tests/test_auth_middleware.py index d59076107..083c188a5 100644 --- a/backend/lcfs/tests/test_auth_middleware.py +++ b/backend/lcfs/tests/test_auth_middleware.py @@ -1,9 +1,11 @@ from unittest.mock import AsyncMock, patch, MagicMock, Mock import pytest -import asyncio +import json +import redis from starlette.exceptions import HTTPException from starlette.requests import Request +from redis.asyncio import Redis, ConnectionPool from lcfs.db.models import UserProfile from lcfs.services.keycloak.authentication import UserAuthentication @@ -35,43 +37,112 @@ def auth_backend(redis_pool, session_generator, settings): @pytest.mark.anyio -async def test_load_jwk_from_redis(auth_backend): - # Mock auth_backend.redis_pool.get to return a JSON string directly - with patch.object(auth_backend.redis_pool, "get", new_callable=AsyncMock) as mock_redis_get: - mock_redis_get.return_value = '{"jwks": "jwks", "jwks_uri": "jwks_uri"}' +async def test_load_jwk_from_redis(): + # Initialize mock Redis client + mock_redis = AsyncMock(spec=Redis) + mock_redis.get = AsyncMock( + return_value='{"jwks": "jwks_data", "jwks_uri": "jwks_uri_data"}' + ) + # Mock the async context manager (__aenter__ and __aexit__) + mock_redis.__aenter__.return_value = mock_redis + mock_redis.__aexit__.return_value = AsyncMock() + + # Initialize mock ConnectionPool + mock_redis_pool = AsyncMock(spec=ConnectionPool) + + # Patch the Redis class in the UserAuthentication module to return mock_redis + with patch("lcfs.services.keycloak.authentication.Redis", return_value=mock_redis): + # Initialize UserAuthentication with the mocked ConnectionPool + auth_backend = UserAuthentication( + redis_pool=mock_redis_pool, + session_factory=AsyncMock(), + settings=MagicMock( + well_known_endpoint="https://example.com/.well-known/openid-configuration" + ), + ) + + # Call refresh_jwk await auth_backend.refresh_jwk() - assert auth_backend.jwks == "jwks" - assert auth_backend.jwks_uri == "jwks_uri" + # Assertions to verify JWKS data was loaded correctly + assert auth_backend.jwks == "jwks_data" + assert auth_backend.jwks_uri == "jwks_uri_data" + + # Verify that Redis `get` was called with the correct key + mock_redis.get.assert_awaited_once_with("jwks_data") @pytest.mark.anyio @patch("httpx.AsyncClient.get") -async def test_refresh_jwk_sets_new_keys_in_redis(mock_get, auth_backend): - # Create a mock response object - mock_response = MagicMock() - - # Set up the json method to return a dictionary with a .get method - mock_json = MagicMock() - mock_json.get.return_value = "{}" - - # Assign the mock_json to the json method of the response - mock_response.json.return_value = mock_json - - mock_response_2 = MagicMock() - mock_response_2.json.return_value = "{}" - - mock_get.side_effect = [ - mock_response, - mock_response_2, - ] - - with patch.object(auth_backend.redis_pool, "get", new_callable=AsyncMock) as mock_redis_get: - mock_redis_get.return_value = None - +async def test_refresh_jwk_sets_new_keys_in_redis(mock_httpx_get): + # Mock responses for the well-known endpoint and JWKS URI + mock_oidc_response = MagicMock() + mock_oidc_response.json.return_value = {"jwks_uri": "https://example.com/jwks"} + mock_oidc_response.raise_for_status = MagicMock() + + mock_certs_response = MagicMock() + mock_certs_response.json.return_value = { + "keys": [{"kty": "RSA", "kid": "key2", "use": "sig", "n": "def", "e": "AQAB"}] + } + mock_certs_response.raise_for_status = MagicMock() + + # Configure the mock to return the above responses in order + mock_httpx_get.side_effect = [mock_oidc_response, mock_certs_response] + + # Initialize mock Redis client + mock_redis = AsyncMock(spec=Redis) + mock_redis.get = AsyncMock(return_value=None) # JWKS data not in cache + mock_redis.set = AsyncMock() + + # Mock the async context manager (__aenter__ and __aexit__) + mock_redis.__aenter__.return_value = mock_redis + mock_redis.__aexit__.return_value = AsyncMock() + + # Initialize mock ConnectionPool + mock_redis_pool = AsyncMock(spec=ConnectionPool) + + # Patch the Redis class in the UserAuthentication module to return mock_redis + with patch("lcfs.services.keycloak.authentication.Redis", return_value=mock_redis): + # Initialize UserAuthentication with the mocked ConnectionPool + auth_backend = UserAuthentication( + redis_pool=mock_redis_pool, + session_factory=AsyncMock(), + settings=MagicMock( + well_known_endpoint="https://example.com/.well-known/openid-configuration" + ), + ) + + # Call refresh_jwk await auth_backend.refresh_jwk() + # Assertions to verify JWKS data was fetched and set correctly + expected_jwks = { + "keys": [ + {"kty": "RSA", "kid": "key2", "use": "sig", "n": "def", "e": "AQAB"} + ] + } + assert auth_backend.jwks == expected_jwks + assert auth_backend.jwks_uri == "https://example.com/jwks" + + # Verify that Redis `get` was called with "jwks_data" + mock_redis.get.assert_awaited_once_with("jwks_data") + + # Verify that the well-known endpoint was called twice + assert mock_httpx_get.call_count == 2 + mock_httpx_get.assert_any_call( + "https://example.com/.well-known/openid-configuration" + ) + mock_httpx_get.assert_any_call("https://example.com/jwks") + + # Verify that Redis `set` was called with the correct parameters + expected_jwks_data = { + "jwks": expected_jwks, + "jwks_uri": "https://example.com/jwks", + } + mock_redis.set.assert_awaited_once_with( + "jwks_data", json.dumps(expected_jwks_data), ex=86400 + ) @pytest.mark.anyio diff --git a/backend/lcfs/web/lifetime.py b/backend/lcfs/web/lifetime.py index 186b485cd..cbeb556b2 100644 --- a/backend/lcfs/web/lifetime.py +++ b/backend/lcfs/web/lifetime.py @@ -4,11 +4,12 @@ import boto3 from fastapi_cache import FastAPICache from fastapi_cache.backends.redis import RedisBackend -from redis import asyncio as aioredis +from redis.asyncio import Redis from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from lcfs.services.rabbitmq.consumers import start_consumers, stop_consumers from lcfs.services.redis.lifetime import init_redis, shutdown_redis +from lcfs.services.s3.lifetime import init_s3, shutdown_s3 from lcfs.services.tfrs.redis_balance import init_org_balance_cache from lcfs.settings import settings @@ -32,33 +33,6 @@ def _setup_db(app: FastAPI) -> None: # pragma: no cover app.state.db_session_factory = session_factory -async def startup_s3(app: FastAPI) -> None: - """ - Initialize the S3 client and store it in the app state. - - :param app: fastAPI application. - """ - app.state.s3_client = boto3.client( - "s3", - aws_access_key_id=settings.s3_access_key, - aws_secret_access_key=settings.s3_secret_key, - endpoint_url=settings.s3_endpoint, - region_name="us-east-1", - ) - print("S3 client initialized.") - - -async def shutdown_s3(app: FastAPI) -> None: - """ - Cleanup the S3 client from the app state. - - :param app: fastAPI application. - """ - if hasattr(app.state, "s3_client"): - del app.state.s3_client - print("S3 client shutdown.") - - def register_startup_event( app: FastAPI, ) -> Callable[[], Awaitable[None]]: # pragma: no cover @@ -83,13 +57,16 @@ async def _startup() -> None: # noqa: WPS430 # Assign settings to app state for global access app.state.settings = settings - # Initialize the cache with Redis backend using app.state.redis_pool - FastAPICache.init(RedisBackend(app.state.redis_pool), prefix="lcfs") + # Create a Redis client from the connection pool + redis_client = Redis(connection_pool=app.state.redis_pool) + + # Initialize FastAPI cache with the Redis client + FastAPICache.init(RedisBackend(redis_client), prefix="lcfs") await init_org_balance_cache(app) # Initialize the S3 client - await startup_s3(app) + await init_s3(app) # Setup RabbitMQ Listeners await start_consumers()