Skip to content

Commit

Permalink
Merge pull request #1293 from bcgov/feat/kevin-1193
Browse files Browse the repository at this point in the history
feat: retrieve compliance report chain
  • Loading branch information
kevin-hashimoto authored Dec 5, 2024
2 parents 591e41b + eb78bed commit 505f78c
Show file tree
Hide file tree
Showing 12 changed files with 356 additions and 180 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from lcfs.web.api.compliance_report.schema import (
ComplianceReportUpdateSchema,
ComplianceReportSummaryUpdateSchema,
ChainedComplianceReportSchema,
)
from lcfs.services.s3.client import DocumentService

Expand Down Expand Up @@ -226,7 +227,9 @@ async def test_get_compliance_report_by_id_success(
) as mock_validate_organization_access:
set_mock_user(fastapi_app, [RoleEnum.GOVERNMENT])

mock_compliance_report = compliance_report_base_schema()
mock_compliance_report = ChainedComplianceReportSchema(
report=compliance_report_base_schema(), chain=[]
)

mock_get_compliance_report_by_id.return_value = mock_compliance_report
mock_validate_organization_access.return_value = None
Expand All @@ -240,7 +243,9 @@ async def test_get_compliance_report_by_id_success(
expected_response = json.loads(mock_compliance_report.json(by_alias=True))

assert response.json() == expected_response
mock_get_compliance_report_by_id.assert_called_once_with(1, False)
mock_get_compliance_report_by_id.assert_called_once_with(
1, False, get_chain=True
)
mock_validate_organization_access.assert_called_once_with(1)


Expand Down
49 changes: 29 additions & 20 deletions backend/lcfs/tests/organization/test_organization_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from lcfs.web.api.organization.validation import OrganizationValidation

from lcfs.web.api.compliance_report.services import ComplianceReportServices
from lcfs.web.api.compliance_report.schema import ChainedComplianceReportSchema


@pytest.mark.anyio
Expand Down Expand Up @@ -160,7 +161,8 @@ async def test_export_transactions_for_org_success(
):
set_mock_user(fastapi_app, [RoleEnum.SUPPLIER])

mock_transactions_services.export_transactions.return_value = {"streaming": True}
mock_transactions_services.export_transactions.return_value = {
"streaming": True}

fastapi_app.dependency_overrides[TransactionsService] = (
lambda: mock_transactions_services
Expand Down Expand Up @@ -188,7 +190,8 @@ async def test_create_transfer_success(
set_mock_user(fastapi_app, [RoleEnum.SUPPLIER])

organization_id = 1
url = fastapi_app.url_path_for("create_transfer", organization_id=organization_id)
url = fastapi_app.url_path_for(
"create_transfer", organization_id=organization_id)

payload = {"from_organization_id": 1, "to_organization_id": 2}

Expand Down Expand Up @@ -226,7 +229,8 @@ async def test_update_transfer_success(
):
set_mock_user(fastapi_app, [RoleEnum.SUPPLIER])

url = fastapi_app.url_path_for("update_transfer", organization_id=1, transfer_id=1)
url = fastapi_app.url_path_for(
"update_transfer", organization_id=1, transfer_id=1)

payload = {"from_organization_id": 1, "to_organization_id": 2}

Expand Down Expand Up @@ -274,7 +278,8 @@ async def test_create_compliance_report_success(
"create_compliance_report", organization_id=organization_id
)

payload = {"compliance_period": "2024", "organization_id": 1, "status": "status"}
payload = {"compliance_period": "2024",
"organization_id": 1, "status": "status"}

mock_organization_validation.create_compliance_report.return_value = None
mock_compliance_report_services.create_compliance_report.return_value = {
Expand Down Expand Up @@ -346,7 +351,8 @@ async def test_get_all_org_reported_years_success(
):
set_mock_user(fastapi_app, [RoleEnum.SUPPLIER])

url = fastapi_app.url_path_for("get_all_org_reported_years", organization_id=1)
url = fastapi_app.url_path_for(
"get_all_org_reported_years", organization_id=1)

mock_compliance_report_services.get_all_org_reported_years.return_value = [
{"compliance_period_id": 1, "description": "2024"}
Expand Down Expand Up @@ -379,20 +385,23 @@ async def test_get_compliance_report_by_id_success(
)

# Mock the compliance report service's method
mock_compliance_report_services.get_compliance_report_by_id.return_value = {
"compliance_report_id": 1,
"compliance_period_id": 1,
"compliance_period": {"compliance_period_id": 1, "description": "2024"},
"organization_id": 1,
"organization": {"organization_id": 1, "name": "org1"},
"current_status_id": 1,
"current_status": {"compliance_report_status_id": 1, "status": "status"},
"summary": {"summary_id": 1, "is_locked": False},
"compliance_report_group_uuid": "uuid",
"version": 0,
"supplemental_initiator": SupplementalInitiatorType.SUPPLIER_SUPPLEMENTAL,
"has_supplemental": False,
}
mock_compliance_report_services.get_compliance_report_by_id.return_value = ChainedComplianceReportSchema(
report={
"compliance_report_id": 1,
"compliance_period_id": 1,
"compliance_period": {"compliance_period_id": 1, "description": "2024"},
"organization_id": 1,
"organization": {"organization_id": 1, "name": "org1"},
"current_status_id": 1,
"current_status": {"compliance_report_status_id": 1, "status": "status"},
"summary": {"summary_id": 1, "is_locked": False},
"compliance_report_group_uuid": "uuid",
"version": 0,
"supplemental_initiator": SupplementalInitiatorType.SUPPLIER_SUPPLEMENTAL,
"has_supplemental": False,
},
chain=[]
)

# Create a mock for the validation service
mock_compliance_report_validation = AsyncMock()
Expand All @@ -412,7 +421,7 @@ async def test_get_compliance_report_by_id_success(
# Assertions
assert response.status_code == 200
mock_compliance_report_services.get_compliance_report_by_id.assert_awaited_once_with(
1, apply_masking=True
1, apply_masking=True, get_chain=True
)
mock_compliance_report_validation.validate_organization_access.assert_awaited_once_with(
1
Expand Down
78 changes: 52 additions & 26 deletions backend/lcfs/web/api/compliance_report/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
)
from lcfs.web.api.compliance_report.schema import (
ComplianceReportBaseSchema,
ComplianceReportSummarySchema,
ComplianceReportSummaryUpdateSchema,
)
from lcfs.db.models.compliance.ComplianceReportHistory import ComplianceReportHistory
Expand Down Expand Up @@ -435,34 +434,61 @@ async def get_compliance_report_by_id(self, report_id: int, is_model: bool = Fal
"""
Retrieve a compliance report from the database by ID
"""
result = (
(
await self.db.execute(
select(ComplianceReport)
.options(
joinedload(ComplianceReport.organization),
joinedload(ComplianceReport.compliance_period),
joinedload(ComplianceReport.current_status),
joinedload(ComplianceReport.summary),
joinedload(ComplianceReport.history).joinedload(
ComplianceReportHistory.status
),
joinedload(ComplianceReport.history).joinedload(
ComplianceReportHistory.user_profile
),
joinedload(ComplianceReport.transaction),
)
.where(ComplianceReport.compliance_report_id == report_id)
)
result = await self.db.execute(
select(ComplianceReport)
.options(
joinedload(ComplianceReport.organization),
joinedload(ComplianceReport.compliance_period),
joinedload(ComplianceReport.current_status),
joinedload(ComplianceReport.summary),
joinedload(ComplianceReport.history).joinedload(
ComplianceReportHistory.status
),
joinedload(ComplianceReport.history).joinedload(
ComplianceReportHistory.user_profile
),
joinedload(ComplianceReport.transaction),
)
.unique()
.scalars()
.first()
.where(ComplianceReport.compliance_report_id == report_id)
)

compliance_report = result.scalars().unique().first()

if not compliance_report:
return None

if is_model:
return result
else:
return ComplianceReportBaseSchema.model_validate(result)
return compliance_report

return ComplianceReportBaseSchema.model_validate(compliance_report)

@repo_handler
async def get_compliance_report_chain(self, group_uuid: str):
result = await self.db.execute(
select(ComplianceReport)
.options(
joinedload(ComplianceReport.organization),
joinedload(ComplianceReport.compliance_period),
joinedload(ComplianceReport.current_status),
joinedload(ComplianceReport.summary),
joinedload(ComplianceReport.history).joinedload(
ComplianceReportHistory.status
),
joinedload(ComplianceReport.history).joinedload(
ComplianceReportHistory.user_profile
),
joinedload(ComplianceReport.transaction),
)
.where(ComplianceReport.compliance_report_group_uuid == group_uuid)
.order_by(ComplianceReport.version.desc()) # Ensure ordering by version
)

compliance_reports = result.scalars().unique().all()

return [
ComplianceReportBaseSchema.model_validate(report)
for report in compliance_reports
]

@repo_handler
async def get_fuel_type(self, fuel_type_id: int) -> FuelType:
Expand Down
5 changes: 5 additions & 0 deletions backend/lcfs/web/api/compliance_report/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,11 @@ class ComplianceReportBaseSchema(BaseSchema):
has_supplemental: bool


class ChainedComplianceReportSchema(BaseSchema):
report: ComplianceReportBaseSchema
chain: Optional[List[ComplianceReportBaseSchema]] = []


class ComplianceReportCreateSchema(BaseSchema):
compliance_period: str
organization_id: int
Expand Down
31 changes: 28 additions & 3 deletions backend/lcfs/web/api/compliance_report/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ async def create_compliance_report(
report_data.status
)
if not draft_status:
raise DataNotFoundException(f"Status '{report_data.status}' not found.")
raise DataNotFoundException(
f"Status '{report_data.status}' not found.")

# Generate a new group_uuid for the new report series
group_uuid = str(uuid.uuid4())
Expand Down Expand Up @@ -193,6 +194,7 @@ def _mask_report_status(self, reports: List) -> List:
ComplianceReportStatusEnum.Submitted.value
)
report.current_status.compliance_report_status_id = None

masked_reports.append(report)
else:
masked_reports.append(report)
Expand All @@ -201,22 +203,45 @@ def _mask_report_status(self, reports: List) -> List:

@service_handler
async def get_compliance_report_by_id(
self, report_id: int, apply_masking: bool = False
) -> ComplianceReportBaseSchema:
self, report_id: int, apply_masking: bool = False, get_chain: bool = False
):
"""Fetches a specific compliance report by ID."""
report = await self.repo.get_compliance_report_by_id(report_id)
if report is None:
raise DataNotFoundException("Compliance report not found.")

validated_report = ComplianceReportBaseSchema.model_validate(report)
masked_report = (
self._mask_report_status([validated_report])[0]
if apply_masking
else validated_report
)

history_masked_report = self._mask_report_status_for_history(
masked_report, apply_masking
)

if get_chain:
compliance_report_chain = await self.repo.get_compliance_report_chain(
report.compliance_report_group_uuid
)

if apply_masking:
# Apply masking to each report in the chain
masked_chain = self._mask_report_status(
compliance_report_chain)
# Apply history masking to each report in the chain
masked_chain = [
self._mask_report_status_for_history(report, apply_masking)
for report in masked_chain
]
compliance_report_chain = masked_chain

return {
"report": history_masked_report,
"chain": compliance_report_chain,
}

return history_masked_report

def _mask_report_status_for_history(
Expand Down
17 changes: 12 additions & 5 deletions backend/lcfs/web/api/compliance_report/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
ComplianceReportBaseSchema,
ComplianceReportListSchema,
ComplianceReportSummarySchema,
ComplianceReportUpdateSchema, ComplianceReportSummaryUpdateSchema,
ChainedComplianceReportSchema,
ComplianceReportUpdateSchema,
ComplianceReportSummaryUpdateSchema,
)
from lcfs.web.api.compliance_report.services import ComplianceReportServices
from lcfs.web.api.compliance_report.summary_service import (
Expand Down Expand Up @@ -66,12 +68,12 @@ async def get_compliance_reports(
pagination.filters.append(
FilterModel(field="status", filter="Draft", filter_type="text", type="notEqual")
)
return await service.get_compliance_reports_paginated(pagination)
return await service.get_compliance_reports_paginated(pagination)


@router.get(
"/{report_id}",
response_model=ComplianceReportBaseSchema,
response_model=ChainedComplianceReportSchema,
status_code=status.HTTP_200_OK,
)
@view_handler([RoleEnum.GOVERNMENT])
Expand All @@ -80,12 +82,16 @@ async def get_compliance_report_by_id(
report_id: int,
service: ComplianceReportServices = Depends(),
validate: ComplianceReportValidation = Depends(),
) -> ComplianceReportBaseSchema:
) -> ChainedComplianceReportSchema:
await validate.validate_organization_access(report_id)

mask_statuses = not user_has_roles(request.user, [RoleEnum.GOVERNMENT])

return await service.get_compliance_report_by_id(report_id, mask_statuses)
result = await service.get_compliance_report_by_id(
report_id, mask_statuses, get_chain=True
)

return result


@router.get(
Expand Down Expand Up @@ -128,6 +134,7 @@ async def update_compliance_report_summary(
report_id, summary_data
)


@view_handler(["*"])
@router.put(
"/{report_id}",
Expand Down
Loading

0 comments on commit 505f78c

Please sign in to comment.