diff --git a/backend/lcfs/db/migrations/versions/2024-12-06-09-59_9206124a098b.py b/backend/lcfs/db/migrations/versions/2024-12-06-09-59_9206124a098b.py new file mode 100644 index 000000000..fc805ff14 --- /dev/null +++ b/backend/lcfs/db/migrations/versions/2024-12-06-09-59_9206124a098b.py @@ -0,0 +1,25 @@ +"""Add Organization name to FSE + +Revision ID: 9206124a098b +Revises: aeaa26f5cdd5 +Create Date: 2024-12-04 09:59:22.876386 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = '9206124a098b' +down_revision = '26ab15f8ab18' +branch_labels = None +depends_on = None + + +def upgrade(): + # Add the column 'organization_name' to 'final_supply_equipment' table + op.add_column("final_supply_equipment", sa.Column("organization_name", sa.String(), nullable=True)) + + +def downgrade(): + # Remove the column 'organization_name' from 'final_supply_equipment' table + op.drop_column("final_supply_equipment", "organization_name") \ No newline at end of file diff --git a/backend/lcfs/db/migrations/versions/2024-12-09-22-33_cd8698fe40e6.py b/backend/lcfs/db/migrations/versions/2024-12-09-22-33_cd8698fe40e6.py new file mode 100644 index 000000000..8ec2b8223 --- /dev/null +++ b/backend/lcfs/db/migrations/versions/2024-12-09-22-33_cd8698fe40e6.py @@ -0,0 +1,34 @@ +"""Remove notifications email from user_profile + +Revision ID: cd8698fe40e6 +Revises: 26ab15f8ab18 +Create Date: 2024-12-09 22:33:29.554360 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "cd8698fe40e6" +down_revision = "9206124a098b" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Remove notifications_email column from user_profile table + op.drop_column("user_profile", "notifications_email") + + +def downgrade() -> None: + # Add notifications_email column to user_profile table + op.add_column( + "user_profile", + sa.Column( + "notifications_email", + sa.String(length=255), + nullable=True, + comment="Email address used for notifications", + ), + ) diff --git a/backend/lcfs/db/migrations/versions/2024-12-09-23-31_7ae38a8413ab.py b/backend/lcfs/db/migrations/versions/2024-12-09-23-31_7ae38a8413ab.py new file mode 100644 index 000000000..51856f8c4 --- /dev/null +++ b/backend/lcfs/db/migrations/versions/2024-12-09-23-31_7ae38a8413ab.py @@ -0,0 +1,95 @@ +"""Update Fuel Types measured in volume to be other-uses + +Revision ID: 7ae38a8413ab +Revises: 26ab15f8ab18 +Create Date: 2024-12-09 19:31:18.199089 + +""" + +import sqlalchemy as sa +from alembic import op +from datetime import datetime + +# revision identifiers, used by Alembic. 
+revision = "7ae38a8413ab" +down_revision = "cd8698fe40e6" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + current_time = datetime.now() + + # Update the `other_uses_fossil_derived` field for all specified fuel types + op.execute( + f""" + UPDATE fuel_type + SET other_uses_fossil_derived = true, + update_date = '{current_time}', + update_user = 'no_user' + WHERE fuel_type IN ( + 'Alternative jet fuel', + 'Biodiesel', + 'Ethanol', + 'HDRD', + 'Renewable gasoline', + 'Renewable naphtha' + ) + """ + ) + + # Update the `other_uses_fossil_derived` field for all specified fuel types + op.execute( + f""" + UPDATE fuel_type + SET other_uses_fossil_derived = false, + update_date = '{current_time}', + update_user = 'no_user' + WHERE fuel_type IN ( + 'CNG', + 'Electricity', + 'Hydrogen', + 'LNG', + 'Propane' + ) + """ + ) + + +def downgrade() -> None: + current_time = datetime.now() + + # Revert the `other_uses_fossil_derived` field to false for the first set of fuel types + op.execute( + f""" + UPDATE fuel_type + SET other_uses_fossil_derived = false, + update_date = '{current_time}', + update_user = 'no_user' + WHERE fuel_type IN ( + 'Alternative jet fuel', + 'Biodiesel', + 'Ethanol', + 'HDRD', + 'Renewable gasoline', + 'Renewable naphtha' + ) + """ + ) + + # Revert the `other_uses_fossil_derived` field to true for the second set of fuel types + op.execute( + f""" + UPDATE fuel_type + SET other_uses_fossil_derived = true, + update_date = '{current_time}', + update_user = 'no_user' + WHERE fuel_type IN ( + 'CNG', + 'Electricity', + 'Hydrogen', + 'LNG', + 'Propane' + ) + """ + ) diff --git a/backend/lcfs/db/models/compliance/FinalSupplyEquipment.py b/backend/lcfs/db/models/compliance/FinalSupplyEquipment.py index e3b8d5685..90bd37b27 100644 --- a/backend/lcfs/db/models/compliance/FinalSupplyEquipment.py +++ b/backend/lcfs/db/models/compliance/FinalSupplyEquipment.py @@ -123,6 +123,7 @@ class FinalSupplyEquipment(BaseModel, Auditable): Double, nullable=False, comment="The longitude of the equipment location." 
) notes = Column(Text, comment="Any additional notes related to the equipment.") + organization_name = Column(Text, comment="External organization name.") # relationships compliance_report = relationship( diff --git a/backend/lcfs/db/models/user/UserProfile.py b/backend/lcfs/db/models/user/UserProfile.py index 681d19aa0..7fecca2f8 100644 --- a/backend/lcfs/db/models/user/UserProfile.py +++ b/backend/lcfs/db/models/user/UserProfile.py @@ -26,9 +26,6 @@ class UserProfile(BaseModel, Auditable): String(150), unique=True, nullable=False, comment="keycloak Username" ) email = Column(String(255), nullable=True, comment="Primary email address") - notifications_email = Column( - String(255), nullable=True, comment="Email address used for notifications" - ) title = Column(String(100), nullable=True, comment="Professional Title") phone = Column(String(50), nullable=True, comment="Primary phone number") mobile_phone = Column(String(50), nullable=True, comment="Mobile phone number") diff --git a/backend/lcfs/tests/fuel_code/test_fuel_code_repo.py b/backend/lcfs/tests/fuel_code/test_fuel_code_repo.py index a8c4945d1..91b3d6b09 100644 --- a/backend/lcfs/tests/fuel_code/test_fuel_code_repo.py +++ b/backend/lcfs/tests/fuel_code/test_fuel_code_repo.py @@ -1,13 +1,40 @@ +from datetime import date +from unittest import mock + import pytest from unittest.mock import AsyncMock, MagicMock +from sqlalchemy.exc import NoResultFound +from sqlalchemy.orm import joinedload + from lcfs.web.api.fuel_code.repo import FuelCodeRepository from lcfs.db.models.fuel.TransportMode import TransportMode +from lcfs.db.models.fuel.FuelType import FuelType +from lcfs.db.models.fuel.FuelCategory import FuelCategory +from lcfs.db.models.fuel.UnitOfMeasure import UnitOfMeasure +from lcfs.db.models.fuel.ExpectedUseType import ExpectedUseType +from lcfs.db.models.fuel.FuelCode import FuelCode +from lcfs.db.models.fuel.FuelCodePrefix import FuelCodePrefix +from lcfs.db.models.fuel.FuelCodeStatus import FuelCodeStatus, FuelCodeStatusEnum +from lcfs.db.models.fuel.EnergyDensity import EnergyDensity +from lcfs.db.models.fuel.EnergyEffectivenessRatio import EnergyEffectivenessRatio +from lcfs.db.models.fuel.ProvisionOfTheAct import ProvisionOfTheAct +from lcfs.db.models.fuel.AdditionalCarbonIntensity import AdditionalCarbonIntensity +from lcfs.db.models.fuel.TargetCarbonIntensity import TargetCarbonIntensity +from lcfs.db.models.compliance.CompliancePeriod import CompliancePeriod +from lcfs.web.exception.exceptions import DatabaseException @pytest.fixture def mock_db(): """Fixture for mocking the database session.""" - return AsyncMock() + mock_session = AsyncMock() + mock_session.execute = AsyncMock() + mock_session.get_one = AsyncMock() + mock_session.add = MagicMock() + mock_session.flush = AsyncMock() + mock_session.refresh = AsyncMock() + mock_session.scalar = AsyncMock() + return mock_session @pytest.fixture @@ -18,21 +45,613 @@ def fuel_code_repo(mock_db): return repo +@pytest.fixture +def valid_fuel_code(): + """Fixture for creating a repository with a mocked database.""" + fc = FuelCode( + fuel_code_id=5, + fuel_suffix="105.0", + prefix_id=1, # Assuming prefix_id=1 exists + fuel_type_id=1, # Assuming fuel_type_id=1 exists + company="Test Company", + contact_name="Test Contact", + contact_email="test@example.com", + carbon_intensity=50.00, + edrms="EDRMS-001", + application_date=date.today(), + feedstock="Corn", + feedstock_location="USA", + feedstock_misc="", + fuel_production_facility_city="CityA", + 
fuel_production_facility_province_state="ProvinceA", + fuel_production_facility_country="CountryA", + last_updated=date.today(), + ) + return fc + + @pytest.mark.anyio -async def test_get_transport_mode_by_name(fuel_code_repo, mock_db): - # Define the test transport mode - transport_mode_name = "Truck" - mock_transport_mode = TransportMode(transport_mode_id=1, transport_mode="Truck") +async def test_get_fuel_types(fuel_code_repo, mock_db): + mock_fuel_type = FuelType(fuel_type_id=1, fuel_type="Diesel") + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_fuel_type] - # Mock the database query result - mock_db.execute.return_value.scalar_one = MagicMock() - mock_db.execute.return_value.scalar_one.return_value = mock_transport_mode + mock_db.execute.return_value = mock_result + result = await fuel_code_repo.get_fuel_types() + assert len(result) == 1 + assert result[0] == mock_fuel_type + mock_db.execute.assert_called_once() - # Call the repository method - result = await fuel_code_repo.get_transport_mode_by_name(transport_mode_name) - # Assert the result matches the mock data - assert result == mock_transport_mode +@pytest.mark.anyio +async def test_get_formatted_fuel_types(fuel_code_repo, mock_db): + # Setup mock data + mock_fuel_type = FuelType( + fuel_type_id=1, + fuel_type="Diesel", + default_carbon_intensity=80.0, + units="gCO2e/MJ", + unrecognized=False, + ) + mock_result = MagicMock() + mock_result.unique.return_value.scalars.return_value.all.return_value = [ + mock_fuel_type + ] + mock_db.execute.return_value = mock_result - # Ensure the database query was called + result = await fuel_code_repo.get_formatted_fuel_types() + assert len(result) == 1 + assert result[0]["fuel_type"] == "Diesel" mock_db.execute.assert_called_once() + + +@pytest.mark.anyio +async def test_get_fuel_type_by_name_found(fuel_code_repo, mock_db): + mock_fuel_type = FuelType(fuel_type_id=2, fuel_type="Gasoline") + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = mock_fuel_type + mock_db.execute.return_value = mock_result + + result = await fuel_code_repo.get_fuel_type_by_name("Gasoline") + assert result == mock_fuel_type + + +@pytest.mark.anyio +async def test_get_fuel_type_by_name_not_found(fuel_code_repo, mock_db): + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = None + mock_db.execute.return_value = mock_result + + with pytest.raises(DatabaseException): + await fuel_code_repo.get_fuel_type_by_name("Nonexistent") + + +@pytest.mark.anyio +async def test_get_fuel_type_by_id_found(fuel_code_repo, mock_db): + mock_fuel_type = FuelType(fuel_type_id=3, fuel_type="Biofuel") + mock_db.get_one.return_value = mock_fuel_type + + result = await fuel_code_repo.get_fuel_type_by_id(3) + assert result == mock_fuel_type + mock_db.get_one.assert_called_once() + + +@pytest.mark.anyio +async def test_get_fuel_type_by_id_not_found(fuel_code_repo, mock_db): + mock_db.get_one.return_value = None + with pytest.raises(DatabaseException): + await fuel_code_repo.get_fuel_type_by_id(999) + + +@pytest.mark.anyio +async def test_get_fuel_categories(fuel_code_repo, mock_db): + mock_fc = FuelCategory(fuel_category_id=1, category="Renewable") + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_fc] + mock_db.execute.return_value = mock_result + + result = await fuel_code_repo.get_fuel_categories() + assert len(result) == 1 + assert result[0] == mock_fc + + +@pytest.mark.anyio +async def 
test_get_fuel_category_by_name(fuel_code_repo, mock_db): + mock_fc = FuelCategory(fuel_category_id=2, category="Fossil") + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = mock_fc + mock_db.execute.return_value = mock_result + + result = await fuel_code_repo.get_fuel_category_by_name("Fossil") + assert result == mock_fc + + +@pytest.mark.anyio +async def test_get_transport_modes(fuel_code_repo, mock_db): + mock_tm = TransportMode(transport_mode_id=1, transport_mode="Truck") + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_tm] + mock_db.execute.return_value = mock_result + + result = await fuel_code_repo.get_transport_modes() + assert len(result) == 1 + assert result[0] == mock_tm + + +@pytest.mark.anyio +async def test_get_transport_mode(fuel_code_repo, mock_db): + mock_tm = TransportMode(transport_mode_id=10, transport_mode="Ship") + mock_db.scalar.return_value = mock_tm + + result = await fuel_code_repo.get_transport_mode(10) + assert result == mock_tm + mock_db.scalar.assert_called_once() + + +@pytest.mark.anyio +async def test_get_transport_mode_by_name_found(fuel_code_repo, mock_db): + mock_tm = TransportMode(transport_mode_id=1, transport_mode="Truck") + mock_result = MagicMock() + mock_result.scalar_one.return_value = mock_tm + mock_db.execute.return_value = mock_result + + result = await fuel_code_repo.get_transport_mode_by_name("Truck") + assert result == mock_tm + + +@pytest.mark.anyio +async def test_get_transport_mode_by_name_not_found(fuel_code_repo, mock_db): + mock_result = MagicMock() + mock_result.scalar_one.side_effect = NoResultFound + mock_db.execute.return_value = mock_result + + with pytest.raises(DatabaseException): + await fuel_code_repo.get_transport_mode_by_name("NonexistentMode") + + +@pytest.mark.anyio +async def test_get_fuel_code_prefixes(fuel_code_repo, mock_db): + mock_prefix = FuelCodePrefix(fuel_code_prefix_id=1, prefix="BC") + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_prefix] + mock_db.execute.return_value = mock_result + + result = await fuel_code_repo.get_fuel_code_prefixes() + assert len(result) == 1 + assert result[0] == mock_prefix + + +@pytest.mark.anyio +async def test_get_fuel_code_prefix(fuel_code_repo, mock_db): + mock_prefix = FuelCodePrefix(fuel_code_prefix_id=2, prefix="AB") + mock_db.get_one.return_value = mock_prefix + + result = await fuel_code_repo.get_fuel_code_prefix(2) + assert result == mock_prefix + + +@pytest.mark.anyio +async def test_get_fuel_status_by_status(fuel_code_repo, mock_db): + mock_status = FuelCodeStatus( + fuel_code_status_id=1, status=FuelCodeStatusEnum.Approved + ) + mock_result = MagicMock() + mock_result.scalar.return_value = mock_status + mock_db.execute.return_value = mock_result + + result = await fuel_code_repo.get_fuel_status_by_status(FuelCodeStatusEnum.Approved) + assert result == mock_status + + +@pytest.mark.anyio +async def test_get_energy_densities(fuel_code_repo, mock_db): + ed = EnergyDensity(energy_density_id=1, density=35.0) + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [ed] + mock_db.execute.return_value = mock_result + + result = await fuel_code_repo.get_energy_densities() + assert len(result) == 1 + assert result[0] == ed + + +@pytest.mark.anyio +async def test_get_energy_density(fuel_code_repo, mock_db): + ed = EnergyDensity(energy_density_id=2, density=40.0) + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = ed + 
mock_db.execute.return_value = mock_result + + result = await fuel_code_repo.get_energy_density(10) + assert result == ed + + +@pytest.mark.anyio +async def test_get_energy_effectiveness_ratios(fuel_code_repo, mock_db): + eer = EnergyEffectivenessRatio(eer_id=1, ratio=2.0) + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [eer] + mock_db.execute.return_value = mock_result + + result = await fuel_code_repo.get_energy_effectiveness_ratios() + assert len(result) == 1 + assert result[0] == eer + + +@pytest.mark.anyio +async def test_get_units_of_measure(fuel_code_repo, mock_db): + uom = UnitOfMeasure(uom_id=1, name="gCO2e/MJ") + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [uom] + mock_db.execute.return_value = mock_result + + result = await fuel_code_repo.get_units_of_measure() + assert len(result) == 1 + assert result[0] == uom + + +@pytest.mark.anyio +async def test_get_expected_use_types(fuel_code_repo, mock_db): + eut = ExpectedUseType(expected_use_type_id=1, name="Vehicle") + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [eut] + mock_db.execute.return_value = mock_result + + result = await fuel_code_repo.get_expected_use_types() + assert len(result) == 1 + assert result[0] == eut + + +@pytest.mark.anyio +async def test_get_expected_use_type_by_name(fuel_code_repo, mock_db): + eut = ExpectedUseType(expected_use_type_id=2, name="Heating") + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = eut + mock_db.execute.return_value = mock_result + + result = await fuel_code_repo.get_expected_use_type_by_name("Heating") + assert result == eut + + +@pytest.mark.anyio +async def test_get_fuel_codes_paginated(fuel_code_repo, mock_db): + fc = FuelCode(fuel_code_id=1, fuel_suffix="101.0") + mock_db.execute.side_effect = [ + MagicMock(scalar=MagicMock(return_value=FuelCodeStatus())), + MagicMock(scalar=MagicMock(return_value=1)), + MagicMock( + unique=MagicMock( + return_value=MagicMock( + scalars=MagicMock( + return_value=MagicMock(all=MagicMock(return_value=[fc])) + ) + ) + ) + ), + ] + pagination = MagicMock(page=1, size=10, filters=[], sort_orders=[]) + result, count = await fuel_code_repo.get_fuel_codes_paginated(pagination) + assert len(result) == 1 + assert result[0] == fc + assert count == 1 + + +@pytest.mark.anyio +async def test_get_fuel_code_statuses(fuel_code_repo, mock_db): + fcs = FuelCodeStatus(fuel_code_status_id=1, status=FuelCodeStatusEnum.Approved) + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [fcs] + mock_db.execute.return_value = mock_result + + result = await fuel_code_repo.get_fuel_code_statuses() + assert len(result) == 1 + assert result[0] == fcs + + +@pytest.mark.anyio +async def test_create_fuel_code(fuel_code_repo, mock_db, valid_fuel_code): + mock_db.flush = AsyncMock() + mock_db.scalar.return_value = valid_fuel_code + + result = await fuel_code_repo.create_fuel_code(valid_fuel_code) + assert result == valid_fuel_code + mock_db.add.assert_called_once_with(valid_fuel_code) + + +@pytest.mark.anyio +async def test_get_fuel_code(fuel_code_repo, mock_db, valid_fuel_code): + mock_db.scalar.return_value = valid_fuel_code + result = await fuel_code_repo.get_fuel_code(1) + assert result == valid_fuel_code + + +@pytest.mark.anyio +async def test_get_fuel_code_status_enum(fuel_code_repo, mock_db): + fcs = FuelCodeStatus(fuel_code_status_id=2, status=FuelCodeStatusEnum.Deleted) + mock_db.scalar.return_value = fcs + result = await 
fuel_code_repo.get_fuel_code_status(FuelCodeStatusEnum.Deleted) + assert result == fcs + + +@pytest.mark.anyio +async def test_update_fuel_code(fuel_code_repo, mock_db, valid_fuel_code): + mock_db.flush = AsyncMock() + mock_db.refresh = AsyncMock() + updated = await fuel_code_repo.update_fuel_code(valid_fuel_code) + assert updated.fuel_code_id == 5 + + +@pytest.mark.anyio +async def test_delete_fuel_code(fuel_code_repo, mock_db): + mock_delete_status = FuelCodeStatus( + fuel_code_status_id=3, status=FuelCodeStatusEnum.Deleted + ) + mock_execute_result = MagicMock() + mock_execute_result.scalar.return_value = mock_delete_status + mock_db.execute.return_value = mock_execute_result + + mock_db.flush = AsyncMock() + + await fuel_code_repo.delete_fuel_code(10) + mock_db.execute.assert_awaited() # Check that execute was awaited + + +@pytest.mark.anyio +async def test_get_distinct_company_names(fuel_code_repo, mock_db): + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = ["CompanyA", "CompanyB"] + mock_db.execute.return_value = mock_result + + result = await fuel_code_repo.get_distinct_company_names("Com") + assert len(result) == 2 + + +@pytest.mark.anyio +async def test_get_contact_names_by_company(fuel_code_repo, mock_db): + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = ["John Doe", "Jane Doe"] + mock_db.execute.return_value = mock_result + + result = await fuel_code_repo.get_contact_names_by_company("CompanyA", "J") + assert len(result) == 2 + + +@pytest.mark.anyio +async def test_get_contact_email_by_company_and_name(fuel_code_repo, mock_db): + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = ["john@example.com"] + mock_db.execute.return_value = mock_result + + result = await fuel_code_repo.get_contact_email_by_company_and_name( + "CompanyA", "John Doe", "john@" + ) + assert len(result) == 1 + + +@pytest.mark.anyio +async def test_get_distinct_fuel_codes_by_code(fuel_code_repo, mock_db): + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = ["101.0", "101.1"] + mock_db.execute.return_value = mock_result + + result = await fuel_code_repo.get_distinct_fuel_codes_by_code("101", "BC") + assert len(result) == 2 + + +@pytest.mark.anyio +async def test_get_fuel_code_by_code_prefix(fuel_code_repo, mock_db): + fc = FuelCode(fuel_code_id=10, fuel_suffix="200.0") + mock_result = MagicMock() + mock_result.unique.return_value.scalars.return_value.all.return_value = [fc] + mock_db.execute.return_value = mock_result + + # Mock the next available suffix + fuel_code_repo.get_next_available_sub_version_fuel_code_by_prefix = AsyncMock( + return_value="200.1" + ) + + result = await fuel_code_repo.get_fuel_code_by_code_prefix("200.0", "BC") + assert len(result) == 1 + assert result[0].fuel_suffix == "200.1" + + +@pytest.mark.anyio +async def test_validate_fuel_code(fuel_code_repo, mock_db): + # Mock no existing code + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = None + mock_db.execute.return_value = mock_result + + result = await fuel_code_repo.validate_fuel_code("300.0", 1) + assert result == "300.0" + + # Mock existing code + mock_result.scalar_one_or_none.return_value = FuelCode( + fuel_code_id=5, fuel_suffix="300.0" + ) + mock_db.execute.return_value = mock_result + fuel_code_repo.get_next_available_sub_version_fuel_code_by_prefix = AsyncMock( + return_value="300.1" + ) + result = await fuel_code_repo.validate_fuel_code("300.0", 1) + assert result == "300.1" + 
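+# A possible follow-up case (sketch only, not asserted by this suite): validate_fuel_code
+# also rolls over to the next main code once the next sub-version passes ".9". Assuming the
+# mocked existing FuelCode carried a fuel_code_prefix, it could look roughly like:
+#   fuel_code_repo.get_next_available_sub_version_fuel_code_by_prefix = AsyncMock(
+#       return_value="300.10"
+#   )
+#   fuel_code_repo.get_next_available_fuel_code_by_prefix = AsyncMock(return_value="301.0")
+#   assert await fuel_code_repo.validate_fuel_code("300.0", 1) == "301.0"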
+ +@pytest.mark.anyio +async def test_get_next_available_fuel_code_by_prefix(fuel_code_repo, mock_db): + mock_execute_result = MagicMock() + mock_execute_result.scalar_one_or_none.return_value = "102.0" + mock_db.execute.return_value = mock_execute_result + + result = await fuel_code_repo.get_next_available_fuel_code_by_prefix("BC") + assert result == "102.0" + + +@pytest.mark.anyio +async def test_get_next_available_sub_version_fuel_code_by_prefix( + fuel_code_repo, mock_db +): + mock_execute_result = MagicMock() + mock_execute_result.scalar_one_or_none.return_value = "200.1" + mock_db.execute.return_value = mock_execute_result + + result = await fuel_code_repo.get_next_available_sub_version_fuel_code_by_prefix( + "200", 1 + ) + assert result == "200.1" + + +@pytest.mark.anyio +async def test_get_latest_fuel_codes(fuel_code_repo, mock_db, valid_fuel_code): + prefix = FuelCodePrefix(fuel_code_prefix_id=1, prefix="BC") + valid_fuel_code.fuel_code_prefix = prefix + + mock_result = MagicMock() + mock_result.unique.return_value.scalars.return_value.all.return_value = [ + valid_fuel_code + ] + mock_db.execute.return_value = mock_result + + result = await fuel_code_repo.get_latest_fuel_codes() + assert len(result) == 1 + # The code increments the version, e.g. "BC101.0" -> "BC101.1" + # Assuming suffix "105.0": + assert result[0]["fuel_code"].endswith(".1") + + +@pytest.mark.anyio +async def test_get_fuel_code_field_options(fuel_code_repo, mock_db): + mock_execute_result = MagicMock() + mock_execute_result.all.return_value = [ + ("CompanyA", "Corn", "USA", None, None, "John Doe", "john@example.com") + ] + mock_db.execute.return_value = mock_execute_result + + result = await fuel_code_repo.get_fuel_code_field_options() + assert len(result) == 1 + + +@pytest.mark.anyio +async def test_get_fp_locations(fuel_code_repo, mock_db): + mock_execute_result = MagicMock() + mock_execute_result.all.return_value = [("CityA", "ProvinceA", "CountryA")] + mock_db.execute.return_value = mock_execute_result + + result = await fuel_code_repo.get_fp_locations() + assert len(result) == 1 + + +@pytest.mark.anyio +async def test_get_fuel_code_by_name(fuel_code_repo, mock_db): + fc = FuelCode(fuel_code_id=50, fuel_suffix="150.0") + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = fc + mock_db.execute.return_value = mock_result + + result = await fuel_code_repo.get_fuel_code_by_name("BC150.0") + assert result == fc + + +@pytest.mark.anyio +async def test_get_provision_of_the_act_by_name(fuel_code_repo, mock_db): + poa = ProvisionOfTheAct(provision_of_the_act_id=1, name="Act Name") + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = poa + mock_db.execute.return_value = mock_result + + result = await fuel_code_repo.get_provision_of_the_act_by_name("Act Name") + assert result == poa + + +@pytest.mark.anyio +async def test_get_energy_effectiveness_ratio(fuel_code_repo, mock_db): + eer = EnergyEffectivenessRatio(eer_id=1, ratio=1.5) + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = eer + mock_db.execute.return_value = mock_result + + result = await fuel_code_repo.get_energy_effectiveness_ratio(1, 2, 3) + assert result == eer + + +@pytest.mark.anyio +async def test_get_target_carbon_intensities(fuel_code_repo, mock_db): + tci = TargetCarbonIntensity( + target_carbon_intensity_id=1, target_carbon_intensity=50.0 + ) + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [tci] + mock_db.execute.return_value = mock_result 
+ + result = await fuel_code_repo.get_target_carbon_intensities(1, "2024") + assert len(result) == 1 + assert result[0] == tci + + +@pytest.mark.anyio +async def test_get_standardized_fuel_data(fuel_code_repo, mock_db): + # Mock dependencies + mock_fuel_type = FuelType( + fuel_type_id=1, fuel_type="Diesel", default_carbon_intensity=80.0 + ) + mock_db.get_one.return_value = mock_fuel_type + mock_db.execute.side_effect = [ + # energy density + MagicMock( + scalars=MagicMock( + return_value=MagicMock( + first=MagicMock(return_value=EnergyDensity(density=35.0)) + ) + ) + ), + # eer + MagicMock( + scalars=MagicMock( + return_value=MagicMock( + first=MagicMock(return_value=EnergyEffectivenessRatio(ratio=2.0)) + ) + ) + ), + # target carbon intensities + MagicMock( + scalars=MagicMock( + return_value=MagicMock( + all=MagicMock( + return_value=[ + TargetCarbonIntensity(target_carbon_intensity=50.0) + ] + ) + ) + ) + ), + # additional carbon intensity + MagicMock( + scalars=MagicMock( + return_value=MagicMock( + one_or_none=MagicMock( + return_value=AdditionalCarbonIntensity(intensity=5.0) + ) + ) + ) + ), + ] + + result = await fuel_code_repo.get_standardized_fuel_data( + fuel_type_id=1, fuel_category_id=2, end_use_id=3, compliance_period="2024" + ) + assert result.effective_carbon_intensity == 80.0 + assert result.target_ci == 50.0 + assert result.eer == 2.0 + assert result.energy_density == 35.0 + assert result.uci == 5.0 + + +@pytest.mark.anyio +async def test_get_additional_carbon_intensity(fuel_code_repo, mock_db): + aci = AdditionalCarbonIntensity(additional_uci_id=1, intensity=10.0) + mock_result = MagicMock() + mock_result.scalars.return_value.one_or_none.return_value = aci + mock_db.execute.return_value = mock_result + + result = await fuel_code_repo.get_additional_carbon_intensity(1, 2) + assert result == aci diff --git a/backend/lcfs/tests/fuel_code/test_fuel_code_service.py b/backend/lcfs/tests/fuel_code/test_fuel_code_service.py index c71608031..5dce49e52 100644 --- a/backend/lcfs/tests/fuel_code/test_fuel_code_service.py +++ b/backend/lcfs/tests/fuel_code/test_fuel_code_service.py @@ -210,7 +210,7 @@ async def test_approve_fuel_code_not_found(): repo_mock.get_fuel_code.return_value = None # Act & Assert - with pytest.raises(ServiceException): + with pytest.raises(ValueError, match="Fuel code not found"): await service.approve_fuel_code(fuel_code_id) repo_mock.get_fuel_code.assert_called_once_with(fuel_code_id) @@ -229,7 +229,7 @@ async def test_approve_fuel_code_invalid_status(): repo_mock.get_fuel_code.return_value = mock_fuel_code # Act & Assert - with pytest.raises(ServiceException): + with pytest.raises(ValueError, match="Fuel code is not in Draft"): await service.approve_fuel_code(fuel_code_id) repo_mock.get_fuel_code.assert_called_once_with(fuel_code_id) diff --git a/backend/lcfs/tests/other_uses/test_other_uses_repo.py b/backend/lcfs/tests/other_uses/test_other_uses_repo.py index 8bd39f562..67ea7d1d5 100644 --- a/backend/lcfs/tests/other_uses/test_other_uses_repo.py +++ b/backend/lcfs/tests/other_uses/test_other_uses_repo.py @@ -11,14 +11,19 @@ @pytest.fixture -def mock_db_session(): +def mock_query_result(): + # Setup mock for database query result chain + mock_result = AsyncMock() + mock_result.unique = MagicMock(return_value=mock_result) + mock_result.scalars = MagicMock(return_value=mock_result) + mock_result.all = MagicMock(return_value=[MagicMock(spec=OtherUses)]) + return mock_result + + +@pytest.fixture +def mock_db_session(mock_query_result): session = 
MagicMock(spec=AsyncSession) - execute_result = AsyncMock() - execute_result.unique = MagicMock(return_value=execute_result) - execute_result.scalars = MagicMock(return_value=execute_result) - execute_result.all = MagicMock(return_value=[MagicMock(spec=OtherUses)]) - execute_result.first = MagicMock(return_value=MagicMock(spec=OtherUses)) - session.execute.return_value = execute_result + session.execute = AsyncMock(return_value=mock_query_result) return session @@ -29,6 +34,14 @@ def other_uses_repo(mock_db_session): repo.fuel_code_repo.get_fuel_categories = AsyncMock(return_value=[]) repo.fuel_code_repo.get_fuel_types = AsyncMock(return_value=[]) repo.fuel_code_repo.get_expected_use_types = AsyncMock(return_value=[]) + + # Mock for local get_formatted_fuel_types method + async def mock_get_formatted_fuel_types(): + mock_result = await mock_db_session.execute(AsyncMock()) + return mock_result.unique().scalars().all() + + repo.get_formatted_fuel_types = AsyncMock(side_effect=mock_get_formatted_fuel_types) + return repo @@ -194,22 +207,21 @@ async def test_get_latest_other_uses_by_group_uuid(other_uses_repo, mock_db_sess mock_other_use_gov.user_type = UserTypeEnum.GOVERNMENT mock_other_use_gov.version = 2 - mock_other_use_supplier = MagicMock(spec=OtherUses) - mock_other_use_supplier.user_type = UserTypeEnum.SUPPLIER - mock_other_use_supplier.version = 3 + # Setup mock result chain + mock_result = AsyncMock() + mock_result.unique = MagicMock(return_value=mock_result) + mock_result.scalars = MagicMock(return_value=mock_result) + mock_result.first = MagicMock(return_value=mock_other_use_gov) - # Mock response with both government and supplier versions - mock_db_session.execute.return_value.scalars.return_value.first.side_effect = [ - mock_other_use_gov, - mock_other_use_supplier, - ] + # Configure mock db session + mock_db_session.execute = AsyncMock(return_value=mock_result) + other_uses_repo.db = mock_db_session result = await other_uses_repo.get_latest_other_uses_by_group_uuid(group_uuid) assert result.user_type == UserTypeEnum.GOVERNMENT assert result.version == 2 - @pytest.mark.anyio async def test_get_other_use_version_by_user(other_uses_repo, mock_db_session): group_uuid = "test-group-uuid" @@ -221,9 +233,14 @@ async def test_get_other_use_version_by_user(other_uses_repo, mock_db_session): mock_other_use.version = version mock_other_use.user_type = user_type - mock_db_session.execute.return_value.scalars.return_value.first.return_value = ( - mock_other_use - ) + # Set up mock result chain + mock_result = AsyncMock() + mock_result.scalars = MagicMock(return_value=mock_result) + mock_result.first = MagicMock(return_value=mock_other_use) + + # Configure mock db session + mock_db_session.execute = AsyncMock(return_value=mock_result) + other_uses_repo.db = mock_db_session result = await other_uses_repo.get_other_use_version_by_user( group_uuid, version, user_type diff --git a/backend/lcfs/tests/other_uses/test_other_uses_services.py b/backend/lcfs/tests/other_uses/test_other_uses_services.py index 3f4705501..14b3bb518 100644 --- a/backend/lcfs/tests/other_uses/test_other_uses_services.py +++ b/backend/lcfs/tests/other_uses/test_other_uses_services.py @@ -189,7 +189,7 @@ async def test_update_other_use_not_found(other_uses_service): mock_repo.get_other_use_version_by_user = AsyncMock(return_value=None) - with pytest.raises(ServiceException): + with pytest.raises(ValueError, match="Other use not found"): await service.update_other_use(other_use_data, UserTypeEnum.SUPPLIER) diff --git 
a/backend/lcfs/tests/user/test_user_repo.py b/backend/lcfs/tests/user/test_user_repo.py index 56cb75403..a96788933 100644 --- a/backend/lcfs/tests/user/test_user_repo.py +++ b/backend/lcfs/tests/user/test_user_repo.py @@ -1,11 +1,10 @@ -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock import pytest from lcfs.db.models import UserProfile, UserLoginHistory from lcfs.web.api.user.repo import UserRepository from lcfs.tests.user.user_payloads import user_orm_model -from lcfs.web.exception.exceptions import DataNotFoundException @pytest.fixture @@ -50,14 +49,13 @@ async def test_create_login_history(dbsession, user_repo): @pytest.mark.anyio -async def test_update_notifications_email_success(dbsession, user_repo): +async def test_update_email_success(dbsession, user_repo): # Arrange: Create a user in the database user = UserProfile( keycloak_user_id="user_id_1", keycloak_email="user1@domain.com", keycloak_username="username1", email="user1@domain.com", - notifications_email=None, title="Developer", phone="1234567890", mobile_phone="0987654321", @@ -70,10 +68,10 @@ async def test_update_notifications_email_success(dbsession, user_repo): await dbsession.commit() await dbsession.refresh(user) - # Act: Update the notifications email - updated_user = await user_repo.update_notifications_email( + # Act: Update the email + updated_user = await user_repo.update_email( user_profile_id=1, email="new_email@domain.com" ) - # Assert: Check if the notifications email was updated - assert updated_user.notifications_email == "new_email@domain.com" + # Assert: Check if the email was updated + assert updated_user.email == "new_email@domain.com" diff --git a/backend/lcfs/tests/user/test_user_views.py b/backend/lcfs/tests/user/test_user_views.py index 75a228c0c..1c68b120d 100644 --- a/backend/lcfs/tests/user/test_user_views.py +++ b/backend/lcfs/tests/user/test_user_views.py @@ -468,19 +468,18 @@ async def test_track_logged_in_success(client: AsyncClient, fastapi_app, set_moc @pytest.mark.anyio -async def test_update_notifications_email_success( +async def test_update_email_success( client: AsyncClient, fastapi_app, set_mock_user, - add_models, ): set_mock_user(fastapi_app, [RoleEnum.GOVERNMENT]) # Prepare request data - request_data = {"notifications_email": "new_email@domain.com"} + request_data = {"email": "new_email@domain.com"} # Act: Send POST request to the endpoint - url = fastapi_app.url_path_for("update_notifications_email") + url = fastapi_app.url_path_for("update_email") response = await client.post(url, json=request_data) # Assert: Check response status and content diff --git a/backend/lcfs/tests/user/user_payloads.py b/backend/lcfs/tests/user/user_payloads.py index 3491cfd5e..38b9c5acc 100644 --- a/backend/lcfs/tests/user/user_payloads.py +++ b/backend/lcfs/tests/user/user_payloads.py @@ -6,7 +6,6 @@ keycloak_email="email@domain.com", keycloak_username="username", email="email@domain.com", - notifications_email=None, title="Developer", phone="1234567890", mobile_phone="0987654321", diff --git a/backend/lcfs/web/api/compliance_report/schema.py b/backend/lcfs/web/api/compliance_report/schema.py index 0f157be8b..9eb215c53 100644 --- a/backend/lcfs/web/api/compliance_report/schema.py +++ b/backend/lcfs/web/api/compliance_report/schema.py @@ -40,10 +40,6 @@ class CompliancePeriodSchema(BaseSchema): display_order: Optional[int] = None -class ComplianceReportOrganizationSchema(BaseSchema): - organization_id: int - name: str - class SummarySchema(BaseSchema): summary_id: int @@ -118,6 
+114,7 @@ class FSEOptionsSchema(BaseSchema): class FinalSupplyEquipmentSchema(BaseSchema): final_supply_equipment_id: int compliance_report_id: int + organization_name: str supply_from_date: date supply_to_date: date registration_nbr: str diff --git a/backend/lcfs/web/api/final_supply_equipment/repo.py b/backend/lcfs/web/api/final_supply_equipment/repo.py index 398f01585..b3680584f 100644 --- a/backend/lcfs/web/api/final_supply_equipment/repo.py +++ b/backend/lcfs/web/api/final_supply_equipment/repo.py @@ -1,6 +1,6 @@ import structlog from typing import List, Tuple -from lcfs.db.models.compliance import EndUserType, FinalSupplyEquipment +from lcfs.db.models.compliance import EndUserType, FinalSupplyEquipment, ComplianceReport from lcfs.db.models.compliance.FinalSupplyEquipmentRegNumber import ( FinalSupplyEquipmentRegNumber, ) @@ -27,8 +27,14 @@ def __init__(self, db: AsyncSession = Depends(get_async_db_session)): @repo_handler async def get_fse_options( - self, - ) -> Tuple[List[EndUseType], List[LevelOfEquipment], List[FuelMeasurementType], List[PortsEnum]]: + self, organization + ) -> Tuple[ + List[EndUseType], + List[LevelOfEquipment], + List[FuelMeasurementType], + List[PortsEnum], + List[str], + ]: """ Retrieve all FSE options in a single database transaction """ @@ -37,13 +43,15 @@ async def get_fse_options( levels_of_equipment = await self.get_levels_of_equipment() fuel_measurement_types = await self.get_fuel_measurement_types() intended_user_types = await self.get_intended_user_types() + organization_names = await self.get_organization_names(organization) ports = list(PortsEnum) return ( intended_use_types, levels_of_equipment, fuel_measurement_types, intended_user_types, - ports + ports, + organization_names, ) async def get_intended_use_types(self) -> List[EndUseType]: @@ -94,6 +102,29 @@ async def get_intended_user_types(self) -> List[EndUserType]: .all() ) + async def get_organization_names(self, organization) -> List[str]: + """ + Retrieve unique organization names for Final Supply Equipment records + associated with the given organization_id via ComplianceReport. + + Args: + organization_id (int): The ID of the organization. + + Returns: + List[str]: A list of unique organization names. 
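+        Note: the argument is the caller's organization model, not a bare ID; its
+        organization_id attribute is used to filter the joined ComplianceReport rows.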
+ """ + organization_names = ( + await self.db.execute( + select(distinct(FinalSupplyEquipment.organization_name)) + .join(ComplianceReport, FinalSupplyEquipment.compliance_report_id == ComplianceReport.compliance_report_id) + .filter(ComplianceReport.organization_id == organization.organization_id) + .filter(FinalSupplyEquipment.organization_name.isnot(None)) + ) + ).all() + + # Extract strings from the list of tuples + return [name[0] for name in organization_names] + @repo_handler async def get_intended_user_by_name(self, intended_user: str) -> EndUseType: """ diff --git a/backend/lcfs/web/api/final_supply_equipment/schema.py b/backend/lcfs/web/api/final_supply_equipment/schema.py index 38f80b2fa..2dc81e8f3 100644 --- a/backend/lcfs/web/api/final_supply_equipment/schema.py +++ b/backend/lcfs/web/api/final_supply_equipment/schema.py @@ -33,11 +33,13 @@ class FSEOptionsSchema(BaseSchema): levels_of_equipment: List[LevelOfEquipmentSchema] intended_user_types: List[EndUserTypeSchema] ports: List[PortsEnum] + organization_names: List[str] class FinalSupplyEquipmentCreateSchema(BaseSchema): final_supply_equipment_id: Optional[int] = None compliance_report_id: Optional[int] = None + organization_name: str supply_from_date: date supply_to_date: date kwh_usage: float diff --git a/backend/lcfs/web/api/final_supply_equipment/services.py b/backend/lcfs/web/api/final_supply_equipment/services.py index 9cc225935..a70b1ce4b 100644 --- a/backend/lcfs/web/api/final_supply_equipment/services.py +++ b/backend/lcfs/web/api/final_supply_equipment/services.py @@ -29,13 +29,15 @@ def __init__( @service_handler async def get_fse_options(self): """Fetches all FSE options concurrently.""" + organization = self.request.user.organization ( intended_use_types, levels_of_equipment, fuel_measurement_types, intended_user_types, - ports - ) = await self.repo.get_fse_options() + ports, + organization_names, + ) = await self.repo.get_fse_options(organization) return { "intended_use_types": [ @@ -52,6 +54,7 @@ async def get_fse_options(self): EndUserTypeSchema.model_validate(u) for u in intended_user_types ], "ports": [port.value for port in ports], + "organization_names": organization_names, } async def convert_to_fse_model(self, fse: FinalSupplyEquipmentCreateSchema): @@ -141,6 +144,7 @@ async def update_final_supply_equipment( if not existing_fse: raise ValueError("final supply equipment not found") + existing_fse.organization_name = fse_data.organization_name existing_fse.kwh_usage = fse_data.kwh_usage existing_fse.serial_nbr = fse_data.serial_nbr existing_fse.manufacturer = fse_data.manufacturer diff --git a/backend/lcfs/web/api/fuel_code/repo.py b/backend/lcfs/web/api/fuel_code/repo.py index aa0577466..594bd156f 100644 --- a/backend/lcfs/web/api/fuel_code/repo.py +++ b/backend/lcfs/web/api/fuel_code/repo.py @@ -396,19 +396,8 @@ async def create_fuel_code(self, fuel_code: FuelCode) -> FuelCode: """ self.db.add(fuel_code) await self.db.flush() - await self.db.refresh( - fuel_code, - [ - "fuel_code_status", - "fuel_code_prefix", - "fuel_type", - "feedstock_fuel_transport_modes", - "finished_fuel_transport_modes", - ], - ) - # Manually load nested relationships - await self.db.refresh(fuel_code.fuel_type, ["provision_1", "provision_2"]) - return fuel_code + result = await self.get_fuel_code(fuel_code.fuel_code_id) + return result @repo_handler async def get_fuel_code(self, fuel_code_id: int) -> FuelCode: @@ -526,7 +515,7 @@ async def get_distinct_fuel_codes_by_code( @repo_handler async def get_fuel_code_by_code_prefix( 
self, fuel_suffix: str, prefix: str - ) -> List[str]: + ) -> list[FuelCodeCloneSchema]: query = ( select(FuelCode) .options( @@ -593,9 +582,14 @@ async def validate_fuel_code(self, suffix: str, prefix_id: int) -> str: result = (await self.db.execute(query)).scalar_one_or_none() if result: fuel_code_main_version = suffix.split(".")[0] - return await self.get_next_available_sub_version_fuel_code_by_prefix( + suffix = await self.get_next_available_sub_version_fuel_code_by_prefix( fuel_code_main_version, prefix_id ) + if int(suffix.split(".")[1]) > 9: + return await self.get_next_available_fuel_code_by_prefix( + result.fuel_code_prefix.prefix + ) + return suffix else: return suffix @@ -757,7 +751,7 @@ async def get_fuel_code_by_name(self, fuel_code: str) -> FuelCode: .options( contains_eager(FuelCode.fuel_code_prefix), joinedload(FuelCode.fuel_code_status), - joinedload(FuelCode.fuel_code_type), + joinedload(FuelCode.fuel_type), ) .where( and_( @@ -782,18 +776,35 @@ async def get_provision_of_the_act_by_name( @repo_handler async def get_energy_effectiveness_ratio( - self, fuel_type_id: int, fuel_category_id: int, end_use_type_id: int + self, fuel_type_id: int, fuel_category_id: int, end_use_type_id: Optional[int] ) -> EnergyEffectivenessRatio: + """ + Retrieves the Energy Effectiveness Ratio based on fuel type, fuel category, + and optionally the end use type. - stmt = select(EnergyEffectivenessRatio).where( + Args: + fuel_type_id (int): The ID of the fuel type. + fuel_category_id (int): The ID of the fuel category. + end_use_type_id (Optional[int]): The ID of the end use type (optional). + + Returns: + Optional[EnergyEffectivenessRatio]: The matching EnergyEffectivenessRatio record or None. + """ + conditions = [ EnergyEffectivenessRatio.fuel_type_id == fuel_type_id, EnergyEffectivenessRatio.fuel_category_id == fuel_category_id, - EnergyEffectivenessRatio.end_use_type_id == end_use_type_id, - ) + ] + + if end_use_type_id is not None: + conditions.append( + EnergyEffectivenessRatio.end_use_type_id == end_use_type_id + ) + + stmt = select(EnergyEffectivenessRatio).where(*conditions) result = await self.db.execute(stmt) - energy_density = result.scalars().first() + energy_effectiveness_ratio = result.scalars().first() - return energy_density + return energy_effectiveness_ratio @repo_handler async def get_target_carbon_intensities( @@ -854,12 +865,10 @@ async def get_standardized_fuel_data( effective_carbon_intensity = fuel_type.default_carbon_intensity # Get energy effectiveness ratio (EER) - eer = None - if fuel_type_id and fuel_category_id and end_use_id: - energy_effectiveness = await self.get_energy_effectiveness_ratio( - fuel_type_id, fuel_category_id, end_use_id - ) - eer = energy_effectiveness.ratio if energy_effectiveness else 1.0 + energy_effectiveness = await self.get_energy_effectiveness_ratio( + fuel_type_id, fuel_category_id, end_use_id + ) + eer = energy_effectiveness.ratio if energy_effectiveness else 1.0 # Fetch target carbon intensity (TCI) target_ci = None diff --git a/backend/lcfs/web/api/fuel_code/services.py b/backend/lcfs/web/api/fuel_code/services.py index 039634e6a..d40fde544 100644 --- a/backend/lcfs/web/api/fuel_code/services.py +++ b/backend/lcfs/web/api/fuel_code/services.py @@ -208,6 +208,8 @@ async def convert_to_model( transport_mode_id=matching_transport_mode.transport_mode_id, ) ) + else: + raise ValueError(f"Invalid transport mode: {transport_mode}") for transport_mode in fuel_code_schema.finished_fuel_transport_mode or []: matching_transport_mode = next( @@ 
-221,6 +223,8 @@ async def convert_to_model( transport_mode_id=matching_transport_mode.transport_mode_id, ) ) + else: + raise ValueError(f"Invalid transport mode: {transport_mode}") return fuel_code diff --git a/backend/lcfs/web/api/monitoring/views.py b/backend/lcfs/web/api/monitoring/views.py index 1b36caf02..c599ef0b1 100644 --- a/backend/lcfs/web/api/monitoring/views.py +++ b/backend/lcfs/web/api/monitoring/views.py @@ -4,9 +4,10 @@ @router.get("/health") -def health_check() -> None: +def health_check() -> str: """ Checks the health of a project. It returns 200 if the project is healthy. """ + return "healthy" diff --git a/backend/lcfs/web/api/other_uses/repo.py b/backend/lcfs/web/api/other_uses/repo.py index 68a8a434a..8e515b484 100644 --- a/backend/lcfs/web/api/other_uses/repo.py +++ b/backend/lcfs/web/api/other_uses/repo.py @@ -1,20 +1,21 @@ import structlog -from typing import List, Optional, Tuple, Any +from datetime import date +from typing import List, Optional, Tuple, Dict, Any from fastapi import Depends - from lcfs.db.base import ActionTypeEnum, UserTypeEnum from lcfs.db.dependencies import get_async_db_session -from sqlalchemy import select, delete, func, case, and_ -from sqlalchemy.orm import joinedload +from sqlalchemy import select, delete, func, case, and_, or_ +from sqlalchemy.orm import joinedload, contains_eager from sqlalchemy.ext.asyncio import AsyncSession from lcfs.db.models.compliance import ComplianceReport from lcfs.db.models.compliance.OtherUses import OtherUses from lcfs.db.models.fuel.ProvisionOfTheAct import ProvisionOfTheAct from lcfs.db.models.fuel.FuelCode import FuelCode -from lcfs.db.models.fuel.FuelType import QuantityUnitsEnum +from lcfs.db.models.fuel.FuelType import FuelType, QuantityUnitsEnum +from lcfs.db.models.fuel.FuelInstance import FuelInstance from lcfs.web.api.fuel_code.repo import FuelCodeRepository from lcfs.web.api.other_uses.schema import OtherUsesSchema from lcfs.web.api.base import PaginationRequestSchema @@ -37,7 +38,7 @@ def __init__( async def get_table_options(self) -> dict: """Get all table options""" fuel_categories = await self.fuel_code_repo.get_fuel_categories() - fuel_types = await self.fuel_code_repo.get_formatted_fuel_types() + fuel_types = await self.get_formatted_fuel_types() expected_uses = await self.fuel_code_repo.get_expected_use_types() units_of_measure = [unit.value for unit in QuantityUnitsEnum] provisions_of_the_act = ( @@ -75,7 +76,7 @@ async def get_latest_other_uses_by_group_uuid( ) result = await self.db.execute(query) - return result.scalars().first() + return result.unique().scalars().first() @repo_handler async def get_other_uses(self, compliance_report_id: int) -> List[OtherUsesSchema]: @@ -302,3 +303,85 @@ async def get_other_use_version_by_user( result = await self.db.execute(query) return result.scalars().first() + + @repo_handler + async def get_formatted_fuel_types(self) -> List[Dict[str, Any]]: + """Get all fuel type options with their associated fuel categories and fuel codes for other uses""" + # Define the filtering conditions for fuel codes + current_date = date.today() + fuel_code_filters = ( + or_( + FuelCode.effective_date == None, FuelCode.effective_date <= current_date + ) + & or_( + FuelCode.expiration_date == None, + FuelCode.expiration_date > current_date, + ) + & (FuelType.other_uses_fossil_derived == True) + ) + + # Build the query with filtered fuel_codes + query = ( + select(FuelType) + .outerjoin(FuelType.fuel_instances) + .outerjoin(FuelInstance.fuel_category) + 
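+            # The fuel_codes outer join plus fuel_code_filters (built above) limit results to
+            # other-uses fossil-derived fuel types and to fuel codes currently in effect.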
.outerjoin(FuelType.fuel_codes) + .where(fuel_code_filters) + .options( + contains_eager(FuelType.fuel_instances).contains_eager( + FuelInstance.fuel_category + ), + contains_eager(FuelType.fuel_codes), + joinedload(FuelType.provision_1), + joinedload(FuelType.provision_2), + ) + ) + + result = await self.db.execute(query) + fuel_types = result.unique().scalars().all() + + # Prepare the data in the format matching your schema + formatted_fuel_types = [] + for fuel_type in fuel_types: + formatted_fuel_type = { + "fuel_type_id": fuel_type.fuel_type_id, + "fuel_type": fuel_type.fuel_type, + "default_carbon_intensity": fuel_type.default_carbon_intensity, + "units": fuel_type.units if fuel_type.units else None, + "unrecognized": fuel_type.unrecognized, + "fuel_categories": [ + { + "fuel_category_id": fc.fuel_category.fuel_category_id, + "category": fc.fuel_category.category, + } + for fc in fuel_type.fuel_instances + ], + "fuel_codes": [ + { + "fuel_code_id": fc.fuel_code_id, + "fuel_code": fc.fuel_code, + "carbon_intensity": fc.carbon_intensity, + } + for fc in fuel_type.fuel_codes + ], + "provision_of_the_act": [], + } + + if fuel_type.provision_1: + formatted_fuel_type["provision_of_the_act"].append( + { + "provision_of_the_act_id": fuel_type.provision_1_id, + "name": fuel_type.provision_1.name, + } + ) + + if fuel_type.provision_2: + formatted_fuel_type["provision_of_the_act"].append( + { + "provision_of_the_act_id": fuel_type.provision_2_id, + "name": fuel_type.provision_2.name, + } + ) + formatted_fuel_types.append(formatted_fuel_type) + + return formatted_fuel_types diff --git a/backend/lcfs/web/api/user/repo.py b/backend/lcfs/web/api/user/repo.py index c1906c34c..f55e0eab4 100644 --- a/backend/lcfs/web/api/user/repo.py +++ b/backend/lcfs/web/api/user/repo.py @@ -669,7 +669,7 @@ async def create_login_history(self, user: UserProfile): self.db.add(login_history) @repo_handler - async def update_notifications_email( + async def update_email( self, user_profile_id: int, email: str ) -> UserProfile: # Fetch the user profile @@ -679,8 +679,7 @@ async def update_notifications_email( result = await self.db.execute(query) user_profile = result.scalar_one_or_none() - # Update the notifications_email field - user_profile.notifications_email = email + user_profile.email = email # Flush and refresh without committing await self.db.flush() diff --git a/backend/lcfs/web/api/user/schema.py b/backend/lcfs/web/api/user/schema.py index 6cf34fe46..e773133b1 100644 --- a/backend/lcfs/web/api/user/schema.py +++ b/backend/lcfs/web/api/user/schema.py @@ -40,7 +40,6 @@ class UserBaseSchema(BaseSchema): keycloak_username: str keycloak_email: EmailStr email: Optional[EmailStr] = None - notifications_email: Optional[EmailStr] = None title: Optional[str] = None phone: Optional[str] = None first_name: Optional[str] = None @@ -97,5 +96,5 @@ class UserLoginHistoryResponseSchema(BaseSchema): pagination: PaginationResponseSchema -class UpdateNotificationsEmailSchema(BaseSchema): - notifications_email: EmailStr +class UpdateEmailSchema(BaseSchema): + email: EmailStr diff --git a/backend/lcfs/web/api/user/services.py b/backend/lcfs/web/api/user/services.py index 48bad17de..bd539f6bb 100644 --- a/backend/lcfs/web/api/user/services.py +++ b/backend/lcfs/web/api/user/services.py @@ -334,15 +334,12 @@ async def track_user_login(self, user: UserProfile): await self.repo.create_login_history(user) @service_handler - async def update_notifications_email(self, user_id: int, email: str): + async def update_email(self, user_id: 
int, email: str): try: - # Update the notifications_email field of the user - return await self.repo.update_notifications_email(user_id, email) - # Return the updated user - return UserBaseSchema.model_validate(user) + return await self.repo.update_email(user_id, email) except DataNotFoundException as e: logger.error(f"User not found: {e}") raise HTTPException(status_code=404, detail=str(e)) except Exception as e: - logger.error(f"Error updating notifications email: {e}") + logger.error(f"Error updating email: {e}") raise HTTPException(status_code=500, detail="Internal Server Error") diff --git a/backend/lcfs/web/api/user/views.py b/backend/lcfs/web/api/user/views.py index ff1a69aca..fab92a37f 100644 --- a/backend/lcfs/web/api/user/views.py +++ b/backend/lcfs/web/api/user/views.py @@ -1,16 +1,7 @@ from typing import List import structlog -from fastapi import ( - APIRouter, - Body, - status, - Request, - Response, - Depends, - Query, - HTTPException, -) +from fastapi import APIRouter, Body, status, Request, Response, Depends, Query from fastapi.responses import StreamingResponse from lcfs.db import dependencies @@ -23,7 +14,7 @@ UserLoginHistoryResponseSchema, UsersSchema, UserActivitiesResponseSchema, - UpdateNotificationsEmailSchema, + UpdateEmailSchema, ) from lcfs.web.api.user.services import UserServices from lcfs.web.core.decorators import view_handler @@ -255,18 +246,18 @@ async def get_all_user_login_history( @router.post( - "/update-notifications-email", - response_model=UpdateNotificationsEmailSchema, + "/update-email", + response_model=UpdateEmailSchema, status_code=status.HTTP_200_OK, ) @view_handler(["*"]) -async def update_notifications_email( +async def update_email( request: Request, - email_data: UpdateNotificationsEmailSchema = Body(...), + email_data: UpdateEmailSchema = Body(...), service: UserServices = Depends(), ): user_id = request.user.user_profile_id - email = email_data.notifications_email + email = email_data.email - user = await service.update_notifications_email(user_id, email) - return user + user = await service.update_email(user_id, email) + return UpdateEmailSchema(email=user.email) diff --git a/backend/lcfs/web/application.py b/backend/lcfs/web/application.py index 6d31484d0..e7117a105 100644 --- a/backend/lcfs/web/application.py +++ b/backend/lcfs/web/application.py @@ -1,30 +1,26 @@ -from importlib import metadata -import structlog import logging +import uuid -import os -import debugpy +import structlog from fastapi import FastAPI, HTTPException from fastapi.exceptions import RequestValidationError -from fastapi.responses import UJSONResponse from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import UJSONResponse from prometheus_fastapi_instrumentator import Instrumentator -from starlette.middleware.authentication import AuthenticationMiddleware from starlette.authentication import ( AuthenticationBackend, AuthCredentials, UnauthenticatedUser, ) +from starlette.middleware.authentication import AuthenticationMiddleware from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import JSONResponse -import uuid -import contextvars -from lcfs.settings import settings from lcfs.logging_config import setup_logging, correlation_id_var -from lcfs.web.api.router import api_router from lcfs.services.keycloak.authentication import UserAuthentication +from lcfs.settings import settings +from lcfs.web.api.router import api_router from lcfs.web.exception.exception_handler import 
validation_exception_handler from lcfs.web.lifetime import register_shutdown_event, register_startup_event @@ -67,6 +63,9 @@ async def authenticate(self, request): if request.scope["method"] == "OPTIONS": return AuthCredentials([]), UnauthenticatedUser() + if request.url.path == "/api/health": # Skip for health check + return AuthCredentials([]), UnauthenticatedUser() + # Lazily retrieve Redis, session, and settings from app state redis_client = self.app.state.redis_client session_factory = self.app.state.db_session_factory diff --git a/backend/lcfs/web/core/decorators.py b/backend/lcfs/web/core/decorators.py index 07dc7c5ab..e67d9afca 100644 --- a/backend/lcfs/web/core/decorators.py +++ b/backend/lcfs/web/core/decorators.py @@ -215,7 +215,7 @@ async def wrapper(*args, **kwargs): return await func(*args, **kwargs) # raise the error to the view layer - except (DatabaseException, HTTPException, DataNotFoundException): + except (DatabaseException, HTTPException, DataNotFoundException, ValueError): raise # all other errors that occur in the service layer will log an error except Exception as e: diff --git a/etl/database/nifi-registry-primary.mv.db b/etl/database/nifi-registry-primary.mv.db index 56acc8498..35bd40492 100644 Binary files a/etl/database/nifi-registry-primary.mv.db and b/etl/database/nifi-registry-primary.mv.db differ diff --git a/etl/nifi/conf/flow.json.gz b/etl/nifi/conf/flow.json.gz index 3d193752f..5780d2a55 100644 Binary files a/etl/nifi/conf/flow.json.gz and b/etl/nifi/conf/flow.json.gz differ diff --git a/etl/nifi/conf/flow.xml.gz b/etl/nifi/conf/flow.xml.gz index be40cbe0a..78bde91d7 100644 Binary files a/etl/nifi/conf/flow.xml.gz and b/etl/nifi/conf/flow.xml.gz differ diff --git a/etl/nifi_scripts/adminAdjTrxn.groovy b/etl/nifi_scripts/adminAdjTrxn.groovy new file mode 100644 index 000000000..5492bbf86 --- /dev/null +++ b/etl/nifi_scripts/adminAdjTrxn.groovy @@ -0,0 +1,436 @@ +import groovy.json.JsonSlurper +import java.sql.Connection +import java.sql.PreparedStatement +import java.sql.ResultSet +import java.time.OffsetDateTime +import java.sql.Timestamp + +def SOURCE_QUERY = """ +WITH + internal_comment AS ( + SELECT + ctc.id, + ctc.credit_trade_id, + ctc.credit_trade_comment, + ctc.create_user_id, + ctc.create_timestamp, + STRING_AGG (r."name", '; ') AS role_names + FROM + credit_trade_comment ctc + JOIN "user" u ON u.id = ctc.create_user_id + AND u.organization_id = 1 + AND ctc.is_privileged_access = TRUE + JOIN user_role ur ON ur.user_id = u.id + JOIN "role" r ON ur.role_id = r.id + GROUP BY + ctc.id, + ctc.credit_trade_id, + ctc.credit_trade_comment, + ctc.create_user_id, + ctc.create_timestamp + ORDER BY + ctc.credit_trade_id, + ctc.create_timestamp + ) + SELECT + ct.id AS admin_adjustment_id, + ct.respondent_id AS to_organization_id, + ct.date_of_written_agreement AS agreement_date, + ct.trade_effective_date AS transaction_effective_date, + ct.number_of_credits AS compliance_units, + ct.create_user_id as create_user, + ct.update_user_id as update_user, + ct.update_timestamp as update_date, + ct.create_timestamp as create_date, + -- Aggregate comments from government with internal comment handling + STRING_AGG (DISTINCT gov_ctc.credit_trade_comment, '; ') AS gov_comment, + -- JSON aggregation for internal comments + json_agg (row_to_json (internal_comment)) AS internal_comments, + -- JSON aggregation for credit trade history + json_agg ( + json_build_object ( + 'admin_adjustment_id', + cth.credit_trade_id, + 'admin_adjustment_status', + case + WHEN 
cts_history.status IN ('Cancelled', 'Not Recommended', 'Declined', 'Refused') or ct.is_rescinded = true THEN 'Deleted' + WHEN cts_history.status IN ('Accepted', 'Submitted', 'Recommended') THEN 'Recommended' + WHEN cts_history.status IN ('Approved', 'Recorded') THEN 'Approved' + ELSE 'Draft' + END, + 'user_profile_id', + cth.create_user_id, + 'create_timestamp', + cth.create_timestamp + ) + ) AS credit_trade_history, + case + WHEN cts.status IN ('Cancelled', 'Not Recommended', 'Declined', 'Refused') or ct.is_rescinded = true THEN 'Deleted' + WHEN cts.status IN ('Accepted', 'Submitted', 'Recommended') THEN 'Recommended' + WHEN cts.status IN ('Approved', 'Recorded') THEN 'Approved' + ELSE 'Draft' + END AS current_status, cts.status + FROM + credit_trade ct + JOIN credit_trade_type ctt ON ct.type_id = ctt.id + LEFT OUTER JOIN credit_trade_category ctc ON ct.trade_category_id = ctc.id + JOIN credit_trade_status cts ON ct.status_id = cts.id + LEFT JOIN credit_trade_zero_reason ctzr ON ctzr.id = ct.zero_reason_id + AND ctzr.reason = 'Internal' + -- Join for Initiator Comments + LEFT JOIN credit_trade_comment from_ctc ON from_ctc.credit_trade_id = ct.id + AND from_ctc.create_user_id IN ( + SELECT + u.id + FROM + "user" u + WHERE + u.organization_id = ct.initiator_id + ) + -- Join for Respondent Comments + LEFT JOIN credit_trade_comment to_ctc ON to_ctc.credit_trade_id = ct.id + AND to_ctc.create_user_id IN ( + SELECT + u.id + FROM + "user" u + WHERE + u.organization_id = ct.respondent_id + ) + -- Join for Government Comments + LEFT JOIN credit_trade_comment gov_ctc ON gov_ctc.credit_trade_id = ct.id + AND gov_ctc.create_user_id IN ( + SELECT + u.id + FROM + "user" u + WHERE + u.organization_id = 1 + AND gov_ctc.is_privileged_access = FALSE + ) + -- Join the internal comment logic for role-based filtering and audience_scope + LEFT JOIN internal_comment ON internal_comment.credit_trade_id = ct.id + -- Join for credit trade history + LEFT JOIN credit_trade_history cth ON cth.credit_trade_id = ct.id + JOIN credit_trade_status cts_history ON cth.status_id = cts_history.id + WHERE + ctt.the_type IN ('Administrative Adjustment') + GROUP BY + ct.id, + ct.respondent_id, + ct.date_of_written_agreement, + ct.trade_effective_date, + ct.number_of_credits, + cts.status, + ctzr.description, + internal_comment.role_names; + """ + +// Fetch connections to both the source and destination databases +def sourceDbcpService = context.controllerServiceLookup.getControllerService('3245b078-0192-1000-ffff-ffffba20c1eb') +def destinationDbcpService = context.controllerServiceLookup.getControllerService('3244bf63-0192-1000-ffff-ffffc8ec6d93') + +Connection sourceConn = null +Connection destinationConn = null + +try { + sourceConn = sourceDbcpService.getConnection() + destinationConn = destinationDbcpService.getConnection() + destinationConn.setAutoCommit(false) + + // Fetch status and category data once and cache it + def preparedData = prepareData(destinationConn) + + def statements = prepareStatements(destinationConn) + + destinationConn.createStatement().execute('DROP FUNCTION IF EXISTS refresh_transaction_aggregate() CASCADE;') + destinationConn.createStatement().execute(""" + CREATE OR REPLACE FUNCTION refresh_transaction_aggregate() + RETURNS void AS \$\$ + BEGIN + -- Temporarily disable the materialized view refresh + END; + \$\$ LANGUAGE plpgsql; + """) + + PreparedStatement sourceStmt = sourceConn.prepareStatement(SOURCE_QUERY) + ResultSet resultSet = sourceStmt.executeQuery() + + int recordCount = 0 + + while 
(resultSet.next()) { + recordCount++ + def jsonSlurper = new JsonSlurper() + def internalComments = resultSet.getString('internal_comments') + def creditTradeHistory = resultSet.getString('credit_trade_history') + + def internalCommentsJson = internalComments ? jsonSlurper.parseText(internalComments) : [] + def creditTradeHistoryJson = creditTradeHistory ? jsonSlurper.parseText(creditTradeHistory) : [] + + def toTransactionId = processTransactions(resultSet.getString('current_status'), + resultSet, statements.transactionStmt) + + def adminAdjustmentId = insertadminAdjustment(resultSet, statements.adminAdjustmentStmt, + toTransactionId, preparedData, destinationConn) + + if (adminAdjustmentId) { + processHistory(adminAdjustmentId, creditTradeHistoryJson, statements.historyStmt, preparedData) + processInternalComments(adminAdjustmentId, internalCommentsJson, statements.internalCommentStmt, + statements.adminAdjustmentInternalCommentStmt) + } else { + log.warn("admin-adjustment not inserted for record: ${resultSet.getInt('admin_adjustment_id')}") + } + } + resultSet.close() + destinationConn.createStatement().execute(""" + CREATE OR REPLACE FUNCTION refresh_transaction_aggregate() + RETURNS void AS \$\$ + BEGIN + REFRESH MATERIALIZED VIEW CONCURRENTLY mv_transaction_aggregate; + END; + \$\$ LANGUAGE plpgsql; + """) + destinationConn.createStatement().execute('REFRESH MATERIALIZED VIEW CONCURRENTLY mv_transaction_aggregate') + + destinationConn.commit() + log.debug("Processed ${recordCount} records successfully.") +} catch (Exception e) { + log.error('Error occurred while processing data', e) + destinationConn?.rollback() +} finally { + if (sourceConn != null) sourceConn.close() + if (destinationConn != null) destinationConn.close() +} + +def logResultSetRow(ResultSet rs) { + def metaData = rs.getMetaData() + def columnCount = metaData.getColumnCount() + def rowData = [:] + + for (int i = 1; i <= columnCount; i++) { + def columnName = metaData.getColumnName(i) + def columnValue = rs.getObject(i) + rowData[columnName] = columnValue + } + + log.debug("Row data: ${rowData}") +} + +def loadTableData(Connection conn, String query, String keyColumn, String valueColumn) { + def dataMap = [:] + def stmt = conn.createStatement() + def rs = stmt.executeQuery(query) + + while (rs.next()) { + def key = rs.getString(keyColumn) + def value = rs.getInt(valueColumn) + if (dataMap.containsKey(key)) { + log.warn("Duplicate key found for ${key}. Existing value: ${dataMap[key]}, New value: ${value}.") + } + dataMap[key] = value + } + + rs.close() + stmt.close() + return dataMap +} + +def prepareData(Connection conn) { + def statusMap = loadTableData(conn, + 'SELECT DISTINCT status, MIN(admin_adjustment_status_id) AS admin_adjustment_status_id FROM admin_adjustment_status GROUP BY status', + 'status', + 'admin_adjustment_status_id' + ) + return [ + statusMap : statusMap + ] +} + +def getStatusId(String status, Map preparedData) { + return preparedData.statusMap[status] +} + +def prepareStatements(Connection conn) { + def INSERT_admin_adjustment_SQL = ''' + INSERT INTO admin_adjustment ( + to_organization_id, transaction_id, transaction_effective_date, compliance_units, gov_comment, + current_status_id, create_date, update_date, create_user, update_user, effective_status, + admin_adjustment_id + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, true, ?) 
+ RETURNING admin_adjustment_id + ''' + def INSERT_admin_adjustment_HISTORY_SQL = ''' + INSERT INTO admin_adjustment_history ( + admin_adjustment_history_id, admin_adjustment_id, admin_adjustment_status_id, user_profile_id, create_date, effective_status + ) VALUES (DEFAULT, ?, ?, ?, ?, true) + ''' + def INSERT_INTERNAL_COMMENT_SQL = ''' + INSERT INTO internal_comment ( + internal_comment_id, comment, audience_scope, create_user, create_date + ) VALUES (DEFAULT, ?, ?::audience_scope, ?, ?) + RETURNING internal_comment_id + ''' + def INSERT_admin_adjustment_INTERNAL_COMMENT_SQL = ''' + INSERT INTO admin_adjustment_internal_comment ( + admin_adjustment_id, internal_comment_id + ) VALUES (?, ?) + ''' + def INSERT_TRANSACTION_SQL = ''' + INSERT INTO transaction ( + transaction_id, compliance_units, organization_id, transaction_action, effective_date, create_user, create_date, effective_status + ) VALUES (DEFAULT, ?, ?, ?::transaction_action_enum, ?, ?, ?, true) + RETURNING transaction_id + ''' + return [adminAdjustmentStmt : conn.prepareStatement(INSERT_admin_adjustment_SQL), + historyStmt : conn.prepareStatement(INSERT_admin_adjustment_HISTORY_SQL), + internalCommentStmt : conn.prepareStatement(INSERT_INTERNAL_COMMENT_SQL), + adminAdjustmentInternalCommentStmt + : conn.prepareStatement(INSERT_admin_adjustment_INTERNAL_COMMENT_SQL), + transactionStmt : conn.prepareStatement(INSERT_TRANSACTION_SQL)] +} + +def toSqlTimestamp(String timestampString) { + try { + // Parse the ISO 8601 timestamp and convert to java.sql.Timestamp + def offsetDateTime = OffsetDateTime.parse(timestampString) + return Timestamp.from(offsetDateTime.toInstant()) + } catch (Exception e) { + log.error("Invalid timestamp format: ${timestampString}, defaulting to '1970-01-01T00:00:00Z'") + return Timestamp.valueOf('1970-01-01 00:00:00') + } +} + +def processTransactions(String currentStatus, ResultSet rs, PreparedStatement stmt) { + def toTransactionId = null + + if (currentStatus == 'Approved') { + toTransactionId = insertTransaction(stmt, rs, 'Adjustment', rs.getInt('to_organization_id')) + } + + return toTransactionId +} + +def insertTransaction(PreparedStatement stmt, ResultSet rs, String action, int orgId) { + stmt.setInt(1, rs.getInt('compliance_units')) + stmt.setInt(2, orgId) + stmt.setString(3, action) + stmt.setDate(4, rs.getDate('transaction_effective_date') ?: rs.getDate('agreement_date')) + stmt.setInt(5, rs.getInt('create_user')) + stmt.setTimestamp(6, rs.getTimestamp('create_date')) + + def result = stmt.executeQuery() + return result.next() ? 
result.getInt('transaction_id') : null +} + +def processHistory(Integer adminAdjustmentId, List creditTradeHistory, PreparedStatement historyStmt, Map preparedData) { + if (!creditTradeHistory) return + + // Use a Set to track unique combinations of admin_adjustment_id and admin_adjustment_status + def processedEntries = new HashSet() + + creditTradeHistory.each { historyItem -> + try { + def statusId = getStatusId(historyItem.admin_adjustment_status, preparedData) + def uniqueKey = "${adminAdjustmentId}_${statusId}" + + // Check if this combination has already been processed + if (!processedEntries.contains(uniqueKey)) { + // If not processed, add to batch and mark as processed + historyStmt.setInt(1, adminAdjustmentId) + historyStmt.setInt(2, statusId) + historyStmt.setInt(3, historyItem.user_profile_id) + historyStmt.setTimestamp(4, toSqlTimestamp(historyItem.create_timestamp ?: '2013-01-01T00:00:00Z')) + historyStmt.addBatch() + + processedEntries.add(uniqueKey) + } + } catch (Exception e) { + log.error("Error processing history record for admin_adjustment_id: ${adminAdjustmentId}", e) + } + } + + // Execute batch + historyStmt.executeBatch() +} + + +def processInternalComments(Integer adminAdjustmentId, List internalComments, + PreparedStatement internalCommentStmt, + PreparedStatement adminAdjustmentInternalCommentStmt) { + if (!internalComments) return + + internalComments.each { comment -> + if (!comment) return // Skip null comments + + try { + // Insert the internal comment + internalCommentStmt.setString(1, comment.credit_trade_comment ?: '') + internalCommentStmt.setString(2, getAudienceScope(comment.role_names ?: '')) + internalCommentStmt.setInt(3, comment.create_user_id ?: null) + internalCommentStmt.setTimestamp(4, toSqlTimestamp(comment.create_timestamp ?: '2013-01-01T00:00:00Z')) + + def internalCommentId = null + def commentResult = internalCommentStmt.executeQuery() + if (commentResult.next()) { + internalCommentId = commentResult.getInt('internal_comment_id') + + // Insert the admin-adjustment-comment relationship + adminAdjustmentInternalCommentStmt.setInt(1, adminAdjustmentId) + adminAdjustmentInternalCommentStmt.setInt(2, internalCommentId) + adminAdjustmentInternalCommentStmt.executeUpdate() + } + + commentResult.close() + } catch (Exception e) { + log.error("Error processing internal comment for admin-adjustment ${adminAdjustmentId}: ${e.getMessage()}", e) + } + } + } + +// Helper function to determine audience scope based on role names +def getAudienceScope(String roleNames) { + if (!roleNames) return 'Analyst' + + switch (true) { + case roleNames.contains('GovDirector'): + return 'Director' + case roleNames.contains('GovComplianceManager'): + return 'Compliance Manager' + default: + return 'Analyst' + } +} + +def insertadminAdjustment(ResultSet rs, PreparedStatement adminAdjustmentStmt, + Long toTransactionId, Map preparedData, Connection conn) { + // Check for duplicates in the `admin_adjustment` table + def adminAdjustmentId = rs.getInt('admin_adjustment_id') + def duplicateCheckStmt = conn.prepareStatement('SELECT COUNT(*) FROM admin_adjustment WHERE admin_adjustment_id = ?') + duplicateCheckStmt.setInt(1, adminAdjustmentId) + def duplicateResult = duplicateCheckStmt.executeQuery() + duplicateResult.next() + def count = duplicateResult.getInt(1) + duplicateResult.close() + duplicateCheckStmt.close() + + if (count > 0) { + log.warn("Duplicate admin_adjustment detected with admin_adjustment_id: ${adminAdjustmentId}, skipping insertion.") + return null + } + + 
// Proceed with insertion if no duplicate exists + def statusId = getStatusId(rs.getString('current_status'), preparedData) + adminAdjustmentStmt.setInt(1, rs.getInt('to_organization_id')) + adminAdjustmentStmt.setObject(2, toTransactionId) + adminAdjustmentStmt.setTimestamp(3, rs.getTimestamp('transaction_effective_date')) + adminAdjustmentStmt.setInt(4, rs.getInt('compliance_units')) + adminAdjustmentStmt.setString(5, rs.getString('gov_comment')) + adminAdjustmentStmt.setObject(6, statusId) + adminAdjustmentStmt.setTimestamp(7, rs.getTimestamp('create_date')) + adminAdjustmentStmt.setTimestamp(8, rs.getTimestamp('update_date')) + adminAdjustmentStmt.setInt(9, rs.getInt('create_user')) + adminAdjustmentStmt.setInt(10, rs.getInt('update_user')) + adminAdjustmentStmt.setInt(11, rs.getInt('admin_adjustment_id')) + def result = adminAdjustmentStmt.executeQuery() + return result.next() ? result.getInt('admin_adjustment_id') : null +} \ No newline at end of file diff --git a/etl/nifi_scripts/initiativeAgrmtTrxn.groovy b/etl/nifi_scripts/initiativeAgrmtTrxn.groovy new file mode 100644 index 000000000..8e3f9f121 --- /dev/null +++ b/etl/nifi_scripts/initiativeAgrmtTrxn.groovy @@ -0,0 +1,436 @@ +import groovy.json.JsonSlurper +import java.sql.Connection +import java.sql.PreparedStatement +import java.sql.ResultSet +import java.time.OffsetDateTime +import java.sql.Timestamp + +def SOURCE_QUERY = """ +WITH + internal_comment AS ( + SELECT + ctc.id, + ctc.credit_trade_id, + ctc.credit_trade_comment, + ctc.create_user_id, + ctc.create_timestamp, + STRING_AGG (r."name", '; ') AS role_names + FROM + credit_trade_comment ctc + JOIN "user" u ON u.id = ctc.create_user_id + AND u.organization_id = 1 + AND ctc.is_privileged_access = TRUE + JOIN user_role ur ON ur.user_id = u.id + JOIN "role" r ON ur.role_id = r.id + GROUP BY + ctc.id, + ctc.credit_trade_id, + ctc.credit_trade_comment, + ctc.create_user_id, + ctc.create_timestamp + ORDER BY + ctc.credit_trade_id, + ctc.create_timestamp + ) + SELECT + ct.id AS initiative_agreement_id, + ct.respondent_id AS to_organization_id, + ct.date_of_written_agreement AS agreement_date, + ct.trade_effective_date AS transaction_effective_date, + ct.number_of_credits AS compliance_units, + ct.create_user_id as create_user, + ct.update_user_id as update_user, + ct.update_timestamp as update_date, + ct.create_timestamp as create_date, + -- Aggregate comments from government with internal comment handling + STRING_AGG (DISTINCT gov_ctc.credit_trade_comment, '; ') AS gov_comment, + -- JSON aggregation for internal comments + json_agg (row_to_json (internal_comment)) AS internal_comments, + -- JSON aggregation for credit trade history + json_agg ( + json_build_object ( + 'initiative_agreement_id', + cth.credit_trade_id, + 'initiative_agreement_status', + case + WHEN cts_history.status IN ('Cancelled', 'Not Recommended', 'Declined', 'Refused') or ct.is_rescinded = true THEN 'Deleted' + WHEN cts_history.status IN ('Accepted', 'Submitted', 'Recommended') THEN 'Recommended' + WHEN cts_history.status IN ('Approved', 'Recorded') THEN 'Approved' + ELSE 'Draft' + END, + 'user_profile_id', + cth.create_user_id, + 'create_timestamp', + cth.create_timestamp + ) + ) AS credit_trade_history, + case + WHEN cts.status IN ('Cancelled', 'Not Recommended', 'Declined', 'Refused') or ct.is_rescinded = true THEN 'Deleted' + WHEN cts.status IN ('Accepted', 'Submitted', 'Recommended') THEN 'Recommended' + WHEN cts.status IN ('Approved', 'Recorded') THEN 'Approved' + ELSE 'Draft' + END AS 
current_status, cts.status + FROM + credit_trade ct + JOIN credit_trade_type ctt ON ct.type_id = ctt.id + LEFT OUTER JOIN credit_trade_category ctc ON ct.trade_category_id = ctc.id + JOIN credit_trade_status cts ON ct.status_id = cts.id + LEFT JOIN credit_trade_zero_reason ctzr ON ctzr.id = ct.zero_reason_id + AND ctzr.reason = 'Internal' + -- Join for Initiator Comments + LEFT JOIN credit_trade_comment from_ctc ON from_ctc.credit_trade_id = ct.id + AND from_ctc.create_user_id IN ( + SELECT + u.id + FROM + "user" u + WHERE + u.organization_id = ct.initiator_id + ) + -- Join for Respondent Comments + LEFT JOIN credit_trade_comment to_ctc ON to_ctc.credit_trade_id = ct.id + AND to_ctc.create_user_id IN ( + SELECT + u.id + FROM + "user" u + WHERE + u.organization_id = ct.respondent_id + ) + -- Join for Government Comments + LEFT JOIN credit_trade_comment gov_ctc ON gov_ctc.credit_trade_id = ct.id + AND gov_ctc.create_user_id IN ( + SELECT + u.id + FROM + "user" u + WHERE + u.organization_id = 1 + AND gov_ctc.is_privileged_access = FALSE + ) + -- Join the internal comment logic for role-based filtering and audience_scope + LEFT JOIN internal_comment ON internal_comment.credit_trade_id = ct.id + -- Join for credit trade history + LEFT JOIN credit_trade_history cth ON cth.credit_trade_id = ct.id + JOIN credit_trade_status cts_history ON cth.status_id = cts_history.id + WHERE + ctt.the_type IN ('Part 3 Award') + GROUP BY + ct.id, + ct.respondent_id, + ct.date_of_written_agreement, + ct.trade_effective_date, + ct.number_of_credits, + cts.status, + ctzr.description, + internal_comment.role_names; + """ + +// Fetch connections to both the source and destination databases +def sourceDbcpService = context.controllerServiceLookup.getControllerService('3245b078-0192-1000-ffff-ffffba20c1eb') +def destinationDbcpService = context.controllerServiceLookup.getControllerService('3244bf63-0192-1000-ffff-ffffc8ec6d93') + +Connection sourceConn = null +Connection destinationConn = null + +try { + sourceConn = sourceDbcpService.getConnection() + destinationConn = destinationDbcpService.getConnection() + destinationConn.setAutoCommit(false) + + // Fetch status and category data once and cache it + def preparedData = prepareData(destinationConn) + + def statements = prepareStatements(destinationConn) + + destinationConn.createStatement().execute('DROP FUNCTION IF EXISTS refresh_transaction_aggregate() CASCADE;') + destinationConn.createStatement().execute(""" + CREATE OR REPLACE FUNCTION refresh_transaction_aggregate() + RETURNS void AS \$\$ + BEGIN + -- Temporarily disable the materialized view refresh + END; + \$\$ LANGUAGE plpgsql; + """) + + PreparedStatement sourceStmt = sourceConn.prepareStatement(SOURCE_QUERY) + ResultSet resultSet = sourceStmt.executeQuery() + + int recordCount = 0 + + while (resultSet.next()) { + recordCount++ + def jsonSlurper = new JsonSlurper() + def internalComments = resultSet.getString('internal_comments') + def creditTradeHistory = resultSet.getString('credit_trade_history') + + def internalCommentsJson = internalComments ? jsonSlurper.parseText(internalComments) : [] + def creditTradeHistoryJson = creditTradeHistory ? 
jsonSlurper.parseText(creditTradeHistory) : [] + + def toTransactionId = processTransactions(resultSet.getString('current_status'), + resultSet, statements.transactionStmt) + + def initiativeAgreementId = insertInitiativeAgreement(resultSet, statements.initiativeAgreementStmt, + toTransactionId, preparedData, destinationConn) + + if (initiativeAgreementId) { + processHistory(initiativeAgreementId, creditTradeHistoryJson, statements.historyStmt, preparedData) + processInternalComments(initiativeAgreementId, internalCommentsJson, statements.internalCommentStmt, + statements.initiativeAgreementInternalCommentStmt) + } else { + log.warn("initiative-agreement not inserted for record: ${resultSet.getInt('initiative_agreement_id')}") + } + } + resultSet.close() + destinationConn.createStatement().execute(""" + CREATE OR REPLACE FUNCTION refresh_transaction_aggregate() + RETURNS void AS \$\$ + BEGIN + REFRESH MATERIALIZED VIEW CONCURRENTLY mv_transaction_aggregate; + END; + \$\$ LANGUAGE plpgsql; + """) + destinationConn.createStatement().execute('REFRESH MATERIALIZED VIEW CONCURRENTLY mv_transaction_aggregate') + + destinationConn.commit() + log.debug("Processed ${recordCount} records successfully.") +} catch (Exception e) { + log.error('Error occurred while processing data', e) + destinationConn?.rollback() +} finally { + if (sourceConn != null) sourceConn.close() + if (destinationConn != null) destinationConn.close() +} + +def logResultSetRow(ResultSet rs) { + def metaData = rs.getMetaData() + def columnCount = metaData.getColumnCount() + def rowData = [:] + + for (int i = 1; i <= columnCount; i++) { + def columnName = metaData.getColumnName(i) + def columnValue = rs.getObject(i) + rowData[columnName] = columnValue + } + + log.debug("Row data: ${rowData}") +} + +def loadTableData(Connection conn, String query, String keyColumn, String valueColumn) { + def dataMap = [:] + def stmt = conn.createStatement() + def rs = stmt.executeQuery(query) + + while (rs.next()) { + def key = rs.getString(keyColumn) + def value = rs.getInt(valueColumn) + if (dataMap.containsKey(key)) { + log.warn("Duplicate key found for ${key}. Existing value: ${dataMap[key]}, New value: ${value}.") + } + dataMap[key] = value + } + + rs.close() + stmt.close() + return dataMap +} + +def prepareData(Connection conn) { + def statusMap = loadTableData(conn, + 'SELECT DISTINCT status, MIN(initiative_agreement_status_id) AS initiative_agreement_status_id FROM initiative_agreement_status GROUP BY status', + 'status', + 'initiative_agreement_status_id' + ) + return [ + statusMap : statusMap + ] +} + +def getStatusId(String status, Map preparedData) { + return preparedData.statusMap[status] +} + +def prepareStatements(Connection conn) { + def INSERT_INITIATIVE_AGREEMENT_SQL = ''' + INSERT INTO initiative_agreement ( + to_organization_id, transaction_id, transaction_effective_date, compliance_units, gov_comment, + current_status_id, create_date, update_date, create_user, update_user, effective_status, + initiative_agreement_id + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, true, ?) 
+ RETURNING initiative_agreement_id + ''' + def INSERT_INITIATIVE_AGREEMENT_HISTORY_SQL = ''' + INSERT INTO initiative_agreement_history ( + initiative_agreement_history_id, initiative_agreement_id, initiative_agreement_status_id, user_profile_id, create_date, effective_status + ) VALUES (DEFAULT, ?, ?, ?, ?, true) + ''' + def INSERT_INTERNAL_COMMENT_SQL = ''' + INSERT INTO internal_comment ( + internal_comment_id, comment, audience_scope, create_user, create_date + ) VALUES (DEFAULT, ?, ?::audience_scope, ?, ?) + RETURNING internal_comment_id + ''' + def INSERT_INITIATIVE_AGREEMENT_INTERNAL_COMMENT_SQL = ''' + INSERT INTO initiative_agreement_internal_comment ( + initiative_agreement_id, internal_comment_id + ) VALUES (?, ?) + ''' + def INSERT_TRANSACTION_SQL = ''' + INSERT INTO transaction ( + transaction_id, compliance_units, organization_id, transaction_action, effective_date, create_user, create_date, effective_status + ) VALUES (DEFAULT, ?, ?, ?::transaction_action_enum, ?, ?, ?, true) + RETURNING transaction_id + ''' + return [initiativeAgreementStmt : conn.prepareStatement(INSERT_INITIATIVE_AGREEMENT_SQL), + historyStmt : conn.prepareStatement(INSERT_INITIATIVE_AGREEMENT_HISTORY_SQL), + internalCommentStmt : conn.prepareStatement(INSERT_INTERNAL_COMMENT_SQL), + initiativeAgreementInternalCommentStmt + : conn.prepareStatement(INSERT_INITIATIVE_AGREEMENT_INTERNAL_COMMENT_SQL), + transactionStmt : conn.prepareStatement(INSERT_TRANSACTION_SQL)] +} + +def toSqlTimestamp(String timestampString) { + try { + // Parse the ISO 8601 timestamp and convert to java.sql.Timestamp + def offsetDateTime = OffsetDateTime.parse(timestampString) + return Timestamp.from(offsetDateTime.toInstant()) + } catch (Exception e) { + log.error("Invalid timestamp format: ${timestampString}, defaulting to '1970-01-01T00:00:00Z'") + return Timestamp.valueOf('1970-01-01 00:00:00') + } +} + +def processTransactions(String currentStatus, ResultSet rs, PreparedStatement stmt) { + def toTransactionId = null + + if (currentStatus == 'Approved') { + toTransactionId = insertTransaction(stmt, rs, 'Adjustment', rs.getInt('to_organization_id')) + } + + return toTransactionId +} + +def insertTransaction(PreparedStatement stmt, ResultSet rs, String action, int orgId) { + stmt.setInt(1, rs.getInt('compliance_units')) + stmt.setInt(2, orgId) + stmt.setString(3, action) + stmt.setDate(4, rs.getDate('transaction_effective_date') ?: rs.getDate('agreement_date')) + stmt.setInt(5, rs.getInt('create_user')) + stmt.setTimestamp(6, rs.getTimestamp('create_date')) + + def result = stmt.executeQuery() + return result.next() ? 
result.getInt('transaction_id') : null +} + +def processHistory(Integer initiativeAgreementId, List creditTradeHistory, PreparedStatement historyStmt, Map preparedData) { + if (!creditTradeHistory) return + + // Use a Set to track unique combinations of initiative_agreement_id and initiative_agreement_status + def processedEntries = new HashSet() + + creditTradeHistory.each { historyItem -> + try { + def statusId = getStatusId(historyItem.initiative_agreement_status, preparedData) + def uniqueKey = "${initiativeAgreementId}_${statusId}" + + // Check if this combination has already been processed + if (!processedEntries.contains(uniqueKey)) { + // If not processed, add to batch and mark as processed + historyStmt.setInt(1, initiativeAgreementId) + historyStmt.setInt(2, statusId) + historyStmt.setInt(3, historyItem.user_profile_id) + historyStmt.setTimestamp(4, toSqlTimestamp(historyItem.create_timestamp ?: '2013-01-01T00:00:00Z')) + historyStmt.addBatch() + + processedEntries.add(uniqueKey) + } + } catch (Exception e) { + log.error("Error processing history record for initiative_agreement_id: ${initiativeAgreementId}", e) + } + } + + // Execute batch + historyStmt.executeBatch() +} + + +def processInternalComments(Integer initiativeAgreementId, List internalComments, + PreparedStatement internalCommentStmt, + PreparedStatement initiativeAgreementInternalCommentStmt) { + if (!internalComments) return + + internalComments.each { comment -> + if (!comment) return // Skip null comments + + try { + // Insert the internal comment + internalCommentStmt.setString(1, comment.credit_trade_comment ?: '') + internalCommentStmt.setString(2, getAudienceScope(comment.role_names ?: '')) + internalCommentStmt.setInt(3, comment.create_user_id ?: null) + internalCommentStmt.setTimestamp(4, toSqlTimestamp(comment.create_timestamp ?: '2013-01-01T00:00:00Z')) + + def internalCommentId = null + def commentResult = internalCommentStmt.executeQuery() + if (commentResult.next()) { + internalCommentId = commentResult.getInt('internal_comment_id') + + // Insert the initiative-agreement-comment relationship + initiativeAgreementInternalCommentStmt.setInt(1, initiativeAgreementId) + initiativeAgreementInternalCommentStmt.setInt(2, internalCommentId) + initiativeAgreementInternalCommentStmt.executeUpdate() + } + + commentResult.close() + } catch (Exception e) { + log.error("Error processing internal comment for initiative-agreement ${initiativeAgreementId}: ${e.getMessage()}", e) + } + } + } + +// Helper function to determine audience scope based on role names +def getAudienceScope(String roleNames) { + if (!roleNames) return 'Analyst' + + switch (true) { + case roleNames.contains('GovDirector'): + return 'Director' + case roleNames.contains('GovComplianceManager'): + return 'Compliance Manager' + default: + return 'Analyst' + } +} + +def insertInitiativeAgreement(ResultSet rs, PreparedStatement initiativeAgreementStmt, + Long toTransactionId, Map preparedData, Connection conn) { + // Check for duplicates in the `initiative_agreement` table + def initiativeAgreementId = rs.getInt('initiative_agreement_id') + def duplicateCheckStmt = conn.prepareStatement('SELECT COUNT(*) FROM initiative_agreement WHERE initiative_agreement_id = ?') + duplicateCheckStmt.setInt(1, initiativeAgreementId) + def duplicateResult = duplicateCheckStmt.executeQuery() + duplicateResult.next() + def count = duplicateResult.getInt(1) + duplicateResult.close() + duplicateCheckStmt.close() + + if (count > 0) { + log.warn("Duplicate 
initiative_agreement detected with initiative_agreement_id: ${initiativeAgreementId}, skipping insertion.") + return null + } + + // Proceed with insertion if no duplicate exists + def statusId = getStatusId(rs.getString('current_status'), preparedData) + initiativeAgreementStmt.setInt(1, rs.getInt('to_organization_id')) + initiativeAgreementStmt.setObject(2, toTransactionId) + initiativeAgreementStmt.setTimestamp(3, rs.getTimestamp('transaction_effective_date')) + initiativeAgreementStmt.setInt(4, rs.getInt('compliance_units')) + initiativeAgreementStmt.setString(5, rs.getString('gov_comment')) + initiativeAgreementStmt.setObject(6, statusId) + initiativeAgreementStmt.setTimestamp(7, rs.getTimestamp('create_date')) + initiativeAgreementStmt.setTimestamp(8, rs.getTimestamp('update_date')) + initiativeAgreementStmt.setInt(9, rs.getInt('create_user')) + initiativeAgreementStmt.setInt(10, rs.getInt('update_user')) + initiativeAgreementStmt.setInt(11, rs.getInt('initiative_agreement_id')) + def result = initiativeAgreementStmt.executeQuery() + return result.next() ? result.getInt('initiative_agreement_id') : null +} \ No newline at end of file diff --git a/etl/nifi_scripts/transfer.groovy b/etl/nifi_scripts/transfer.groovy index 69211d890..aa6c5362d 100644 --- a/etl/nifi_scripts/transfer.groovy +++ b/etl/nifi_scripts/transfer.groovy @@ -404,9 +404,6 @@ def processInternalComments(Integer transferId, List internalComments, PreparedStatement transferInternalCommentStmt) { if (!internalComments) return - // Use Set to track processed IDs and avoid duplicates - def processedIds = new HashSet() - internalComments.each { comment -> if (!comment) return // Skip null comments diff --git a/frontend/cypress/e2e/Pages/ComplianceReport/ComplianceReport.test.js b/frontend/cypress/e2e/Pages/ComplianceReport/ComplianceReport.test.js index e175c8f20..8e92fff86 100644 --- a/frontend/cypress/e2e/Pages/ComplianceReport/ComplianceReport.test.js +++ b/frontend/cypress/e2e/Pages/ComplianceReport/ComplianceReport.test.js @@ -2,7 +2,7 @@ import { Given, When, Then } from '@badeball/cypress-cucumber-preprocessor' const currentYear = new Date().getFullYear().toString() -Given('the supplier is on the login page', () => { +Given('the user is on the login page', () => { cy.clearAllCookies() cy.clearAllLocalStorage() cy.clearAllSessionStorage() @@ -20,7 +20,7 @@ When('the supplier logs in with valid credentials', () => { cy.getByDataTest('dashboard-container').should('exist') }) -When('the supplier navigates to the compliance reports page', () => { +When('they navigate to the compliance reports page', () => { cy.get('a[href="/compliance-reporting"]').click() }) @@ -140,3 +140,43 @@ Then('the compliance report summary includes the quantity', () => { .should('be.visible') .and('have.text', '500') }) + +When('the supplier fills out line 6', () => { + cy.get( + '[data-test="renewable-summary"] > .MuiTable-root > .MuiTableBody-root > :nth-child(6) > :nth-child(3)' + ) + .find('input') + .type('50{enter}') + .blur() +}) + +Then('it should round the amount to 25', () => { + cy.get( + '[data-test="renewable-summary"] > .MuiTable-root > .MuiTableBody-root > :nth-child(6) > :nth-child(3)' + ) + .find('input') + .should('be.visible') + .and('have.value', '25') +}) + +When('the supplier accepts the agreement', () => { + cy.get('#signing-authority-declaration').click() +}) + +When('the supplier submits the report', () => { + cy.contains('button', 'Submit report').click() + cy.get('#modal-btn-submit-report').click() + cy.wait(2000) +}) 
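The compliance-report steps above lean on a cy.getByDataTest() helper that these specs call (for example on 'dashboard-container' and 'main-layout-navbar') but whose definition sits outside this diff. A minimal sketch of such a support command, assuming it simply wraps a data-test attribute selector in cypress/support/commands.js:

    // Hypothetical sketch, not part of this change set: resolve elements by their
    // data-test attribute so specs can write cy.getByDataTest('dashboard-container').
    Cypress.Commands.add('getByDataTest', (selector, ...args) => {
      return cy.get(`[data-test="${selector}"]`, ...args)
    })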
+ +Then('the status should change to Submitted', () => { + cy.get('[data-test="compliance-report-status"]') + .should('be.visible') + .and('have.text', 'Status: Submitted') +}) + +Then('they see the previously submitted report', () => { + cy.get('.ag-column-first > a > .MuiBox-root') + .should('be.visible') + .and('have.text', currentYear) +}) diff --git a/frontend/cypress/e2e/Pages/ComplianceReport/ComplianceReportManagement.feature b/frontend/cypress/e2e/Pages/ComplianceReport/ComplianceReportManagement.feature index a0e9c1975..501cf3790 100644 --- a/frontend/cypress/e2e/Pages/ComplianceReport/ComplianceReportManagement.feature +++ b/frontend/cypress/e2e/Pages/ComplianceReport/ComplianceReportManagement.feature @@ -1,12 +1,23 @@ Feature: Compliance Report Management Scenario: Supplier saves a draft compliance report - Given the supplier is on the login page - When the supplier logs in with valid credentials - And the supplier navigates to the compliance reports page + Given the user is on the login page + And the supplier logs in with valid credentials + And they navigate to the compliance reports page And the supplier creates a new compliance report Then the compliance report introduction is shown When the supplier navigates to the fuel supply page And the supplier enters a valid fuel supply row And saves and returns to the report Then the compliance report summary includes the quantity + When the supplier fills out line 6 + Then it should round the amount to 25 + When the supplier accepts the agreement + And the supplier submits the report + Then the status should change to Submitted + + Scenario: Analyst logs in to review a compliance report + Given the user is on the login page + And the analyst logs in with valid credentials + And they navigate to the compliance reports page + Then they see the previously submitted report \ No newline at end of file diff --git a/frontend/cypress/e2e/Pages/User/UserCreation.test.js b/frontend/cypress/e2e/Pages/User/UserCreation.test.js index bf9b8ff25..079dd77a6 100644 --- a/frontend/cypress/e2e/Pages/User/UserCreation.test.js +++ b/frontend/cypress/e2e/Pages/User/UserCreation.test.js @@ -25,39 +25,39 @@ When('the IDIR user logs in with valid credentials', () => { }) When('the IDIR user navigates to the user creation page', () => { - cy.get('a[href="/admin"]').click() - cy.url().should('include', '/admin/users') - cy.contains('New user').click() - cy.url().should('include', '/admin/users/add-user') + cy.get('a[href="/admin"]').click() + cy.url().should('include', '/admin/users') + cy.contains('New user').click() + cy.url().should('include', '/admin/users/add-user') }) When('the IDIR user fills out the form with valid data', () => { - cy.get('input[id="firstName"]').type('John') - cy.get('input[id="lastName"]').type('Doe') - cy.get('input[id="jobTitle"]').type('Senior Analyst') - cy.get('input[id="userName"]').type('johndoe') - cy.get('input[id="keycloakEmail"]').type('john.doe@example.com') - cy.get('input[id="phone"]').type('1234567890') - cy.get('input[id="mobilePhone"]').type('0987654321') + cy.get('input[id="firstName"]').type('John') + cy.get('input[id="lastName"]').type('Doe') + cy.get('input[id="jobTitle"]').type('Senior Analyst') + cy.get('input[id="userName"]').type('johndoe') + cy.get('input[id="keycloakEmail"]').type('john.doe@example.com') + cy.get('input[id="phone"]').type('1234567890') + cy.get('input[id="mobilePhone"]').type('0987654321') - // Select the Analyst role - cy.get('input[type="radio"][value="analyst"]').check() + // Select 
the Analyst role + cy.get('input[type="radio"][value="analyst"]').check() }) When('the IDIR user submits the form', () => { - cy.get('button[data-test="saveUser"]').click() + cy.get('button[data-test="saveUser"]').click() }) Then('a success message is displayed', () => { - cy.get("[data-test='alert-box'] .MuiBox-root").should( - 'contain', - 'User has been successfully saved.' - ) + cy.get("[data-test='alert-box'] .MuiBox-root").should( + 'contain', + 'User has been successfully saved.' + ) }) Then('the new user appears in the user list', () => { - cy.visit('/admin/users') - cy.contains('a', Cypress.env('john.doe@example.com')).should('be.visible') + cy.visit('/admin/users') + cy.contains('a', Cypress.env('john.doe@example.com')).should('be.visible') }) // Test for validation error @@ -75,7 +75,7 @@ When('the IDIR user fills out the form with invalid data', () => { }) Then('an error message is displayed for validation', () => { - cy.get('#userName-helper-text').should('contain', 'User name is required') + cy.get('#userName-helper-text').should('contain', 'User name is required') }) // Cleanup after the test diff --git a/frontend/cypress/e2e/accessibility.cy.js b/frontend/cypress/e2e/accessibility.cy.js index 4c44b2543..44147beff 100644 --- a/frontend/cypress/e2e/accessibility.cy.js +++ b/frontend/cypress/e2e/accessibility.cy.js @@ -18,7 +18,7 @@ describe('Accessibility Tests for LCFS', () => { it('Should have no accessibility violations in the navigation bar', () => { cy.visit('/') cy.injectAxe() - cy.login( + cy.loginWith( 'idir', Cypress.env('IDIR_TEST_USER'), Cypress.env('IDIR_TEST_PASS') diff --git a/frontend/cypress/e2e/add__edit_org.cy.js b/frontend/cypress/e2e/add__edit_org.cy.js index 1cd1e136e..1a714ee3f 100644 --- a/frontend/cypress/e2e/add__edit_org.cy.js +++ b/frontend/cypress/e2e/add__edit_org.cy.js @@ -6,7 +6,7 @@ describe('Add Organization Test Suite', () => { beforeEach(() => { - cy.login( + cy.loginWith( 'idir', Cypress.env('IDIR_TEST_USER'), Cypress.env('IDIR_TEST_PASS') diff --git a/frontend/cypress/e2e/disclaimer_banner.cy.js b/frontend/cypress/e2e/disclaimer_banner.cy.js index 363cc207a..ded9a66a4 100644 --- a/frontend/cypress/e2e/disclaimer_banner.cy.js +++ b/frontend/cypress/e2e/disclaimer_banner.cy.js @@ -5,7 +5,7 @@ describe('Disclaimer Banner Visibility Test Suite', () => { context('BCeID User', () => { beforeEach(() => { - cy.login( + cy.loginWith( 'bceid', Cypress.env('BCEID_TEST_USER'), Cypress.env('BCEID_TEST_PASS') @@ -29,7 +29,7 @@ describe('Disclaimer Banner Visibility Test Suite', () => { context('IDIR User', () => { beforeEach(() => { - cy.login( + cy.loginWith( 'idir', Cypress.env('IDIR_TEST_USER'), Cypress.env('IDIR_TEST_PASS') diff --git a/frontend/cypress/e2e/organization.cy.js b/frontend/cypress/e2e/organization.cy.js index 4ca2e7f38..2ac742e9b 100644 --- a/frontend/cypress/e2e/organization.cy.js +++ b/frontend/cypress/e2e/organization.cy.js @@ -5,7 +5,7 @@ describe('Organization Test Suite', () => { beforeEach(() => { // Login and visit the page - cy.login( + cy.loginWith( 'idir', Cypress.env('IDIR_TEST_USER'), Cypress.env('IDIR_TEST_PASS') diff --git a/frontend/cypress/e2e/user_flow.cy.js b/frontend/cypress/e2e/user_flow.cy.js index e99dde6be..233823650 100644 --- a/frontend/cypress/e2e/user_flow.cy.js +++ b/frontend/cypress/e2e/user_flow.cy.js @@ -28,12 +28,12 @@ describe('User Login Test Suite', () => { describe('IDIR Login Flow', () => { it('fails login with wrong IDIR user credentials', () => { - cy.login('idir', 'wrong_username', 
'wrong_password') + cy.loginWith('idir', 'wrong_username', 'wrong_password') cy.getByDataTest('main-layout-navbar').should('not.exist') }) it('completes login with IDIR user credentials', () => { - cy.login( + cy.loginWith( 'idir', Cypress.env('IDIR_TEST_USER'), Cypress.env('IDIR_TEST_PASS') @@ -42,7 +42,7 @@ describe('User Login Test Suite', () => { }) it('executes logout functionality for IDIR user', () => { - cy.login( + cy.loginWith( 'idir', Cypress.env('IDIR_TEST_USER'), Cypress.env('IDIR_TEST_PASS') @@ -53,12 +53,12 @@ describe('User Login Test Suite', () => { describe('BCeID Login Flow', () => { it('fails login with wrong BCeID user credentials', () => { - cy.login('bceid', 'wrong_username', 'wrong_password') + cy.loginWith('bceid', 'wrong_username', 'wrong_password') cy.getByDataTest('main-layout-navbar').should('not.exist') }) it('completes login with BCeID user credentials', () => { - cy.login( + cy.loginWith( 'bceid', Cypress.env('BCEID_TEST_USER'), Cypress.env('BCEID_TEST_PASS') diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index 127024c7d..2e6c33946 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -32,7 +32,7 @@ import { AddEditOtherUses } from './views/OtherUses/AddEditOtherUses' import { AddEditFinalSupplyEquipments } from './views/FinalSupplyEquipments/AddEditFinalSupplyEquipments' import { AddEditFuelSupplies } from './views/FuelSupplies/AddEditFuelSupplies' import { AddEditFuelExports } from './views/FuelExports/AddEditFuelExports' -import { AddEditAllocationAgreements } from './views/AllocationAgreements/AddAllocationAgreements' +import { AddEditAllocationAgreements } from './views/AllocationAgreements/AddEditAllocationAgreements' import { logout } from '@/utils/keycloak.js' import { CompareReports } from '@/views/CompareReports/CompareReports' diff --git a/frontend/src/assets/locales/en/finalSupplyEquipment.json b/frontend/src/assets/locales/en/finalSupplyEquipment.json index 00e74933f..d48862f59 100644 --- a/frontend/src/assets/locales/en/finalSupplyEquipment.json +++ b/frontend/src/assets/locales/en/finalSupplyEquipment.json @@ -27,6 +27,7 @@ "rows": "rows", "finalSupplyEquipmentColLabels": { "complianceReportId": "Compliance Report ID", + "organizationName": "Organization", "supplyFrom": "Supply date range", "kwhUsage":"kWh usage", "supplyFromDate": "Dates of supply from", diff --git a/frontend/src/components/BCDataGrid/components/Renderers/ValidationRenderer2.jsx b/frontend/src/components/BCDataGrid/components/Renderers/ValidationRenderer2.jsx index 043b127b5..5d4931865 100644 --- a/frontend/src/components/BCDataGrid/components/Renderers/ValidationRenderer2.jsx +++ b/frontend/src/components/BCDataGrid/components/Renderers/ValidationRenderer2.jsx @@ -18,7 +18,7 @@ export const ValidationRenderer2 = ({ data }) => { ) case 'error': return ( - + { ...options, mutationFn: async (fuelCodeID) => { return await client.delete( - apiRoutes.updateFuelCode.replace(':fuelCodeId', fuelCodeID) + apiRoutes.deleteFuelCode.replace(':fuelCodeId', fuelCodeID) ) } }) diff --git a/frontend/src/layouts/MainLayout/components/UserProfileActions.jsx b/frontend/src/layouts/MainLayout/components/UserProfileActions.jsx index f5f7dbc68..37ebd0557 100644 --- a/frontend/src/layouts/MainLayout/components/UserProfileActions.jsx +++ b/frontend/src/layouts/MainLayout/components/UserProfileActions.jsx @@ -38,7 +38,6 @@ export const UserProfileActions = () => { refetchInterval: 60000 // Automatically refetch every 1 minute (60000ms) }) const notificationsCount = 
notificationsData?.count || 0 - console.log(notificationsData) // Call refetch whenever the route changes useEffect(() => { diff --git a/frontend/src/views/AllocationAgreements/AddAllocationAgreements.jsx b/frontend/src/views/AllocationAgreements/AddEditAllocationAgreements.jsx similarity index 94% rename from frontend/src/views/AllocationAgreements/AddAllocationAgreements.jsx rename to frontend/src/views/AllocationAgreements/AddEditAllocationAgreements.jsx index 7d3a50512..f09d3ee0e 100644 --- a/frontend/src/views/AllocationAgreements/AddAllocationAgreements.jsx +++ b/frontend/src/views/AllocationAgreements/AddEditAllocationAgreements.jsx @@ -71,7 +71,21 @@ export const AddEditAllocationAgreements = () => { const onGridReady = useCallback( async (params) => { setGridApi(params.api) - setRowData([...(data.allocationAgreements || { id: uuid() })]) + + if ( + Array.isArray(data.allocationAgreements) && + data.allocationAgreements.length > 0 + ) { + const updatedRowData = data.allocationAgreements.map((item) => ({ + ...item, + id: item.id || uuid() // Ensure every item has a unique ID + })) + setRowData(updatedRowData) + } else { + // If allocationAgreements is not available or empty, initialize with a single row + setRowData([{ id: uuid() }]) + } + params.api.sizeColumnsToFit() }, [data] diff --git a/frontend/src/views/AllocationAgreements/__tests__/AddEditAllocationAgreements.test.jsx b/frontend/src/views/AllocationAgreements/__tests__/AddEditAllocationAgreements.test.jsx new file mode 100644 index 000000000..ce7d72113 --- /dev/null +++ b/frontend/src/views/AllocationAgreements/__tests__/AddEditAllocationAgreements.test.jsx @@ -0,0 +1,133 @@ +// src/views/AllocationAgreements/__tests__/AddEditAllocationAgreements.test.jsx + +import React from 'react' +import { render, screen, fireEvent, waitFor } from '@testing-library/react' +import { describe, it, expect, beforeEach, vi } from 'vitest' +import { AddEditAllocationAgreements } from '../AddEditAllocationAgreements' +import * as useAllocationAgreementHook from '@/hooks/useAllocationAgreement' +import { wrapper } from '@/tests/utils/wrapper' + +// Mock react-router-dom hooks +const mockUseLocation = vi.fn() +const mockUseNavigate = vi.fn() +const mockUseParams = vi.fn() + +vi.mock('react-router-dom', () => ({ + ...vi.importActual('react-router-dom'), + useLocation: () => mockUseLocation(), + useNavigate: () => mockUseNavigate(), + useParams: () => mockUseParams() +})) + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key) => key + }) +})) + +// Mock hooks related to allocation agreements +vi.mock('@/hooks/useAllocationAgreement') + +// Mock BCGridEditor component +vi.mock('@/components/BCDataGrid/BCGridEditor', () => ({ + BCGridEditor: ({ + gridRef, + alertRef, + onGridReady, + rowData, + onCellValueChanged, + onCellEditingStopped + }) => ( +
+ <div>
+ {rowData.map((row, index) => (
+ <div key={index} data-test="grid-row">
+ {row.id}
+ </div>
+ ))}
+ </div>
+ ) +})) + +describe('AddEditAllocationAgreements', () => { + beforeEach(() => { + vi.resetAllMocks() + + // Mock react-router-dom hooks with complete location object + mockUseLocation.mockReturnValue({ + pathname: '/test-path', // Include pathname to prevent undefined errors + state: {} + }) + mockUseNavigate.mockReturnValue(vi.fn()) + mockUseParams.mockReturnValue({ + complianceReportId: 'testReportId', + compliancePeriod: '2024' + }) + + // Mock useGetAllocationAgreements hook to return empty data initially + vi.mocked( + useAllocationAgreementHook.useGetAllocationAgreements + ).mockReturnValue({ + data: { allocationAgreements: [] }, + isLoading: false + }) + + // Mock useAllocationAgreementOptions hook + vi.mocked( + useAllocationAgreementHook.useAllocationAgreementOptions + ).mockReturnValue({ + data: { fuelTypes: [] }, + isLoading: false, + isFetched: true + }) + + // Mock useSaveAllocationAgreement hook + vi.mocked( + useAllocationAgreementHook.useSaveAllocationAgreement + ).mockReturnValue({ + mutateAsync: vi.fn() + }) + }) + + it('renders the component', () => { + render(, { wrapper }) + expect( + screen.getByText('allocationAgreement:addAllocationAgreementRowsTitle') + ).toBeInTheDocument() + }) + + it('initializes with at least one row in the empty state', () => { + render(, { wrapper }) + const rows = screen.getAllByTestId('grid-row') + expect(rows.length).toBe(1) // Ensure at least one row exists + }) + + it('loads data when allocationAgreements are available', async () => { + // Update the mock to return allocation agreements + vi.mocked( + useAllocationAgreementHook.useGetAllocationAgreements + ).mockReturnValue({ + data: { + allocationAgreements: [ + { allocationAgreementId: 'testId1' }, + { allocationAgreementId: 'testId2' } + ] + }, + isLoading: false + }) + + render(, { wrapper }) + + // Use findAllByTestId for asynchronous elements + const rows = await screen.findAllByTestId('grid-row') + expect(rows.length).toBe(2) + // Check that each row's textContent matches UUID format + const uuidRegex = + /^[0-9a-f]{8}-[0-9a-f]{4}-[1-5][0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$/i + rows.forEach((row) => { + expect(uuidRegex.test(row.textContent)).toBe(true) + }) + }) +}) diff --git a/frontend/src/views/FinalSupplyEquipments/FinalSupplyEquipmentSummary.jsx b/frontend/src/views/FinalSupplyEquipments/FinalSupplyEquipmentSummary.jsx index 84437a91b..1a2f7851e 100644 --- a/frontend/src/views/FinalSupplyEquipments/FinalSupplyEquipmentSummary.jsx +++ b/frontend/src/views/FinalSupplyEquipments/FinalSupplyEquipmentSummary.jsx @@ -48,6 +48,12 @@ export const FinalSupplyEquipmentSummary = ({ data }) => { ) const columns = useMemo( () => [ + { + headerName: t( + 'finalSupplyEquipment:finalSupplyEquipmentColLabels.organizationName' + ), + field: 'organizationName' + }, { headerName: t( 'finalSupplyEquipment:finalSupplyEquipmentColLabels.supplyFromDate' diff --git a/frontend/src/views/FinalSupplyEquipments/_schema.jsx b/frontend/src/views/FinalSupplyEquipments/_schema.jsx index d57cbb739..656471a6c 100644 --- a/frontend/src/views/FinalSupplyEquipments/_schema.jsx +++ b/frontend/src/views/FinalSupplyEquipments/_schema.jsx @@ -12,7 +12,7 @@ import i18n from '@/i18n' import { actions, validation } from '@/components/BCDataGrid/columns' import moment from 'moment' import { CommonArrayRenderer } from '@/utils/grid/cellRenderers' -import { StandardCellErrors } from '@/utils/grid/errorRenderers' +import { StandardCellWarningAndErrors, StandardCellErrors } from '@/utils/grid/errorRenderers' import { 
apiRoutes } from '@/constants/routes' import { numberFormatter } from '@/utils/formatters.js' @@ -41,6 +41,42 @@ export const finalSupplyEquipmentColDefs = ( cellDataType: 'text', hide: true }, + { + field: 'organizationName', + headerComponent: RequiredHeader, + headerName: i18n.t( + 'finalSupplyEquipment:finalSupplyEquipmentColLabels.organizationName' + ), + cellEditor: AutocompleteCellEditor, + cellRenderer: (params) => + params.value || + (!params.value && Select), + cellEditorParams: { + options: optionsData?.organizationNames?.sort() || [], + multiple: false, + disableCloseOnSelect: false, + freeSolo: true, + openOnFocus: true, + }, + cellStyle: (params) => + StandardCellWarningAndErrors(params, errors), + suppressKeyboardEvent, + minWidth: 260, + editable: true, + valueGetter: (params) => { + return params.data?.organizationName || ''; + }, + valueSetter: (params) => { + if (params.newValue) { + const isValidOrganizationName = optionsData?.organizationNames.includes(params.newValue); + + params.data.organizationName = isValidOrganizationName ? params.newValue : params.newValue; + return true; + } + return false; + }, + tooltipValueGetter: (params) => "Select the organization name from the list" + }, { field: 'supplyFrom', headerName: i18n.t( diff --git a/frontend/src/views/FuelCodes/AddFuelCode/AddEditFuelCode.jsx b/frontend/src/views/FuelCodes/AddFuelCode/AddEditFuelCode.jsx index 1b5d6448c..c06a73a7b 100644 --- a/frontend/src/views/FuelCodes/AddFuelCode/AddEditFuelCode.jsx +++ b/frontend/src/views/FuelCodes/AddFuelCode/AddEditFuelCode.jsx @@ -25,6 +25,7 @@ import BCModal from '@/components/BCModal' import BCTypography from '@/components/BCTypography' import { FUEL_CODE_STATUSES } from '@/constants/statuses' import { useCurrentUser } from '@/hooks/useCurrentUser' +import Papa from 'papaparse' const AddEditFuelCodeBase = () => { const { fuelCodeID } = useParams() @@ -197,6 +198,7 @@ const AddEditFuelCodeBase = () => { } else { const res = await createFuelCode(updatedData) updatedData.fuelCodeId = res.data.fuelCodeId + updatedData.fuelSuffix = res.data.fuelSuffix } updatedData = { @@ -210,7 +212,9 @@ const AddEditFuelCodeBase = () => { }) } catch (error) { setErrors({ - [params.node.data.id]: error.response.data.errors[0].fields + [params.node.data.id]: + error.response.data?.errors && + error.response.data?.errors[0]?.fields }) updatedData = { @@ -229,10 +233,12 @@ const AddEditFuelCodeBase = () => { const errMsg = `Error updating row: ${ fieldLabels.length === 1 ? 
fieldLabels[0] : '' } ${message}` - + updatedData.validationMsg = errMsg handleError(error, errMsg) } else { - handleError(error, `Error updating row: ${error.message}`) + const errMsg = error.response?.data?.detail || error.message + updatedData.validationMsg = errMsg + handleError(error, `Error updating row: ${errMsg}`) } } @@ -241,6 +247,69 @@ const AddEditFuelCodeBase = () => { [updateFuelCode, t] ) + const handlePaste = useCallback( + (event, { api, columnApi }) => { + const newData = [] + const clipboardData = event.clipboardData || window.clipboardData + const pastedData = clipboardData.getData('text/plain') + const headerRow = api + .getAllDisplayedColumns() + .map((column) => column.colDef.field) + .filter((col) => col) + .join('\t') + const parsedData = Papa.parse(headerRow + '\n' + pastedData, { + delimiter: '\t', + header: true, + transform: (value, field) => { + // Check for date fields and format them + const dateRegex = /^\d{4}-\d{2}-\d{2}$/ // Matches YYYY-MM-DD format + if (field.toLowerCase().includes('date') && !dateRegex.test(value)) { + const parsedDate = new Date(value) + if (!isNaN(parsedDate)) { + return parsedDate.toISOString().split('T')[0] // Format as YYYY-MM-DD + } + } + const num = Number(value) // Attempt to convert to a number if possible + return isNaN(num) ? value : num // Return the number if valid, otherwise keep as string + }, + skipEmptyLines: true + }) + if (parsedData.data?.length < 0) { + return + } + parsedData.data.forEach((row) => { + const newRow = { ...row } + newRow.id = uuid() + newRow.prefixId = optionsData?.fuelCodePrefixes?.find( + (o) => o.prefix === row.prefix + )?.fuelCodePrefixId + newRow.fuelTypeId = optionsData?.fuelTypes?.find( + (o) => o.fuelType === row.fuelType + )?.fuelTypeId + newRow.fuelSuffix = newRow.fuelSuffix.toString() + newRow.feedstockFuelTransportMode = row.feedstockFuelTransportMode + .split(',') + .map((item) => item.trim()) + newRow.finishedFuelTransportMode = row.finishedFuelTransportMode + .split(',') + .map((item) => item.trim()) + newRow.modified = true + newData.push(newRow) + }) + const transactions = api.applyTransaction({ add: newData }) + // Trigger onCellEditingStopped event to update the row in backend. 
+ transactions.add.forEach((node) => { + onCellEditingStopped({ + node, + oldValue: '', + newvalue: undefined, + ...api + }) + }) + }, + [onCellEditingStopped, optionsData] + ) + const duplicateFuelCode = async (params) => { const rowData = { ...params.data, @@ -327,6 +396,7 @@ const AddEditFuelCodeBase = () => { onAction={onAction} showAddRowsButton={!existingFuelCode && hasRoles(roles.analyst)} context={{ errors }} + handlePaste={handlePaste} /> {existingFuelCode?.fuelCodeStatus.status !== FUEL_CODE_STATUSES.APPROVED && ( diff --git a/frontend/src/views/FuelCodes/AddFuelCode/_schema.jsx b/frontend/src/views/FuelCodes/AddFuelCode/_schema.jsx index e42f2b8c7..66d033af6 100644 --- a/frontend/src/views/FuelCodes/AddFuelCode/_schema.jsx +++ b/frontend/src/views/FuelCodes/AddFuelCode/_schema.jsx @@ -97,9 +97,9 @@ export const fuelCodeColDefs = (optionsData, errors, isCreate, canEdit) => [ const selectedPrefix = optionsData?.fuelCodePrefixes?.find( (obj) => obj.prefix === params.newValue ) - params.data.fuelTypeId = selectedPrefix.fuelCodePrefixId + params.data.fuelCodePrefixId = selectedPrefix.fuelCodePrefixId - params.data.fuelCode = optionsData?.fuelCodePrefixes?.find( + params.data.fuelSuffix = optionsData?.fuelCodePrefixes?.find( (obj) => obj.prefix === params.newValue )?.nextFuelCode params.data.company = undefined @@ -327,12 +327,12 @@ export const fuelCodeColDefs = (optionsData, errors, isCreate, canEdit) => [ return selectedOption.fuelType } const selectedOption = optionsData?.fuelTypes?.find( - (obj) => obj.fuelType === params.data.fuel + (obj) => obj.fuelType === params.data.fuelType ) if (selectedOption) { params.data.fuelTypeId = selectedOption.fuelTypeId } - return params.data.fuel + return params.data.fuelType }, valueSetter: (params) => { if (params.newValue) { @@ -341,6 +341,7 @@ export const fuelCodeColDefs = (optionsData, errors, isCreate, canEdit) => [ ) params.data.fuelTypeId = selectedFuelType.fuelTypeId } + return params.data.fuelType }, cellEditorParams: { options: optionsData?.fuelTypes diff --git a/frontend/src/views/Notifications/NotificationMenu/components/BCeIDNotificationSettings.jsx b/frontend/src/views/Notifications/NotificationMenu/components/BCeIDNotificationSettings.jsx index 74af3935c..d96cd56e9 100644 --- a/frontend/src/views/Notifications/NotificationMenu/components/BCeIDNotificationSettings.jsx +++ b/frontend/src/views/Notifications/NotificationMenu/components/BCeIDNotificationSettings.jsx @@ -29,7 +29,7 @@ const BCeIDNotificationSettings = () => { ) } diff --git a/frontend/src/views/Notifications/NotificationMenu/components/NotificationSettingsForm.jsx b/frontend/src/views/Notifications/NotificationMenu/components/NotificationSettingsForm.jsx index 6fd482744..5a7e0bd93 100644 --- a/frontend/src/views/Notifications/NotificationMenu/components/NotificationSettingsForm.jsx +++ b/frontend/src/views/Notifications/NotificationMenu/components/NotificationSettingsForm.jsx @@ -25,8 +25,6 @@ import { notificationTypes, notificationChannels } from '@/constants/notificationTypes' -import { useNavigate } from 'react-router-dom' -import { ROUTES } from '@/constants/routes' import MailIcon from '@mui/icons-material/Mail' import NotificationsIcon from '@mui/icons-material/Notifications' import BCButton from '@/components/BCButton' @@ -40,19 +38,18 @@ const NotificationSettingsForm = ({ initialEmail }) => { const { t } = useTranslation(['notifications']) - const navigate = useNavigate() const [isFormLoading, setIsFormLoading] = useState(false) const [message, setMessage] 
= useState('') const { data: subscriptionsData, isLoading: isSubscriptionsLoading } = useNotificationSubscriptions() const createSubscription = useCreateSubscription() const deleteSubscription = useDeleteSubscription() - const updateNotificationsEmail = useUpdateNotificationsEmail() + const updateEmail = useUpdateNotificationsEmail() // Validation schema const validationSchema = Yup.object().shape({ ...(showEmailField && { - notificationEmail: Yup.string() + email: Yup.string() .email(t('errors.invalidEmail')) .required(t('errors.emailRequired')) }) @@ -75,7 +72,7 @@ const NotificationSettingsForm = ({ }) } if (showEmailField && initialEmail) { - setValue('notificationEmail', initialEmail) + setValue('email', initialEmail) } }, [subscriptionsData, showEmailField, initialEmail, setValue]) @@ -122,8 +119,8 @@ const NotificationSettingsForm = ({ try { if (showEmailField) { // BCeID user, save the email address - await updateNotificationsEmail.mutateAsync({ - notifications_email: data.notificationEmail + await updateEmail.mutateAsync({ + email: data.email }) setMessage(t('messages.emailSaved')) } @@ -423,11 +420,11 @@ const NotificationSettingsForm = ({ > {showEmailField && ( - - {t('notificationsEmail')}: + + {t('email')}: ( { helperText={errors.complianceUnits?.message} value={formattedValue} onChange={(e) => { - // Remove all non-digit characters - const numericValue = e.target.value.replace(/\D/g, '') + // Remove all non-digit characters (other than - at the front) + const numericValue = e.target.value.replace( + /(?!^-)[^0-9]/g, + '' + ) // Update the form state with the raw number onChange(numericValue) }}
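The final hunk replaces the digit-only filter /\D/g with /(?!^-)[^0-9]/g in the compliance-units input handler, so a negative value keeps its leading minus sign while commas, spaces, and any other non-digit characters are still stripped before the raw number is pushed into form state. A short illustration with made-up input values (not taken from the diff):

    // '\D' drops the sign along with the separators; the negative lookahead
    // spares a '-' only when it sits at the very start of the string.
    '-1,500'.replace(/\D/g, '')            // => '1500'
    '-1,500'.replace(/(?!^-)[^0-9]/g, '')  // => '-1500'
    '1-5'.replace(/(?!^-)[^0-9]/g, '')     // => '15'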