From 2fc7a88344eddf90e2d0caa59daf21f9cc44704a Mon Sep 17 00:00:00 2001 From: Arturo Reyes Lopez Date: Wed, 11 Dec 2024 10:26:34 -0700 Subject: [PATCH] changes in await --- .../tests/other_uses/test_other_uses_repo.py | 19 +++++++++---------- backend/lcfs/web/api/other_uses/repo.py | 15 ++++++++++----- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/backend/lcfs/tests/other_uses/test_other_uses_repo.py b/backend/lcfs/tests/other_uses/test_other_uses_repo.py index 97b5308f5..8ef79cd70 100644 --- a/backend/lcfs/tests/other_uses/test_other_uses_repo.py +++ b/backend/lcfs/tests/other_uses/test_other_uses_repo.py @@ -207,22 +207,21 @@ async def test_get_latest_other_uses_by_group_uuid(other_uses_repo, mock_db_sess mock_other_use_gov.user_type = UserTypeEnum.GOVERNMENT mock_other_use_gov.version = 2 - mock_other_use_supplier = MagicMock(spec=OtherUses) - mock_other_use_supplier.user_type = UserTypeEnum.SUPPLIER - mock_other_use_supplier.version = 3 - - # Mock response with both government and supplier versions - mock_db_session.execute.return_value.scalars.return_value.first.side_effect = [ - mock_other_use_gov, - mock_other_use_supplier, - ] + # Setup mock result chain + mock_result = AsyncMock() + mock_result.unique = MagicMock(return_value=mock_result) + mock_result.scalars = MagicMock(return_value=mock_result) + mock_result.first = MagicMock(return_value=mock_other_use_gov) + + # Configure mock db session + mock_db_session.execute = AsyncMock(return_value=mock_result) + other_uses_repo.db = mock_db_session result = await other_uses_repo.get_latest_other_uses_by_group_uuid(group_uuid) assert result.user_type == UserTypeEnum.GOVERNMENT assert result.version == 2 - @pytest.mark.anyio async def test_get_other_use_version_by_user(other_uses_repo, mock_db_session): group_uuid = "test-group-uuid" diff --git a/backend/lcfs/web/api/other_uses/repo.py b/backend/lcfs/web/api/other_uses/repo.py index 595de5882..9f49a3d5d 100644 --- a/backend/lcfs/web/api/other_uses/repo.py +++ b/backend/lcfs/web/api/other_uses/repo.py @@ -76,7 +76,7 @@ async def get_latest_other_uses_by_group_uuid( ) result = await self.db.execute(query) - return await result.unique().scalars().first() + return result.unique().scalars().first() @repo_handler async def get_other_uses(self, compliance_report_id: int) -> List[OtherUsesSchema]: @@ -310,8 +310,13 @@ async def get_formatted_fuel_types(self) -> List[Dict[str, Any]]: # Define the filtering conditions for fuel codes current_date = date.today() fuel_code_filters = ( - or_(FuelCode.effective_date == None, FuelCode.effective_date <= current_date) - & or_(FuelCode.expiration_date == None, FuelCode.expiration_date > current_date) + or_( + FuelCode.effective_date == None, FuelCode.effective_date <= current_date + ) + & or_( + FuelCode.expiration_date == None, + FuelCode.expiration_date > current_date, + ) & (FuelType.other_uses_fossil_derived == True) ) @@ -333,7 +338,7 @@ async def get_formatted_fuel_types(self) -> List[Dict[str, Any]]: ) result = await self.db.execute(query) - fuel_types = await result.unique().scalars().all() + fuel_types = result.unique().scalars().all() # Prepare the data in the format matching your schema formatted_fuel_types = [] @@ -379,4 +384,4 @@ async def get_formatted_fuel_types(self) -> List[Dict[str, Any]]: ) formatted_fuel_types.append(formatted_fuel_type) - return formatted_fuel_types \ No newline at end of file + return formatted_fuel_types