diff --git a/backend/lcfs/tests/other_uses/test_other_uses_repo.py b/backend/lcfs/tests/other_uses/test_other_uses_repo.py index 8bd39f562..84ef51cf0 100644 --- a/backend/lcfs/tests/other_uses/test_other_uses_repo.py +++ b/backend/lcfs/tests/other_uses/test_other_uses_repo.py @@ -11,24 +11,58 @@ @pytest.fixture -def mock_db_session(): +def mock_query_result(): + # Setup mock for database query result chain + mock_result = AsyncMock() + mock_result.unique = MagicMock(return_value=mock_result) + mock_result.scalars = MagicMock(return_value=mock_result) + mock_result.all = MagicMock( + return_value=[ + MagicMock( + fuel_type_id=1, + fuel_type="Gasoline", + default_carbon_intensity=12.34, + units="L", + unrecognized=False, + fuel_instances=[ + MagicMock( + fuel_category=MagicMock( + fuel_category_id=1, category="Petroleum-based" + ) + ) + ], + fuel_codes=[ + MagicMock(fuel_code_id=1, fuel_code="FC123", carbon_intensity=10.5) + ], + ) + ] + ) + return mock_result + + +@pytest.fixture +def mock_db_session(mock_query_result): session = MagicMock(spec=AsyncSession) - execute_result = AsyncMock() - execute_result.unique = MagicMock(return_value=execute_result) - execute_result.scalars = MagicMock(return_value=execute_result) - execute_result.all = MagicMock(return_value=[MagicMock(spec=OtherUses)]) - execute_result.first = MagicMock(return_value=MagicMock(spec=OtherUses)) - session.execute.return_value = execute_result + session.execute = AsyncMock(return_value=mock_query_result) return session @pytest.fixture def other_uses_repo(mock_db_session): repo = OtherUsesRepository(db=mock_db_session) + repo.fuel_code_repo = MagicMock() repo.fuel_code_repo.get_fuel_categories = AsyncMock(return_value=[]) repo.fuel_code_repo.get_fuel_types = AsyncMock(return_value=[]) repo.fuel_code_repo.get_expected_use_types = AsyncMock(return_value=[]) + + # Mock for local get_formatted_fuel_types method + async def mock_get_formatted_fuel_types(): + mock_result = await mock_db_session.execute(AsyncMock()) + return mock_result.unique().scalars().all() + + repo.get_formatted_fuel_types = AsyncMock(side_effect=mock_get_formatted_fuel_types) + return repo diff --git a/backend/lcfs/web/api/other_uses/repo.py b/backend/lcfs/web/api/other_uses/repo.py index 8e0f5ca4d..595de5882 100644 --- a/backend/lcfs/web/api/other_uses/repo.py +++ b/backend/lcfs/web/api/other_uses/repo.py @@ -76,7 +76,7 @@ async def get_latest_other_uses_by_group_uuid( ) result = await self.db.execute(query) - return result.scalars().first() + return await result.unique().scalars().first() @repo_handler async def get_other_uses(self, compliance_report_id: int) -> List[OtherUsesSchema]: @@ -302,7 +302,7 @@ async def get_other_use_version_by_user( ) result = await self.db.execute(query) - return result.scalars().first() + return await result.scalars().first() @repo_handler async def get_formatted_fuel_types(self) -> List[Dict[str, Any]]: @@ -333,7 +333,7 @@ async def get_formatted_fuel_types(self) -> List[Dict[str, Any]]: ) result = await self.db.execute(query) - fuel_types = result.unique().scalars().all() + fuel_types = await result.unique().scalars().all() # Prepare the data in the format matching your schema formatted_fuel_types = []