Skip to content

Commit

Permalink
Merge branch 'release-0.2.0' into fix/hamed-fuel-code-logic-fix-1303-…
Browse files Browse the repository at this point in the history
…1433
  • Loading branch information
dhaselhan authored Dec 18, 2024
2 parents ec65eb2 + f3b404d commit 8d403e5
Show file tree
Hide file tree
Showing 25 changed files with 697 additions and 273 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Add legacy id to compliance reports
Revision ID: 5b374dd97469
Revises: f93546eaec61
Create Date: 2024-17-13 12:25:32.076684
"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "5b374dd97469"
down_revision = "f93546eaec61"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"compliance_report",
sa.Column(
"legacy_id",
sa.Integer(),
nullable=True,
comment="ID from TFRS if this is a transferred application, NULL otherwise",
),
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("compliance_report", "legacy_id")
# ### end Alembic commands ###
5 changes: 5 additions & 0 deletions backend/lcfs/db/models/compliance/ComplianceReport.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ class ComplianceReport(BaseModel, Auditable):
default=lambda: str(uuid.uuid4()),
comment="UUID that groups all versions of a compliance report",
)
legacy_id = Column(
Integer,
nullable=True,
comment="ID from TFRS if this is a transferred application, NULL otherwise",
)
version = Column(
Integer,
nullable=False,
Expand Down
5 changes: 3 additions & 2 deletions backend/lcfs/services/rabbitmq/base_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import aio_pika
from aio_pika.abc import AbstractChannel, AbstractQueue
from fastapi import FastAPI

from lcfs.settings import settings

Expand All @@ -12,11 +13,12 @@


class BaseConsumer:
def __init__(self, queue_name=None):
def __init__(self, app: FastAPI, queue_name: str):
self.connection = None
self.channel = None
self.queue = None
self.queue_name = queue_name
self.app = app

async def connect(self):
"""Connect to RabbitMQ and set up the consumer."""
Expand All @@ -42,7 +44,6 @@ async def start_consuming(self):
async with message.process():
logger.debug(f"Received message: {message.body.decode()}")
await self.process_message(message.body)
logger.debug("Message Processed")

async def process_message(self, body: bytes):
"""Process the incoming message. Override this method in subclasses."""
Expand Down
12 changes: 6 additions & 6 deletions backend/lcfs/services/rabbitmq/consumers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import asyncio

from lcfs.services.rabbitmq.transaction_consumer import (
setup_transaction_consumer,
close_transaction_consumer,
from lcfs.services.rabbitmq.report_consumer import (
setup_report_consumer,
close_report_consumer,
)


async def start_consumers():
await setup_transaction_consumer()
async def start_consumers(app):
await setup_report_consumer(app)


async def stop_consumers():
await close_transaction_consumer()
await close_report_consumer()
292 changes: 292 additions & 0 deletions backend/lcfs/services/rabbitmq/report_consumer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
import asyncio
import json
import logging
from typing import Optional

from fastapi import FastAPI
from sqlalchemy.ext.asyncio import AsyncSession

from lcfs.db.dependencies import async_engine
from lcfs.db.models.compliance.ComplianceReportStatus import ComplianceReportStatusEnum
from lcfs.db.models.transaction.Transaction import TransactionActionEnum
from lcfs.db.models.user import UserProfile
from lcfs.services.rabbitmq.base_consumer import BaseConsumer
from lcfs.services.tfrs.redis_balance import RedisBalanceService
from lcfs.settings import settings
from lcfs.web.api.compliance_report.repo import ComplianceReportRepository
from lcfs.web.api.compliance_report.schema import ComplianceReportCreateSchema
from lcfs.web.api.compliance_report.services import ComplianceReportServices
from lcfs.web.api.organizations.repo import OrganizationsRepository
from lcfs.web.api.organizations.services import OrganizationsService
from lcfs.web.api.transaction.repo import TransactionRepository
from lcfs.web.api.user.repo import UserRepository
from lcfs.web.exception.exceptions import ServiceException

logger = logging.getLogger(__name__)

consumer = None
consumer_task = None

VALID_ACTIONS = {"Created", "Submitted", "Approved"}


async def setup_report_consumer(app: FastAPI):
"""
Set up the report consumer and start consuming messages.
"""
global consumer, consumer_task
consumer = ReportConsumer(app)
await consumer.connect()
consumer_task = asyncio.create_task(consumer.start_consuming())


async def close_report_consumer():
"""
Cancel the consumer task if it exists and close the consumer connection.
"""
global consumer, consumer_task

if consumer_task:
consumer_task.cancel()

if consumer:
await consumer.close_connection()


class ReportConsumer(BaseConsumer):
"""
A consumer for handling TFRS compliance report messages from a RabbitMQ queue.
"""

def __init__(
self, app: FastAPI, queue_name: str = settings.rabbitmq_transaction_queue
):
super().__init__(app, queue_name)

async def process_message(self, body: bytes):
"""
Process an incoming message from the queue.
Expected message structure:
{
"tfrs_id": int,
"organization_id": int,
"compliance_period": str,
"nickname": str,
"action": "Created"|"Submitted"|"Approved",
"credits": int (optional),
"user_id": int
}
"""
message = self._parse_message(body)
if not message:
return # Invalid message already logged

action = message["action"]
org_id = message["organization_id"]

if action not in VALID_ACTIONS:
logger.error(f"Invalid action '{action}' in message.")
return

logger.info(f"Received '{action}' action from TFRS for Org {org_id}")

try:
await self.handle_message(
action=action,
compliance_period=message.get("compliance_period"),
compliance_units=message.get("credits"),
legacy_id=message["tfrs_id"],
nickname=message.get("nickname"),
org_id=org_id,
user_id=message["user_id"],
)
except Exception:
logger.exception("Failed to handle message")

def _parse_message(self, body: bytes) -> Optional[dict]:
"""
Parse the message body into a dictionary.
Log and return None if parsing fails or required fields are missing.
"""
try:
message_content = json.loads(body.decode())
except json.JSONDecodeError:
logger.error("Failed to decode message body as JSON.")
return None

required_fields = ["tfrs_id", "organization_id", "action", "user_id"]
if any(field not in message_content for field in required_fields):
logger.error("Message missing required fields.")
return None

return message_content

async def handle_message(
self,
action: str,
compliance_period: str,
compliance_units: Optional[int],
legacy_id: int,
nickname: Optional[str],
org_id: int,
user_id: int,
):
"""
Handle a given message action by loading dependencies and calling the respective handler.
"""
redis_client = self.app.state.redis_client

async with AsyncSession(async_engine) as session:
async with session.begin():
# Initialize repositories and services
org_repo = OrganizationsRepository(db=session)
transaction_repo = TransactionRepository(db=session)
redis_balance_service = RedisBalanceService(
transaction_repo=transaction_repo, redis_client=redis_client
)
org_service = OrganizationsService(
repo=org_repo,
transaction_repo=transaction_repo,
redis_balance_service=redis_balance_service,
)
compliance_report_repo = ComplianceReportRepository(db=session)
compliance_report_service = ComplianceReportServices(
repo=compliance_report_repo
)
user = await UserRepository(db=session).get_user_by_id(user_id)

if action == "Created":
await self._handle_created(
org_id,
legacy_id,
compliance_period,
nickname,
user,
compliance_report_service,
)
elif action == "Submitted":
await self._handle_submitted(
compliance_report_repo,
compliance_units,
legacy_id,
org_id,
org_service,
session,
user,
)
elif action == "Approved":
await self._handle_approved(
legacy_id,
compliance_report_repo,
transaction_repo,
user,
session,
)

async def _handle_created(
self,
org_id: int,
legacy_id: int,
compliance_period: str,
nickname: str,
user: UserProfile,
compliance_report_service: ComplianceReportServices,
):
"""
Handle the 'Created' action by creating a new compliance report draft.
"""
lcfs_report = ComplianceReportCreateSchema(
legacy_id=legacy_id,
compliance_period=compliance_period,
organization_id=org_id,
nickname=nickname,
status=ComplianceReportStatusEnum.Draft.value,
)
await compliance_report_service.create_compliance_report(
org_id, lcfs_report, user
)

async def _handle_approved(
self,
legacy_id: int,
compliance_report_repo: ComplianceReportRepository,
transaction_repo: TransactionRepository,
user: UserProfile,
session: AsyncSession,
):
"""
Handle the 'Approved' action by updating the report status to 'Assessed'
and confirming the associated transaction.
"""
existing_report = (
await compliance_report_repo.get_compliance_report_by_legacy_id(legacy_id)
)
if not existing_report:
raise ServiceException(
f"No compliance report found for legacy ID {legacy_id}"
)

new_status = await compliance_report_repo.get_compliance_report_status_by_desc(
ComplianceReportStatusEnum.Assessed.value
)
existing_report.current_status_id = new_status.compliance_report_status_id
session.add(existing_report)
await session.flush()

await compliance_report_repo.add_compliance_report_history(
existing_report, user
)

existing_transaction = await transaction_repo.get_transaction_by_id(
existing_report.transaction_id
)
if not existing_transaction:
raise ServiceException(
"Compliance Report does not have an associated transaction"
)

if existing_transaction.transaction_action != TransactionActionEnum.Reserved:
raise ServiceException(
f"Transaction {existing_transaction.transaction_id} is not in 'Reserved' status"
)

await transaction_repo.confirm_transaction(existing_transaction.transaction_id)

async def _handle_submitted(
self,
compliance_report_repo: ComplianceReportRepository,
compliance_units: int,
legacy_id: int,
org_id: int,
org_service: OrganizationsService,
session: AsyncSession,
user: UserProfile,
):
"""
Handle the 'Submitted' action by linking a reserved transaction
to the compliance report and updating its status.
"""
existing_report = (
await compliance_report_repo.get_compliance_report_by_legacy_id(legacy_id)
)
if not existing_report:
raise ServiceException(
f"No compliance report found for legacy ID {legacy_id}"
)

transaction = await org_service.adjust_balance(
TransactionActionEnum.Reserved, compliance_units, org_id
)
existing_report.transaction_id = transaction.transaction_id

new_status = await compliance_report_repo.get_compliance_report_status_by_desc(
ComplianceReportStatusEnum.Submitted.value
)
existing_report.current_status_id = new_status.compliance_report_status_id
session.add(existing_report)
await session.flush()

await compliance_report_repo.add_compliance_report_history(
existing_report, user
)
Loading

0 comments on commit 8d403e5

Please sign in to comment.