diff --git a/.github/workflows/dev-ci.yml b/.github/workflows/dev-ci.yml index ac8071561..6de1db940 100644 --- a/.github/workflows/dev-ci.yml +++ b/.github/workflows/dev-ci.yml @@ -23,9 +23,36 @@ concurrency: jobs: + install-oc: + runs-on: ubuntu-latest + outputs: + cache-hit: ${{ steps.cache.outputs.cache-hit }} + steps: + - name: Check out repository + uses: actions/checkout@v4.1.1 + + - name: Set up cache for OpenShift CLI + id: cache + uses: actions/cache@v4.2.0 + with: + path: /usr/local/bin/oc # Path where the `oc` binary will be installed + key: oc-cli-${{ runner.os }} + + - name: Install OpenShift CLI (if not cached) + if: steps.cache.outputs.cache-hit != 'true' + run: | + curl -LO https://mirror.openshift.com/pub/openshift-v4/clients/ocp/stable/openshift-client-linux.tar.gz + tar -xvf openshift-client-linux.tar.gz + sudo mv oc /usr/local/bin/ + oc version --client + + - name: Confirm OpenShift CLI is Available + run: oc version --client + set-pre-release: name: Calculate pre-release number runs-on: ubuntu-latest + needs: [install-oc] outputs: output1: ${{ steps.set-pre-release.outputs.PRE_RELEASE }} @@ -49,6 +76,12 @@ jobs: - name: Check out repository uses: actions/checkout@v4.1.1 + - name: Restore oc command from Cache + uses: actions/cache@v4.2.0 + with: + path: /usr/local/bin/oc + key: oc-cli-${{ runner.os }} + - name: Log in to Openshift uses: redhat-actions/oc-login@v1.3 with: diff --git a/.github/workflows/pr-build.yaml b/.github/workflows/pr-build.yaml index f351a91c5..7086a5834 100644 --- a/.github/workflows/pr-build.yaml +++ b/.github/workflows/pr-build.yaml @@ -19,17 +19,51 @@ concurrency: cancel-in-progress: true jobs: + install-oc: + runs-on: ubuntu-latest + outputs: + cache-hit: ${{ steps.cache.outputs.cache-hit }} + steps: + - name: Check out repository + uses: actions/checkout@v4.1.1 + + - name: Set up cache for OpenShift CLI + id: cache + uses: actions/cache@v4.2.0 + with: + path: /usr/local/bin/oc # Path where the `oc` binary will be installed + key: oc-cli-${{ runner.os }} + + - name: Install OpenShift CLI (if not cached) + if: steps.cache.outputs.cache-hit != 'true' + run: | + curl -LO https://mirror.openshift.com/pub/openshift-v4/clients/ocp/stable/openshift-client-linux.tar.gz + tar -xvf openshift-client-linux.tar.gz + sudo mv oc /usr/local/bin/ + oc version --client + + - name: Confirm OpenShift CLI is Available + run: oc version --client + get-version: if: > (github.event.action == 'labeled' && github.event.label.name == 'build' && github.event.pull_request.base.ref == github.event.repository.default_branch) || (github.event.action == 'synchronize' && contains(github.event.pull_request.labels.*.name, 'build') && github.event.pull_request.base.ref == github.event.repository.default_branch) name: Retrieve version runs-on: ubuntu-latest + needs: [install-oc] outputs: output1: ${{ steps.get-version.outputs.VERSION }} steps: + + - name: Restore oc command from Cache + uses: actions/cache@v4.2.0 + with: + path: /usr/local/bin/oc + key: oc-cli-${{ runner.os }} + - name: Log in to Openshift uses: redhat-actions/oc-login@v1.3 with: @@ -69,6 +103,12 @@ jobs: with: ref: ${{ github.event.pull_request.head.ref }} + - name: Restore oc command from Cache + uses: actions/cache@v4.2.0 + with: + path: /usr/local/bin/oc + key: oc-cli-${{ runner.os }} + - name: Log in to Openshift uses: redhat-actions/oc-login@v1.3 with: @@ -123,6 +163,12 @@ jobs: ref: main ssh-key: ${{ secrets.MANIFEST_REPO_DEPLOY_KEY }} + - name: Restore oc command from Cache + uses: actions/cache@v4.2.0 + with: + path: /usr/local/bin/oc + key: oc-cli-${{ runner.os }} + - name: Log in to Openshift uses: redhat-actions/oc-login@v1.3 with: diff --git a/.github/workflows/pr-teardown.yaml b/.github/workflows/pr-teardown.yaml index 783c221a8..201c08e04 100644 --- a/.github/workflows/pr-teardown.yaml +++ b/.github/workflows/pr-teardown.yaml @@ -13,6 +13,31 @@ concurrency: cancel-in-progress: true jobs: + install-oc: + runs-on: ubuntu-latest + outputs: + cache-hit: ${{ steps.cache.outputs.cache-hit }} + steps: + - name: Check out repository + uses: actions/checkout@v4.1.1 + + - name: Set up cache for OpenShift CLI + id: cache + uses: actions/cache@v4.2.0 + with: + path: /usr/local/bin/oc # Path where the `oc` binary will be installed + key: oc-cli-${{ runner.os }} + + - name: Install OpenShift CLI (if not cached) + if: steps.cache.outputs.cache-hit != 'true' + run: | + curl -LO https://mirror.openshift.com/pub/openshift-v4/clients/ocp/stable/openshift-client-linux.tar.gz + tar -xvf openshift-client-linux.tar.gz + sudo mv oc /usr/local/bin/ + oc version --client + + - name: Confirm OpenShift CLI is Available + run: oc version --client teardown: if: > @@ -20,9 +45,16 @@ jobs: (github.event.action == 'closed' && contains(github.event.pull_request.labels.*.name, 'build') ) name: PR Teardown runs-on: ubuntu-latest + needs: [install-oc] timeout-minutes: 60 steps: + + - name: Restore oc command from Cache + uses: actions/cache@v4.2.0 + with: + path: /usr/local/bin/oc + key: oc-cli-${{ runner.os }} - name: Log in to Openshift uses: redhat-actions/oc-login@v1.3 diff --git a/.github/workflows/prod-ci.yaml b/.github/workflows/prod-ci.yaml index 3478be8ff..b3a1eab61 100644 --- a/.github/workflows/prod-ci.yaml +++ b/.github/workflows/prod-ci.yaml @@ -14,12 +14,38 @@ concurrency: cancel-in-progress: true jobs: + install-oc: + runs-on: ubuntu-latest + outputs: + cache-hit: ${{ steps.cache.outputs.cache-hit }} + steps: + - name: Check out repository + uses: actions/checkout@v4.1.1 + + - name: Set up cache for OpenShift CLI + id: cache + uses: actions/cache@v4.2.0 + with: + path: /usr/local/bin/oc # Path where the `oc` binary will be installed + key: oc-cli-${{ runner.os }} + + - name: Install OpenShift CLI (if not cached) + if: steps.cache.outputs.cache-hit != 'true' + run: | + curl -LO https://mirror.openshift.com/pub/openshift-v4/clients/ocp/stable/openshift-client-linux.tar.gz + tar -xvf openshift-client-linux.tar.gz + sudo mv oc /usr/local/bin/ + oc version --client + + - name: Confirm OpenShift CLI is Available + run: oc version --client # Read the image tag from test environment get-image-tag: name: Get the image-tag from values-test.yaml runs-on: ubuntu-latest + needs: [install-oc] outputs: IMAGE_TAG: ${{ steps.get-image-tag.outputs.IMAGE_TAG }} @@ -84,6 +110,12 @@ jobs: approvers: AlexZorkin,kuanfandevops,hamed-valiollahi,airinggov,areyeslo,dhaselhan,Grulin minimum-approvals: 2 issue-title: "LCFS ${{env.IMAGE_TAG }} Prod Deployment at ${{ env.CURRENT_TIME }}." + + - name: Restore oc command from Cache + uses: actions/cache@v4.2.0 + with: + path: /usr/local/bin/oc + key: oc-cli-${{ runner.os }} - name: Log in to Openshift uses: redhat-actions/oc-login@v1.3 diff --git a/.github/workflows/test-ci.yaml b/.github/workflows/test-ci.yaml index 1119b9432..80d6690f2 100644 --- a/.github/workflows/test-ci.yaml +++ b/.github/workflows/test-ci.yaml @@ -14,9 +14,36 @@ concurrency: cancel-in-progress: true jobs: + install-oc: + runs-on: ubuntu-latest + outputs: + cache-hit: ${{ steps.cache.outputs.cache-hit }} + steps: + - name: Check out repository + uses: actions/checkout@v4.1.1 + + - name: Set up cache for OpenShift CLI + id: cache + uses: actions/cache@v4.2.0 + with: + path: /usr/local/bin/oc # Path where the `oc` binary will be installed + key: oc-cli-${{ runner.os }} + + - name: Install OpenShift CLI (if not cached) + if: steps.cache.outputs.cache-hit != 'true' + run: | + curl -LO https://mirror.openshift.com/pub/openshift-v4/clients/ocp/stable/openshift-client-linux.tar.gz + tar -xvf openshift-client-linux.tar.gz + sudo mv oc /usr/local/bin/ + oc version --client + + - name: Confirm OpenShift CLI is Available + run: oc version --client + run-tests: name: Run Tests runs-on: ubuntu-latest + needs: [install-oc] steps: - uses: actions/checkout@v3 @@ -229,6 +256,12 @@ jobs: minimum-approvals: 1 issue-title: "LCFS ${{ env.VERSION }}-${{ env.PRE_RELEASE }} Test Deployment" + - name: Restore oc command from Cache + uses: actions/cache@v4.2.0 + with: + path: /usr/local/bin/oc + key: oc-cli-${{ runner.os }} + - name: Log in to Openshift uses: redhat-actions/oc-login@v1.3 with: diff --git a/backend/lcfs/db/migrations/versions/2024-12-17-11-23_f93546eaec61.py b/backend/lcfs/db/migrations/versions/2024-12-17-11-23_f93546eaec61.py new file mode 100644 index 000000000..4fbabc280 --- /dev/null +++ b/backend/lcfs/db/migrations/versions/2024-12-17-11-23_f93546eaec61.py @@ -0,0 +1,33 @@ +"""update notification message model + +Revision ID: f93546eaec61 +Revises: 5d729face5ab +Create Date: 2024-12-17 11:23:19.563138 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "f93546eaec61" +down_revision = "5d729face5ab" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("notification_message", sa.Column("type", sa.Text(), nullable=False)) + op.add_column( + "notification_message", + sa.Column("related_transaction_id", sa.Text(), nullable=False), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("notification_message", "related_transaction_id") + op.drop_column("notification_message", "type") + # ### end Alembic commands ### diff --git a/backend/lcfs/db/migrations/versions/2024-12-17-12-25_5b374dd97469.py b/backend/lcfs/db/migrations/versions/2024-12-17-12-25_5b374dd97469.py new file mode 100644 index 000000000..3c7475040 --- /dev/null +++ b/backend/lcfs/db/migrations/versions/2024-12-17-12-25_5b374dd97469.py @@ -0,0 +1,36 @@ +"""Add legacy id to compliance reports + +Revision ID: 5b374dd97469 +Revises: f93546eaec61 +Create Date: 2024-17-13 12:25:32.076684 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "5b374dd97469" +down_revision = "f93546eaec61" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "compliance_report", + sa.Column( + "legacy_id", + sa.Integer(), + nullable=True, + comment="ID from TFRS if this is a transferred application, NULL otherwise", + ), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("compliance_report", "legacy_id") + # ### end Alembic commands ### diff --git a/backend/lcfs/db/models/compliance/ComplianceReport.py b/backend/lcfs/db/models/compliance/ComplianceReport.py index d88e023d2..6656cc6ec 100644 --- a/backend/lcfs/db/models/compliance/ComplianceReport.py +++ b/backend/lcfs/db/models/compliance/ComplianceReport.py @@ -100,6 +100,11 @@ class ComplianceReport(BaseModel, Auditable): default=lambda: str(uuid.uuid4()), comment="UUID that groups all versions of a compliance report", ) + legacy_id = Column( + Integer, + nullable=True, + comment="ID from TFRS if this is a transferred application, NULL otherwise", + ) version = Column( Integer, nullable=False, diff --git a/backend/lcfs/db/models/notification/NotificationMessage.py b/backend/lcfs/db/models/notification/NotificationMessage.py index b28919330..fddc1a961 100644 --- a/backend/lcfs/db/models/notification/NotificationMessage.py +++ b/backend/lcfs/db/models/notification/NotificationMessage.py @@ -20,6 +20,7 @@ class NotificationMessage(BaseModel, Auditable): is_warning = Column(Boolean, default=False) is_error = Column(Boolean, default=False) is_archived = Column(Boolean, default=False) + type = Column(Text, nullable=False) message = Column(Text, nullable=False) related_organization_id = Column( @@ -32,12 +33,9 @@ class NotificationMessage(BaseModel, Auditable): notification_type_id = Column( Integer, ForeignKey("notification_type.notification_type_id") ) + related_transaction_id = Column(Text, nullable=False) - # Models not created yet - # related_transaction_id = Column(Integer,ForeignKey('')) - # related_document_id = Column(Integer, ForeignKey('document.id')) - # related_report_id = Column(Integer, ForeignKey('compliance_report.id')) - + # Relationships related_organization = relationship( "Organization", back_populates="notification_messages" ) diff --git a/backend/lcfs/services/rabbitmq/base_consumer.py b/backend/lcfs/services/rabbitmq/base_consumer.py index 26bc3ebdd..9a80bf56a 100644 --- a/backend/lcfs/services/rabbitmq/base_consumer.py +++ b/backend/lcfs/services/rabbitmq/base_consumer.py @@ -3,6 +3,7 @@ import aio_pika from aio_pika.abc import AbstractChannel, AbstractQueue +from fastapi import FastAPI from lcfs.settings import settings @@ -12,11 +13,12 @@ class BaseConsumer: - def __init__(self, queue_name=None): + def __init__(self, app: FastAPI, queue_name: str): self.connection = None self.channel = None self.queue = None self.queue_name = queue_name + self.app = app async def connect(self): """Connect to RabbitMQ and set up the consumer.""" @@ -42,7 +44,6 @@ async def start_consuming(self): async with message.process(): logger.debug(f"Received message: {message.body.decode()}") await self.process_message(message.body) - logger.debug("Message Processed") async def process_message(self, body: bytes): """Process the incoming message. Override this method in subclasses.""" diff --git a/backend/lcfs/services/rabbitmq/consumers.py b/backend/lcfs/services/rabbitmq/consumers.py index de934c00f..17cfdf193 100644 --- a/backend/lcfs/services/rabbitmq/consumers.py +++ b/backend/lcfs/services/rabbitmq/consumers.py @@ -1,14 +1,14 @@ import asyncio -from lcfs.services.rabbitmq.transaction_consumer import ( - setup_transaction_consumer, - close_transaction_consumer, +from lcfs.services.rabbitmq.report_consumer import ( + setup_report_consumer, + close_report_consumer, ) -async def start_consumers(): - await setup_transaction_consumer() +async def start_consumers(app): + await setup_report_consumer(app) async def stop_consumers(): - await close_transaction_consumer() + await close_report_consumer() diff --git a/backend/lcfs/services/rabbitmq/report_consumer.py b/backend/lcfs/services/rabbitmq/report_consumer.py new file mode 100644 index 000000000..f03df28ee --- /dev/null +++ b/backend/lcfs/services/rabbitmq/report_consumer.py @@ -0,0 +1,292 @@ +import asyncio +import json +import logging +from typing import Optional + +from fastapi import FastAPI +from sqlalchemy.ext.asyncio import AsyncSession + +from lcfs.db.dependencies import async_engine +from lcfs.db.models.compliance.ComplianceReportStatus import ComplianceReportStatusEnum +from lcfs.db.models.transaction.Transaction import TransactionActionEnum +from lcfs.db.models.user import UserProfile +from lcfs.services.rabbitmq.base_consumer import BaseConsumer +from lcfs.services.tfrs.redis_balance import RedisBalanceService +from lcfs.settings import settings +from lcfs.web.api.compliance_report.repo import ComplianceReportRepository +from lcfs.web.api.compliance_report.schema import ComplianceReportCreateSchema +from lcfs.web.api.compliance_report.services import ComplianceReportServices +from lcfs.web.api.organizations.repo import OrganizationsRepository +from lcfs.web.api.organizations.services import OrganizationsService +from lcfs.web.api.transaction.repo import TransactionRepository +from lcfs.web.api.user.repo import UserRepository +from lcfs.web.exception.exceptions import ServiceException + +logger = logging.getLogger(__name__) + +consumer = None +consumer_task = None + +VALID_ACTIONS = {"Created", "Submitted", "Approved"} + + +async def setup_report_consumer(app: FastAPI): + """ + Set up the report consumer and start consuming messages. + """ + global consumer, consumer_task + consumer = ReportConsumer(app) + await consumer.connect() + consumer_task = asyncio.create_task(consumer.start_consuming()) + + +async def close_report_consumer(): + """ + Cancel the consumer task if it exists and close the consumer connection. + """ + global consumer, consumer_task + + if consumer_task: + consumer_task.cancel() + + if consumer: + await consumer.close_connection() + + +class ReportConsumer(BaseConsumer): + """ + A consumer for handling TFRS compliance report messages from a RabbitMQ queue. + """ + + def __init__( + self, app: FastAPI, queue_name: str = settings.rabbitmq_transaction_queue + ): + super().__init__(app, queue_name) + + async def process_message(self, body: bytes): + """ + Process an incoming message from the queue. + + Expected message structure: + { + "tfrs_id": int, + "organization_id": int, + "compliance_period": str, + "nickname": str, + "action": "Created"|"Submitted"|"Approved", + "credits": int (optional), + "user_id": int + } + """ + message = self._parse_message(body) + if not message: + return # Invalid message already logged + + action = message["action"] + org_id = message["organization_id"] + + if action not in VALID_ACTIONS: + logger.error(f"Invalid action '{action}' in message.") + return + + logger.info(f"Received '{action}' action from TFRS for Org {org_id}") + + try: + await self.handle_message( + action=action, + compliance_period=message.get("compliance_period"), + compliance_units=message.get("credits"), + legacy_id=message["tfrs_id"], + nickname=message.get("nickname"), + org_id=org_id, + user_id=message["user_id"], + ) + except Exception: + logger.exception("Failed to handle message") + + def _parse_message(self, body: bytes) -> Optional[dict]: + """ + Parse the message body into a dictionary. + Log and return None if parsing fails or required fields are missing. + """ + try: + message_content = json.loads(body.decode()) + except json.JSONDecodeError: + logger.error("Failed to decode message body as JSON.") + return None + + required_fields = ["tfrs_id", "organization_id", "action", "user_id"] + if any(field not in message_content for field in required_fields): + logger.error("Message missing required fields.") + return None + + return message_content + + async def handle_message( + self, + action: str, + compliance_period: str, + compliance_units: Optional[int], + legacy_id: int, + nickname: Optional[str], + org_id: int, + user_id: int, + ): + """ + Handle a given message action by loading dependencies and calling the respective handler. + """ + redis_client = self.app.state.redis_client + + async with AsyncSession(async_engine) as session: + async with session.begin(): + # Initialize repositories and services + org_repo = OrganizationsRepository(db=session) + transaction_repo = TransactionRepository(db=session) + redis_balance_service = RedisBalanceService( + transaction_repo=transaction_repo, redis_client=redis_client + ) + org_service = OrganizationsService( + repo=org_repo, + transaction_repo=transaction_repo, + redis_balance_service=redis_balance_service, + ) + compliance_report_repo = ComplianceReportRepository(db=session) + compliance_report_service = ComplianceReportServices( + repo=compliance_report_repo + ) + user = await UserRepository(db=session).get_user_by_id(user_id) + + if action == "Created": + await self._handle_created( + org_id, + legacy_id, + compliance_period, + nickname, + user, + compliance_report_service, + ) + elif action == "Submitted": + await self._handle_submitted( + compliance_report_repo, + compliance_units, + legacy_id, + org_id, + org_service, + session, + user, + ) + elif action == "Approved": + await self._handle_approved( + legacy_id, + compliance_report_repo, + transaction_repo, + user, + session, + ) + + async def _handle_created( + self, + org_id: int, + legacy_id: int, + compliance_period: str, + nickname: str, + user: UserProfile, + compliance_report_service: ComplianceReportServices, + ): + """ + Handle the 'Created' action by creating a new compliance report draft. + """ + lcfs_report = ComplianceReportCreateSchema( + legacy_id=legacy_id, + compliance_period=compliance_period, + organization_id=org_id, + nickname=nickname, + status=ComplianceReportStatusEnum.Draft.value, + ) + await compliance_report_service.create_compliance_report( + org_id, lcfs_report, user + ) + + async def _handle_approved( + self, + legacy_id: int, + compliance_report_repo: ComplianceReportRepository, + transaction_repo: TransactionRepository, + user: UserProfile, + session: AsyncSession, + ): + """ + Handle the 'Approved' action by updating the report status to 'Assessed' + and confirming the associated transaction. + """ + existing_report = ( + await compliance_report_repo.get_compliance_report_by_legacy_id(legacy_id) + ) + if not existing_report: + raise ServiceException( + f"No compliance report found for legacy ID {legacy_id}" + ) + + new_status = await compliance_report_repo.get_compliance_report_status_by_desc( + ComplianceReportStatusEnum.Assessed.value + ) + existing_report.current_status_id = new_status.compliance_report_status_id + session.add(existing_report) + await session.flush() + + await compliance_report_repo.add_compliance_report_history( + existing_report, user + ) + + existing_transaction = await transaction_repo.get_transaction_by_id( + existing_report.transaction_id + ) + if not existing_transaction: + raise ServiceException( + "Compliance Report does not have an associated transaction" + ) + + if existing_transaction.transaction_action != TransactionActionEnum.Reserved: + raise ServiceException( + f"Transaction {existing_transaction.transaction_id} is not in 'Reserved' status" + ) + + await transaction_repo.confirm_transaction(existing_transaction.transaction_id) + + async def _handle_submitted( + self, + compliance_report_repo: ComplianceReportRepository, + compliance_units: int, + legacy_id: int, + org_id: int, + org_service: OrganizationsService, + session: AsyncSession, + user: UserProfile, + ): + """ + Handle the 'Submitted' action by linking a reserved transaction + to the compliance report and updating its status. + """ + existing_report = ( + await compliance_report_repo.get_compliance_report_by_legacy_id(legacy_id) + ) + if not existing_report: + raise ServiceException( + f"No compliance report found for legacy ID {legacy_id}" + ) + + transaction = await org_service.adjust_balance( + TransactionActionEnum.Reserved, compliance_units, org_id + ) + existing_report.transaction_id = transaction.transaction_id + + new_status = await compliance_report_repo.get_compliance_report_status_by_desc( + ComplianceReportStatusEnum.Submitted.value + ) + existing_report.current_status_id = new_status.compliance_report_status_id + session.add(existing_report) + await session.flush() + + await compliance_report_repo.add_compliance_report_history( + existing_report, user + ) diff --git a/backend/lcfs/services/rabbitmq/transaction_consumer.py b/backend/lcfs/services/rabbitmq/transaction_consumer.py deleted file mode 100644 index 381142d6c..000000000 --- a/backend/lcfs/services/rabbitmq/transaction_consumer.py +++ /dev/null @@ -1,71 +0,0 @@ -import asyncio -import json -import logging - -from redis.asyncio import Redis -from sqlalchemy.ext.asyncio import AsyncSession -from lcfs.services.redis.dependency import get_redis_client -from fastapi import Request - -from lcfs.db.dependencies import async_engine -from lcfs.db.models.transaction.Transaction import TransactionActionEnum -from lcfs.services.rabbitmq.base_consumer import BaseConsumer -from lcfs.services.tfrs.redis_balance import RedisBalanceService -from lcfs.settings import settings -from lcfs.web.api.organizations.repo import OrganizationsRepository -from lcfs.web.api.organizations.services import OrganizationsService -from lcfs.web.api.transaction.repo import TransactionRepository - -logger = logging.getLogger(__name__) -consumer = None -consumer_task = None - - -async def setup_transaction_consumer(): - global consumer, consumer_task - consumer = TransactionConsumer() - await consumer.connect() - consumer_task = asyncio.create_task(consumer.start_consuming()) - - -async def close_transaction_consumer(): - global consumer, consumer_task - - if consumer_task: - consumer_task.cancel() - - if consumer: - await consumer.close_connection() - - -class TransactionConsumer(BaseConsumer): - def __init__( - self, - queue_name=settings.rabbitmq_transaction_queue, - ): - super().__init__(queue_name) - - async def process_message(self, body: bytes, request: Request): - message_content = json.loads(body.decode()) - compliance_units = message_content.get("compliance_units_amount") - org_id = message_content.get("organization_id") - - redis_client = await get_redis_client(request) - - async with AsyncSession(async_engine) as session: - async with session.begin(): - repo = OrganizationsRepository(db=session) - transaction_repo = TransactionRepository(db=session) - redis_balance_service = RedisBalanceService( - transaction_repo=transaction_repo, redis_client=redis_client - ) - org_service = OrganizationsService( - repo=repo, - transaction_repo=transaction_repo, - redis_balance_service=redis_balance_service, - ) - - await org_service.adjust_balance( - TransactionActionEnum.Adjustment, compliance_units, org_id - ) - logger.debug(f"Processed Transaction from TFRS for Org {org_id}") diff --git a/backend/lcfs/tests/compliance_report/test_compliance_report_repo.py b/backend/lcfs/tests/compliance_report/test_compliance_report_repo.py index c26603dd8..84ed2520b 100644 --- a/backend/lcfs/tests/compliance_report/test_compliance_report_repo.py +++ b/backend/lcfs/tests/compliance_report/test_compliance_report_repo.py @@ -558,7 +558,7 @@ async def test_add_compliance_report_success( version=1, ) - report = await compliance_report_repo.add_compliance_report(report=new_report) + report = await compliance_report_repo.create_compliance_report(report=new_report) assert isinstance(report, ComplianceReportBaseSchema) assert report.compliance_period_id == compliance_periods[0].compliance_period_id @@ -577,7 +577,7 @@ async def test_add_compliance_report_exception( new_report = ComplianceReport() with pytest.raises(DatabaseException): - await compliance_report_repo.add_compliance_report(report=new_report) + await compliance_report_repo.create_compliance_report(report=new_report) @pytest.mark.anyio diff --git a/backend/lcfs/tests/compliance_report/test_compliance_report_services.py b/backend/lcfs/tests/compliance_report/test_compliance_report_services.py index 3237762be..9300d2918 100644 --- a/backend/lcfs/tests/compliance_report/test_compliance_report_services.py +++ b/backend/lcfs/tests/compliance_report/test_compliance_report_services.py @@ -4,6 +4,7 @@ from lcfs.db.models.compliance.ComplianceReportStatus import ComplianceReportStatus from lcfs.web.exception.exceptions import ServiceException, DataNotFoundException + # get_all_compliance_periods @pytest.mark.anyio async def test_get_all_compliance_periods_success(compliance_report_service, mock_repo): @@ -41,6 +42,8 @@ async def test_create_compliance_report_success( compliance_report_base_schema, compliance_report_create_schema, ): + mock_user = MagicMock() + # Mock the compliance period mock_compliance_period = CompliancePeriod( compliance_period_id=1, @@ -57,10 +60,10 @@ async def test_create_compliance_report_success( # Mock the added compliance report mock_compliance_report = compliance_report_base_schema() - mock_repo.add_compliance_report.return_value = mock_compliance_report + mock_repo.create_compliance_report.return_value = mock_compliance_report result = await compliance_report_service.create_compliance_report( - 1, compliance_report_create_schema + 1, compliance_report_create_schema, mock_user ) assert result == mock_compliance_report @@ -70,14 +73,16 @@ async def test_create_compliance_report_success( mock_repo.get_compliance_report_status_by_desc.assert_called_once_with( compliance_report_create_schema.status ) - mock_repo.add_compliance_report.assert_called_once() + mock_repo.create_compliance_report.assert_called_once() @pytest.mark.anyio async def test_create_compliance_report_unexpected_error( compliance_report_service, mock_repo ): - mock_repo.add_compliance_report.side_effect = Exception("Unexpected error occurred") + mock_repo.create_compliance_report.side_effect = Exception( + "Unexpected error occurred" + ) with pytest.raises(ServiceException): await compliance_report_service.create_compliance_report( diff --git a/backend/lcfs/tests/compliance_report/test_update_service.py b/backend/lcfs/tests/compliance_report/test_update_service.py index ec4b7e130..12532c4e0 100644 --- a/backend/lcfs/tests/compliance_report/test_update_service.py +++ b/backend/lcfs/tests/compliance_report/test_update_service.py @@ -30,8 +30,8 @@ def mock_user_has_roles(): def mock_notification_service(): mock_service = AsyncMock(spec=NotificationService) with patch( - "lcfs.web.api.compliance_report.update_service.Depends", - return_value=mock_service + "lcfs.web.api.compliance_report.update_service.Depends", + return_value=mock_service, ): yield mock_service @@ -47,6 +47,7 @@ def mock_environment_vars(): mock_settings.ches_sender_name = "Mock Notification System" yield mock_settings + # Mock for adjust_balance method within the OrganizationsService @pytest.fixture def mock_org_service(): @@ -66,6 +67,9 @@ async def test_update_compliance_report_status_change( mock_report.compliance_report_id = report_id mock_report.current_status = MagicMock(spec=ComplianceReportStatus) mock_report.current_status.status = ComplianceReportStatusEnum.Draft + mock_report.compliance_period = MagicMock() + mock_report.compliance_period.description = "2024" + mock_report.transaction_id = 123 new_status = MagicMock(spec=ComplianceReportStatus) new_status.status = ComplianceReportStatusEnum.Submitted @@ -78,8 +82,8 @@ async def test_update_compliance_report_status_change( mock_repo.get_compliance_report_by_id.return_value = mock_report mock_repo.get_compliance_report_status_by_desc.return_value = new_status compliance_report_update_service.handle_status_change = AsyncMock() - compliance_report_update_service.notfn_service = mock_notification_service mock_repo.update_compliance_report.return_value = mock_report + compliance_report_update_service._perform_notification_call = AsyncMock() # Call the method updated_report = await compliance_report_update_service.update_compliance_report( @@ -101,10 +105,9 @@ async def test_update_compliance_report_status_change( mock_report, compliance_report_update_service.request.user ) mock_repo.update_compliance_report.assert_called_once_with(mock_report) - - assert mock_report.current_status == new_status - assert mock_report.supplemental_note == report_data.supplemental_note - mock_notification_service.send_notification.assert_called_once() + compliance_report_update_service._perform_notification_call.assert_called_once_with( + mock_report, "Submitted" + ) @pytest.mark.anyio @@ -117,7 +120,11 @@ async def test_update_compliance_report_no_status_change( mock_report.compliance_report_id = report_id mock_report.current_status = MagicMock(spec=ComplianceReportStatus) mock_report.current_status.status = ComplianceReportStatusEnum.Draft + mock_report.compliance_period = MagicMock() + mock_report.compliance_period.description = "2024" + mock_report.transaction_id = 123 + # Status does not change report_data = ComplianceReportUpdateSchema( status="Draft", supplemental_note="Test note" ) @@ -128,9 +135,7 @@ async def test_update_compliance_report_no_status_change( mock_report.current_status ) mock_repo.update_compliance_report.return_value = mock_report - - # Mock the handle_status_change method - compliance_report_update_service.handle_status_change = AsyncMock() + compliance_report_update_service._perform_notification_call = AsyncMock() # Call the method updated_report = await compliance_report_update_service.update_compliance_report( @@ -139,19 +144,11 @@ async def test_update_compliance_report_no_status_change( # Assertions assert updated_report == mock_report - mock_repo.get_compliance_report_by_id.assert_called_once_with( - report_id, is_model=True + compliance_report_update_service._perform_notification_call.assert_called_once_with( + mock_report, "Draft" ) - mock_repo.get_compliance_report_status_by_desc.assert_called_once_with( - report_data.status - ) - compliance_report_update_service.handle_status_change.assert_not_called() - mock_repo.add_compliance_report_history.assert_not_called() mock_repo.update_compliance_report.assert_called_once_with(mock_report) - assert mock_report.current_status == mock_report.current_status - assert mock_report.supplemental_note == report_data.supplemental_note - @pytest.mark.anyio async def test_update_compliance_report_not_found( diff --git a/backend/lcfs/tests/initiative_agreement/test_initiative_agreement_services.py b/backend/lcfs/tests/initiative_agreement/test_initiative_agreement_services.py index 85d0299a9..cb0ee6994 100644 --- a/backend/lcfs/tests/initiative_agreement/test_initiative_agreement_services.py +++ b/backend/lcfs/tests/initiative_agreement/test_initiative_agreement_services.py @@ -87,14 +87,23 @@ async def test_get_initiative_agreement(service, mock_repo): mock_repo.get_initiative_agreement_by_id.assert_called_once_with(1) +@pytest.mark.anyio @pytest.mark.anyio async def test_create_initiative_agreement(service, mock_repo, mock_request): + # Mock status for the initiative agreement mock_status = MagicMock(status=InitiativeAgreementStatusEnum.Recommended) mock_repo.get_initiative_agreement_status_by_name.return_value = mock_status - mock_repo.create_initiative_agreement.return_value = MagicMock( - spec=InitiativeAgreement - ) + # Create a mock initiative agreement with serializable fields + mock_initiative_agreement = MagicMock(spec=InitiativeAgreement) + mock_initiative_agreement.initiative_agreement_id = 1 + mock_initiative_agreement.current_status.status = "Recommended" + mock_initiative_agreement.to_organization_id = 3 + + # Mock return value of create_initiative_agreement + mock_repo.create_initiative_agreement.return_value = mock_initiative_agreement + + # Create input data create_data = InitiativeAgreementCreateSchema( compliance_units=150, current_status="Recommended", @@ -104,10 +113,18 @@ async def test_create_initiative_agreement(service, mock_repo, mock_request): internal_comment=None, ) + # Mock _perform_notification_call to isolate it + service._perform_notification_call = AsyncMock() + + # Call the service method result = await service.create_initiative_agreement(create_data) - assert isinstance(result, InitiativeAgreement) + # Assertions + assert result == mock_initiative_agreement mock_repo.create_initiative_agreement.assert_called_once() + service._perform_notification_call.assert_called_once_with( + mock_initiative_agreement + ) @pytest.mark.anyio diff --git a/backend/lcfs/tests/notification/test_notification_repo.py b/backend/lcfs/tests/notification/test_notification_repo.py index 20eb31169..bbc4ee80f 100644 --- a/backend/lcfs/tests/notification/test_notification_repo.py +++ b/backend/lcfs/tests/notification/test_notification_repo.py @@ -1,3 +1,5 @@ +from lcfs.db.models.notification.NotificationChannel import ChannelEnum +from lcfs.web.api.base import NotificationTypeEnum, PaginationRequestSchema import pytest from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import delete @@ -92,34 +94,21 @@ async def mock_execute(*args, **kwargs): @pytest.mark.anyio async def test_get_notification_messages_by_user(notification_repo, mock_db_session): mock_notification1 = MagicMock(spec=NotificationMessage) - mock_notification1.related_user_id = 1 - mock_notification1.origin_user_id = 2 - mock_notification1.notification_message_id = 1 - mock_notification1.message = "Test message 1" - mock_notification2 = MagicMock(spec=NotificationMessage) - mock_notification2.related_user_id = 1 - mock_notification2.origin_user_id = 2 - mock_notification2.notification_message_id = 2 - mock_notification2.message = "Test message 2" - mock_result_chain = MagicMock() - mock_result_chain.scalars.return_value.all.return_value = [ + mock_result = MagicMock() + mock_result.unique.return_value.scalars.return_value.all.return_value = [ mock_notification1, mock_notification2, ] - async def mock_execute(*args, **kwargs): - return mock_result_chain - - # Inject the mocked execute method into the session - mock_db_session.execute = mock_execute + mock_db_session.execute = AsyncMock(return_value=mock_result) result = await notification_repo.get_notification_messages_by_user(1) assert len(result) == 2 - assert result[0].notification_message_id == 1 - assert result[1].notification_message_id == 2 + assert result == [mock_notification1, mock_notification2] + mock_db_session.execute.assert_called_once() @pytest.mark.anyio @@ -158,7 +147,7 @@ async def test_delete_notification_message(notification_repo, mock_db_session): NotificationMessage.notification_message_id == notification_id ) assert str(executed_query) == str(expected_query) - + mock_db_session.execute.assert_called_once() mock_db_session.flush.assert_called_once() @@ -277,3 +266,83 @@ async def mock_execute(*args, **kwargs): assert result is not None assert result.notification_channel_subscription_id == subscription_id + + +@pytest.mark.anyio +async def test_create_notification_messages(notification_repo, mock_db_session): + messages = [ + MagicMock(spec=NotificationMessage), + MagicMock(spec=NotificationMessage), + ] + + await notification_repo.create_notification_messages(messages) + + mock_db_session.add_all.assert_called_once_with(messages) + mock_db_session.flush.assert_called_once() + + +@pytest.mark.anyio +async def test_mark_notifications_as_read(notification_repo, mock_db_session): + user_id = 1 + notification_ids = [1, 2, 3] + + mock_db_session.execute = AsyncMock() + mock_db_session.flush = AsyncMock() + + result = await notification_repo.mark_notifications_as_read( + user_id, notification_ids + ) + + assert result == notification_ids + mock_db_session.execute.assert_called_once() + mock_db_session.flush.assert_called_once() + + +@pytest.mark.anyio +async def test_get_notification_type_by_name(notification_repo, mock_db_session): + # Create a mock result that properly simulates the SQLAlchemy result + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = 123 + mock_result.scalars.return_value = mock_scalars + + mock_db_session.execute = AsyncMock(return_value=mock_result) + + result = await notification_repo.get_notification_type_by_name("TestNotification") + + assert result == 123 + mock_db_session.execute.assert_called_once() + + +@pytest.mark.anyio +async def test_get_notification_channel_by_name(notification_repo, mock_db_session): + # Similar setup to the previous test + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = 456 + mock_result.scalars.return_value = mock_scalars + + mock_db_session.execute = AsyncMock(return_value=mock_result) + + result = await notification_repo.get_notification_channel_by_name(ChannelEnum.EMAIL) + + assert result == 456 + mock_db_session.execute.assert_called_once() + + +@pytest.mark.anyio +async def test_get_subscribed_users_by_channel(notification_repo, mock_db_session): + # Similar setup, but using .all() instead of .first() + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.all.return_value = [1, 2, 3] + mock_result.scalars.return_value = mock_scalars + + mock_db_session.execute = AsyncMock(return_value=mock_result) + + result = await notification_repo.get_subscribed_users_by_channel( + NotificationTypeEnum.BCEID__TRANSFER__PARTNER_ACTIONS, ChannelEnum.EMAIL + ) + + assert result == [1, 2, 3] + mock_db_session.execute.assert_called_once() diff --git a/backend/lcfs/tests/services/rabbitmq/test_report_consumer.py b/backend/lcfs/tests/services/rabbitmq/test_report_consumer.py new file mode 100644 index 000000000..838c9fe0c --- /dev/null +++ b/backend/lcfs/tests/services/rabbitmq/test_report_consumer.py @@ -0,0 +1,218 @@ +import json +from contextlib import ExitStack +from unittest.mock import AsyncMock, patch, MagicMock + +import pytest +from pandas.io.formats.format import return_docstring + +from lcfs.db.models.transaction.Transaction import TransactionActionEnum, Transaction +from lcfs.services.rabbitmq.report_consumer import ( + ReportConsumer, +) +from lcfs.tests.fuel_export.conftest import mock_compliance_report_repo +from lcfs.web.api.compliance_report.schema import ComplianceReportCreateSchema + + +@pytest.fixture +def mock_app(): + """Fixture to provide a mocked FastAPI app.""" + return MagicMock() + + +@pytest.fixture +def mock_redis(): + """Fixture to mock Redis client.""" + return AsyncMock() + + +@pytest.fixture +def mock_session(): + # Create a mock session that behaves like an async context manager. + # Specifying `spec=AsyncSession` helps ensure it behaves like the real class. + from sqlalchemy.ext.asyncio import AsyncSession + + mock_session = AsyncMock(spec=AsyncSession) + + # `async with mock_session:` should work, so we define what happens on enter/exit + mock_session.__aenter__.return_value = mock_session + mock_session.__aexit__.return_value = None + + # Now mock the transaction context manager returned by `session.begin()` + mock_transaction = AsyncMock() + mock_transaction.__aenter__.return_value = mock_transaction + mock_transaction.__aexit__.return_value = None + mock_session.begin.return_value = mock_transaction + + return mock_session + + +@pytest.fixture +def mock_repositories(): + """Fixture to mock all repositories and services.""" + + mock_compliance_report_repo = MagicMock() + mock_compliance_report_repo.get_compliance_report_by_legacy_id = AsyncMock( + return_value=MagicMock() + ) + mock_compliance_report_repo.get_compliance_report_status_by_desc = AsyncMock( + return_value=MagicMock() + ) + mock_compliance_report_repo.add_compliance_report_history = AsyncMock() + + org_service = MagicMock() + org_service.adjust_balance = AsyncMock() + + mock_transaction_repo = MagicMock() + mock_transaction_repo.get_transaction_by_id = AsyncMock( + return_value=MagicMock( + spec=Transaction, transaction_action=TransactionActionEnum.Reserved + ) + ) + + return { + "compliance_report_repo": mock_compliance_report_repo, + "transaction_repo": mock_transaction_repo, + "user_repo": AsyncMock(), + "org_service": org_service, + "compliance_service": AsyncMock(), + } + + +@pytest.fixture +def setup_patches(mock_redis, mock_session, mock_repositories): + """Fixture to apply patches for dependencies.""" + with ExitStack() as stack: + stack.enter_context( + patch("redis.asyncio.Redis.from_url", return_value=mock_redis) + ) + + stack.enter_context( + patch( + "lcfs.services.rabbitmq.report_consumer.AsyncSession", + return_value=mock_session, + ) + ) + stack.enter_context( + patch("lcfs.services.rabbitmq.report_consumer.async_engine", MagicMock()) + ) + + stack.enter_context( + patch( + "lcfs.services.rabbitmq.report_consumer.ComplianceReportRepository", + return_value=mock_repositories["compliance_report_repo"], + ) + ) + stack.enter_context( + patch( + "lcfs.services.rabbitmq.report_consumer.TransactionRepository", + return_value=mock_repositories["transaction_repo"], + ) + ) + stack.enter_context( + patch( + "lcfs.services.rabbitmq.report_consumer.UserRepository", + return_value=mock_repositories["user_repo"], + ) + ) + stack.enter_context( + patch( + "lcfs.services.rabbitmq.report_consumer.OrganizationsService", + return_value=mock_repositories["org_service"], + ) + ) + stack.enter_context( + patch( + "lcfs.services.rabbitmq.report_consumer.ComplianceReportServices", + return_value=mock_repositories["compliance_service"], + ) + ) + yield stack + + +@pytest.mark.anyio +async def test_process_message_created(mock_app, setup_patches, mock_repositories): + consumer = ReportConsumer(mock_app) + + # Prepare a sample message for "Created" action + message = { + "tfrs_id": 123, + "organization_id": 1, + "compliance_period": "2023", + "nickname": "Test Report", + "action": "Created", + "user_id": 42, + } + body = json.dumps(message).encode() + + # Ensure correct mock setup + mock_user = MagicMock() + mock_repositories["user_repo"].get_user_by_id.return_value = mock_user + + await consumer.process_message(body) + + # Assertions for "Created" action + mock_repositories[ + "compliance_service" + ].create_compliance_report.assert_called_once_with( + 1, # org_id + ComplianceReportCreateSchema( + legacy_id=123, + compliance_period="2023", + organization_id=1, + nickname="Test Report", + status="Draft", + ), + mock_user, + ) + + +@pytest.mark.anyio +async def test_process_message_submitted(mock_app, setup_patches, mock_repositories): + consumer = ReportConsumer(mock_app) + + # Prepare a sample message for "Submitted" action + message = { + "tfrs_id": 123, + "organization_id": 1, + "compliance_period": "2023", + "nickname": "Test Report", + "action": "Submitted", + "credits": 50, + "user_id": 42, + } + body = json.dumps(message).encode() + + await consumer.process_message(body) + + # Assertions for "Submitted" action + mock_repositories[ + "compliance_report_repo" + ].get_compliance_report_by_legacy_id.assert_called_once_with(123) + mock_repositories["org_service"].adjust_balance.assert_called_once_with( + TransactionActionEnum.Reserved, 50, 1 + ) + mock_repositories[ + "compliance_report_repo" + ].add_compliance_report_history.assert_called_once() + + +@pytest.mark.anyio +async def test_process_message_approved(mock_app, setup_patches, mock_repositories): + consumer = ReportConsumer(mock_app) + + # Prepare a sample message for "Approved" action + message = { + "tfrs_id": 123, + "organization_id": 1, + "action": "Approved", + "user_id": 42, + } + body = json.dumps(message).encode() + + await consumer.process_message(body) + + # Assertions for "Approved" action + mock_repositories[ + "compliance_report_repo" + ].get_compliance_report_by_legacy_id.assert_called_once_with(123) + mock_repositories["transaction_repo"].confirm_transaction.assert_called_once() diff --git a/backend/lcfs/tests/services/rabbitmq/test_transaction_consumer.py b/backend/lcfs/tests/services/rabbitmq/test_transaction_consumer.py deleted file mode 100644 index 3bd8d539a..000000000 --- a/backend/lcfs/tests/services/rabbitmq/test_transaction_consumer.py +++ /dev/null @@ -1,111 +0,0 @@ -from contextlib import ExitStack - -import pytest -from unittest.mock import AsyncMock, patch, MagicMock -import json - - -from lcfs.db.models.transaction.Transaction import TransactionActionEnum -from lcfs.services.rabbitmq.transaction_consumer import ( - setup_transaction_consumer, - close_transaction_consumer, - TransactionConsumer, - consumer, - consumer_task, -) - - -@pytest.mark.anyio -async def test_setup_transaction_consumer(): - with patch( - "lcfs.services.rabbitmq.transaction_consumer.TransactionConsumer" - ) as MockConsumer: - mock_consumer = MockConsumer.return_value - mock_consumer.connect = AsyncMock() - mock_consumer.start_consuming = AsyncMock() - - await setup_transaction_consumer() - - mock_consumer.connect.assert_called_once() - mock_consumer.start_consuming.assert_called_once() - - -@pytest.mark.anyio -async def test_close_transaction_consumer(): - with patch( - "lcfs.services.rabbitmq.transaction_consumer.TransactionConsumer" - ) as MockConsumer: - mock_consumer = MockConsumer.return_value - mock_consumer.connect = AsyncMock() - mock_consumer.start_consuming = AsyncMock() - mock_consumer.close_connection = AsyncMock() - - await setup_transaction_consumer() - - await close_transaction_consumer() - - mock_consumer.close_connection.assert_called_once() - - -@pytest.mark.anyio -async def test_process_message(): - mock_redis = AsyncMock() - mock_session = AsyncMock() - mock_repo = AsyncMock() - mock_redis_balance_service = AsyncMock() - adjust_balance = AsyncMock() - - with ExitStack() as stack: - stack.enter_context( - patch("redis.asyncio.Redis.from_url", return_value=mock_redis) - ) - stack.enter_context( - patch("sqlalchemy.ext.asyncio.AsyncSession", return_value=mock_session) - ) - stack.enter_context( - patch( - "lcfs.web.api.organizations.repo.OrganizationsRepository", - return_value=mock_repo, - ) - ) - stack.enter_context( - patch( - "lcfs.web.api.transaction.repo.TransactionRepository.calculate_available_balance", - side_effect=[100, 200, 150, 250, 300, 350], - ) - ) - stack.enter_context( - patch( - "lcfs.web.api.transaction.repo.TransactionRepository.calculate_reserved_balance", - side_effect=[100, 200, 150, 250, 300, 350], - ) - ) - stack.enter_context( - patch( - "lcfs.services.tfrs.redis_balance.RedisBalanceService", - return_value=mock_redis_balance_service, - ) - ) - stack.enter_context( - patch( - "lcfs.web.api.organizations.services.OrganizationsService.adjust_balance", - adjust_balance, - ) - ) - - # Create an instance of the consumer - consumer = TransactionConsumer() - - # Prepare a sample message - message = { - "compliance_units_amount": 100, - "organization_id": 1, - } - body = json.dumps(message).encode() - - mock_request = AsyncMock() - - await consumer.process_message(body, mock_request) - - # Assert that the organization service's adjust_balance method was called correctly - adjust_balance.assert_called_once_with(TransactionActionEnum.Adjustment, 100, 1) diff --git a/backend/lcfs/tests/transfer/test_transfer_services.py b/backend/lcfs/tests/transfer/test_transfer_services.py index d9e30abfb..f82ef70c7 100644 --- a/backend/lcfs/tests/transfer/test_transfer_services.py +++ b/backend/lcfs/tests/transfer/test_transfer_services.py @@ -76,13 +76,13 @@ async def test_create_transfer_success(transfer_service, mock_transfer_repo): ) mock_transfer_repo.create_transfer.return_value = transfer_data - # Patch the _perform_notificaiton_call method - with patch.object(transfer_service, "_perform_notificaiton_call", AsyncMock()): + # Patch the _perform_notification_call method + with patch.object(transfer_service, "_perform_notification_call", AsyncMock()): result = await transfer_service.create_transfer(transfer_data) assert result.transfer_id == transfer_id assert isinstance(result, TransferCreateSchema) - transfer_service._perform_notificaiton_call.assert_called_once() + transfer_service._perform_notification_call.assert_called_once() @pytest.mark.anyio @@ -91,8 +91,15 @@ async def test_update_transfer_success( ): transfer_status = TransferStatus(transfer_status_id=1, status="status") transfer_id = 1 + # Create valid nested organization objects + from_org = Organization(organization_id=1, name="org1") + to_org = Organization(organization_id=2, name="org2") + + # Create a Transfer object with the necessary attributes transfer = Transfer( transfer_id=transfer_id, + from_organization=from_org, + to_organization=to_org, from_organization_id=1, to_organization_id=2, from_transaction_id=1, @@ -114,11 +121,22 @@ async def test_update_transfer_success( mock_transfer_repo.get_transfer_by_id.return_value = transfer mock_transfer_repo.update_transfer.return_value = transfer + # Replace _perform_notification_call with an AsyncMock + transfer_service._perform_notification_call = AsyncMock() + result = await transfer_service.update_transfer(transfer) + # Assertions assert result.transfer_id == transfer_id assert isinstance(result, Transfer) + # Verify mocks + mock_transfer_repo.get_transfer_by_id.assert_called_once_with(transfer_id) + mock_transfer_repo.update_transfer.assert_called_once_with(transfer) + transfer_service._perform_notification_call.assert_awaited_once_with( + transfer, status="Return to analyst" + ) + @pytest.mark.anyio async def test_update_category_success(transfer_service, mock_transfer_repo): diff --git a/backend/lcfs/web/api/compliance_report/repo.py b/backend/lcfs/web/api/compliance_report/repo.py index 194afb8d0..1bac62843 100644 --- a/backend/lcfs/web/api/compliance_report/repo.py +++ b/backend/lcfs/web/api/compliance_report/repo.py @@ -2,6 +2,8 @@ from typing import List, Optional, Dict from collections import defaultdict from datetime import datetime + +from lcfs.db.models import UserProfile from lcfs.db.models.organization.Organization import Organization from lcfs.db.models.fuel.FuelType import FuelType from lcfs.db.models.fuel.FuelCategory import FuelCategory @@ -15,7 +17,6 @@ PaginationRequestSchema, apply_filter_conditions, get_field_for_filter, - get_enum_value, ) from lcfs.db.models.compliance import CompliancePeriod from lcfs.db.models.compliance.ComplianceReport import ( @@ -181,7 +182,9 @@ async def check_compliance_report( ) @repo_handler - async def get_compliance_report_status_by_desc(self, status: str) -> int: + async def get_compliance_report_status_by_desc( + self, status: str + ) -> ComplianceReportStatus: """ Retrieve the compliance report status ID from the database based on the description. Replaces spaces with underscores in the status description. @@ -266,7 +269,7 @@ async def get_assessed_compliance_report_by_period( return result @repo_handler - async def add_compliance_report(self, report: ComplianceReport): + async def create_compliance_report(self, report: ComplianceReport): """ Add a new compliance report to the database """ @@ -304,7 +307,9 @@ async def get_compliance_report_history(self, report: ComplianceReport): return history.scalar_one_or_none() @repo_handler - async def add_compliance_report_history(self, report: ComplianceReport, user): + async def add_compliance_report_history( + self, report: ComplianceReport, user: UserProfile + ): """ Add a new compliance report history record to the database """ @@ -823,3 +828,26 @@ async def get_latest_report_by_group_uuid( .limit(1) ) return result.scalars().first() + + async def get_compliance_report_by_legacy_id(self, legacy_id): + """ + Retrieve a compliance report from the database by ID + """ + result = await self.db.execute( + select(ComplianceReport) + .options( + joinedload(ComplianceReport.organization), + joinedload(ComplianceReport.compliance_period), + joinedload(ComplianceReport.current_status), + joinedload(ComplianceReport.summary), + joinedload(ComplianceReport.history).joinedload( + ComplianceReportHistory.status + ), + joinedload(ComplianceReport.history).joinedload( + ComplianceReportHistory.user_profile + ), + joinedload(ComplianceReport.transaction), + ) + .where(ComplianceReport.legacy_id == legacy_id) + ) + return result.scalars().unique().first() diff --git a/backend/lcfs/web/api/compliance_report/schema.py b/backend/lcfs/web/api/compliance_report/schema.py index 9eb215c53..34696dee2 100644 --- a/backend/lcfs/web/api/compliance_report/schema.py +++ b/backend/lcfs/web/api/compliance_report/schema.py @@ -148,7 +148,6 @@ class ComplianceReportBaseSchema(BaseSchema): current_status_id: int current_status: ComplianceReportStatusSchema transaction_id: Optional[int] = None - # transaction: Optional[TransactionBaseSchema] = None nickname: Optional[str] = None supplemental_note: Optional[str] = None reporting_frequency: Optional[ReportingFrequency] = None @@ -166,6 +165,8 @@ class ComplianceReportCreateSchema(BaseSchema): compliance_period: str organization_id: int status: str + legacy_id: Optional[int] = None + nickname: Optional[str] = None class ComplianceReportListSchema(BaseSchema): diff --git a/backend/lcfs/web/api/compliance_report/services.py b/backend/lcfs/web/api/compliance_report/services.py index 31993bc75..dac78edd9 100644 --- a/backend/lcfs/web/api/compliance_report/services.py +++ b/backend/lcfs/web/api/compliance_report/services.py @@ -27,10 +27,7 @@ class ComplianceReportServices: - def __init__( - self, request: Request = None, repo: ComplianceReportRepository = Depends() - ) -> None: - self.request = request + def __init__(self, repo: ComplianceReportRepository = Depends()) -> None: self.repo = repo @service_handler @@ -41,7 +38,10 @@ async def get_all_compliance_periods(self) -> List[CompliancePeriodSchema]: @service_handler async def create_compliance_report( - self, organization_id: int, report_data: ComplianceReportCreateSchema + self, + organization_id: int, + report_data: ComplianceReportCreateSchema, + user: UserProfile, ) -> ComplianceReportBaseSchema: """Creates a new compliance report.""" period = await self.repo.get_compliance_period(report_data.compliance_period) @@ -52,8 +52,7 @@ async def create_compliance_report( report_data.status ) if not draft_status: - raise DataNotFoundException( - f"Status '{report_data.status}' not found.") + raise DataNotFoundException(f"Status '{report_data.status}' not found.") # Generate a new group_uuid for the new report series group_uuid = str(uuid.uuid4()) @@ -65,15 +64,17 @@ async def create_compliance_report( reporting_frequency=ReportingFrequency.ANNUAL, compliance_report_group_uuid=group_uuid, # New group_uuid for the series version=0, # Start with version 0 - nickname="Original Report", + nickname=report_data.nickname or "Original Report", summary=ComplianceReportSummary(), # Create an empty summary object + legacy_id=report_data.legacy_id, + create_user=user.keycloak_username, ) # Add the new compliance report - report = await self.repo.add_compliance_report(report) + report = await self.repo.create_compliance_report(report) # Create the history record - await self.repo.add_compliance_report_history(report, self.request.user) + await self.repo.add_compliance_report_history(report, user) return ComplianceReportBaseSchema.model_validate(report) @@ -137,7 +138,7 @@ async def create_supplemental_report( ) # Add the new supplemental report - new_report = await self.repo.add_compliance_report(new_report) + new_report = await self.repo.create_compliance_report(new_report) # Create the history record for the new supplemental report await self.repo.add_compliance_report_history(new_report, user) @@ -228,8 +229,7 @@ async def get_compliance_report_by_id( if apply_masking: # Apply masking to each report in the chain - masked_chain = self._mask_report_status( - compliance_report_chain) + masked_chain = self._mask_report_status(compliance_report_chain) # Apply history masking to each report in the chain masked_chain = [ self._mask_report_status_for_history(report, apply_masking) diff --git a/backend/lcfs/web/api/compliance_report/summary_service.py b/backend/lcfs/web/api/compliance_report/summary_service.py index 6d781a3d6..f241da143 100644 --- a/backend/lcfs/web/api/compliance_report/summary_service.py +++ b/backend/lcfs/web/api/compliance_report/summary_service.py @@ -118,9 +118,11 @@ def convert_summary_to_dict( "description" ] ), - field=RENEWABLE_FUEL_TARGET_DESCRIPTIONS[str(line)]["field"], + field=RENEWABLE_FUEL_TARGET_DESCRIPTIONS[str( + line)]["field"], ) - summary.renewable_fuel_target_summary.append(existing_element) + summary.renewable_fuel_target_summary.append( + existing_element) value = int(getattr(summary_obj, column.key) or 0) if column.key.endswith("_gasoline"): existing_element.gasoline = value @@ -150,7 +152,8 @@ def convert_summary_to_dict( "description" ] ), - field=LOW_CARBON_FUEL_TARGET_DESCRIPTIONS[str(line)]["field"], + field=LOW_CARBON_FUEL_TARGET_DESCRIPTIONS[str( + line)]["field"], value=int(getattr(summary_obj, column.key) or 0), ) ) @@ -188,7 +191,8 @@ def convert_summary_to_dict( "field" ], ) - summary.non_compliance_penalty_summary.append(existing_element) + summary.non_compliance_penalty_summary.append( + existing_element) value = int(getattr(summary_obj, column.key) or 0) if column.key.endswith("_gasoline"): existing_element.gasoline = value @@ -307,7 +311,8 @@ async def calculate_compliance_report_summary( for transfer in notional_transfers.notional_transfers: # Normalize the fuel category key - normalized_category = transfer.fuel_category.replace(" ", "_").lower() + normalized_category = transfer.fuel_category.replace( + " ", "_").lower() # Update the corresponding category sum if transfer.received_or_transferred.lower() == "received": @@ -324,12 +329,12 @@ async def calculate_compliance_report_summary( fossil_quantities = await self.calculate_fuel_quantities( compliance_report.compliance_report_id, effective_fuel_supplies, - fossil_derived=True, + fossil_derived=True ) renewable_quantities = await self.calculate_fuel_quantities( compliance_report.compliance_report_id, effective_fuel_supplies, - fossil_derived=False, + fossil_derived=False ) renewable_fuel_target_summary = self.calculate_renewable_fuel_target_summary( @@ -450,18 +455,21 @@ def calculate_renewable_fuel_target_summary( deferred_renewables = {"gasoline": 0.0, "diesel": 0.0, "jet_fuel": 0.0} for category in ["gasoline", "diesel", "jet_fuel"]: - required_renewable_quantity = eligible_renewable_fuel_required.get(category) + required_renewable_quantity = eligible_renewable_fuel_required.get( + category) previous_required_renewable_quantity = getattr( - prev_summary, f"line_4_eligible_renewable_fuel_required_{category}" + prev_summary, f"""line_4_eligible_renewable_fuel_required_{ + category}""" ) # only carry over line 6,8 if required quantities have not changed if previous_required_renewable_quantity == required_renewable_quantity: retained_renewables[category] = getattr( - prev_summary, f"line_6_renewable_fuel_retained_{category}" + prev_summary, f"""line_6_renewable_fuel_retained_{ + category}""" ) deferred_renewables[category] = getattr( - prev_summary, f"line_8_obligation_deferred_{category}" + prev_summary, f"""line_8_obligation_deferred_{category}""" ) # line 10 @@ -557,9 +565,12 @@ def calculate_renewable_fuel_target_summary( line=line, description=( RENEWABLE_FUEL_TARGET_DESCRIPTIONS[line]["description"].format( - "{:,}".format(int(summary_lines["4"]["gasoline"] * 0.05)), - "{:,}".format(int(summary_lines["4"]["diesel"] * 0.05)), - "{:,}".format(int(summary_lines["4"]["jet_fuel"] * 0.05)), + "{:,}".format( + int(summary_lines["4"]["gasoline"] * 0.05)), + "{:,}".format( + int(summary_lines["4"]["diesel"] * 0.05)), + "{:,}".format( + int(summary_lines["4"]["jet_fuel"] * 0.05)), ) if (line in ["6", "8"]) else RENEWABLE_FUEL_TARGET_DESCRIPTIONS[line]["description"] @@ -571,7 +582,8 @@ def calculate_renewable_fuel_target_summary( total_value=values.get("gasoline", 0) + values.get("diesel", 0) + values.get("jet_fuel", 0), - format=(FORMATS.CURRENCY if (str(line) == "11") else FORMATS.NUMBER), + format=(FORMATS.CURRENCY if ( + str(line) == "11") else FORMATS.NUMBER), ) for line, values in summary_lines.items() ] @@ -660,7 +672,8 @@ async def calculate_low_carbon_fuel_target_summary( ), field=LOW_CARBON_FUEL_TARGET_DESCRIPTIONS[line]["field"], value=values.get("value", 0), - format=(FORMATS.CURRENCY if (str(line) == "21") else FORMATS.NUMBER), + format=(FORMATS.CURRENCY if ( + str(line) == "21") else FORMATS.NUMBER), ) for line, values in low_carbon_summary_lines.items() ] @@ -675,7 +688,8 @@ def calculate_non_compliance_penalty_summary( non_compliance_penalty_payable = int( (non_compliance_penalty_payable_units * Decimal(-600.0)).max(0) ) - line_11 = next(row for row in renewable_fuel_target_summary if row.line == "11") + line_11 = next( + row for row in renewable_fuel_target_summary if row.line == "11") non_compliance_summary_lines = { "11": {"total_value": line_11.total_value}, @@ -720,11 +734,6 @@ async def calculate_fuel_quantities( await self.repo.aggregate_other_uses(compliance_report_id, fossil_derived) ) - if not fossil_derived: - fuel_quantities.update( - await self.repo.aggregate_allocation_agreements(compliance_report_id) - ) - return dict(fuel_quantities) @service_handler @@ -752,7 +761,8 @@ async def calculate_fuel_supply_compliance_units( ED = fuel_supply.energy_density or 0 # Energy Density # Apply the compliance units formula - compliance_units = calculate_compliance_units(TCI, EER, RCI, UCI, Q, ED) + compliance_units = calculate_compliance_units( + TCI, EER, RCI, UCI, Q, ED) compliance_units_sum += compliance_units return int(compliance_units_sum) @@ -781,9 +791,11 @@ async def calculate_fuel_export_compliance_units( ED = fuel_export.energy_density or 0 # Energy Density # Apply the compliance units formula - compliance_units = calculate_compliance_units(TCI, EER, RCI, UCI, Q, ED) + compliance_units = calculate_compliance_units( + TCI, EER, RCI, UCI, Q, ED) compliance_units = -compliance_units - compliance_units = round(compliance_units) if compliance_units < 0 else 0 + compliance_units = round( + compliance_units) if compliance_units < 0 else 0 compliance_units_sum += compliance_units diff --git a/backend/lcfs/web/api/compliance_report/update_service.py b/backend/lcfs/web/api/compliance_report/update_service.py index 7e76ea76b..1a1d7d9c7 100644 --- a/backend/lcfs/web/api/compliance_report/update_service.py +++ b/backend/lcfs/web/api/compliance_report/update_service.py @@ -1,3 +1,4 @@ +import json from fastapi import Depends, HTTPException, Request from lcfs.web.api.notification.schema import ( COMPLIANCE_REPORT_STATUS_NOTIFICATION_MAPPER, @@ -48,13 +49,7 @@ async def update_compliance_report( raise DataNotFoundException( f"Compliance report with ID {report_id} not found" ) - - notifications = None - notification_data: NotificationMessageSchema = NotificationMessageSchema( - message=f"Compliance report {report.compliance_report_id} has been updated", - related_organization_id=report.organization_id, - origin_user_profile_id=self.request.user.user_profile_id, - ) + current_status = report_data.status # if we're just returning the compliance report back to either compliance manager or analyst, # then neither history nor any updates to summary is required. if report_data.status in RETURN_STATUSES: @@ -64,19 +59,10 @@ async def update_compliance_report( ) if report_data.status == "Return to analyst": report_data.status = ComplianceReportStatusEnum.Submitted.value - notification_data.message = f"Compliance report {report.compliance_report_id} has been returned to analyst" else: report_data.status = ( ComplianceReportStatusEnum.Recommended_by_analyst.value ) - - notification_data.message = f"Compliance report {report.compliance_report_id} has been returned by director" - notification_data.related_user_profile_id = [ - h.user_profile.user_profile_id - for h in report.history - if h.status.status - == ComplianceReportStatusEnum.Recommended_by_analyst - ][0] else: status_has_changed = report.current_status.status != getattr( ComplianceReportStatusEnum, report_data.status.replace(" ", "_") @@ -91,14 +77,37 @@ async def update_compliance_report( updated_report = await self.repo.update_compliance_report(report) if status_has_changed: await self.handle_status_change(report, new_status.status) - notification_data.message = ( - f"Compliance report {report.compliance_report_id} has been updated" - ) - notifications = COMPLIANCE_REPORT_STATUS_NOTIFICATION_MAPPER.get( - new_status.status - ) # Add history record await self.repo.add_compliance_report_history(report, self.request.user) + + await self._perform_notification_call(report, current_status) + return updated_report + + async def _perform_notification_call(self, report, status): + """Send notifications based on the current status of the transfer.""" + status_mapper = status.replace(" ", "_") + notifications = COMPLIANCE_REPORT_STATUS_NOTIFICATION_MAPPER.get( + ( + ComplianceReportStatusEnum[status_mapper] + if status_mapper in ComplianceReportStatusEnum.__members__ + else status + ), + None, + ) + message_data = { + "service": "ComplianceReport", + "id": report.compliance_report_id, + "transactionId": report.transaction_id, + "compliancePeriod": report.compliance_period.description, + "status": status.lower(), + } + notification_data = NotificationMessageSchema( + type=f"Compliance report {status.lower()}", + related_transaction_id=f"CR{report.compliance_report_id}", + message=json.dumps(message_data), + related_organization_id=report.organization_id, + origin_user_profile_id=self.request.user.user_profile_id, + ) if notifications and isinstance(notifications, list): await self.notfn_service.send_notification( NotificationRequestSchema( @@ -106,7 +115,6 @@ async def update_compliance_report( notification_data=notification_data, ) ) - return updated_report async def handle_status_change( self, report: ComplianceReport, new_status: ComplianceReportStatusEnum diff --git a/backend/lcfs/web/api/email/services.py b/backend/lcfs/web/api/email/services.py index 8c7dc4cd8..066a7a664 100644 --- a/backend/lcfs/web/api/email/services.py +++ b/backend/lcfs/web/api/email/services.py @@ -23,19 +23,8 @@ class CHESEmailService: def __init__(self, repo: CHESEmailRepository = Depends()): self.repo = repo - - # CHES configuration - self.config = { - "AUTH_URL": settings.ches_auth_url, - "EMAIL_URL": settings.ches_email_url, - "CLIENT_ID": settings.ches_client_id, - "CLIENT_SECRET": settings.ches_client_secret, - "SENDER_EMAIL": settings.ches_sender_email, - "SENDER_NAME": settings.ches_sender_name, - } self._access_token = None self._token_expiry = None - self._validate_configuration() # Update template directory path to the root templates directory template_dir = os.path.join(os.path.dirname(__file__), "templates") @@ -48,9 +37,24 @@ def _validate_configuration(self): """ Validate the CHES configuration to ensure all necessary environment variables are set. """ - missing = [key for key, value in self.config.items() if not value] - if missing: - raise ValueError(f"Missing configuration: {', '.join(missing)}") + missing_configs = [] + + # Check each required CHES configuration setting + if not settings.ches_auth_url: + missing_configs.append("ches_auth_url") + if not settings.ches_email_url: + missing_configs.append("ches_email_url") + if not settings.ches_client_id: + missing_configs.append("ches_client_id") + if not settings.ches_client_secret: + missing_configs.append("ches_client_secret") + if not settings.ches_sender_email: + missing_configs.append("ches_sender_email") + if not settings.ches_sender_name: + missing_configs.append("ches_sender_name") + + if missing_configs: + raise ValueError(f"Missing CHES configuration: {', '.join(missing_configs)}") @service_handler async def send_notification_email( @@ -62,6 +66,9 @@ async def send_notification_email( """ Send an email notification to users subscribed to the specified notification type. """ + # Validate configuration before performing any operations + self._validate_configuration() + # Retrieve subscribed user emails recipient_emails = await self.repo.get_subscribed_user_emails( notification_type.value, organization_id @@ -109,7 +116,7 @@ def _build_email_payload( return { "bcc": recipients, "to": ["Undisclosed recipients"], - "from": f"{self.config['SENDER_NAME']} <{self.config['SENDER_EMAIL']}>", + "from": f"{settings.ches_sender_name} <{settings.ches_sender_email}>", "delayTS": 0, "encoding": "utf-8", "priority": "normal", @@ -124,9 +131,12 @@ async def send_email(self, payload: Dict[str, Any]) -> bool: """ Send an email using CHES. """ + # Validate configuration before performing any operations + self._validate_configuration() + token = await self.get_ches_token() response = requests.post( - self.config["EMAIL_URL"], + settings.ches_email_url, json=payload, headers={ "Authorization": f"Bearer {token}", @@ -142,12 +152,15 @@ async def get_ches_token(self) -> str: """ Retrieve and cache the CHES access token. """ + # Validate configuration before performing any operations + self._validate_configuration() + if self._access_token and datetime.now().timestamp() < self._token_expiry: return self._access_token response = requests.post( - self.config["AUTH_URL"], + settings.ches_auth_url, data={"grant_type": "client_credentials"}, - auth=(self.config["CLIENT_ID"], self.config["CLIENT_SECRET"]), + auth=(settings.ches_client_id, settings.ches_client_secret), timeout=10, ) response.raise_for_status() @@ -158,4 +171,4 @@ async def get_ches_token(self) -> str: "expires_in", 3600 ) logger.info("Retrieved new CHES token.") - return self._access_token + return self._access_token \ No newline at end of file diff --git a/backend/lcfs/web/api/fuel_export/repo.py b/backend/lcfs/web/api/fuel_export/repo.py index 36aeb4ce1..d09a546dc 100644 --- a/backend/lcfs/web/api/fuel_export/repo.py +++ b/backend/lcfs/web/api/fuel_export/repo.py @@ -260,6 +260,7 @@ async def create_fuel_export(self, fuel_export: FuelExport) -> FuelExport: "fuel_type", "provision_of_the_act", "end_use_type", + "fuel_code", ], ) return fuel_export diff --git a/backend/lcfs/web/api/initiative_agreement/services.py b/backend/lcfs/web/api/initiative_agreement/services.py index b7697f2a4..c9cb3c4de 100644 --- a/backend/lcfs/web/api/initiative_agreement/services.py +++ b/backend/lcfs/web/api/initiative_agreement/services.py @@ -1,3 +1,4 @@ +import json from lcfs.web.api.notification.schema import ( INITIATIVE_AGREEMENT_STATUS_NOTIFICATION_MAPPER, NotificationMessageSchema, @@ -129,7 +130,7 @@ async def update_initiative_agreement( # Return the updated initiative agreement schema with the returned status flag ia_schema = InitiativeAgreementSchema.from_orm(updated_initiative_agreement) ia_schema.returned = returned - await self._perform_notificaiton_call(ia_schema, re_recommended) + await self._perform_notification_call(updated_initiative_agreement, returned) return ia_schema @service_handler @@ -174,7 +175,7 @@ async def create_initiative_agreement( await self.internal_comment_service.create_internal_comment( internal_comment_data ) - await self._perform_notificaiton_call(initiative_agreement) + await self._perform_notification_call(initiative_agreement) return initiative_agreement async def director_approve_initiative_agreement( @@ -208,16 +209,28 @@ async def director_approve_initiative_agreement( initiative_agreement.transaction_effective_date = datetime.now().date() await self.repo.refresh_initiative_agreement(initiative_agreement) - await self._perform_notificaiton_call(initiative_agreement) - async def _perform_notificaiton_call(self, ia, re_recommended=False): + async def _perform_notification_call(self, ia, returned=False): """Send notifications based on the current status of the transfer.""" + status = ia.current_status.status if not returned else "Return to analyst" + status_val = ( + status.value + if isinstance(status, InitiativeAgreementStatusEnum) + else status + ).lower() notifications = INITIATIVE_AGREEMENT_STATUS_NOTIFICATION_MAPPER.get( - ia.current_status.status if not re_recommended else "Return to analyst", - None, + status, None ) + message_data = { + "service": "InitiativeAgreement", + "id": ia.initiative_agreement_id, + "transactionId": ia.transaction_id, + "status": status_val, + } notification_data = NotificationMessageSchema( - message=f"Initiative Agreement {ia.initiative_agreement_id} has been {ia.current_status.status}", + type=f"Initiative agreement {status_val}", + related_transaction_id=f"IA{ia.initiative_agreement_id}", + message=json.dumps(message_data), related_organization_id=ia.to_organization_id, origin_user_profile_id=self.request.user.user_profile_id, ) diff --git a/backend/lcfs/web/api/notification/repo.py b/backend/lcfs/web/api/notification/repo.py index ec32f9716..bd9d874fa 100644 --- a/backend/lcfs/web/api/notification/repo.py +++ b/backend/lcfs/web/api/notification/repo.py @@ -5,20 +5,29 @@ NotificationType, ChannelEnum, ) +from lcfs.db.models.organization import Organization from lcfs.db.models.user import UserProfile -from lcfs.web.api.base import NotificationTypeEnum +from lcfs.db.models.user.UserRole import UserRole +from lcfs.web.api.base import ( + NotificationTypeEnum, + PaginationRequestSchema, + apply_filter_conditions, + get_field_for_filter, + validate_pagination, +) import structlog -from typing import List, Optional +from typing import List, Optional, Sequence from fastapi import Depends from lcfs.db.dependencies import get_async_db_session from lcfs.web.exception.exceptions import DataNotFoundException -from sqlalchemy import delete, or_, select, func +from sqlalchemy import asc, delete, desc, or_, select, func, update from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload +from sqlalchemy.orm import selectinload, joinedload from lcfs.web.core.decorators import repo_handler +from sqlalchemy import and_ logger = structlog.get_logger(__name__) @@ -66,8 +75,15 @@ async def get_notification_messages_by_user( Retrieve all notification messages for a user """ # Start building the query - query = select(NotificationMessage).where( - NotificationMessage.related_user_profile_id == user_profile_id + query = ( + select(NotificationMessage) + .options( + joinedload(NotificationMessage.related_organization), + joinedload(NotificationMessage.origin_user_profile) + .joinedload(UserProfile.user_roles) + .joinedload(UserRole.role), + ) + .where(NotificationMessage.related_user_profile_id == user_profile_id) ) # Apply additional filter for `is_read` if provided @@ -76,7 +92,116 @@ async def get_notification_messages_by_user( # Execute the query and retrieve the results result = await self.db.execute(query) - return result.scalars().all() + return result.unique().scalars().all() + + def _apply_notification_filters( + self, pagination: PaginationRequestSchema, conditions: List + ): + for filter in pagination.filters: + filter_value = filter.filter + filter_option = filter.type + filter_type = filter.filter_type + + # Handle date filters + if filter.field == "date": + filter_value = filter.date_from + field = get_field_for_filter(NotificationMessage, "create_date") + conditions.append( + apply_filter_conditions( + field, filter_value, filter_option, filter_type + ) + ) + elif filter.field == "user": + conditions.append( + NotificationMessage.origin_user_profile.has( + UserProfile.first_name.like(f"%{filter_value}%") + ) + ) + elif filter.field == "organization": + conditions.append( + NotificationMessage.related_organization.has( + Organization.name.like(f"%{filter_value}%") + ) + ) + elif filter.field == "transaction_id": + field = get_field_for_filter(NotificationMessage, 'related_transaction_id') + conditions.append( + apply_filter_conditions( + field, filter_value, filter_option, filter_type + ) + ) + else: + field = get_field_for_filter(NotificationMessage, filter.field) + conditions.append( + apply_filter_conditions( + field, filter_value, filter_option, filter_type + ) + ) + + return conditions + + @repo_handler + async def get_paginated_notification_messages( + self, user_id, pagination: PaginationRequestSchema + ) -> tuple[Sequence[NotificationMessage], int]: + """ + Queries notification messages from the database with optional filters. Supports pagination and sorting. + + Args: + pagination (dict): Pagination and sorting parameters. + + Returns: + List[NotificationSchema]: A list of notification messages matching the query. + """ + conditions = [NotificationMessage.related_user_profile_id == user_id] + pagination = validate_pagination(pagination) + + if pagination.filters: + self._apply_notification_filters(pagination, conditions) + + offset = 0 if (pagination.page < 1) else (pagination.page - 1) * pagination.size + limit = pagination.size + # Start building the query + query = ( + select(NotificationMessage) + .options( + joinedload(NotificationMessage.related_organization), + joinedload(NotificationMessage.origin_user_profile) + .joinedload(UserProfile.user_roles) + .joinedload(UserRole.role), + ) + .where(and_(*conditions)) + ) + + # Apply sorting + order_clauses = [] + if not pagination.sort_orders: + order_clauses.append(desc(NotificationMessage.create_date)) + else: + for order in pagination.sort_orders: + direction = asc if order.direction == "asc" else desc + if order.field == "date": + field = NotificationMessage.create_date + elif order.field == "user": + field = UserProfile.first_name + elif order.field == "organization": + field = Organization.name + elif order.field == "transaction_id": + field = NotificationMessage.related_transaction_id + else: + field = getattr(NotificationMessage, order.field) + if field is not None: + order_clauses.append(direction(field)) + query = query.order_by(*order_clauses) + + # Execute the count query to get the total count + count_query = query.with_only_columns(func.count()).order_by(None) + total_count = (await self.db.execute(count_query)).scalar() + + # Execute the main query to retrieve all notification_messages + result = await self.db.execute(query.offset(offset).limit(limit)) + notification_messages = result.unique().scalars().all() + return notification_messages, total_count @repo_handler async def get_notification_message_by_id( @@ -136,6 +261,20 @@ async def delete_notification_message(self, notification_id: int): await self.db.execute(query) await self.db.flush() + @repo_handler + async def delete_notification_messages(self, user_id, notification_ids: List[int]): + """ + Delete a notification_message by id + """ + query = delete(NotificationMessage).where( + and_( + NotificationMessage.notification_message_id.in_(notification_ids), + NotificationMessage.related_user_profile_id == user_id, + ) + ) + await self.db.execute(query) + await self.db.flush() + @repo_handler async def mark_notification_as_read( self, notification_id @@ -156,6 +295,31 @@ async def mark_notification_as_read( return notification + @repo_handler + async def mark_notifications_as_read( + self, user_id: int, notification_ids: List[int] + ): + """ + Mark notification messages as read for a user + """ + if not notification_ids: + return [] + + stmt = ( + update(NotificationMessage) + .where( + and_( + NotificationMessage.notification_message_id.in_(notification_ids), + NotificationMessage.related_user_profile_id == user_id, + ) + ) + .values(is_read=True) + ) + await self.db.execute(stmt) + await self.db.flush() + + return notification_ids + @repo_handler async def create_notification_channel_subscription( self, notification_channel_subscription: NotificationChannelSubscription @@ -291,7 +455,7 @@ async def get_subscribed_users_by_channel( NotificationChannel.channel_name == channel.value, or_( UserProfile.organization_id == organization_id, - UserProfile.organization_id.is_(None), + UserProfile.organization_id.is_(None), ), ) ) diff --git a/backend/lcfs/web/api/notification/schema.py b/backend/lcfs/web/api/notification/schema.py index 0176b9bdd..30ff2d5f2 100644 --- a/backend/lcfs/web/api/notification/schema.py +++ b/backend/lcfs/web/api/notification/schema.py @@ -1,3 +1,4 @@ +from datetime import datetime from typing import Any, Dict, List, Optional from lcfs.db.models.compliance.ComplianceReportStatus import ComplianceReportStatusEnum @@ -5,7 +6,34 @@ InitiativeAgreementStatusEnum, ) from lcfs.db.models.transfer.TransferStatus import TransferStatusEnum -from lcfs.web.api.base import BaseSchema, NotificationTypeEnum +from lcfs.web.api.base import BaseSchema, NotificationTypeEnum, PaginationResponseSchema +from pydantic import computed_field, model_validator + + +class NotificationOrganizationSchema(BaseSchema): + organization_id: int + name: str + + +class NotificationUserProfileSchema(BaseSchema): + first_name: str + last_name: str + organization_id: Optional[int] = None + is_government: bool = False + + @model_validator(mode="before") + def update_government_profile(cls, data): + if data.is_government: + data.first_name = "Government of B.C." + data.last_name = "" + return data + + @computed_field + @property + def full_name(self) -> str: + if self.is_government: + return "Government of B.C." + return f"{self.first_name} {self.last_name}" class NotificationMessageSchema(BaseSchema): @@ -14,9 +42,14 @@ class NotificationMessageSchema(BaseSchema): is_archived: Optional[bool] = False is_warning: Optional[bool] = False is_error: Optional[bool] = False + type: Optional[str] = None message: Optional[str] = None related_organization_id: Optional[int] = None + related_organization: Optional[NotificationOrganizationSchema] = None + related_transaction_id: Optional[str] = None + create_date: Optional[datetime] = None origin_user_profile_id: Optional[int] = None + origin_user_profile: Optional[NotificationUserProfileSchema] = None related_user_profile_id: Optional[int] = None notification_type_id: Optional[int] = None deleted: Optional[bool] = None @@ -53,6 +86,11 @@ class DeleteNotificationChannelSubscriptionResponseSchema(BaseSchema): message: str +class NotificationsSchema(BaseSchema): + notifications: List[NotificationMessageSchema] = [] + pagination: PaginationResponseSchema = None + + class NotificationRequestSchema(BaseSchema): notification_types: List[NotificationTypeEnum] notification_context: Optional[Dict[str, Any]] = {} @@ -93,11 +131,15 @@ class NotificationRequestSchema(BaseSchema): TransferStatusEnum.Sent: [ NotificationTypeEnum.BCEID__TRANSFER__PARTNER_ACTIONS, ], + TransferStatusEnum.Rescinded: [ + NotificationTypeEnum.BCEID__TRANSFER__PARTNER_ACTIONS, + ], TransferStatusEnum.Declined: [ NotificationTypeEnum.BCEID__TRANSFER__PARTNER_ACTIONS, ], TransferStatusEnum.Submitted: [ - NotificationTypeEnum.IDIR_ANALYST__TRANSFER__SUBMITTED_FOR_REVIEW + NotificationTypeEnum.BCEID__TRANSFER__PARTNER_ACTIONS, + NotificationTypeEnum.IDIR_ANALYST__TRANSFER__SUBMITTED_FOR_REVIEW, ], TransferStatusEnum.Recommended: [ NotificationTypeEnum.IDIR_DIRECTOR__TRANSFER__ANALYST_RECOMMENDATION @@ -110,6 +152,9 @@ class NotificationRequestSchema(BaseSchema): NotificationTypeEnum.BCEID__TRANSFER__DIRECTOR_DECISION, NotificationTypeEnum.IDIR_ANALYST__TRANSFER__DIRECTOR_RECORDED, ], + "Return to analyst": [ + NotificationTypeEnum.IDIR_ANALYST__TRANSFER__SUBMITTED_FOR_REVIEW + ], } INITIATIVE_AGREEMENT_STATUS_NOTIFICATION_MAPPER = { @@ -118,6 +163,7 @@ class NotificationRequestSchema(BaseSchema): ], InitiativeAgreementStatusEnum.Approved: [ NotificationTypeEnum.BCEID__INITIATIVE_AGREEMENT__DIRECTOR_APPROVAL, + NotificationTypeEnum.IDIR_ANALYST__INITIATIVE_AGREEMENT__RETURNED_TO_ANALYST ], "Return to analyst": [ NotificationTypeEnum.IDIR_ANALYST__INITIATIVE_AGREEMENT__RETURNED_TO_ANALYST diff --git a/backend/lcfs/web/api/notification/services.py b/backend/lcfs/web/api/notification/services.py index e64848823..a36796af7 100644 --- a/backend/lcfs/web/api/notification/services.py +++ b/backend/lcfs/web/api/notification/services.py @@ -1,13 +1,18 @@ -from typing import List, Optional, Union +import math +from typing import List, Optional from lcfs.db.models.notification import ( NotificationChannelSubscription, NotificationMessage, ChannelEnum, ) -from lcfs.web.api.base import NotificationTypeEnum +from lcfs.web.api.base import ( + PaginationRequestSchema, + PaginationResponseSchema, +) from lcfs.web.api.email.services import CHESEmailService from lcfs.web.api.notification.schema import ( NotificationRequestSchema, + NotificationsSchema, SubscriptionSchema, NotificationMessageSchema, ) @@ -47,6 +52,51 @@ async def get_notification_messages_by_user_id( for message in notification_models ] + @service_handler + async def get_paginated_notification_messages( + self, user_id: int, pagination: PaginationRequestSchema + ) -> NotificationsSchema: + """ + Retrieve all notifications for a given user with pagination, filtering and sorting. + """ + notifications, total_count = ( + await self.repo.get_paginated_notification_messages(user_id, pagination) + ) + return NotificationsSchema( + pagination=PaginationResponseSchema( + total=total_count, + page=pagination.page, + size=pagination.size, + total_pages=math.ceil(total_count / pagination.size), + ), + notifications=[ + NotificationMessageSchema.model_validate(notification) + for notification in notifications + ], + ) + + @service_handler + async def update_notification_messages( + self, user_id: int, notification_ids: List[int] + ): + """ + Update multiple notifications (mark as read). + """ + await self.repo.mark_notifications_as_read(user_id, notification_ids) + + return notification_ids + + @service_handler + async def delete_notification_messages( + self, user_id: int, notification_ids: List[int] + ): + """ + Delete multiple notifications. + """ + await self.repo.delete_notification_messages(user_id, notification_ids) + + return notification_ids + @service_handler async def get_notification_message_by_id(self, notification_id: int): """ diff --git a/backend/lcfs/web/api/notification/views.py b/backend/lcfs/web/api/notification/views.py index f4f98d0e9..f5caa8695 100644 --- a/backend/lcfs/web/api/notification/views.py +++ b/backend/lcfs/web/api/notification/views.py @@ -3,15 +3,17 @@ """ from typing import Union, List +from lcfs.web.api.base import PaginationRequestSchema from lcfs.web.exception.exceptions import DataNotFoundException import structlog -from fastapi import APIRouter, Body, Depends, HTTPException, Request +from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response from lcfs.db.models.user.Role import RoleEnum from lcfs.web.api.notification.schema import ( DeleteNotificationChannelSubscriptionResponseSchema, DeleteNotificationMessageResponseSchema, DeleteSubscriptionSchema, DeleteNotificationMessageSchema, + NotificationsSchema, SubscriptionSchema, NotificationMessageSchema, NotificationCountSchema, @@ -43,6 +45,56 @@ async def get_notification_messages_by_user_id( ) +@router.post( + "/list", response_model=NotificationsSchema, status_code=status.HTTP_200_OK +) +@view_handler(["*"]) +async def get_notification_messages_by_user_id( + request: Request, + pagination: PaginationRequestSchema = Body(..., embed=False), + response: Response = None, + service: NotificationService = Depends(), +): + """ + Retrieve all notifications of a user with pagination + """ + return await service.get_paginated_notification_messages( + user_id=request.user.user_profile_id, pagination=pagination + ) + + +@router.put("/", response_model=List[int], status_code=status.HTTP_200_OK) +@view_handler(["*"]) +async def update_notification_messages_to_read( + request: Request, + notification_ids: List[int] = Body(..., embed=False), + response: Response = None, + service: NotificationService = Depends(), +): + """ + Update notifications (mark the messages as read) + """ + return await service.update_notification_messages( + request.user.user_profile_id, notification_ids + ) + + +@router.delete("/", response_model=List[int], status_code=status.HTTP_200_OK) +@view_handler(["*"]) +async def delete_notification_messages( + request: Request, + notification_ids: List[int] = Body(..., embed=False), + response: Response = None, + service: NotificationService = Depends(), +): + """ + Delete notification messages + """ + return await service.delete_notification_messages( + request.user.user_profile_id, notification_ids + ) + + @router.get( "/count", response_model=NotificationCountSchema, diff --git a/backend/lcfs/web/api/organization/views.py b/backend/lcfs/web/api/organization/views.py index e175bf756..a33cdd984 100644 --- a/backend/lcfs/web/api/organization/views.py +++ b/backend/lcfs/web/api/organization/views.py @@ -33,7 +33,7 @@ ComplianceReportCreateSchema, ComplianceReportListSchema, CompliancePeriodSchema, - ChainedComplianceReportSchema + ChainedComplianceReportSchema, ) from lcfs.web.api.compliance_report.services import ComplianceReportServices from .services import OrganizationService @@ -56,8 +56,7 @@ async def get_org_users( request: Request, organization_id: int, - status: str = Query( - default="Active", description="Active or Inactive users list"), + status: str = Query(default="Active", description="Active or Inactive users list"), pagination: PaginationRequestSchema = Body(..., embed=False), response: Response = None, org_service: OrganizationService = Depends(), @@ -249,7 +248,9 @@ async def create_compliance_report( validate: OrganizationValidation = Depends(), ): await validate.create_compliance_report(organization_id, report_data) - return await report_service.create_compliance_report(organization_id, report_data) + return await report_service.create_compliance_report( + organization_id, report_data, request.user + ) @router.post( @@ -307,4 +308,6 @@ async def get_compliance_report_by_id( This endpoint returns the information of a user by ID, including their roles and organization. """ await report_validate.validate_organization_access(report_id) - return await report_service.get_compliance_report_by_id(report_id, apply_masking=True, get_chain=True) + return await report_service.get_compliance_report_by_id( + report_id, apply_masking=True, get_chain=True + ) diff --git a/backend/lcfs/web/api/organizations/services.py b/backend/lcfs/web/api/organizations/services.py index e8ef43620..35c2155a3 100644 --- a/backend/lcfs/web/api/organizations/services.py +++ b/backend/lcfs/web/api/organizations/services.py @@ -16,6 +16,7 @@ OrganizationStatus, OrgStatusEnum, ) +from lcfs.db.models.transaction import Transaction from lcfs.db.models.transaction.Transaction import TransactionActionEnum from lcfs.services.tfrs.redis_balance import ( RedisBalanceService, @@ -44,6 +45,7 @@ logger = structlog.get_logger(__name__) + class OrganizationsService: def __init__( self, @@ -198,7 +200,6 @@ async def update_organization( updated_organization = await self.repo.update_organization(organization) return updated_organization - @service_handler async def get_organization(self, organization_id: int): """handles fetching an organization""" @@ -400,7 +401,7 @@ async def adjust_balance( transaction_action: TransactionActionEnum, compliance_units: int, organization_id: int, - ): + ) -> Transaction: """ Adjusts an organization's balance based on the transaction action. diff --git a/backend/lcfs/web/api/other_uses/schema.py b/backend/lcfs/web/api/other_uses/schema.py index 51327f772..db3e591be 100644 --- a/backend/lcfs/web/api/other_uses/schema.py +++ b/backend/lcfs/web/api/other_uses/schema.py @@ -40,12 +40,19 @@ class ExpectedUseTypeSchema(BaseSchema): description: Optional[str] = None +class FuelCategorySchema(BaseSchema): + fuel_category_id: int + category: str + description: Optional[str] = None + + class FuelTypeSchema(BaseSchema): fuel_type_id: int fuel_type: str fossil_derived: Optional[bool] = None provision_1_id: Optional[int] = None provision_2_id: Optional[int] = None + fuel_categories: List[FuelCategorySchema] default_carbon_intensity: Optional[float] = None fuel_codes: Optional[List[FuelCodeSchema]] = [] provision_of_the_act: Optional[List[ProvisionOfTheActSchema]] = [] diff --git a/backend/lcfs/web/api/transaction/repo.py b/backend/lcfs/web/api/transaction/repo.py index 7134e7332..861b1e32b 100644 --- a/backend/lcfs/web/api/transaction/repo.py +++ b/backend/lcfs/web/api/transaction/repo.py @@ -318,7 +318,7 @@ async def create_transaction( transaction_action: TransactionActionEnum, compliance_units: int, organization_id: int, - ): + ) -> Transaction: """ Creates and saves a new transaction to the database. diff --git a/backend/lcfs/web/api/transaction/schema.py b/backend/lcfs/web/api/transaction/schema.py index 34a44b441..ad0d8411e 100644 --- a/backend/lcfs/web/api/transaction/schema.py +++ b/backend/lcfs/web/api/transaction/schema.py @@ -1,6 +1,6 @@ from typing import Optional, List -from pydantic import ConfigDict +from pydantic import ConfigDict, Field from lcfs.web.api.base import BaseSchema from datetime import datetime from enum import Enum diff --git a/backend/lcfs/web/api/transfer/schema.py b/backend/lcfs/web/api/transfer/schema.py index 858accf73..889437c8a 100644 --- a/backend/lcfs/web/api/transfer/schema.py +++ b/backend/lcfs/web/api/transfer/schema.py @@ -3,7 +3,7 @@ from typing import Optional, List from datetime import date, datetime from enum import Enum -from pydantic import ConfigDict +from pydantic import ConfigDict, Field class TransferRecommendationEnumSchema(str, Enum): diff --git a/backend/lcfs/web/api/transfer/services.py b/backend/lcfs/web/api/transfer/services.py index f498d927e..4b2aba0c9 100644 --- a/backend/lcfs/web/api/transfer/services.py +++ b/backend/lcfs/web/api/transfer/services.py @@ -1,3 +1,4 @@ +import json from lcfs.web.api.notification.schema import ( TRANSFER_STATUS_NOTIFICATION_MAPPER, NotificationMessageSchema, @@ -155,7 +156,6 @@ async def create_transfer( # transfer.transfer_category_id = 1 transfer.current_status = current_status - notifications = TRANSFER_STATUS_NOTIFICATION_MAPPER.get(current_status.status) if current_status.status == TransferStatusEnum.Sent: await self.sign_and_send_from_supplier(transfer) @@ -166,7 +166,7 @@ async def create_transfer( current_status.transfer_status_id, self.request.user.user_profile_id, ) - await self._perform_notificaiton_call(notifications, transfer) + await self._perform_notification_call(transfer, current_status.status) return transfer @service_handler @@ -264,37 +264,68 @@ async def update_transfer(self, transfer_data: TransferCreateSchema) -> Transfer # Finally, update the transfer's status and save the changes transfer.current_status = new_status transfer_result = await self.repo.update_transfer(transfer) - await self._perform_notificaiton_call(transfer_result) + await self._perform_notification_call( + transfer, + status=( + new_status.status + if status_has_changed or re_recommended + else "Return to analyst" + ), + ) return transfer_result - async def _perform_notificaiton_call(self, transfer): + async def _perform_notification_call( + self, transfer: TransferSchema, status: TransferStatusEnum + ): """Send notifications based on the current status of the transfer.""" - notifications = TRANSFER_STATUS_NOTIFICATION_MAPPER.get( - transfer.current_status.status - ) - notification_data = NotificationMessageSchema( - message=f"Transfer {transfer.transfer_id} has been updated", - origin_user_profile_id=self.request.user.user_profile_id, - ) - if notifications and isinstance(notifications, list): - notification_data.related_organization_id = ( - transfer.from_organization_id - if transfer.current_status.status == TransferStatusEnum.Declined - else transfer.to_organization_id - ) - await self.notfn_service.send_notification( - NotificationRequestSchema( - notification_types=notifications, - notification_data=notification_data, - ) + notifications = TRANSFER_STATUS_NOTIFICATION_MAPPER.get(status) + status_val = ( + status.value if isinstance(status, TransferStatusEnum) else status + ).lower() + organization_ids = [] + if status in [ + TransferStatusEnum.Submitted, + TransferStatusEnum.Recommended, + TransferStatusEnum.Declined, + ]: + organization_ids = [transfer.from_organization.organization_id] + elif status in [ + TransferStatusEnum.Sent, + TransferStatusEnum.Rescinded, + ]: + organization_ids = [transfer.to_organization.organization_id] + elif status in [ + TransferStatusEnum.Recorded, + TransferStatusEnum.Refused, + ]: + organization_ids = [ + transfer.to_organization.organization_id, + transfer.from_organization.organization_id, + ] + message_data = { + "service": "Transfer", + "id": transfer.transfer_id, + "transactionId": transfer.from_transaction.transaction_id if getattr(transfer, 'from_transaction', None) else None, + "status": status_val, + "fromOrganizationId": transfer.from_organization.organization_id, + "fromOrganization": transfer.from_organization.name, + "toOrganizationId": transfer.to_organization.organization_id, + "toOrganization": transfer.to_organization.name, + } + type = f"Transfer {status_val}" + if status_val == "sent": + type = "Transfer received" + elif status_val == "return to analyst": + type = "Transfer returned" + for org_id in organization_ids: + notification_data = NotificationMessageSchema( + type=type, + related_transaction_id=f"CT{transfer.transfer_id}", + message=json.dumps(message_data), + related_organization_id=org_id, + origin_user_profile_id=self.request.user.user_profile_id, ) - if transfer.current_status.status in [ - TransferStatusEnum.Refused, - TransferStatusEnum.Recorded, - ]: - notification_data.related_organization_id = ( - transfer.from_organization_id - ) + if notifications and isinstance(notifications, list): await self.notfn_service.send_notification( NotificationRequestSchema( notification_types=notifications, diff --git a/backend/lcfs/web/application.py b/backend/lcfs/web/application.py index e7117a105..8aef6126c 100644 --- a/backend/lcfs/web/application.py +++ b/backend/lcfs/web/application.py @@ -1,4 +1,6 @@ import logging +import os +import debugpy import uuid import structlog diff --git a/backend/lcfs/web/lifetime.py b/backend/lcfs/web/lifetime.py index 5de67c16c..fbe7b0b6e 100644 --- a/backend/lcfs/web/lifetime.py +++ b/backend/lcfs/web/lifetime.py @@ -62,7 +62,7 @@ async def _startup() -> None: # noqa: WPS430 await init_org_balance_cache(app) # Setup RabbitMQ Listeners - await start_consumers() + await start_consumers(app) return _startup diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 129680549..fdce412f9 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -8,11 +8,11 @@ "name": "frontend", "version": "0.0.0", "dependencies": { - "@ag-grid-community/client-side-row-model": "^32.0.2", - "@ag-grid-community/core": "^32.0.2", - "@ag-grid-community/csv-export": "^32.0.2", - "@ag-grid-community/react": "^32.0.2", - "@ag-grid-community/styles": "^32.0.2", + "@ag-grid-community/client-side-row-model": "^32.3.3", + "@ag-grid-community/core": "^32.3.3", + "@ag-grid-community/csv-export": "^32.3.3", + "@ag-grid-community/react": "^32.3.3", + "@ag-grid-community/styles": "^32.3.3", "@bcgov/bc-sans": "^2.1.0", "@emotion/react": "^11.11.3", "@emotion/styled": "^11.11.0", @@ -119,49 +119,49 @@ "license": "MIT" }, "node_modules/@ag-grid-community/client-side-row-model": { - "version": "32.1.0", - "resolved": "https://registry.npmjs.org/@ag-grid-community/client-side-row-model/-/client-side-row-model-32.1.0.tgz", - "integrity": "sha512-R/IA3chA/w9fy6/EeZhi42PTwVnb6bNjGMah1GWGvuNDTvfbPO4X9r4nhOMj6YH483bO+C7pPb4EoLECx0dfRQ==", + "version": "32.3.3", + "resolved": "https://registry.npmjs.org/@ag-grid-community/client-side-row-model/-/client-side-row-model-32.3.3.tgz", + "integrity": "sha512-/6OFltj9qax/xfOcYMOKGFQRFTrPX8hrELfS2jChWwpo/+rpnnFqN2iUlIiAB1tDJZsi2ryl8S4UoFSTcEv/VA==", "dependencies": { - "@ag-grid-community/core": "32.1.0", + "@ag-grid-community/core": "32.3.3", "tslib": "^2.3.0" } }, "node_modules/@ag-grid-community/core": { - "version": "32.1.0", - "resolved": "https://registry.npmjs.org/@ag-grid-community/core/-/core-32.1.0.tgz", - "integrity": "sha512-fHpgSZa/aBjg2DdOzooDxILFZqxmxP8vsjRfeZVtqby19mTKwNAclE7Z6rWzOA0GYjgN9s8JwLFcNA5pvfswMg==", + "version": "32.3.3", + "resolved": "https://registry.npmjs.org/@ag-grid-community/core/-/core-32.3.3.tgz", + "integrity": "sha512-JMr5ahDjjl+pvQbBM1/VrfVFlioCVnMl1PKWc6MC1ENhpXT1+CPQdfhUEUw2VytOulQeQ4eeP0pFKPuBZ5Jn2g==", "dependencies": { - "ag-charts-types": "10.1.0", + "ag-charts-types": "10.3.3", "tslib": "^2.3.0" } }, "node_modules/@ag-grid-community/csv-export": { - "version": "32.1.0", - "resolved": "https://registry.npmjs.org/@ag-grid-community/csv-export/-/csv-export-32.1.0.tgz", - "integrity": "sha512-rtHY+MvfmzlRq3dH8prvoNPOmNrvSxZNDmxSYEGC/y12d6ucoAH+Q1cTksMx5d/LKrUXGCrd/jKoPEi9FSdkNA==", + "version": "32.3.3", + "resolved": "https://registry.npmjs.org/@ag-grid-community/csv-export/-/csv-export-32.3.3.tgz", + "integrity": "sha512-uu5BdegnQCpoySFbhd7n0/yK9mMoepZMN6o36DblPydLXCOLEqOuroIPqQv008slDOK676Pe/O6bMszY3/MUlQ==", "dependencies": { - "@ag-grid-community/core": "32.1.0", + "@ag-grid-community/core": "32.3.3", "tslib": "^2.3.0" } }, "node_modules/@ag-grid-community/react": { - "version": "32.1.0", - "resolved": "https://registry.npmjs.org/@ag-grid-community/react/-/react-32.1.0.tgz", - "integrity": "sha512-ObaMk+g5IpfuiHSNar56IhJ0dLKkHaeMQYI9H1JlJyf5+3IafY1DiuGZ5mZTU7GyfNBgmMuRWrUxwOyt0tp7Lw==", + "version": "32.3.3", + "resolved": "https://registry.npmjs.org/@ag-grid-community/react/-/react-32.3.3.tgz", + "integrity": "sha512-YU8nOMZjvJsrbbW41PT1jFZQw67p1RGvTk3W7w1dFmtzXFOoXzpB2pWf2jMxREyLYGvz2P9TwmfeHEM50osSPQ==", "dependencies": { "prop-types": "^15.8.1" }, "peerDependencies": { - "@ag-grid-community/core": "32.1.0", - "react": "^16.3.0 || ^17.0.0 || ^18.0.0", - "react-dom": "^16.3.0 || ^17.0.0 || ^18.0.0" + "@ag-grid-community/core": "32.3.3", + "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0", + "react-dom": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, "node_modules/@ag-grid-community/styles": { - "version": "32.1.0", - "resolved": "https://registry.npmjs.org/@ag-grid-community/styles/-/styles-32.1.0.tgz", - "integrity": "sha512-OjakLetS/zr0g5mJWpnjldk/RjGnl7Rv3I/5cGuvtgdmSgS+4FNZMr8ZmyR8Bl34s0RM63OSIphpVaFGlnJM4w==" + "version": "32.3.3", + "resolved": "https://registry.npmjs.org/@ag-grid-community/styles/-/styles-32.3.3.tgz", + "integrity": "sha512-QAJc1CPbmFsAAq5M/8r0IOm8HL4Fb3eVK6tZXKzV9zibIereBjUwvvJRaSJa8iwtTlgxCtaULAQyE2gJcctphA==" }, "node_modules/@ampproject/remapping": { "version": "2.3.0", @@ -8912,9 +8912,9 @@ } }, "node_modules/ag-charts-types": { - "version": "10.1.0", - "resolved": "https://registry.npmjs.org/ag-charts-types/-/ag-charts-types-10.1.0.tgz", - "integrity": "sha512-pk9ft8hbgTXJ/thI/SEUR1BoauNplYExpcHh7tMOqVikoDsta1O15TB1ZL4XWnl4TPIzROBmONKsz7d8a2HBuQ==" + "version": "10.3.3", + "resolved": "https://registry.npmjs.org/ag-charts-types/-/ag-charts-types-10.3.3.tgz", + "integrity": "sha512-8rmyquaTkwfP4Lzei/W/cbkq9wwEl8+grIo3z97mtxrMIXh9sHJK1oJipd/u08MmBZrca5Jjtn5F1+UNPu/4fQ==" }, "node_modules/agent-base": { "version": "7.1.1", diff --git a/frontend/package.json b/frontend/package.json index 136507a04..6e8783097 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -30,11 +30,11 @@ ] }, "dependencies": { - "@ag-grid-community/client-side-row-model": "^32.0.2", - "@ag-grid-community/core": "^32.0.2", - "@ag-grid-community/csv-export": "^32.0.2", - "@ag-grid-community/react": "^32.0.2", - "@ag-grid-community/styles": "^32.0.2", + "@ag-grid-community/client-side-row-model": "^32.3.3", + "@ag-grid-community/core": "^32.3.3", + "@ag-grid-community/csv-export": "^32.3.3", + "@ag-grid-community/react": "^32.3.3", + "@ag-grid-community/styles": "^32.3.3", "@bcgov/bc-sans": "^2.1.0", "@emotion/react": "^11.11.3", "@emotion/styled": "^11.11.0", diff --git a/frontend/src/assets/locales/en/allocationAgreement.json b/frontend/src/assets/locales/en/allocationAgreement.json index 1da1af078..f120b7ebe 100644 --- a/frontend/src/assets/locales/en/allocationAgreement.json +++ b/frontend/src/assets/locales/en/allocationAgreement.json @@ -3,6 +3,7 @@ "noAllocationAgreementsFound": "No allocation agreements found", "addAllocationAgreementRowsTitle": "Allocation agreements (e.g., allocating responsibility for fuel)", "allocationAgreementSubtitle": "Enter allocation agreement details below", + "fuelCodeFieldRequiredError": "Error updating row: Fuel code field required", "allocationAgreementColLabels": { "transaction": "Responsibility", "transactionPartner": "Legal name of transaction partner", diff --git a/frontend/src/assets/locales/en/fuelExport.json b/frontend/src/assets/locales/en/fuelExport.json index 002fba7c1..d5cf55dc6 100644 --- a/frontend/src/assets/locales/en/fuelExport.json +++ b/frontend/src/assets/locales/en/fuelExport.json @@ -32,5 +32,6 @@ }, "validateMsg": { "isRequired": "{{field}} is required" - } + }, + "fuelCodeFieldRequiredError": "Error updating row: Fuel code field required" } diff --git a/frontend/src/assets/locales/en/fuelSupply.json b/frontend/src/assets/locales/en/fuelSupply.json index 3e6036080..93c75760b 100644 --- a/frontend/src/assets/locales/en/fuelSupply.json +++ b/frontend/src/assets/locales/en/fuelSupply.json @@ -9,6 +9,7 @@ "LoadFailMsg": "Failed to load supply of fuel rows", "addRow": "Add row", "rows": "rows", + "fuelCodeFieldRequiredError": "Error updating row: Fuel code field required", "fuelSupplyColLabels": { "complianceReportId": "Compliance Report ID", "fuelSupplyId": "Fuel supply ID", diff --git a/frontend/src/assets/locales/en/notifications.json b/frontend/src/assets/locales/en/notifications.json index 5f985416c..5d3ac2313 100644 --- a/frontend/src/assets/locales/en/notifications.json +++ b/frontend/src/assets/locales/en/notifications.json @@ -90,5 +90,24 @@ "managerRecommendation": "Compliance manager recommendation" } } - } + }, + "buttonStack": { + "selectAll": "Select all", + "unselectAll": "Unselect all", + "markAsRead": "Mark as read", + "deleteSelected": "Delete selected" + }, + "notificationColLabels": { + "type": "Type", + "date": "Date", + "user": "User", + "transactionId": "Transaction ID", + "organization": "Organization" + }, + "noNotificationsFound": "No notification messages found.", + "noNotificationsSelectedText": "No messages selected.", + "deleteSuccessText": "Successfully deleted selected message(s).", + "deleteErrorText": "An error occurred while deleting the selected message(s).", + "markAsReadSuccessText": "Successfully updated message(s) as read.", + "markAsReadErrorText": "An error occurred while updating the message(s) as read." } diff --git a/frontend/src/assets/locales/en/otherUses.json b/frontend/src/assets/locales/en/otherUses.json index c67e328ab..70b32650d 100644 --- a/frontend/src/assets/locales/en/otherUses.json +++ b/frontend/src/assets/locales/en/otherUses.json @@ -21,6 +21,7 @@ "approveConfirmText": "Are you sure you want to approve this other use entry?", "addRow": "Add row", "rows": "rows", + "fuelCodeFieldRequiredError": "Error updating row: Fuel code field required", "otherUsesColLabels": { "complianceReportId": "Compliance report ID", "fuelType": "Fuel type", diff --git a/frontend/src/components/BCDataGrid/BCGridBase.jsx b/frontend/src/components/BCDataGrid/BCGridBase.jsx index d26ef13b7..4a30c198e 100644 --- a/frontend/src/components/BCDataGrid/BCGridBase.jsx +++ b/frontend/src/components/BCDataGrid/BCGridBase.jsx @@ -34,7 +34,6 @@ export const BCGridBase = forwardRef(({ autoSizeStrategy, ...props }, ref) => { suppressMovableColumns suppressColumnMoveAnimation={false} reactiveCustomComponents - rowSelection="multiple" suppressCsvExport={false} suppressPaginationPanel suppressScrollOnNewData diff --git a/frontend/src/components/BCDataGrid/BCGridViewer.jsx b/frontend/src/components/BCDataGrid/BCGridViewer.jsx index 429c3c153..8a5638431 100644 --- a/frontend/src/components/BCDataGrid/BCGridViewer.jsx +++ b/frontend/src/components/BCDataGrid/BCGridViewer.jsx @@ -1,4 +1,4 @@ -import BCAlert from '@/components/BCAlert' +import BCAlert, { FloatingAlert } from '@/components/BCAlert' import BCBox from '@/components/BCBox' import { BCGridBase } from '@/components/BCDataGrid/BCGridBase' import { BCPagination } from '@/components/BCDataGrid/components' @@ -9,6 +9,7 @@ import BCButton from '../BCButton' export const BCGridViewer = ({ gridRef, + alertRef, loading, defaultColDef, columnDefs, @@ -202,6 +203,7 @@ export const BCGridViewer = ({ className="bc-grid-container" data-test="bc-grid-container" > + ({ cellRendererParams: props, pinned: 'left', maxWidth: 110, + minWidth: 90, editable: false, suppressKeyboardEvent, filter: false, diff --git a/frontend/src/components/BCDataGrid/components/Editors/DateEditor.jsx b/frontend/src/components/BCDataGrid/components/Editors/DateEditor.jsx index 01321d990..d2c5f51e4 100644 --- a/frontend/src/components/BCDataGrid/components/Editors/DateEditor.jsx +++ b/frontend/src/components/BCDataGrid/components/Editors/DateEditor.jsx @@ -2,11 +2,23 @@ import { DatePicker } from '@mui/x-date-pickers' import { format, parseISO } from 'date-fns' import { useEffect, useRef, useState } from 'react' -export const DateEditor = ({ value, onValueChange, minDate, maxDate }) => { +export const DateEditor = ({ + value, + onValueChange, + minDate, + maxDate, + rowIndex, + api, + autoOpenLastRow +}) => { const [selectedDate, setSelectedDate] = useState( value ? parseISO(value) : null ) - const [isOpen, setIsOpen] = useState(false) + const [isOpen, setIsOpen] = useState(() => { + if (!autoOpenLastRow) return false + const lastRowIndex = api.getLastDisplayedRowIndex() + return rowIndex === lastRowIndex + }) const containerRef = useRef(null) useEffect(() => { diff --git a/frontend/src/components/BCDataGrid/components/Renderers/ActionsRenderer2.jsx b/frontend/src/components/BCDataGrid/components/Renderers/ActionsRenderer2.jsx index 452d75531..4e6d622a9 100644 --- a/frontend/src/components/BCDataGrid/components/Renderers/ActionsRenderer2.jsx +++ b/frontend/src/components/BCDataGrid/components/Renderers/ActionsRenderer2.jsx @@ -7,7 +7,7 @@ export const ActionsRenderer2 = (props) => { .some((cell) => cell.rowIndex === props.node.rowIndex) return ( - + {props.enableDuplicate && ( diff --git a/frontend/src/constants/routes/apiRoutes.js b/frontend/src/constants/routes/apiRoutes.js index c96a65db8..d79d25464 100644 --- a/frontend/src/constants/routes/apiRoutes.js +++ b/frontend/src/constants/routes/apiRoutes.js @@ -73,6 +73,8 @@ export const apiRoutes = { getUserLoginHistories: '/users/login-history', getAuditLogs: '/audit-log/list', getAuditLog: '/audit-log/:auditLogId', + notifications: '/notifications/', + getNotifications: '/notifications/list', getNotificationsCount: '/notifications/count', getNotificationSubscriptions: '/notifications/subscriptions', saveNotificationSubscriptions: '/notifications/subscriptions/save', diff --git a/frontend/src/hooks/useNotifications.js b/frontend/src/hooks/useNotifications.js index 87b286a87..df52f3fa1 100644 --- a/frontend/src/hooks/useNotifications.js +++ b/frontend/src/hooks/useNotifications.js @@ -17,6 +17,54 @@ export const useNotificationsCount = (options) => { }) } +export const useGetNotificationMessages = ( + { page = 1, size = 10, sortOrders = [], filters = [] } = {}, + options +) => { + const client = useApiService() + return useQuery({ + queryKey: ['notification-messages', page, size, sortOrders, filters], + queryFn: async () => { + const response = await client.post(apiRoutes.getNotifications, { + page, + size, + sortOrders, + filters + }) + return response.data + }, + ...options + }) +} + +export const useMarkNotificationAsRead = (options) => { + const client = useApiService() + const queryClient = useQueryClient() + return useMutation({ + mutationFn: (_ids) => + client.put(apiRoutes.notifications, _ids), + onSettled: () => { + queryClient.invalidateQueries(['notifications-count']) + queryClient.invalidateQueries(['notifications-messages']) + }, + ...options + }) +} + +export const useDeleteNotificationMessages = (options) => { + const client = useApiService() + const queryClient = useQueryClient() + return useMutation({ + mutationFn: (_ids) => + client.delete(apiRoutes.notifications, { data: _ids }), + onSettled: () => { + queryClient.invalidateQueries(['notifications-count']) + queryClient.invalidateQueries(['notifications-messages']) + }, + ...options + }) +} + export const useNotificationSubscriptions = (options) => { const client = useApiService() return useQuery({ diff --git a/frontend/src/themes/base/globals.js b/frontend/src/themes/base/globals.js index 4d39236fb..f8a092509 100644 --- a/frontend/src/themes/base/globals.js +++ b/frontend/src/themes/base/globals.js @@ -88,6 +88,7 @@ const globals = { '--ag-header-column-resize-handle-height': '30%', '--ag-header-column-resize-handle-width': '2px', '--ag-header-column-resize-handle-color': '#dde2eb', + '--ag-material-accent-color': grey[700], '--ag-borders': `1px solid ${grey[700]}`, '--ag-border-color': grey[700], '--ag-odd-row-background-color': rgba(light.main, 0.6), @@ -109,6 +110,10 @@ const globals = { border: 'none', borderBottom: '2px solid #495057' }, + '.unread-row': { + fontWeight: 700, + color: grey[700] + }, // editor theme for ag-grid quertz theme '.ag-theme-quartz': { '--ag-borders': `0.5px solid ${grey[400]} !important`, @@ -197,10 +202,10 @@ const globals = { color: grey[600], textTransform: 'none', fontSize: pxToRem(14), - padding: '6px 12px', + padding: '6px 12px' }, '.ag-filter-apply-panel-button[data-ref="clearFilterButton"]:hover': { - color: grey[900], + color: grey[900] }, '.MuiPaper-elevation': { diff --git a/frontend/src/views/Admin/AdminMenu/components/UserLoginHistory.jsx b/frontend/src/views/Admin/AdminMenu/components/UserLoginHistory.jsx index 1dbdd618c..3f3fbc452 100644 --- a/frontend/src/views/Admin/AdminMenu/components/UserLoginHistory.jsx +++ b/frontend/src/views/Admin/AdminMenu/components/UserLoginHistory.jsx @@ -31,7 +31,6 @@ export const UserLoginHistory = () => { defaultMinWidth: 50, defaultMaxWidth: 600 }} - rowSelection={{ isRowSelectable: false }} /> diff --git a/frontend/src/views/AllocationAgreements/AddEditAllocationAgreements.jsx b/frontend/src/views/AllocationAgreements/AddEditAllocationAgreements.jsx index f09d3ee0e..6a79fd45e 100644 --- a/frontend/src/views/AllocationAgreements/AddEditAllocationAgreements.jsx +++ b/frontend/src/views/AllocationAgreements/AddEditAllocationAgreements.jsx @@ -3,7 +3,6 @@ import BCTypography from '@/components/BCTypography' import Grid2 from '@mui/material/Unstable_Grid2/Grid2' import { useTranslation } from 'react-i18next' import { useLocation, useNavigate, useParams } from 'react-router-dom' -import { BCAlert2 } from '@/components/BCAlert' import BCBox from '@/components/BCBox' import { BCGridEditor } from '@/components/BCDataGrid/BCGridEditor' import { @@ -66,7 +65,24 @@ export const AddEditAllocationAgreements = () => { severity: location.state.severity || 'info' }) } - }, [location.state]) + }, [location.state?.message, location.state?.severity]) + + const validate = (params, validationFn, errorMessage, alertRef, field = null) => { + const value = field ? params.node?.data[field] : params; + + if (field && params.colDef.field !== field) { + return true; + } + + if (!validationFn(value)) { + alertRef.current?.triggerAlert({ + message: errorMessage, + severity: 'error', + }); + return false; + } + return true; // Proceed with the update + }; const onGridReady = useCallback( async (params) => { @@ -80,13 +96,22 @@ export const AddEditAllocationAgreements = () => { ...item, id: item.id || uuid() // Ensure every item has a unique ID })) - setRowData(updatedRowData) + setRowData([...updatedRowData, { id: uuid() }]) } else { // If allocationAgreements is not available or empty, initialize with a single row setRowData([{ id: uuid() }]) } params.api.sizeColumnsToFit() + + setTimeout(() => { + const lastRowIndex = params.api.getLastDisplayedRowIndex() + params.api.setFocusedCell(lastRowIndex, 'allocationTransactionType') + params.api.startEditingCell({ + rowIndex: lastRowIndex, + colKey: 'allocationTransactionType' + }) + }, 100) }, [data] ) @@ -148,6 +173,22 @@ export const AddEditAllocationAgreements = () => { async (params) => { if (params.oldValue === params.newValue) return + const isValid = validate( + params, + (value) => { + return value !== null && !isNaN(value) && value > 0; + }, + 'Quantity supplied must be greater than 0.', + alertRef, + 'quantity', + ); + + if (!isValid) { + return + } + + if (!isValid) return + params.node.updateData({ ...params.node.data, validationStatus: 'pending' @@ -169,6 +210,29 @@ export const AddEditAllocationAgreements = () => { updatedData.ciOfFuel = DEFAULT_CI_FUEL[updatedData.fuelCategory] } + const isFuelCodeScenario = + params.data.provisionOfTheAct === PROVISION_APPROVED_FUEL_CODE + if (isFuelCodeScenario && !updatedData.fuelCode) { + // Fuel code is required but not provided + setErrors((prevErrors) => ({ + ...prevErrors, + [params.node.data.id]: ['fuelCode'] + })) + + alertRef.current?.triggerAlert({ + message: t('allocationAgreement:fuelCodeFieldRequiredError'), + severity: 'error' + }) + + updatedData = { + ...updatedData, + validationStatus: 'error' + } + + params.node.updateData(updatedData) + return // Stop execution, do not proceed to save + } + try { setErrors({}) await saveRow(updatedData) diff --git a/frontend/src/views/AllocationAgreements/_schema.jsx b/frontend/src/views/AllocationAgreements/_schema.jsx index 6e503e3dd..e9d6e29d6 100644 --- a/frontend/src/views/AllocationAgreements/_schema.jsx +++ b/frontend/src/views/AllocationAgreements/_schema.jsx @@ -57,9 +57,13 @@ export const allocationAgreementColDefs = (optionsData, errors) => [ headerName: i18n.t( 'allocationAgreement:allocationAgreementColLabels.transaction' ), - cellEditor: 'agSelectCellEditor', + cellEditor: AutocompleteCellEditor, cellEditorParams: { - values: ['Allocated from', 'Allocated to'] + options: ['Allocated from', 'Allocated to'], + multiple: false, + disableCloseOnSelect: false, + freeSolo: false, + openOnFocus: true }, cellRenderer: (params) => params.value || @@ -196,6 +200,7 @@ export const allocationAgreementColDefs = (optionsData, errors) => [ params.data.units = fuelType?.units params.data.unrecognized = fuelType?.unrecognized params.data.provisionOfTheAct = null + params.data.fuelCode = undefined } return true }, @@ -302,16 +307,85 @@ export const allocationAgreementColDefs = (optionsData, errors) => [ }), cellStyle: (params) => { const style = StandardCellErrors(params, errors) - const conditionalStyle = + const isFuelCodeScenario = params.data.provisionOfTheAct === PROVISION_APPROVED_FUEL_CODE - ? { backgroundColor: '#fff', borderColor: 'unset' } - : { backgroundColor: '#f2f2f2' } + const fuelType = optionsData?.fuelTypes?.find( + (obj) => params.data.fuelType === obj.fuelType + ) + const fuelCodes = fuelType?.fuelCodes || [] + const fuelCodeRequiredAndMissing = + isFuelCodeScenario && !params.data.fuelCode + + let conditionalStyle = {} + + // If required and missing, show red border and white background + if (fuelCodeRequiredAndMissing) { + style.borderColor = 'red' + style.backgroundColor = '#fff' + } else { + // Apply conditional styling if not missing + conditionalStyle = + isFuelCodeScenario && fuelCodes.length > 0 + ? { + backgroundColor: '#fff', + borderColor: style.borderColor || 'unset' + } + : { backgroundColor: '#f2f2f2' } + } + return { ...style, ...conditionalStyle } }, suppressKeyboardEvent, minWidth: 150, editable: (params) => - params.data.provisionOfTheAct === PROVISION_APPROVED_FUEL_CODE, + params.data.provisionOfTheAct === PROVISION_APPROVED_FUEL_CODE && + optionsData?.fuelTypes?.find( + (obj) => params.data.fuelType === obj.fuelType + )?.fuelCodes?.length > 0, + valueGetter: (params) => { + const fuelTypeObj = optionsData?.fuelTypes?.find( + (obj) => params.data.fuelType === obj.fuelType + ) + if (!fuelTypeObj) return params.data.fuelCode + + const isFuelCodeScenario = + params.data.provisionOfTheAct === PROVISION_APPROVED_FUEL_CODE + const fuelCodes = + fuelTypeObj.fuelCodes?.map((item) => item.fuelCode) || [] + + if (isFuelCodeScenario && !params.data.fuelCode) { + // Autopopulate if only one fuel code is available + if (fuelCodes.length === 1) { + const singleFuelCode = fuelTypeObj.fuelCodes[0] + params.data.fuelCode = singleFuelCode.fuelCode + params.data.fuelCodeId = singleFuelCode.fuelCodeId + } + } + + return params.data.fuelCode + }, + valueSetter: (params) => { + if (params.newValue) { + params.data.fuelCode = params.newValue + + const fuelType = optionsData?.fuelTypes?.find( + (obj) => params.data.fuelType === obj.fuelType + ) + if (params.data.provisionOfTheAct === PROVISION_APPROVED_FUEL_CODE) { + const matchingFuelCode = fuelType?.fuelCodes?.find( + (fuelCode) => params.data.fuelCode === fuelCode.fuelCode + ) + if (matchingFuelCode) { + params.data.fuelCodeId = matchingFuelCode.fuelCodeId + } + } + } else { + // If user clears the value + params.data.fuelCode = undefined + params.data.fuelCodeId = undefined + } + return true + }, tooltipValueGetter: (p) => 'Select the approved fuel code' }, { @@ -362,7 +436,7 @@ export const allocationAgreementColDefs = (optionsData, errors) => [ 'allocationAgreement:allocationAgreementColLabels.quantity' ), editor: NumberEditor, - valueFormatter, + valueFormatter: (params) => valueFormatter({ value: params.value }), cellEditor: NumberEditor, cellEditorParams: { precision: 0, diff --git a/frontend/src/views/FinalSupplyEquipments/AddEditFinalSupplyEquipments.jsx b/frontend/src/views/FinalSupplyEquipments/AddEditFinalSupplyEquipments.jsx index fb9fcafaa..2244efea4 100644 --- a/frontend/src/views/FinalSupplyEquipments/AddEditFinalSupplyEquipments.jsx +++ b/frontend/src/views/FinalSupplyEquipments/AddEditFinalSupplyEquipments.jsx @@ -75,18 +75,32 @@ export const AddEditFinalSupplyEquipments = () => { } ]) } else { - setRowData( - data.finalSupplyEquipments.map((item) => ({ + setRowData([ + ...data.finalSupplyEquipments.map((item) => ({ ...item, levelOfEquipment: item.levelOfEquipment.name, fuelMeasurementType: item.fuelMeasurementType.type, intendedUses: item.intendedUseTypes.map((i) => i.type), intendedUsers: item.intendedUserTypes.map((i) => i.typeName), id: uuid() - })) - ) + })), + { + id: uuid(), + complianceReportId, + supplyFromDate: `${compliancePeriod}-01-01`, + supplyToDate: `${compliancePeriod}-12-31` + } + ]) } params.api.sizeColumnsToFit() + + setTimeout(() => { + const lastRowIndex = params.api.getLastDisplayedRowIndex() + params.api.startEditingCell({ + rowIndex: lastRowIndex, + colKey: 'organizationName' + }) + }, 100) }, [compliancePeriod, complianceReportId, data] ) diff --git a/frontend/src/views/FuelExports/AddEditFuelExports.jsx b/frontend/src/views/FuelExports/AddEditFuelExports.jsx index ebaa03498..577cde049 100644 --- a/frontend/src/views/FuelExports/AddEditFuelExports.jsx +++ b/frontend/src/views/FuelExports/AddEditFuelExports.jsx @@ -13,14 +13,19 @@ import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useLocation, useNavigate, useParams } from 'react-router-dom' import { v4 as uuid } from 'uuid' -import { defaultColDef, fuelExportColDefs } from './_schema' +import { + defaultColDef, + fuelExportColDefs, + PROVISION_APPROVED_FUEL_CODE +} from './_schema' export const AddEditFuelExports = () => { const [rowData, setRowData] = useState([]) const gridRef = useRef(null) - const [gridApi, setGridApi] = useState() + const [, setGridApi] = useState() const [errors, setErrors] = useState({}) const [columnDefs, setColumnDefs] = useState([]) + const [gridReady, setGridReady] = useState(false) const alertRef = useRef() const location = useLocation() const { t } = useTranslation(['common', 'fuelExport']) @@ -74,21 +79,34 @@ export const AddEditFuelExports = () => { endUse: item.endUse?.type || 'Any', id: uuid() })) - setRowData(updatedRowData) + setRowData([...updatedRowData, { id: uuid(), compliancePeriod }]) } else { setRowData([{ id: uuid(), compliancePeriod }]) } params.api.sizeColumnsToFit() + + setTimeout(() => { + const lastRowIndex = params.api.getLastDisplayedRowIndex() + params.api.startEditingCell({ + rowIndex: lastRowIndex, + colKey: 'exportDate' + }) + setGridReady(true) + }, 500) }, [compliancePeriod, data] ) useEffect(() => { if (optionsData?.fuelTypes?.length > 0) { - const updatedColumnDefs = fuelExportColDefs(optionsData, errors) + const updatedColumnDefs = fuelExportColDefs( + optionsData, + errors, + gridReady + ) setColumnDefs(updatedColumnDefs) } - }, [errors, optionsData]) + }, [errors, gridReady, optionsData]) useEffect(() => { if (!fuelExportsLoading && !isArrayEmpty(data)) { @@ -143,7 +161,33 @@ export const AddEditFuelExports = () => { acc[key] = value return acc }, {}) + updatedData.compliancePeriod = compliancePeriod + + // Local validation before saving + const isFuelCodeScenario = + params.data.provisionOfTheAct === PROVISION_APPROVED_FUEL_CODE + + if (isFuelCodeScenario && !updatedData.fuelCode) { + // Fuel code is required but not provided + setErrors((prevErrors) => ({ + ...prevErrors, + [params.node.data.id]: ['fuelCode'] + })) + + alertRef.current?.triggerAlert({ + message: t('fuelExport:fuelCodeFieldRequiredError'), + severity: 'error' + }) + + updatedData = { + ...updatedData, + validationStatus: 'error' + } + params.node.updateData(updatedData) + return // Don't proceed with save + } + try { setErrors({}) await saveRow(updatedData) @@ -189,7 +233,7 @@ export const AddEditFuelExports = () => { params.node.updateData(updatedData) }, - [saveRow, t] + [saveRow, t, compliancePeriod] ) const onAction = async (action, params) => { diff --git a/frontend/src/views/FuelExports/_schema.jsx b/frontend/src/views/FuelExports/_schema.jsx index f113671d3..d826fa567 100644 --- a/frontend/src/views/FuelExports/_schema.jsx +++ b/frontend/src/views/FuelExports/_schema.jsx @@ -17,6 +17,8 @@ import { fuelTypeOtherConditionalStyle } from '@/utils/fuelTypeOther' +export const PROVISION_APPROVED_FUEL_CODE = 'Fuel code - section 19 (b) (i)' + const cellErrorStyle = (params, errors) => { let style = {} if ( @@ -44,7 +46,7 @@ const cellErrorStyle = (params, errors) => { return style } -export const fuelExportColDefs = (optionsData, errors) => [ +export const fuelExportColDefs = (optionsData, errors, gridReady) => [ validation, actions({ enableDuplicate: false, @@ -103,7 +105,10 @@ export const fuelExportColDefs = (optionsData, errors) => [ ), suppressKeyboardEvent, cellEditor: DateEditor, - cellEditorPopup: true + cellEditorPopup: true, + cellEditorParams: { + autoOpenLastRow: !gridReady + } }, { field: 'fuelType', @@ -318,29 +323,94 @@ export const fuelExportColDefs = (optionsData, errors) => [ field: 'fuelCode', headerName: i18n.t('fuelExport:fuelExportColLabels.fuelCode'), cellEditor: 'agSelectCellEditor', - cellEditorParams: (params) => ({ - values: optionsData?.fuelTypes - ?.find((obj) => params.data.fuelType === obj.fuelType) - ?.fuelCodes.map((item) => item.fuelCode) - }), + suppressKeyboardEvent, + minWidth: 135, + cellEditorParams: (params) => { + const fuelTypeObj = optionsData?.fuelTypes?.find( + (obj) => params.data.fuelType === obj.fuelType + ) + return { + values: fuelTypeObj?.fuelCodes?.map((item) => item.fuelCode) || [] + } + }, cellStyle: (params) => { const style = cellErrorStyle(params, errors) + const fuelTypeObj = optionsData?.fuelTypes?.find( + (obj) => params.data.fuelType === obj.fuelType + ) + const fuelCodes = + fuelTypeObj?.fuelCodes.map((item) => item.fuelCode) || [] + const isFuelCodeScenario = + params.data.provisionOfTheAct === PROVISION_APPROVED_FUEL_CODE + + // Check if fuel code is required (scenario) but missing + const fuelCodeRequiredAndMissing = + isFuelCodeScenario && !params.data.fuelCode + + if (fuelCodeRequiredAndMissing) { + // If required and missing, force a red border + style.borderColor = 'red' + } + const conditionalStyle = - optionsData?.fuelTypes - ?.find((obj) => params.data.fuelType === obj.fuelType) - ?.fuelCodes.map((item) => item.fuelCode).length > 0 && - /Fuel code/i.test(params.data.provisionOfTheAct) + fuelCodes.length > 0 && + isFuelCodeScenario && + !fuelCodeRequiredAndMissing ? { backgroundColor: '#fff', borderColor: 'unset' } : { backgroundColor: '#f2f2f2' } return { ...style, ...conditionalStyle } }, - suppressKeyboardEvent, - minWidth: 135, - editable: (params) => - optionsData?.fuelTypes - ?.find((obj) => params.data.fuelType === obj.fuelType) - ?.fuelCodes.map((item) => item.fuelCode).length > 0 && - /Fuel code/i.test(params.data.provisionOfTheAct) + editable: (params) => { + const fuelTypeObj = optionsData?.fuelTypes?.find( + (obj) => params.data.fuelType === obj.fuelType + ) + const fuelCodes = fuelTypeObj?.fuelCodes || [] + return ( + fuelCodes.length > 0 && + params.data.provisionOfTheAct === PROVISION_APPROVED_FUEL_CODE + ) + }, + valueGetter: (params) => { + const fuelTypeObj = optionsData?.fuelTypes?.find( + (obj) => params.data.fuelType === obj.fuelType + ) + if (!fuelTypeObj) return params.data.fuelCode + + const isFuelCodeScenario = + params.data.provisionOfTheAct === PROVISION_APPROVED_FUEL_CODE + const fuelCodes = + fuelTypeObj.fuelCodes?.map((item) => item.fuelCode) || [] + + if (isFuelCodeScenario && !params.data.fuelCode) { + // Autopopulate if only one fuel code is available + if (fuelCodes.length === 1) { + const singleFuelCode = fuelTypeObj.fuelCodes[0] + params.data.fuelCode = singleFuelCode.fuelCode + params.data.fuelCodeId = singleFuelCode.fuelCodeId + } + } + + return params.data.fuelCode + }, + valueSetter: (params) => { + const newCode = params.newValue + const fuelTypeObj = optionsData?.fuelTypes?.find( + (obj) => params.data.fuelType === obj.fuelType + ) + const selectedFuelCodeObj = fuelTypeObj?.fuelCodes.find( + (item) => item.fuelCode === newCode + ) + + if (selectedFuelCodeObj) { + params.data.fuelCode = selectedFuelCodeObj.fuelCode + params.data.fuelCodeId = selectedFuelCodeObj.fuelCodeId + } else { + params.data.fuelCode = undefined + params.data.fuelCodeId = undefined + } + + return true + } }, { field: 'quantity', diff --git a/frontend/src/views/FuelSupplies/AddEditFuelSupplies.jsx b/frontend/src/views/FuelSupplies/AddEditFuelSupplies.jsx index 7d7c4bd29..4badf3167 100644 --- a/frontend/src/views/FuelSupplies/AddEditFuelSupplies.jsx +++ b/frontend/src/views/FuelSupplies/AddEditFuelSupplies.jsx @@ -14,7 +14,11 @@ import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useLocation, useNavigate, useParams } from 'react-router-dom' import { v4 as uuid } from 'uuid' -import { defaultColDef, fuelSupplyColDefs } from './_schema' +import { + defaultColDef, + fuelSupplyColDefs, + PROVISION_APPROVED_FUEL_CODE +} from './_schema' export const AddEditFuelSupplies = () => { const [rowData, setRowData] = useState([]) @@ -54,13 +58,30 @@ export const AddEditFuelSupplies = () => { ) useEffect(() => { - if (location.state?.message) { + if (location?.state?.message) { alertRef.current?.triggerAlert({ message: location.state.message, severity: location.state.severity || 'info' }) } - }, [location.state]) + }, [location?.state?.message, location?.state?.severity]); + + const validate = (params, validationFn, errorMessage, alertRef, field = null) => { + const value = field ? params.node?.data[field] : params; + + if (field && params.colDef.field !== field) { + return true; + } + + if (!validationFn(value)) { + alertRef.current?.triggerAlert({ + message: errorMessage, + severity: 'error', + }); + return false; + } + return true; // Proceed with the update + }; const onGridReady = useCallback( async (params) => { @@ -79,10 +100,17 @@ export const AddEditFuelSupplies = () => { endUse: item.endUse?.type || 'Any', id: uuid() })) - setRowData(updatedRowData) + setRowData([...updatedRowData, { id: uuid() }]) } else { setRowData([{ id: uuid() }]) } + setTimeout(() => { + const lastRowIndex = params.api.getLastDisplayedRowIndex() + params.api.startEditingCell({ + rowIndex: lastRowIndex, + colKey: 'fuelType' + }) + }, 100) }, [data, complianceReportId, compliancePeriod] ) @@ -131,15 +159,6 @@ export const AddEditFuelSupplies = () => { 'fuelCategory', fuelCategoryOptions[0] ?? null ) - - const fuelCodeOptions = selectedFuelType.fuelCodes.map( - (code) => code.fuelCode - ) - params.node.setDataValue('fuelCode', fuelCodeOptions[0] ?? null) - params.node.setDataValue( - 'fuelCodeId', - selectedFuelType.fuelCodes[0]?.fuelCodeId ?? null - ) } } }, @@ -150,6 +169,20 @@ export const AddEditFuelSupplies = () => { async (params) => { if (params.oldValue === params.newValue) return + const isValid = validate( + params, + (value) => { + return value !== null && !isNaN(value) && value > 0; + }, + 'Quantity supplied must be greater than 0.', + alertRef, + 'quantity', + ); + + if (!isValid) { + return + } + params.node.updateData({ ...params.node.data, validationStatus: 'pending' @@ -164,6 +197,28 @@ export const AddEditFuelSupplies = () => { if (updatedData.fuelType === 'Other') { updatedData.ciOfFuel = DEFAULT_CI_FUEL[updatedData.fuelCategory] } + + const isFuelCodeScenario = + params.node.data.provisionOfTheAct === PROVISION_APPROVED_FUEL_CODE + if (isFuelCodeScenario && !params.node.data.fuelCode) { + // Set error on the row + setErrors({ + [params.node.data.id]: ['fuelCode'] + }) + + alertRef.current?.triggerAlert({ + message: t('fuelSupply:fuelCodeFieldRequiredError'), + severity: 'error' + }) + + // Update node data to reflect error state + params.node.updateData({ + ...params.node.data, + validationStatus: 'error' + }) + return // Stop saving further + } + try { setErrors({}) await saveRow(updatedData) diff --git a/frontend/src/views/FuelSupplies/_schema.jsx b/frontend/src/views/FuelSupplies/_schema.jsx index 912dc994d..5939d074b 100644 --- a/frontend/src/views/FuelSupplies/_schema.jsx +++ b/frontend/src/views/FuelSupplies/_schema.jsx @@ -19,6 +19,8 @@ import { } from '@/utils/grid/errorRenderers' import { apiRoutes } from '@/constants/routes' +export const PROVISION_APPROVED_FUEL_CODE = 'Fuel code - section 19 (b) (i)' + export const fuelSupplyColDefs = (optionsData, errors, warnings) => [ validation, actions({ @@ -102,7 +104,6 @@ export const fuelSupplyColDefs = (optionsData, errors, warnings) => [ params.data.provisionOfTheAct = null params.data.fuelCode = null params.data.fuelCodeId = null - params.data.quantity = 0 params.data.units = fuelType?.unit params.data.unrecognized = fuelType?.unrecognized } @@ -176,7 +177,6 @@ export const fuelSupplyColDefs = (optionsData, errors, warnings) => [ params.data.eer = null params.data.provisionOfTheAct = null params.data.fuelCode = null - params.data.quantity = 0 } return true }, @@ -292,21 +292,35 @@ export const fuelSupplyColDefs = (optionsData, errors, warnings) => [ field: 'fuelCode', headerName: i18n.t('fuelSupply:fuelSupplyColLabels.fuelCode'), cellEditor: 'agSelectCellEditor', - cellEditorParams: (params) => ({ - values: optionsData?.fuelTypes - ?.find((obj) => params.data.fuelType === obj.fuelType) - ?.fuelCodes.map((item) => item.fuelCode) - }), + cellEditorParams: (params) => { + const fuelType = optionsData?.fuelTypes?.find( + (obj) => params.data.fuelType === obj.fuelType + ) + return { + values: fuelType?.fuelCodes.map((item) => item.fuelCode) || [] + } + }, cellStyle: (params) => { const style = StandardCellWarningAndErrors(params, errors, warnings) - const conditionalStyle = - optionsData?.fuelTypes - ?.find((obj) => params.data.fuelType === obj.fuelType) - ?.fuelCodes.map((item) => item.fuelCode).length > 0 && - /Fuel code/i.test(params.data.provisionOfTheAct) - ? { backgroundColor: '#fff' } - : { backgroundColor: '#f2f2f2', borderColor: 'unset' } - return { ...style, ...conditionalStyle } + const isFuelCodeScenario = + params.data?.provisionOfTheAct === PROVISION_APPROVED_FUEL_CODE + const fuelType = optionsData?.fuelTypes?.find( + (obj) => params.data.fuelType === obj.fuelType + ) + const fuelCodes = fuelType?.fuelCodes || [] + const fuelCodeRequiredAndMissing = + isFuelCodeScenario && !params.data.fuelCode + + if (fuelCodeRequiredAndMissing) { + // Highlight the cell if fuel code is required but not selected + return { ...style, backgroundColor: '#fff', borderColor: 'red' } + } else if (isFuelCodeScenario && fuelCodes.length > 0) { + // Allow selection when scenario matches and codes are present + return { ...style, backgroundColor: '#fff', borderColor: 'unset' } + } else { + // Otherwise disabled styling + return { ...style, backgroundColor: '#f2f2f2', borderColor: 'unset' } + } }, suppressKeyboardEvent, minWidth: 135, @@ -314,29 +328,50 @@ export const fuelSupplyColDefs = (optionsData, errors, warnings) => [ const fuelType = optionsData?.fuelTypes?.find( (obj) => params.data.fuelType === obj.fuelType ) - if (fuelType) { - return ( - fuelType.fuelCodes.map((item) => item.fuelCode).length > 0 && - /Fuel code/i.test(params.data.provisionOfTheAct) - ) + return ( + params.data.provisionOfTheAct === PROVISION_APPROVED_FUEL_CODE && + fuelType?.fuelCodes?.length > 0 + ) + }, + valueGetter: (params) => { + const fuelType = optionsData?.fuelTypes?.find( + (obj) => params.data.fuelType === obj.fuelType + ) + if (!fuelType) return params.data.fuelCode + + const isFuelCodeScenario = + params.data.provisionOfTheAct === PROVISION_APPROVED_FUEL_CODE + const fuelCodes = fuelType?.fuelCodes?.map((item) => item.fuelCode) || [] + + if (isFuelCodeScenario && !params.data.fuelCode) { + // If only one code is available, auto-populate + if (fuelCodes.length === 1) { + const singleFuelCode = fuelType.fuelCodes[0] + params.data.fuelCode = singleFuelCode.fuelCode + params.data.fuelCodeId = singleFuelCode.fuelCodeId + } } - return false + + return params.data.fuelCode }, valueSetter: (params) => { if (params.newValue) { params.data.fuelCode = params.newValue - const fuelType = optionsData?.fuelTypes?.find( (obj) => params.data.fuelType === obj.fuelType ) - if (/Fuel code/i.test(params.data.provisionOfTheAct)) { - const matchingFuelCode = fuelType.fuelCodes?.find( + if (params.data.provisionOfTheAct === PROVISION_APPROVED_FUEL_CODE) { + const matchingFuelCode = fuelType?.fuelCodes?.find( (fuelCode) => params.data.fuelCode === fuelCode.fuelCode ) if (matchingFuelCode) { params.data.fuelCodeId = matchingFuelCode.fuelCodeId } } + } else { + // If user clears the value + params.data.fuelCode = undefined + params.data.fuelCodeId = undefined } return true } @@ -345,7 +380,7 @@ export const fuelSupplyColDefs = (optionsData, errors, warnings) => [ field: 'quantity', headerComponent: RequiredHeader, headerName: i18n.t('fuelSupply:fuelSupplyColLabels.quantity'), - valueFormatter, + valueFormatter: (params) => valueFormatter({ value: params.value }), cellEditor: NumberEditor, cellEditorParams: { precision: 0, diff --git a/frontend/src/views/Notifications/NotificationMenu/components/IDIRAnalystNotificationSettings.jsx b/frontend/src/views/Notifications/NotificationMenu/components/IDIRAnalystNotificationSettings.jsx index 731dedd68..8294115b1 100644 --- a/frontend/src/views/Notifications/NotificationMenu/components/IDIRAnalystNotificationSettings.jsx +++ b/frontend/src/views/Notifications/NotificationMenu/components/IDIRAnalystNotificationSettings.jsx @@ -9,8 +9,10 @@ const IDIRAnalystNotificationSettings = () => { 'idirAnalyst.categories.transfers.submittedForReview', IDIR_ANALYST__TRANSFER__RESCINDED_ACTION: 'idirAnalyst.categories.transfers.rescindedAction', - IDIR_ANALYST__TRANSFER__DIRECTOR_RECORDEDIDIR_A__TR__DIRECTOR_RECORDED: - 'idirAnalyst.categories.transfers.directorRecorded' + IDIR_ANALYST__TRANSFER__DIRECTOR_RECORDED: + 'idirAnalyst.categories.transfers.directorRecorded', + IDIR_ANALYST__TRANSFER__RETURNED_TO_ANALYST: + 'idirAnalyst.categories.initiativeAgreements.returnedToAnalyst' }, 'idirAnalyst.categories.initiativeAgreements': { title: 'idirAnalyst.categories.initiativeAgreements.title', diff --git a/frontend/src/views/Notifications/NotificationMenu/components/NotificationSettingsForm.jsx b/frontend/src/views/Notifications/NotificationMenu/components/NotificationSettingsForm.jsx index 5a7e0bd93..bb2058482 100644 --- a/frontend/src/views/Notifications/NotificationMenu/components/NotificationSettingsForm.jsx +++ b/frontend/src/views/Notifications/NotificationMenu/components/NotificationSettingsForm.jsx @@ -34,8 +34,8 @@ import BCTypography from '@/components/BCTypography' const NotificationSettingsForm = ({ categories, - showEmailField, - initialEmail + showEmailField = false, + initialEmail = '' }) => { const { t } = useTranslation(['notifications']) const [isFormLoading, setIsFormLoading] = useState(false) @@ -468,9 +468,4 @@ NotificationSettingsForm.propTypes = { initialEmail: PropTypes.string } -NotificationSettingsForm.defaultProps = { - showEmailField: false, - initialEmail: '' -} - export default NotificationSettingsForm diff --git a/frontend/src/views/Notifications/NotificationMenu/components/Notifications.jsx b/frontend/src/views/Notifications/NotificationMenu/components/Notifications.jsx index 13d413031..afb8bd493 100644 --- a/frontend/src/views/Notifications/NotificationMenu/components/Notifications.jsx +++ b/frontend/src/views/Notifications/NotificationMenu/components/Notifications.jsx @@ -1,14 +1,238 @@ +import { useCallback, useMemo, useRef, useState } from 'react' +import { useNavigate } from 'react-router-dom' import { useTranslation } from 'react-i18next' -import BCTypography from '@/components/BCTypography' +import { Stack, Grid } from '@mui/material' +import { FontAwesomeIcon } from '@fortawesome/react-fontawesome' +import { faSquareCheck } from '@fortawesome/free-solid-svg-icons' + +import BCButton from '@/components/BCButton' +import { BCGridViewer } from '@/components/BCDataGrid/BCGridViewer' +import { columnDefs, routesMapping } from './_schema' +import { + useGetNotificationMessages, + useMarkNotificationAsRead, + useDeleteNotificationMessages +} from '@/hooks/useNotifications' +import { useCurrentUser } from '@/hooks/useCurrentUser' export const Notifications = () => { + const gridRef = useRef(null) + const alertRef = useRef(null) + const [isAllSelected, setIsAllSelected] = useState(false) + const [selectedRowCount, setSelectedRowCount] = useState(0) + const { t } = useTranslation(['notifications']) + const navigate = useNavigate() + const { data: currentUser } = useCurrentUser() + // react query hooks + const { refetch } = useGetNotificationMessages() + const markAsReadMutation = useMarkNotificationAsRead() + const deleteMutation = useDeleteNotificationMessages() + + // row class rules for unread messages + const rowClassRules = useMemo( + () => ({ + 'unread-row': (params) => !params.data.isRead + }), + [] + ) + const selectionColumnDef = useMemo(() => { + return { + sortable: false, + resizable: false, + suppressHeaderMenuButton: true, + headerTooltip: 'Checkboxes indicate selection' + } + }, []) + const rowSelection = useMemo(() => { + return { + mode: 'multiRow' + } + }, []) + + // Consolidated mutation handler + const handleMutation = useCallback( + (mutation, selectedNotifications, successMessage, errorMessage) => { + if (selectedNotifications.length === 0) { + alertRef.current?.triggerAlert({ + message: t('notifications:noNotificationsSelectedText'), + severity: 'warning' + }) + return + } + + mutation.mutate(selectedNotifications, { + onSuccess: () => { + // eslint-disable-next-line chai-friendly/no-unused-expressions + successMessage && + alertRef.current?.triggerAlert({ + message: t(successMessage), + severity: 'success' + }) + refetch() + }, + onError: (error) => { + alertRef.current?.triggerAlert({ + message: t(errorMessage, { error: error.message }), + severity: 'error' + }) + } + }) + }, + [t, refetch] + ) + + // Toggle selection for visible rows + const toggleSelectVisibleRows = useCallback(() => { + const gridApi = gridRef.current?.api + if (gridApi) { + gridApi.forEachNodeAfterFilterAndSort((node) => { + node.setSelected(!isAllSelected) + }) + setIsAllSelected(!isAllSelected) + } + }, [isAllSelected]) + + // event handlers for delete, markAsRead, and row-level deletes + const handleMarkAsRead = useCallback(() => { + const gridApi = gridRef.current?.api + if (gridApi) { + const selectedNotifications = gridApi + .getSelectedNodes() + .map((node) => node.data.notificationMessageId) + handleMutation( + markAsReadMutation, + selectedNotifications, + 'notifications:markAsReadSuccessText', + 'notifications:markAsReadErrorText' + ) + } + }, [handleMutation, markAsReadMutation]) + + const handleDelete = useCallback(() => { + const gridApi = gridRef.current?.api + if (gridApi) { + const selectedNotifications = gridApi + .getSelectedNodes() + .map((node) => node.data.notificationMessageId) + handleMutation( + deleteMutation, + selectedNotifications, + 'notifications:deleteSuccessText', + 'notifications:deleteErrorText' + ) + } + }, [handleMutation, deleteMutation]) + + const handleRowClicked = useCallback( + (params) => { + const { id, service, compliancePeriod } = JSON.parse(params.data.message) + // Select the appropriate route based on the notification type + const routeTemplate = routesMapping(currentUser)[service] + + if (routeTemplate && params.event.target.dataset.action !== 'delete') { + navigate( + // replace any matching query params by chaining these replace methods + routeTemplate + .replace(':transactionId', id) + .replace(':transferId', id) + .replace(':compliancePeriod', compliancePeriod) + .replace(':complianceReportId', id) + ) + handleMutation(markAsReadMutation, [params.data.notificationMessageId]) + } + }, + [currentUser, navigate] + ) + + const onCellClicked = useCallback( + (params) => { + if ( + params.column.colId === 'action' && + params.event.target.dataset.action + ) { + handleMutation( + deleteMutation, + [params.data.notificationMessageId], + 'notifications:deleteSuccessText', + 'notifications:deleteErrorText' + ) + } + }, + [handleMutation, deleteMutation] + ) + + const onSelectionChanged = useCallback(() => { + const gridApi = gridRef.current?.api + const visibleRows = [] + gridApi.forEachNodeAfterFilterAndSort((node) => { + visibleRows.push(node) + }) + const selectedRows = visibleRows.filter((node) => node.isSelected()) + setSelectedRowCount(selectedRows.length) + setIsAllSelected( + visibleRows.length > 0 && visibleRows.length === selectedRows.length + ) + }, []) return ( - <> - - {t('title.Notifications')} - - + + + + } + onClick={toggleSelectVisibleRows} + > + {isAllSelected + ? t('notifications:buttonStack.unselectAll') + : t('notifications:buttonStack.selectAll')} + + + {t('notifications:buttonStack.markAsRead')} + + + {t('notifications:buttonStack.deleteSelected')} + + + + ) } diff --git a/frontend/src/views/Notifications/NotificationMenu/components/_schema.jsx b/frontend/src/views/Notifications/NotificationMenu/components/_schema.jsx new file mode 100644 index 000000000..601fb29c6 --- /dev/null +++ b/frontend/src/views/Notifications/NotificationMenu/components/_schema.jsx @@ -0,0 +1,70 @@ +import { dateFormatter } from '@/utils/formatters' +import { actions } from '@/components/BCDataGrid/columns' +import { ROUTES } from '@/constants/routes' + +export const columnDefs = (t, currentUser) => [ + { + ...actions({ enableDelete: true }), + headerName: 'Delete', + pinned: '' + }, + { + colId: 'type', + field: 'type', + headerName: t('notifications:notificationColLabels.type') + }, + { + colId: 'date', + field: 'date', + cellDataType: 'date', + headerName: t('notifications:notificationColLabels.date'), + valueGetter: (params) => params.data.createDate, + valueFormatter: dateFormatter + }, + { + colId: 'user', + field: 'user', + headerName: t('notifications:notificationColLabels.user'), + valueGetter: (params) => params.data.originUserProfile.fullName.trim() + }, + { + colId: 'transactionId', + field: 'transactionId', + headerName: t('notifications:notificationColLabels.transactionId'), + valueGetter: (params) => params.data.relatedTransactionId + }, + { + colId: 'organization', + field: 'organization', + headerName: t('notifications:notificationColLabels.organization'), + valueGetter: (params) => { + const { service, toOrganizationId, fromOrganization } = JSON.parse( + params.data.message + ) + if ( + service === 'Transfer' && + toOrganizationId === currentUser?.organization?.organizationId + ) { + return fromOrganization + } + return params.data.relatedOrganization.name + } + } +] + +export const defaultColDef = { + editable: false, + resizable: true, + sortable: true +} + +export const routesMapping = (currentUser) => ({ + Transfer: ROUTES.TRANSFERS_VIEW, + AdminAdjustment: currentUser.isGovernmentUser + ? ROUTES.ADMIN_ADJUSTMENT_VIEW + : ROUTES.ORG_ADMIN_ADJUSTMENT_VIEW, + InitiativeAgreement: currentUser.isGovernmentUser + ? ROUTES.INITIATIVE_AGREEMENT_VIEW + : ROUTES.ORG_INITIATIVE_AGREEMENT_VIEW, + ComplianceReport: ROUTES.REPORTS_VIEW +}) diff --git a/frontend/src/views/Notifications/Notifications.jsx b/frontend/src/views/Notifications/Notifications.jsx deleted file mode 100644 index da555141f..000000000 --- a/frontend/src/views/Notifications/Notifications.jsx +++ /dev/null @@ -1,14 +0,0 @@ -import * as ROUTES from '@/constants/routes/routes.js' -import withFeatureFlag from '@/utils/withFeatureFlag.jsx' -import { FEATURE_FLAGS } from '@/constants/config.js' - -export const NotificationsBase = () => { - return
Notifications
-} - -export const Notifications = withFeatureFlag( - NotificationsBase, - FEATURE_FLAGS.NOTIFICATIONS, - ROUTES.DASHBOARD -) -Notifications.displayName = 'Notifications' diff --git a/frontend/src/views/NotionalTransfers/AddEditNotionalTransfers.jsx b/frontend/src/views/NotionalTransfers/AddEditNotionalTransfers.jsx index 2360dfb19..5d8073495 100644 --- a/frontend/src/views/NotionalTransfers/AddEditNotionalTransfers.jsx +++ b/frontend/src/views/NotionalTransfers/AddEditNotionalTransfers.jsx @@ -39,13 +39,30 @@ export const AddEditNotionalTransfers = () => { const navigate = useNavigate() useEffect(() => { - if (location.state?.message) { + if (location?.state?.message) { alertRef.triggerAlert({ message: location.state.message, severity: location.state.severity || 'info' }) } - }, [location.state]) + }, [location?.state?.message, location?.state?.severity]); + + const validate = (params, validationFn, errorMessage, alertRef, field = null) => { + const value = field ? params.node?.data[field] : params; + + if (field && params.colDef.field !== field) { + return true; + } + + if (!validationFn(value)) { + alertRef.current?.triggerAlert({ + message: errorMessage, + severity: 'error', + }); + return false; + } + return true; // Proceed with the update + }; const onGridReady = (params) => { const ensureRowIds = (rows) => { @@ -64,7 +81,13 @@ export const AddEditNotionalTransfers = () => { if (notionalTransfers && notionalTransfers.length > 0) { try { - setRowData(ensureRowIds(notionalTransfers)) + setRowData([ + ...ensureRowIds(notionalTransfers), + { + id: uuid(), + complianceReportId + } + ]) } catch (error) { alertRef.triggerAlert({ message: t('notionalTransfer:LoadFailMsg'), @@ -78,12 +101,34 @@ export const AddEditNotionalTransfers = () => { } params.api.sizeColumnsToFit() + + setTimeout(() => { + const lastRowIndex = params.api.getLastDisplayedRowIndex() + params.api.startEditingCell({ + rowIndex: lastRowIndex, + colKey: 'legalName' + }) + }, 100) } const onCellEditingStopped = useCallback( async (params) => { if (params.oldValue === params.newValue) return + const isValid = validate( + params, + (value) => { + return value !== null && !isNaN(value) && value > 0; + }, + 'Quantity supplied must be greater than 0.', + alertRef, + 'quantity', + ); + + if (!isValid) { + return + } + // Initialize updated data with 'pending' status params.node.updateData({ ...params.node.data, diff --git a/frontend/src/views/NotionalTransfers/_schema.jsx b/frontend/src/views/NotionalTransfers/_schema.jsx index 38673f676..6064debe4 100644 --- a/frontend/src/views/NotionalTransfers/_schema.jsx +++ b/frontend/src/views/NotionalTransfers/_schema.jsx @@ -140,7 +140,7 @@ export const notionalTransferColDefs = (optionsData, errors) => [ min: 0, showStepperButtons: false }, - valueFormatter, + valueFormatter: (params) => valueFormatter({ value: params.value }), cellStyle: (params) => StandardCellErrors(params, errors) } ] diff --git a/frontend/src/views/OtherUses/AddEditOtherUses.jsx b/frontend/src/views/OtherUses/AddEditOtherUses.jsx index bbd553ca3..d3fc18ec8 100644 --- a/frontend/src/views/OtherUses/AddEditOtherUses.jsx +++ b/frontend/src/views/OtherUses/AddEditOtherUses.jsx @@ -1,5 +1,4 @@ -import { BCAlert2 } from '@/components/BCAlert' -import BCButton from '@/components/BCButton' + import { BCGridEditor } from '@/components/BCDataGrid/BCGridEditor' import Loading from '@/components/Loading' import { @@ -8,16 +7,17 @@ import { useSaveOtherUses } from '@/hooks/useOtherUses' import { cleanEmptyStringValues } from '@/utils/formatters' -import { faFloppyDisk } from '@fortawesome/free-solid-svg-icons' -import { FontAwesomeIcon } from '@fortawesome/react-fontawesome' -import { Stack } from '@mui/material' import BCTypography from '@/components/BCTypography' import Grid2 from '@mui/material/Unstable_Grid2/Grid2' -import { useCallback, useEffect, useMemo, useRef, useState } from 'react' +import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useLocation, useNavigate, useParams } from 'react-router-dom' import { v4 as uuid } from 'uuid' -import { defaultColDef, otherUsesColDefs, PROVISION_APPROVED_FUEL_CODE} from './_schema' +import { + defaultColDef, + otherUsesColDefs, + PROVISION_APPROVED_FUEL_CODE +} from './_schema' import * as ROUTES from '@/constants/routes/routes.js' export const AddEditOtherUses = () => { @@ -55,31 +55,48 @@ export const AddEditOtherUses = () => { rows.map((row) => ({ ...row, id: row.id || uuid(), - isValid: true, - })); + isValid: true + })) - setRowData(ensureRowIds(otherUses)); + setRowData(ensureRowIds(otherUses)) } - }, [otherUses]); + }, [otherUses]) const findCiOfFuel = useCallback((data, optionsData) => { - let ciOfFuel = 0; + let ciOfFuel = 0 if (data.provisionOfTheAct === PROVISION_APPROVED_FUEL_CODE) { const fuelType = optionsData?.fuelTypes?.find( (obj) => data.fuelType === obj.fuelType - ); + ) const fuelCode = fuelType?.fuelCodes?.find( (item) => item.fuelCode === data.fuelCode - ); - ciOfFuel = fuelCode?.carbonIntensity || 0; + ) + ciOfFuel = fuelCode?.carbonIntensity || 0 } else { const fuelType = optionsData?.fuelTypes?.find( (obj) => data.fuelType === obj.fuelType - ); - ciOfFuel = fuelType?.defaultCarbonIntensity || 0; + ) + ciOfFuel = fuelType?.defaultCarbonIntensity || 0 + } + return ciOfFuel + }, []) + + const validate = (params, validationFn, errorMessage, alertRef, field = null) => { + const value = field ? params.node?.data[field] : params; + + if (field && params.colDef.field !== field) { + return true; + } + + if (!validationFn(value)) { + alertRef.current?.triggerAlert({ + message: errorMessage, + severity: 'error', + }); + return false; } - return ciOfFuel; - }, []); + return true; // Proceed with the update + }; const onGridReady = (params) => { const ensureRowIds = (rows) => { @@ -98,7 +115,10 @@ export const AddEditOtherUses = () => { if (otherUses && otherUses.length > 0) { try { - setRowData(ensureRowIds(otherUses)) + setRowData([ + ...ensureRowIds(otherUses), + { id: uuid(), complianceReportId } + ]) } catch (error) { alertRef.triggerAlert({ message: t('otherUses:otherUsesLoadFailMsg'), @@ -112,6 +132,15 @@ export const AddEditOtherUses = () => { } params.api.sizeColumnsToFit() + + setTimeout(() => { + const lastRowIndex = params.api.getLastDisplayedRowIndex() + + params.api.startEditingCell({ + rowIndex: lastRowIndex, + colKey: 'fuelType' + }) + }, 100) } const onAction = async (action, params) => { @@ -151,25 +180,59 @@ export const AddEditOtherUses = () => { const ciOfFuel = findCiOfFuel(params.data, optionsData) params.node.setDataValue('ciOfFuel', ciOfFuel) - // Auto-populate the "Unit" field based on the selected fuel type - if (params.colDef.field === 'fuelType') { - const fuelType = optionsData?.fuelTypes?.find( - (obj) => params.data.fuelType === obj.fuelType - ); - if (fuelType && fuelType.units) { - params.node.setDataValue('units', fuelType.units); - } else { - params.node.setDataValue('units', ''); + // Auto-populate fields based on the selected fuel type + if (params.colDef.field === 'fuelType') { + const fuelType = optionsData?.fuelTypes?.find( + (obj) => params.data.fuelType === obj.fuelType + ); + if (fuelType) { + // Auto-populate the "units" field + if (fuelType.units) { + params.node.setDataValue('units', fuelType.units); + } else { + params.node.setDataValue('units', ''); + } + + // Auto-populate the "fuelCategory" field + const fuelCategoryOptions = fuelType.fuelCategories.map( + (item) => item.category + ); + params.node.setDataValue('fuelCategory', fuelCategoryOptions[0] ?? null); + + // Auto-populate the "fuelCode" field + const fuelCodeOptions = fuelType.fuelCodes.map( + (code) => code.fuelCode + ); + params.node.setDataValue('fuelCode', fuelCodeOptions[0] ?? null); + params.node.setDataValue( + 'fuelCodeId', + fuelType.fuelCodes[0]?.fuelCodeId ?? null + ); + } } } - } }, - [optionsData] + [optionsData, findCiOfFuel] ) const onCellEditingStopped = useCallback( async (params) => { if (params.oldValue === params.newValue) return + + const isValid = validate( + params, + (value) => { + return value !== null && !isNaN(value) && value > 0; + }, + 'Quantity supplied must be greater than 0.', + alertRef, + 'quantitySupplied', + ); + + if (!isValid) { + return + } + params.data.complianceReportId = complianceReportId params.data.validationStatus = 'pending' @@ -181,6 +244,29 @@ export const AddEditOtherUses = () => { // clean up any null or empty string values let updatedData = cleanEmptyStringValues(params.data) + const isFuelCodeScenario = + params.data.provisionOfTheAct === PROVISION_APPROVED_FUEL_CODE + if (isFuelCodeScenario && !updatedData.fuelCode) { + // Fuel code is required but not provided + setErrors((prevErrors) => ({ + ...prevErrors, + [params.node.data.id]: ['fuelCode'] + })) + + alertRef.current?.triggerAlert({ + message: t('otherUses:fuelCodeFieldRequiredError'), + severity: 'error' + }) + + updatedData = { + ...updatedData, + validationStatus: 'error' + } + + params.node.updateData(updatedData) + return // Stop execution, do not proceed to save + } + try { setErrors({}) await saveRow(updatedData) diff --git a/frontend/src/views/OtherUses/_schema.jsx b/frontend/src/views/OtherUses/_schema.jsx index 9af82de0b..401b6db61 100644 --- a/frontend/src/views/OtherUses/_schema.jsx +++ b/frontend/src/views/OtherUses/_schema.jsx @@ -23,7 +23,7 @@ export const otherUsesColDefs = (optionsData, errors) => [ hide: true }, { - field:'otherUsesId', + field: 'otherUsesId', hide: true }, { @@ -42,19 +42,32 @@ export const otherUsesColDefs = (optionsData, errors) => [ suppressKeyboardEvent, cellRenderer: (params) => params.value || Select, - cellStyle: (params) => StandardCellErrors(params, errors) + cellStyle: (params) => StandardCellErrors(params, errors), + valueSetter: (params) => { + if (params.newValue) { + // TODO: Evaluate if additional fields need to be reset when fuel type changes + params.data.fuelType = params.newValue + params.data.fuelCode = undefined + } + return true + } }, { field: 'fuelCategory', headerName: i18n.t('otherUses:otherUsesColLabels.fuelCategory'), headerComponent: RequiredHeader, cellEditor: AutocompleteCellEditor, - cellEditorParams: { - options: optionsData.fuelCategories.map((obj) => obj.category), - multiple: false, - disableCloseOnSelect: false, - freeSolo: false, - openOnFocus: true + cellEditorParams: (params) => { + const fuelType = optionsData?.fuelTypes?.find( + (obj) => params.data.fuelType === obj.fuelType + ); + return { + options: fuelType ? fuelType.fuelCategories.map((item) => item.category) : [], + multiple: false, + disableCloseOnSelect: false, + freeSolo: false, + openOnFocus: true + }; }, suppressKeyboardEvent, cellRenderer: (params) => @@ -65,9 +78,7 @@ export const otherUsesColDefs = (optionsData, errors) => [ { field: 'provisionOfTheAct', headerComponent: RequiredHeader, - headerName: i18n.t( - 'otherUses:otherUsesColLabels.provisionOfTheAct' - ), + headerName: i18n.t('otherUses:otherUsesColLabels.provisionOfTheAct'), cellEditor: 'agSelectCellEditor', cellEditorParams: (params) => { const fuelType = optionsData?.fuelTypes?.find( @@ -89,11 +100,11 @@ export const otherUsesColDefs = (optionsData, errors) => [ suppressKeyboardEvent, valueSetter: (params) => { if (params.newValue !== params.oldValue) { - params.data.provisionOfTheAct = params.newValue; - params.data.fuelCode = ''; // Reset fuelCode when provisionOfTheAct changes - return true; + params.data.provisionOfTheAct = params.newValue + params.data.fuelCode = '' // Reset fuelCode when provisionOfTheAct changes + return true } - return false; + return false }, minWidth: 300, editable: true, @@ -105,61 +116,91 @@ export const otherUsesColDefs = (optionsData, errors) => [ headerName: i18n.t('otherUses:otherUsesColLabels.fuelCode'), cellEditor: AutocompleteCellEditor, cellEditorParams: (params) => { - const fuelType = optionsData?.fuelTypes?.find((obj) => params.data.fuelType === obj.fuelType); + const fuelType = optionsData?.fuelTypes?.find( + (obj) => params.data.fuelType === obj.fuelType + ) - return { - options: fuelType?.fuelCodes?.map((item) => item.fuelCode) || [], // Safely access fuelCodes - multiple: false, - disableCloseOnSelect: false, - freeSolo: false, - openOnFocus: true - }; + return { + options: fuelType?.fuelCodes?.map((item) => item.fuelCode) || [], // Safely access fuelCodes + multiple: false, + disableCloseOnSelect: false, + freeSolo: false, + openOnFocus: true + } }, cellRenderer: (params) => { - const fuelType = optionsData?.fuelTypes?.find((obj) => params.data.fuelType === obj.fuelType); + const fuelType = optionsData?.fuelTypes?.find( + (obj) => params.data.fuelType === obj.fuelType + ) if ( params.data.provisionOfTheAct === PROVISION_APPROVED_FUEL_CODE && fuelType?.fuelCodes?.length > 0 ) { - return params.value || Select; + return ( + params.value || Select + ) } - return null; + return null }, cellStyle: (params) => { - const style = StandardCellErrors(params, errors); - const conditionalStyle = - params.data.provisionOfTheAct === PROVISION_APPROVED_FUEL_CODE && - optionsData?.fuelTypes - ?.find((obj) => params.data.fuelType === obj.fuelType) - ?.fuelCodes?.length > 0 - ? { backgroundColor: '#fff', borderColor: 'unset' } - : { backgroundColor: '#f2f2f2' }; - return { ...style, ...conditionalStyle }; + const style = StandardCellErrors(params, errors) + const isFuelCodeScenario = + params.data.provisionOfTheAct === PROVISION_APPROVED_FUEL_CODE + const fuelType = optionsData?.fuelTypes?.find( + (obj) => params.data.fuelType === obj.fuelType + ) + const fuelCodes = fuelType?.fuelCodes || [] + const fuelCodeRequiredAndMissing = + isFuelCodeScenario && !params.data.fuelCode + + // If required and missing, show red border + if (fuelCodeRequiredAndMissing) { + style.borderColor = 'red' + } + + const conditionalStyle = + isFuelCodeScenario && + fuelCodes.length > 0 && + !fuelCodeRequiredAndMissing + ? { + backgroundColor: '#fff', + borderColor: style.borderColor || 'unset' + } + : { backgroundColor: '#f2f2f2' } + + return { ...style, ...conditionalStyle } }, suppressKeyboardEvent, minWidth: 150, editable: (params) => { - const fuelType = optionsData?.fuelTypes?.find((obj) => params.data.fuelType === obj.fuelType); + const fuelType = optionsData?.fuelTypes?.find( + (obj) => params.data.fuelType === obj.fuelType + ) return ( params.data.provisionOfTheAct === PROVISION_APPROVED_FUEL_CODE && fuelType?.fuelCodes?.length > 0 - ); + ) }, - valueSetter: (params) => { - if (params.newValue) { - params.data.fuelCode = params.newValue; + valueGetter: (params) => { + const isFuelCodeScenario = + params.data.provisionOfTheAct === PROVISION_APPROVED_FUEL_CODE + const fuelType = optionsData?.fuelTypes?.find( + (obj) => params.data.fuelType === obj.fuelType + ) + const fuelCodes = fuelType?.fuelCodes || [] - const fuelType = optionsData?.fuelTypes?.find((obj) => params.data.fuelType === obj.fuelType); - if (params.data.provisionOfTheAct === PROVISION_APPROVED_FUEL_CODE) { - const matchingFuelCode = fuelType?.fuelCodes?.find( - (fuelCode) => params.data.fuelCode === fuelCode.fuelCode - ); - if (matchingFuelCode) { - params.data.fuelCodeId = matchingFuelCode.fuelCodeId; - } - } - } - return true; + if ( + isFuelCodeScenario && + !params.data.fuelCode && + fuelCodes.length === 1 + ) { + // Autopopulate if only one fuel code is available + const singleFuelCode = fuelCodes[0] + params.data.fuelCode = singleFuelCode.fuelCode + params.data.fuelCodeId = singleFuelCode.fuelCodeId + } + + return params.data.fuelCode }, tooltipValueGetter: (p) => 'Select the approved fuel code' }, @@ -168,11 +209,10 @@ export const otherUsesColDefs = (optionsData, errors) => [ headerName: i18n.t('otherUses:otherUsesColLabels.quantitySupplied'), headerComponent: RequiredHeader, cellEditor: NumberEditor, - valueFormatter, + valueFormatter: (params) => valueFormatter({ value: params.value }), type: 'numericColumn', cellEditorParams: { precision: 0, - min: 0, showStepperButtons: false }, cellStyle: (params) => StandardCellErrors(params, errors), @@ -207,31 +247,31 @@ export const otherUsesColDefs = (optionsData, errors) => [ valueGetter: (params) => { const fuelType = optionsData?.fuelTypes?.find( (obj) => params.data.fuelType === obj.fuelType - ); + ) if (params.data.provisionOfTheAct === PROVISION_APPROVED_FUEL_CODE) { return ( fuelType?.fuelCodes?.find( (item) => item.fuelCode === params.data.fuelCode )?.carbonIntensity || 0 - ); + ) } if (fuelType) { if (params.data.fuelType === 'Other' && params.data.fuelCategory) { - const categories = fuelType.fuelCategories; + const categories = fuelType.fuelCategories const defaultCI = categories?.find( (cat) => cat.category === params.data.fuelCategory - )?.defaultAndPrescribedCi; + )?.defaultAndPrescribedCi - return defaultCI || 0; + return defaultCI || 0 } - return fuelType.defaultCarbonIntensity || 0; + return fuelType.defaultCarbonIntensity || 0 } - return 0; + return 0 }, - minWidth: 150, + minWidth: 150 }, { field: 'expectedUse',