Skip to content

Commit

Permalink
changes in await
Browse files Browse the repository at this point in the history
  • Loading branch information
areyeslo committed Dec 11, 2024
1 parent 7b56dac commit be92276
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 15 deletions.
19 changes: 9 additions & 10 deletions backend/lcfs/tests/other_uses/test_other_uses_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,22 +207,21 @@ async def test_get_latest_other_uses_by_group_uuid(other_uses_repo, mock_db_sess
mock_other_use_gov.user_type = UserTypeEnum.GOVERNMENT
mock_other_use_gov.version = 2

mock_other_use_supplier = MagicMock(spec=OtherUses)
mock_other_use_supplier.user_type = UserTypeEnum.SUPPLIER
mock_other_use_supplier.version = 3

# Mock response with both government and supplier versions
mock_db_session.execute.return_value.scalars.return_value.first.side_effect = [
mock_other_use_gov,
mock_other_use_supplier,
]
# Setup mock result chain
mock_result = AsyncMock()
mock_result.unique = MagicMock(return_value=mock_result)
mock_result.scalars = MagicMock(return_value=mock_result)
mock_result.first = MagicMock(return_value=mock_other_use_gov)

# Configure mock db session
mock_db_session.execute = AsyncMock(return_value=mock_result)
other_uses_repo.db = mock_db_session

result = await other_uses_repo.get_latest_other_uses_by_group_uuid(group_uuid)

assert result.user_type == UserTypeEnum.GOVERNMENT
assert result.version == 2


@pytest.mark.anyio
async def test_get_other_use_version_by_user(other_uses_repo, mock_db_session):
group_uuid = "test-group-uuid"
Expand Down
15 changes: 10 additions & 5 deletions backend/lcfs/web/api/other_uses/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ async def get_latest_other_uses_by_group_uuid(
)

result = await self.db.execute(query)
return await result.unique().scalars().first()
return result.unique().scalars().first()

@repo_handler
async def get_other_uses(self, compliance_report_id: int) -> List[OtherUsesSchema]:
Expand Down Expand Up @@ -310,8 +310,13 @@ async def get_formatted_fuel_types(self) -> List[Dict[str, Any]]:
# Define the filtering conditions for fuel codes
current_date = date.today()
fuel_code_filters = (
or_(FuelCode.effective_date == None, FuelCode.effective_date <= current_date)
& or_(FuelCode.expiration_date == None, FuelCode.expiration_date > current_date)
or_(
FuelCode.effective_date == None, FuelCode.effective_date <= current_date
)
& or_(
FuelCode.expiration_date == None,
FuelCode.expiration_date > current_date,
)
& (FuelType.other_uses_fossil_derived == True)
)

Expand All @@ -333,7 +338,7 @@ async def get_formatted_fuel_types(self) -> List[Dict[str, Any]]:
)

result = await self.db.execute(query)
fuel_types = await result.unique().scalars().all()
fuel_types = result.unique().scalars().all()

# Prepare the data in the format matching your schema
formatted_fuel_types = []
Expand Down Expand Up @@ -379,4 +384,4 @@ async def get_formatted_fuel_types(self) -> List[Dict[str, Any]]:
)
formatted_fuel_types.append(formatted_fuel_type)

return formatted_fuel_types
return formatted_fuel_types

0 comments on commit be92276

Please sign in to comment.