Skip to content

Commit

Permalink
Merge branch 'release-0.2.0' into feat/prashanth-pills-ag-grid-1249
Browse files Browse the repository at this point in the history
  • Loading branch information
prv-proton authored Dec 4, 2024
2 parents 005d576 + 07139c3 commit 03b6531
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 107 deletions.
71 changes: 43 additions & 28 deletions backend/lcfs/services/keycloak/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

import httpx
import jwt
from fastapi import HTTPException, Depends
from redis import ConnectionPool
from redis.asyncio import Redis
from fastapi import HTTPException
from redis.asyncio import Redis, ConnectionPool
from sqlalchemy import func
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import async_sessionmaker
Expand All @@ -27,7 +26,7 @@ class UserAuthentication(AuthenticationBackend):

def __init__(
self,
redis_pool: Redis,
redis_pool: ConnectionPool,
session_factory: async_sessionmaker,
settings: Settings,
):
Expand All @@ -39,30 +38,46 @@ def __init__(
self.test_keycloak_user = None

async def refresh_jwk(self):
# Try to get the JWKS data from Redis cache
jwks_data = await self.redis_pool.get("jwks_data")

if jwks_data:
jwks_data = json.loads(jwks_data)
self.jwks = jwks_data.get("jwks")
self.jwks_uri = jwks_data.get("jwks_uri")
return

# If not in cache, retrieve from the well-known endpoint
async with httpx.AsyncClient() as client:
oidc_response = await client.get(self.settings.well_known_endpoint)
jwks_uri = oidc_response.json().get("jwks_uri")
certs_response = await client.get(jwks_uri)
jwks = certs_response.json()

# Composite object containing both JWKS and JWKS URI
jwks_data = {"jwks": jwks, "jwks_uri": jwks_uri}

# Cache the composite JWKS data with a TTL of 1 day (86400 seconds)
await self.redis_pool.set("jwks_data", json.dumps(jwks_data), ex=86400)

self.jwks = jwks
self.jwks_uri = jwks_uri
"""
Refreshes the JSON Web Key (JWK) used for token verification.
This method attempts to retrieve the JWK from Redis cache.
If not found, it fetches it from the well-known endpoint
and stores it in Redis for future use.
"""
# Create a Redis client from the connection pool
async with Redis(connection_pool=self.redis_pool) as redis:
# Try to get the JWKS data from Redis cache
jwks_data = await redis.get("jwks_data")

if jwks_data:
jwks_data = json.loads(jwks_data)
self.jwks = jwks_data.get("jwks")
self.jwks_uri = jwks_data.get("jwks_uri")
return

# If not in cache, retrieve from the well-known endpoint
async with httpx.AsyncClient() as client:
oidc_response = await client.get(self.settings.well_known_endpoint)
oidc_response.raise_for_status()
jwks_uri = oidc_response.json().get("jwks_uri")

if not jwks_uri:
raise ValueError(
"JWKS URI not found in the well-known endpoint response."
)

certs_response = await client.get(jwks_uri)
certs_response.raise_for_status()
jwks = certs_response.json()

# Composite object containing both JWKS and JWKS URI
jwks_data = {"jwks": jwks, "jwks_uri": jwks_uri}

# Cache the composite JWKS data with a TTL of 1 day (86400 seconds)
await redis.set("jwks_data", json.dumps(jwks_data), ex=86400)

self.jwks = jwks
self.jwks_uri = jwks_uri

async def authenticate(self, request):
# Extract the authorization header from the request
Expand Down
36 changes: 31 additions & 5 deletions backend/lcfs/services/tfrs/redis_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,19 @@


async def init_org_balance_cache(app: FastAPI):
redis = await app.state.redis_pool
"""
Initialize the organization balance cache and populate it with data.
:param app: FastAPI application instance.
"""
# Get the Redis connection pool from app state
redis_pool: ConnectionPool = app.state.redis_pool

# Create a Redis client using the connection pool
redis = Redis(connection_pool=redis_pool)

async with AsyncSession(async_engine) as session:
async with session.begin():

organization_repo = OrganizationsRepository(db=session)
transaction_repo = TransactionRepository(db=session)

Expand All @@ -29,23 +38,32 @@ async def init_org_balance_cache(app: FastAPI):

# Get the current year
current_year = datetime.now().year
logger.info(f"Starting balance cache population {current_year}")
logger.info(f"Starting balance cache population for {current_year}")

# Fetch all organizations
all_orgs = await organization_repo.get_organizations()

# Loop from the oldest year to the current year
for year in range(int(oldest_year), current_year + 1):
# Call the function to process transactions for each year
for org in all_orgs:
# Calculate the balance for each organization and year
balance = (
await transaction_repo.calculate_available_balance_for_period(
org.organization_id, year
)
)
# Set the balance in Redis
await set_cache_value(org.organization_id, year, balance, redis)
logger.debug(f"Set balance for {org.name} for {year} to {balance}")
logger.debug(
f"Set balance for organization {org.name} "
f"for {year} to {balance}"
)

logger.info(f"Cache populated with {len(all_orgs)} organizations")

# Close the Redis client
await redis.close()


class RedisBalanceService:
def __init__(
Expand Down Expand Up @@ -84,4 +102,12 @@ async def populate_organization_redis_balance(
async def set_cache_value(
organization_id: int, period: int, balance: int, redis: Redis
) -> None:
"""
Set a cache value in Redis for a specific organization and period.
:param organization_id: ID of the organization.
:param period: The year or period for which the balance is being set.
:param balance: The balance value to set in the cache.
:param redis: Redis client instance.
"""
await redis.set(name=f"balance_{organization_id}_{period}", value=balance)
79 changes: 37 additions & 42 deletions backend/lcfs/tests/services/tfrs/test_redis_balance.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from unittest.mock import AsyncMock, patch, MagicMock, call
from datetime import datetime

from redis.asyncio import ConnectionPool, Redis
Expand All @@ -13,55 +13,50 @@

@pytest.mark.anyio
async def test_init_org_balance_cache():
# Mock the session and repositories
mock_session = AsyncMock()

# Mock the Redis client
# Mock the Redis connection pool
mock_redis_pool = AsyncMock()
mock_redis = AsyncMock()
mock_redis.set = AsyncMock() # Ensure the `set` method is mocked

# Mock the settings
mock_settings = MagicMock()
mock_settings.redis_url = "redis://localhost"

# Create a mock app object
mock_app = MagicMock()

# Simulate redis_pool as an awaitable returning mock_redis
async def mock_redis_pool():
return mock_redis
mock_redis.set = AsyncMock()

mock_app.state.redis_pool = mock_redis_pool()
mock_app.state.settings = mock_settings
# Ensure the `Redis` instance is created with the connection pool
with patch("lcfs.services.tfrs.redis_balance.Redis", return_value=mock_redis):
# Mock the app object
mock_app = MagicMock()
mock_app.state.redis_pool = mock_redis_pool

current_year = datetime.now().year
last_year = current_year - 1
current_year = datetime.now().year
last_year = current_year - 1

with patch(
"lcfs.web.api.organizations.services.OrganizationsRepository.get_organizations",
return_value=[
MagicMock(organization_id=1, name="Org1"),
MagicMock(organization_id=2, name="Org2"),
],
):
# Mock repository methods
with patch(
"lcfs.web.api.organizations.repo.OrganizationsRepository.get_organizations",
return_value=[
MagicMock(organization_id=1, name="Org1"),
MagicMock(organization_id=2, name="Org2"),
],
), patch(
"lcfs.web.api.transaction.repo.TransactionRepository.get_transaction_start_year",
return_value=last_year,
), patch(
"lcfs.web.api.transaction.repo.TransactionRepository.calculate_available_balance_for_period",
side_effect=[100, 200, 150, 250],
):
with patch(
"lcfs.web.api.transaction.repo.TransactionRepository.calculate_available_balance_for_period",
side_effect=[100, 200, 150, 250, 300, 350],
):
# Pass the mock app to the function
await init_org_balance_cache(mock_app)

# Assert that each cache set operation was called correctly
calls = mock_redis.set.mock_calls
assert len(calls) == 4
mock_redis.set.assert_any_call(name=f"balance_1_{last_year}", value=100)
mock_redis.set.assert_any_call(name=f"balance_2_{last_year}", value=200)
mock_redis.set.assert_any_call(name=f"balance_1_{current_year}", value=150)
mock_redis.set.assert_any_call(name=f"balance_2_{current_year}", value=250)
# Execute the function with the mocked app
await init_org_balance_cache(mock_app)

# Define expected calls to Redis `set`
expected_calls = [
call(name=f"balance_1_{last_year}", value=100),
call(name=f"balance_2_{last_year}", value=200),
call(name=f"balance_1_{current_year}", value=150),
call(name=f"balance_2_{current_year}", value=250),
]

# Assert that Redis `set` method was called with the expected arguments
mock_redis.set.assert_has_calls(expected_calls, any_order=True)

# Ensure the number of calls matches the expected count
assert mock_redis.set.call_count == len(expected_calls)


@pytest.mark.anyio
Expand Down
Loading

0 comments on commit 03b6531

Please sign in to comment.