diff --git a/backend/lcfs/db/migrations/versions/2024-12-17-12-25_5b374dd97469.py b/backend/lcfs/db/migrations/versions/2024-12-17-12-25_5b374dd97469.py new file mode 100644 index 000000000..3c7475040 --- /dev/null +++ b/backend/lcfs/db/migrations/versions/2024-12-17-12-25_5b374dd97469.py @@ -0,0 +1,36 @@ +"""Add legacy id to compliance reports + +Revision ID: 5b374dd97469 +Revises: f93546eaec61 +Create Date: 2024-17-13 12:25:32.076684 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "5b374dd97469" +down_revision = "f93546eaec61" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "compliance_report", + sa.Column( + "legacy_id", + sa.Integer(), + nullable=True, + comment="ID from TFRS if this is a transferred application, NULL otherwise", + ), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("compliance_report", "legacy_id") + # ### end Alembic commands ### diff --git a/backend/lcfs/db/models/compliance/ComplianceReport.py b/backend/lcfs/db/models/compliance/ComplianceReport.py index d88e023d2..6656cc6ec 100644 --- a/backend/lcfs/db/models/compliance/ComplianceReport.py +++ b/backend/lcfs/db/models/compliance/ComplianceReport.py @@ -100,6 +100,11 @@ class ComplianceReport(BaseModel, Auditable): default=lambda: str(uuid.uuid4()), comment="UUID that groups all versions of a compliance report", ) + legacy_id = Column( + Integer, + nullable=True, + comment="ID from TFRS if this is a transferred application, NULL otherwise", + ) version = Column( Integer, nullable=False, diff --git a/backend/lcfs/services/rabbitmq/base_consumer.py b/backend/lcfs/services/rabbitmq/base_consumer.py index 26bc3ebdd..9a80bf56a 100644 --- a/backend/lcfs/services/rabbitmq/base_consumer.py +++ b/backend/lcfs/services/rabbitmq/base_consumer.py @@ -3,6 +3,7 @@ import aio_pika from aio_pika.abc import AbstractChannel, AbstractQueue +from fastapi import FastAPI from lcfs.settings import settings @@ -12,11 +13,12 @@ class BaseConsumer: - def __init__(self, queue_name=None): + def __init__(self, app: FastAPI, queue_name: str): self.connection = None self.channel = None self.queue = None self.queue_name = queue_name + self.app = app async def connect(self): """Connect to RabbitMQ and set up the consumer.""" @@ -42,7 +44,6 @@ async def start_consuming(self): async with message.process(): logger.debug(f"Received message: {message.body.decode()}") await self.process_message(message.body) - logger.debug("Message Processed") async def process_message(self, body: bytes): """Process the incoming message. Override this method in subclasses.""" diff --git a/backend/lcfs/services/rabbitmq/consumers.py b/backend/lcfs/services/rabbitmq/consumers.py index de934c00f..17cfdf193 100644 --- a/backend/lcfs/services/rabbitmq/consumers.py +++ b/backend/lcfs/services/rabbitmq/consumers.py @@ -1,14 +1,14 @@ import asyncio -from lcfs.services.rabbitmq.transaction_consumer import ( - setup_transaction_consumer, - close_transaction_consumer, +from lcfs.services.rabbitmq.report_consumer import ( + setup_report_consumer, + close_report_consumer, ) -async def start_consumers(): - await setup_transaction_consumer() +async def start_consumers(app): + await setup_report_consumer(app) async def stop_consumers(): - await close_transaction_consumer() + await close_report_consumer() diff --git a/backend/lcfs/services/rabbitmq/report_consumer.py b/backend/lcfs/services/rabbitmq/report_consumer.py new file mode 100644 index 000000000..f03df28ee --- /dev/null +++ b/backend/lcfs/services/rabbitmq/report_consumer.py @@ -0,0 +1,292 @@ +import asyncio +import json +import logging +from typing import Optional + +from fastapi import FastAPI +from sqlalchemy.ext.asyncio import AsyncSession + +from lcfs.db.dependencies import async_engine +from lcfs.db.models.compliance.ComplianceReportStatus import ComplianceReportStatusEnum +from lcfs.db.models.transaction.Transaction import TransactionActionEnum +from lcfs.db.models.user import UserProfile +from lcfs.services.rabbitmq.base_consumer import BaseConsumer +from lcfs.services.tfrs.redis_balance import RedisBalanceService +from lcfs.settings import settings +from lcfs.web.api.compliance_report.repo import ComplianceReportRepository +from lcfs.web.api.compliance_report.schema import ComplianceReportCreateSchema +from lcfs.web.api.compliance_report.services import ComplianceReportServices +from lcfs.web.api.organizations.repo import OrganizationsRepository +from lcfs.web.api.organizations.services import OrganizationsService +from lcfs.web.api.transaction.repo import TransactionRepository +from lcfs.web.api.user.repo import UserRepository +from lcfs.web.exception.exceptions import ServiceException + +logger = logging.getLogger(__name__) + +consumer = None +consumer_task = None + +VALID_ACTIONS = {"Created", "Submitted", "Approved"} + + +async def setup_report_consumer(app: FastAPI): + """ + Set up the report consumer and start consuming messages. + """ + global consumer, consumer_task + consumer = ReportConsumer(app) + await consumer.connect() + consumer_task = asyncio.create_task(consumer.start_consuming()) + + +async def close_report_consumer(): + """ + Cancel the consumer task if it exists and close the consumer connection. + """ + global consumer, consumer_task + + if consumer_task: + consumer_task.cancel() + + if consumer: + await consumer.close_connection() + + +class ReportConsumer(BaseConsumer): + """ + A consumer for handling TFRS compliance report messages from a RabbitMQ queue. + """ + + def __init__( + self, app: FastAPI, queue_name: str = settings.rabbitmq_transaction_queue + ): + super().__init__(app, queue_name) + + async def process_message(self, body: bytes): + """ + Process an incoming message from the queue. + + Expected message structure: + { + "tfrs_id": int, + "organization_id": int, + "compliance_period": str, + "nickname": str, + "action": "Created"|"Submitted"|"Approved", + "credits": int (optional), + "user_id": int + } + """ + message = self._parse_message(body) + if not message: + return # Invalid message already logged + + action = message["action"] + org_id = message["organization_id"] + + if action not in VALID_ACTIONS: + logger.error(f"Invalid action '{action}' in message.") + return + + logger.info(f"Received '{action}' action from TFRS for Org {org_id}") + + try: + await self.handle_message( + action=action, + compliance_period=message.get("compliance_period"), + compliance_units=message.get("credits"), + legacy_id=message["tfrs_id"], + nickname=message.get("nickname"), + org_id=org_id, + user_id=message["user_id"], + ) + except Exception: + logger.exception("Failed to handle message") + + def _parse_message(self, body: bytes) -> Optional[dict]: + """ + Parse the message body into a dictionary. + Log and return None if parsing fails or required fields are missing. + """ + try: + message_content = json.loads(body.decode()) + except json.JSONDecodeError: + logger.error("Failed to decode message body as JSON.") + return None + + required_fields = ["tfrs_id", "organization_id", "action", "user_id"] + if any(field not in message_content for field in required_fields): + logger.error("Message missing required fields.") + return None + + return message_content + + async def handle_message( + self, + action: str, + compliance_period: str, + compliance_units: Optional[int], + legacy_id: int, + nickname: Optional[str], + org_id: int, + user_id: int, + ): + """ + Handle a given message action by loading dependencies and calling the respective handler. + """ + redis_client = self.app.state.redis_client + + async with AsyncSession(async_engine) as session: + async with session.begin(): + # Initialize repositories and services + org_repo = OrganizationsRepository(db=session) + transaction_repo = TransactionRepository(db=session) + redis_balance_service = RedisBalanceService( + transaction_repo=transaction_repo, redis_client=redis_client + ) + org_service = OrganizationsService( + repo=org_repo, + transaction_repo=transaction_repo, + redis_balance_service=redis_balance_service, + ) + compliance_report_repo = ComplianceReportRepository(db=session) + compliance_report_service = ComplianceReportServices( + repo=compliance_report_repo + ) + user = await UserRepository(db=session).get_user_by_id(user_id) + + if action == "Created": + await self._handle_created( + org_id, + legacy_id, + compliance_period, + nickname, + user, + compliance_report_service, + ) + elif action == "Submitted": + await self._handle_submitted( + compliance_report_repo, + compliance_units, + legacy_id, + org_id, + org_service, + session, + user, + ) + elif action == "Approved": + await self._handle_approved( + legacy_id, + compliance_report_repo, + transaction_repo, + user, + session, + ) + + async def _handle_created( + self, + org_id: int, + legacy_id: int, + compliance_period: str, + nickname: str, + user: UserProfile, + compliance_report_service: ComplianceReportServices, + ): + """ + Handle the 'Created' action by creating a new compliance report draft. + """ + lcfs_report = ComplianceReportCreateSchema( + legacy_id=legacy_id, + compliance_period=compliance_period, + organization_id=org_id, + nickname=nickname, + status=ComplianceReportStatusEnum.Draft.value, + ) + await compliance_report_service.create_compliance_report( + org_id, lcfs_report, user + ) + + async def _handle_approved( + self, + legacy_id: int, + compliance_report_repo: ComplianceReportRepository, + transaction_repo: TransactionRepository, + user: UserProfile, + session: AsyncSession, + ): + """ + Handle the 'Approved' action by updating the report status to 'Assessed' + and confirming the associated transaction. + """ + existing_report = ( + await compliance_report_repo.get_compliance_report_by_legacy_id(legacy_id) + ) + if not existing_report: + raise ServiceException( + f"No compliance report found for legacy ID {legacy_id}" + ) + + new_status = await compliance_report_repo.get_compliance_report_status_by_desc( + ComplianceReportStatusEnum.Assessed.value + ) + existing_report.current_status_id = new_status.compliance_report_status_id + session.add(existing_report) + await session.flush() + + await compliance_report_repo.add_compliance_report_history( + existing_report, user + ) + + existing_transaction = await transaction_repo.get_transaction_by_id( + existing_report.transaction_id + ) + if not existing_transaction: + raise ServiceException( + "Compliance Report does not have an associated transaction" + ) + + if existing_transaction.transaction_action != TransactionActionEnum.Reserved: + raise ServiceException( + f"Transaction {existing_transaction.transaction_id} is not in 'Reserved' status" + ) + + await transaction_repo.confirm_transaction(existing_transaction.transaction_id) + + async def _handle_submitted( + self, + compliance_report_repo: ComplianceReportRepository, + compliance_units: int, + legacy_id: int, + org_id: int, + org_service: OrganizationsService, + session: AsyncSession, + user: UserProfile, + ): + """ + Handle the 'Submitted' action by linking a reserved transaction + to the compliance report and updating its status. + """ + existing_report = ( + await compliance_report_repo.get_compliance_report_by_legacy_id(legacy_id) + ) + if not existing_report: + raise ServiceException( + f"No compliance report found for legacy ID {legacy_id}" + ) + + transaction = await org_service.adjust_balance( + TransactionActionEnum.Reserved, compliance_units, org_id + ) + existing_report.transaction_id = transaction.transaction_id + + new_status = await compliance_report_repo.get_compliance_report_status_by_desc( + ComplianceReportStatusEnum.Submitted.value + ) + existing_report.current_status_id = new_status.compliance_report_status_id + session.add(existing_report) + await session.flush() + + await compliance_report_repo.add_compliance_report_history( + existing_report, user + ) diff --git a/backend/lcfs/services/rabbitmq/transaction_consumer.py b/backend/lcfs/services/rabbitmq/transaction_consumer.py deleted file mode 100644 index 381142d6c..000000000 --- a/backend/lcfs/services/rabbitmq/transaction_consumer.py +++ /dev/null @@ -1,71 +0,0 @@ -import asyncio -import json -import logging - -from redis.asyncio import Redis -from sqlalchemy.ext.asyncio import AsyncSession -from lcfs.services.redis.dependency import get_redis_client -from fastapi import Request - -from lcfs.db.dependencies import async_engine -from lcfs.db.models.transaction.Transaction import TransactionActionEnum -from lcfs.services.rabbitmq.base_consumer import BaseConsumer -from lcfs.services.tfrs.redis_balance import RedisBalanceService -from lcfs.settings import settings -from lcfs.web.api.organizations.repo import OrganizationsRepository -from lcfs.web.api.organizations.services import OrganizationsService -from lcfs.web.api.transaction.repo import TransactionRepository - -logger = logging.getLogger(__name__) -consumer = None -consumer_task = None - - -async def setup_transaction_consumer(): - global consumer, consumer_task - consumer = TransactionConsumer() - await consumer.connect() - consumer_task = asyncio.create_task(consumer.start_consuming()) - - -async def close_transaction_consumer(): - global consumer, consumer_task - - if consumer_task: - consumer_task.cancel() - - if consumer: - await consumer.close_connection() - - -class TransactionConsumer(BaseConsumer): - def __init__( - self, - queue_name=settings.rabbitmq_transaction_queue, - ): - super().__init__(queue_name) - - async def process_message(self, body: bytes, request: Request): - message_content = json.loads(body.decode()) - compliance_units = message_content.get("compliance_units_amount") - org_id = message_content.get("organization_id") - - redis_client = await get_redis_client(request) - - async with AsyncSession(async_engine) as session: - async with session.begin(): - repo = OrganizationsRepository(db=session) - transaction_repo = TransactionRepository(db=session) - redis_balance_service = RedisBalanceService( - transaction_repo=transaction_repo, redis_client=redis_client - ) - org_service = OrganizationsService( - repo=repo, - transaction_repo=transaction_repo, - redis_balance_service=redis_balance_service, - ) - - await org_service.adjust_balance( - TransactionActionEnum.Adjustment, compliance_units, org_id - ) - logger.debug(f"Processed Transaction from TFRS for Org {org_id}") diff --git a/backend/lcfs/tests/compliance_report/test_compliance_report_repo.py b/backend/lcfs/tests/compliance_report/test_compliance_report_repo.py index c26603dd8..84ed2520b 100644 --- a/backend/lcfs/tests/compliance_report/test_compliance_report_repo.py +++ b/backend/lcfs/tests/compliance_report/test_compliance_report_repo.py @@ -558,7 +558,7 @@ async def test_add_compliance_report_success( version=1, ) - report = await compliance_report_repo.add_compliance_report(report=new_report) + report = await compliance_report_repo.create_compliance_report(report=new_report) assert isinstance(report, ComplianceReportBaseSchema) assert report.compliance_period_id == compliance_periods[0].compliance_period_id @@ -577,7 +577,7 @@ async def test_add_compliance_report_exception( new_report = ComplianceReport() with pytest.raises(DatabaseException): - await compliance_report_repo.add_compliance_report(report=new_report) + await compliance_report_repo.create_compliance_report(report=new_report) @pytest.mark.anyio diff --git a/backend/lcfs/tests/compliance_report/test_compliance_report_services.py b/backend/lcfs/tests/compliance_report/test_compliance_report_services.py index 3237762be..9300d2918 100644 --- a/backend/lcfs/tests/compliance_report/test_compliance_report_services.py +++ b/backend/lcfs/tests/compliance_report/test_compliance_report_services.py @@ -4,6 +4,7 @@ from lcfs.db.models.compliance.ComplianceReportStatus import ComplianceReportStatus from lcfs.web.exception.exceptions import ServiceException, DataNotFoundException + # get_all_compliance_periods @pytest.mark.anyio async def test_get_all_compliance_periods_success(compliance_report_service, mock_repo): @@ -41,6 +42,8 @@ async def test_create_compliance_report_success( compliance_report_base_schema, compliance_report_create_schema, ): + mock_user = MagicMock() + # Mock the compliance period mock_compliance_period = CompliancePeriod( compliance_period_id=1, @@ -57,10 +60,10 @@ async def test_create_compliance_report_success( # Mock the added compliance report mock_compliance_report = compliance_report_base_schema() - mock_repo.add_compliance_report.return_value = mock_compliance_report + mock_repo.create_compliance_report.return_value = mock_compliance_report result = await compliance_report_service.create_compliance_report( - 1, compliance_report_create_schema + 1, compliance_report_create_schema, mock_user ) assert result == mock_compliance_report @@ -70,14 +73,16 @@ async def test_create_compliance_report_success( mock_repo.get_compliance_report_status_by_desc.assert_called_once_with( compliance_report_create_schema.status ) - mock_repo.add_compliance_report.assert_called_once() + mock_repo.create_compliance_report.assert_called_once() @pytest.mark.anyio async def test_create_compliance_report_unexpected_error( compliance_report_service, mock_repo ): - mock_repo.add_compliance_report.side_effect = Exception("Unexpected error occurred") + mock_repo.create_compliance_report.side_effect = Exception( + "Unexpected error occurred" + ) with pytest.raises(ServiceException): await compliance_report_service.create_compliance_report( diff --git a/backend/lcfs/tests/services/rabbitmq/test_report_consumer.py b/backend/lcfs/tests/services/rabbitmq/test_report_consumer.py new file mode 100644 index 000000000..838c9fe0c --- /dev/null +++ b/backend/lcfs/tests/services/rabbitmq/test_report_consumer.py @@ -0,0 +1,218 @@ +import json +from contextlib import ExitStack +from unittest.mock import AsyncMock, patch, MagicMock + +import pytest +from pandas.io.formats.format import return_docstring + +from lcfs.db.models.transaction.Transaction import TransactionActionEnum, Transaction +from lcfs.services.rabbitmq.report_consumer import ( + ReportConsumer, +) +from lcfs.tests.fuel_export.conftest import mock_compliance_report_repo +from lcfs.web.api.compliance_report.schema import ComplianceReportCreateSchema + + +@pytest.fixture +def mock_app(): + """Fixture to provide a mocked FastAPI app.""" + return MagicMock() + + +@pytest.fixture +def mock_redis(): + """Fixture to mock Redis client.""" + return AsyncMock() + + +@pytest.fixture +def mock_session(): + # Create a mock session that behaves like an async context manager. + # Specifying `spec=AsyncSession` helps ensure it behaves like the real class. + from sqlalchemy.ext.asyncio import AsyncSession + + mock_session = AsyncMock(spec=AsyncSession) + + # `async with mock_session:` should work, so we define what happens on enter/exit + mock_session.__aenter__.return_value = mock_session + mock_session.__aexit__.return_value = None + + # Now mock the transaction context manager returned by `session.begin()` + mock_transaction = AsyncMock() + mock_transaction.__aenter__.return_value = mock_transaction + mock_transaction.__aexit__.return_value = None + mock_session.begin.return_value = mock_transaction + + return mock_session + + +@pytest.fixture +def mock_repositories(): + """Fixture to mock all repositories and services.""" + + mock_compliance_report_repo = MagicMock() + mock_compliance_report_repo.get_compliance_report_by_legacy_id = AsyncMock( + return_value=MagicMock() + ) + mock_compliance_report_repo.get_compliance_report_status_by_desc = AsyncMock( + return_value=MagicMock() + ) + mock_compliance_report_repo.add_compliance_report_history = AsyncMock() + + org_service = MagicMock() + org_service.adjust_balance = AsyncMock() + + mock_transaction_repo = MagicMock() + mock_transaction_repo.get_transaction_by_id = AsyncMock( + return_value=MagicMock( + spec=Transaction, transaction_action=TransactionActionEnum.Reserved + ) + ) + + return { + "compliance_report_repo": mock_compliance_report_repo, + "transaction_repo": mock_transaction_repo, + "user_repo": AsyncMock(), + "org_service": org_service, + "compliance_service": AsyncMock(), + } + + +@pytest.fixture +def setup_patches(mock_redis, mock_session, mock_repositories): + """Fixture to apply patches for dependencies.""" + with ExitStack() as stack: + stack.enter_context( + patch("redis.asyncio.Redis.from_url", return_value=mock_redis) + ) + + stack.enter_context( + patch( + "lcfs.services.rabbitmq.report_consumer.AsyncSession", + return_value=mock_session, + ) + ) + stack.enter_context( + patch("lcfs.services.rabbitmq.report_consumer.async_engine", MagicMock()) + ) + + stack.enter_context( + patch( + "lcfs.services.rabbitmq.report_consumer.ComplianceReportRepository", + return_value=mock_repositories["compliance_report_repo"], + ) + ) + stack.enter_context( + patch( + "lcfs.services.rabbitmq.report_consumer.TransactionRepository", + return_value=mock_repositories["transaction_repo"], + ) + ) + stack.enter_context( + patch( + "lcfs.services.rabbitmq.report_consumer.UserRepository", + return_value=mock_repositories["user_repo"], + ) + ) + stack.enter_context( + patch( + "lcfs.services.rabbitmq.report_consumer.OrganizationsService", + return_value=mock_repositories["org_service"], + ) + ) + stack.enter_context( + patch( + "lcfs.services.rabbitmq.report_consumer.ComplianceReportServices", + return_value=mock_repositories["compliance_service"], + ) + ) + yield stack + + +@pytest.mark.anyio +async def test_process_message_created(mock_app, setup_patches, mock_repositories): + consumer = ReportConsumer(mock_app) + + # Prepare a sample message for "Created" action + message = { + "tfrs_id": 123, + "organization_id": 1, + "compliance_period": "2023", + "nickname": "Test Report", + "action": "Created", + "user_id": 42, + } + body = json.dumps(message).encode() + + # Ensure correct mock setup + mock_user = MagicMock() + mock_repositories["user_repo"].get_user_by_id.return_value = mock_user + + await consumer.process_message(body) + + # Assertions for "Created" action + mock_repositories[ + "compliance_service" + ].create_compliance_report.assert_called_once_with( + 1, # org_id + ComplianceReportCreateSchema( + legacy_id=123, + compliance_period="2023", + organization_id=1, + nickname="Test Report", + status="Draft", + ), + mock_user, + ) + + +@pytest.mark.anyio +async def test_process_message_submitted(mock_app, setup_patches, mock_repositories): + consumer = ReportConsumer(mock_app) + + # Prepare a sample message for "Submitted" action + message = { + "tfrs_id": 123, + "organization_id": 1, + "compliance_period": "2023", + "nickname": "Test Report", + "action": "Submitted", + "credits": 50, + "user_id": 42, + } + body = json.dumps(message).encode() + + await consumer.process_message(body) + + # Assertions for "Submitted" action + mock_repositories[ + "compliance_report_repo" + ].get_compliance_report_by_legacy_id.assert_called_once_with(123) + mock_repositories["org_service"].adjust_balance.assert_called_once_with( + TransactionActionEnum.Reserved, 50, 1 + ) + mock_repositories[ + "compliance_report_repo" + ].add_compliance_report_history.assert_called_once() + + +@pytest.mark.anyio +async def test_process_message_approved(mock_app, setup_patches, mock_repositories): + consumer = ReportConsumer(mock_app) + + # Prepare a sample message for "Approved" action + message = { + "tfrs_id": 123, + "organization_id": 1, + "action": "Approved", + "user_id": 42, + } + body = json.dumps(message).encode() + + await consumer.process_message(body) + + # Assertions for "Approved" action + mock_repositories[ + "compliance_report_repo" + ].get_compliance_report_by_legacy_id.assert_called_once_with(123) + mock_repositories["transaction_repo"].confirm_transaction.assert_called_once() diff --git a/backend/lcfs/tests/services/rabbitmq/test_transaction_consumer.py b/backend/lcfs/tests/services/rabbitmq/test_transaction_consumer.py deleted file mode 100644 index 3bd8d539a..000000000 --- a/backend/lcfs/tests/services/rabbitmq/test_transaction_consumer.py +++ /dev/null @@ -1,111 +0,0 @@ -from contextlib import ExitStack - -import pytest -from unittest.mock import AsyncMock, patch, MagicMock -import json - - -from lcfs.db.models.transaction.Transaction import TransactionActionEnum -from lcfs.services.rabbitmq.transaction_consumer import ( - setup_transaction_consumer, - close_transaction_consumer, - TransactionConsumer, - consumer, - consumer_task, -) - - -@pytest.mark.anyio -async def test_setup_transaction_consumer(): - with patch( - "lcfs.services.rabbitmq.transaction_consumer.TransactionConsumer" - ) as MockConsumer: - mock_consumer = MockConsumer.return_value - mock_consumer.connect = AsyncMock() - mock_consumer.start_consuming = AsyncMock() - - await setup_transaction_consumer() - - mock_consumer.connect.assert_called_once() - mock_consumer.start_consuming.assert_called_once() - - -@pytest.mark.anyio -async def test_close_transaction_consumer(): - with patch( - "lcfs.services.rabbitmq.transaction_consumer.TransactionConsumer" - ) as MockConsumer: - mock_consumer = MockConsumer.return_value - mock_consumer.connect = AsyncMock() - mock_consumer.start_consuming = AsyncMock() - mock_consumer.close_connection = AsyncMock() - - await setup_transaction_consumer() - - await close_transaction_consumer() - - mock_consumer.close_connection.assert_called_once() - - -@pytest.mark.anyio -async def test_process_message(): - mock_redis = AsyncMock() - mock_session = AsyncMock() - mock_repo = AsyncMock() - mock_redis_balance_service = AsyncMock() - adjust_balance = AsyncMock() - - with ExitStack() as stack: - stack.enter_context( - patch("redis.asyncio.Redis.from_url", return_value=mock_redis) - ) - stack.enter_context( - patch("sqlalchemy.ext.asyncio.AsyncSession", return_value=mock_session) - ) - stack.enter_context( - patch( - "lcfs.web.api.organizations.repo.OrganizationsRepository", - return_value=mock_repo, - ) - ) - stack.enter_context( - patch( - "lcfs.web.api.transaction.repo.TransactionRepository.calculate_available_balance", - side_effect=[100, 200, 150, 250, 300, 350], - ) - ) - stack.enter_context( - patch( - "lcfs.web.api.transaction.repo.TransactionRepository.calculate_reserved_balance", - side_effect=[100, 200, 150, 250, 300, 350], - ) - ) - stack.enter_context( - patch( - "lcfs.services.tfrs.redis_balance.RedisBalanceService", - return_value=mock_redis_balance_service, - ) - ) - stack.enter_context( - patch( - "lcfs.web.api.organizations.services.OrganizationsService.adjust_balance", - adjust_balance, - ) - ) - - # Create an instance of the consumer - consumer = TransactionConsumer() - - # Prepare a sample message - message = { - "compliance_units_amount": 100, - "organization_id": 1, - } - body = json.dumps(message).encode() - - mock_request = AsyncMock() - - await consumer.process_message(body, mock_request) - - # Assert that the organization service's adjust_balance method was called correctly - adjust_balance.assert_called_once_with(TransactionActionEnum.Adjustment, 100, 1) diff --git a/backend/lcfs/web/api/compliance_report/repo.py b/backend/lcfs/web/api/compliance_report/repo.py index 194afb8d0..1bac62843 100644 --- a/backend/lcfs/web/api/compliance_report/repo.py +++ b/backend/lcfs/web/api/compliance_report/repo.py @@ -2,6 +2,8 @@ from typing import List, Optional, Dict from collections import defaultdict from datetime import datetime + +from lcfs.db.models import UserProfile from lcfs.db.models.organization.Organization import Organization from lcfs.db.models.fuel.FuelType import FuelType from lcfs.db.models.fuel.FuelCategory import FuelCategory @@ -15,7 +17,6 @@ PaginationRequestSchema, apply_filter_conditions, get_field_for_filter, - get_enum_value, ) from lcfs.db.models.compliance import CompliancePeriod from lcfs.db.models.compliance.ComplianceReport import ( @@ -181,7 +182,9 @@ async def check_compliance_report( ) @repo_handler - async def get_compliance_report_status_by_desc(self, status: str) -> int: + async def get_compliance_report_status_by_desc( + self, status: str + ) -> ComplianceReportStatus: """ Retrieve the compliance report status ID from the database based on the description. Replaces spaces with underscores in the status description. @@ -266,7 +269,7 @@ async def get_assessed_compliance_report_by_period( return result @repo_handler - async def add_compliance_report(self, report: ComplianceReport): + async def create_compliance_report(self, report: ComplianceReport): """ Add a new compliance report to the database """ @@ -304,7 +307,9 @@ async def get_compliance_report_history(self, report: ComplianceReport): return history.scalar_one_or_none() @repo_handler - async def add_compliance_report_history(self, report: ComplianceReport, user): + async def add_compliance_report_history( + self, report: ComplianceReport, user: UserProfile + ): """ Add a new compliance report history record to the database """ @@ -823,3 +828,26 @@ async def get_latest_report_by_group_uuid( .limit(1) ) return result.scalars().first() + + async def get_compliance_report_by_legacy_id(self, legacy_id): + """ + Retrieve a compliance report from the database by ID + """ + result = await self.db.execute( + select(ComplianceReport) + .options( + joinedload(ComplianceReport.organization), + joinedload(ComplianceReport.compliance_period), + joinedload(ComplianceReport.current_status), + joinedload(ComplianceReport.summary), + joinedload(ComplianceReport.history).joinedload( + ComplianceReportHistory.status + ), + joinedload(ComplianceReport.history).joinedload( + ComplianceReportHistory.user_profile + ), + joinedload(ComplianceReport.transaction), + ) + .where(ComplianceReport.legacy_id == legacy_id) + ) + return result.scalars().unique().first() diff --git a/backend/lcfs/web/api/compliance_report/schema.py b/backend/lcfs/web/api/compliance_report/schema.py index 9eb215c53..34696dee2 100644 --- a/backend/lcfs/web/api/compliance_report/schema.py +++ b/backend/lcfs/web/api/compliance_report/schema.py @@ -148,7 +148,6 @@ class ComplianceReportBaseSchema(BaseSchema): current_status_id: int current_status: ComplianceReportStatusSchema transaction_id: Optional[int] = None - # transaction: Optional[TransactionBaseSchema] = None nickname: Optional[str] = None supplemental_note: Optional[str] = None reporting_frequency: Optional[ReportingFrequency] = None @@ -166,6 +165,8 @@ class ComplianceReportCreateSchema(BaseSchema): compliance_period: str organization_id: int status: str + legacy_id: Optional[int] = None + nickname: Optional[str] = None class ComplianceReportListSchema(BaseSchema): diff --git a/backend/lcfs/web/api/compliance_report/services.py b/backend/lcfs/web/api/compliance_report/services.py index 31993bc75..dac78edd9 100644 --- a/backend/lcfs/web/api/compliance_report/services.py +++ b/backend/lcfs/web/api/compliance_report/services.py @@ -27,10 +27,7 @@ class ComplianceReportServices: - def __init__( - self, request: Request = None, repo: ComplianceReportRepository = Depends() - ) -> None: - self.request = request + def __init__(self, repo: ComplianceReportRepository = Depends()) -> None: self.repo = repo @service_handler @@ -41,7 +38,10 @@ async def get_all_compliance_periods(self) -> List[CompliancePeriodSchema]: @service_handler async def create_compliance_report( - self, organization_id: int, report_data: ComplianceReportCreateSchema + self, + organization_id: int, + report_data: ComplianceReportCreateSchema, + user: UserProfile, ) -> ComplianceReportBaseSchema: """Creates a new compliance report.""" period = await self.repo.get_compliance_period(report_data.compliance_period) @@ -52,8 +52,7 @@ async def create_compliance_report( report_data.status ) if not draft_status: - raise DataNotFoundException( - f"Status '{report_data.status}' not found.") + raise DataNotFoundException(f"Status '{report_data.status}' not found.") # Generate a new group_uuid for the new report series group_uuid = str(uuid.uuid4()) @@ -65,15 +64,17 @@ async def create_compliance_report( reporting_frequency=ReportingFrequency.ANNUAL, compliance_report_group_uuid=group_uuid, # New group_uuid for the series version=0, # Start with version 0 - nickname="Original Report", + nickname=report_data.nickname or "Original Report", summary=ComplianceReportSummary(), # Create an empty summary object + legacy_id=report_data.legacy_id, + create_user=user.keycloak_username, ) # Add the new compliance report - report = await self.repo.add_compliance_report(report) + report = await self.repo.create_compliance_report(report) # Create the history record - await self.repo.add_compliance_report_history(report, self.request.user) + await self.repo.add_compliance_report_history(report, user) return ComplianceReportBaseSchema.model_validate(report) @@ -137,7 +138,7 @@ async def create_supplemental_report( ) # Add the new supplemental report - new_report = await self.repo.add_compliance_report(new_report) + new_report = await self.repo.create_compliance_report(new_report) # Create the history record for the new supplemental report await self.repo.add_compliance_report_history(new_report, user) @@ -228,8 +229,7 @@ async def get_compliance_report_by_id( if apply_masking: # Apply masking to each report in the chain - masked_chain = self._mask_report_status( - compliance_report_chain) + masked_chain = self._mask_report_status(compliance_report_chain) # Apply history masking to each report in the chain masked_chain = [ self._mask_report_status_for_history(report, apply_masking) diff --git a/backend/lcfs/web/api/organization/views.py b/backend/lcfs/web/api/organization/views.py index e175bf756..a33cdd984 100644 --- a/backend/lcfs/web/api/organization/views.py +++ b/backend/lcfs/web/api/organization/views.py @@ -33,7 +33,7 @@ ComplianceReportCreateSchema, ComplianceReportListSchema, CompliancePeriodSchema, - ChainedComplianceReportSchema + ChainedComplianceReportSchema, ) from lcfs.web.api.compliance_report.services import ComplianceReportServices from .services import OrganizationService @@ -56,8 +56,7 @@ async def get_org_users( request: Request, organization_id: int, - status: str = Query( - default="Active", description="Active or Inactive users list"), + status: str = Query(default="Active", description="Active or Inactive users list"), pagination: PaginationRequestSchema = Body(..., embed=False), response: Response = None, org_service: OrganizationService = Depends(), @@ -249,7 +248,9 @@ async def create_compliance_report( validate: OrganizationValidation = Depends(), ): await validate.create_compliance_report(organization_id, report_data) - return await report_service.create_compliance_report(organization_id, report_data) + return await report_service.create_compliance_report( + organization_id, report_data, request.user + ) @router.post( @@ -307,4 +308,6 @@ async def get_compliance_report_by_id( This endpoint returns the information of a user by ID, including their roles and organization. """ await report_validate.validate_organization_access(report_id) - return await report_service.get_compliance_report_by_id(report_id, apply_masking=True, get_chain=True) + return await report_service.get_compliance_report_by_id( + report_id, apply_masking=True, get_chain=True + ) diff --git a/backend/lcfs/web/api/organizations/services.py b/backend/lcfs/web/api/organizations/services.py index e8ef43620..35c2155a3 100644 --- a/backend/lcfs/web/api/organizations/services.py +++ b/backend/lcfs/web/api/organizations/services.py @@ -16,6 +16,7 @@ OrganizationStatus, OrgStatusEnum, ) +from lcfs.db.models.transaction import Transaction from lcfs.db.models.transaction.Transaction import TransactionActionEnum from lcfs.services.tfrs.redis_balance import ( RedisBalanceService, @@ -44,6 +45,7 @@ logger = structlog.get_logger(__name__) + class OrganizationsService: def __init__( self, @@ -198,7 +200,6 @@ async def update_organization( updated_organization = await self.repo.update_organization(organization) return updated_organization - @service_handler async def get_organization(self, organization_id: int): """handles fetching an organization""" @@ -400,7 +401,7 @@ async def adjust_balance( transaction_action: TransactionActionEnum, compliance_units: int, organization_id: int, - ): + ) -> Transaction: """ Adjusts an organization's balance based on the transaction action. diff --git a/backend/lcfs/web/api/transaction/repo.py b/backend/lcfs/web/api/transaction/repo.py index 7134e7332..861b1e32b 100644 --- a/backend/lcfs/web/api/transaction/repo.py +++ b/backend/lcfs/web/api/transaction/repo.py @@ -318,7 +318,7 @@ async def create_transaction( transaction_action: TransactionActionEnum, compliance_units: int, organization_id: int, - ): + ) -> Transaction: """ Creates and saves a new transaction to the database. diff --git a/backend/lcfs/web/lifetime.py b/backend/lcfs/web/lifetime.py index 5de67c16c..fbe7b0b6e 100644 --- a/backend/lcfs/web/lifetime.py +++ b/backend/lcfs/web/lifetime.py @@ -62,7 +62,7 @@ async def _startup() -> None: # noqa: WPS430 await init_org_balance_cache(app) # Setup RabbitMQ Listeners - await start_consumers() + await start_consumers(app) return _startup