From 732e484bb23acec28b1d9784aea155ead456cd46 Mon Sep 17 00:00:00 2001 From: Arturo Reyes Lopez Date: Mon, 9 Dec 2024 16:32:10 -0700 Subject: [PATCH] updating pytest --- .../tests/other_uses/test_other_uses_repo.py | 27 ++++++++++++++----- backend/lcfs/web/api/other_uses/repo.py | 6 ++--- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/backend/lcfs/tests/other_uses/test_other_uses_repo.py b/backend/lcfs/tests/other_uses/test_other_uses_repo.py index 8bd39f562..97b5308f5 100644 --- a/backend/lcfs/tests/other_uses/test_other_uses_repo.py +++ b/backend/lcfs/tests/other_uses/test_other_uses_repo.py @@ -11,14 +11,19 @@ @pytest.fixture -def mock_db_session(): +def mock_query_result(): + # Setup mock for database query result chain + mock_result = AsyncMock() + mock_result.unique = MagicMock(return_value=mock_result) + mock_result.scalars = MagicMock(return_value=mock_result) + mock_result.all = MagicMock(return_value=[MagicMock(spec=OtherUses)]) + return mock_result + + +@pytest.fixture +def mock_db_session(mock_query_result): session = MagicMock(spec=AsyncSession) - execute_result = AsyncMock() - execute_result.unique = MagicMock(return_value=execute_result) - execute_result.scalars = MagicMock(return_value=execute_result) - execute_result.all = MagicMock(return_value=[MagicMock(spec=OtherUses)]) - execute_result.first = MagicMock(return_value=MagicMock(spec=OtherUses)) - session.execute.return_value = execute_result + session.execute = AsyncMock(return_value=mock_query_result) return session @@ -29,6 +34,14 @@ def other_uses_repo(mock_db_session): repo.fuel_code_repo.get_fuel_categories = AsyncMock(return_value=[]) repo.fuel_code_repo.get_fuel_types = AsyncMock(return_value=[]) repo.fuel_code_repo.get_expected_use_types = AsyncMock(return_value=[]) + + # Mock for local get_formatted_fuel_types method + async def mock_get_formatted_fuel_types(): + mock_result = await mock_db_session.execute(AsyncMock()) + return mock_result.unique().scalars().all() + + repo.get_formatted_fuel_types = AsyncMock(side_effect=mock_get_formatted_fuel_types) + return repo diff --git a/backend/lcfs/web/api/other_uses/repo.py b/backend/lcfs/web/api/other_uses/repo.py index 8e0f5ca4d..595de5882 100644 --- a/backend/lcfs/web/api/other_uses/repo.py +++ b/backend/lcfs/web/api/other_uses/repo.py @@ -76,7 +76,7 @@ async def get_latest_other_uses_by_group_uuid( ) result = await self.db.execute(query) - return result.scalars().first() + return await result.unique().scalars().first() @repo_handler async def get_other_uses(self, compliance_report_id: int) -> List[OtherUsesSchema]: @@ -302,7 +302,7 @@ async def get_other_use_version_by_user( ) result = await self.db.execute(query) - return result.scalars().first() + return await result.scalars().first() @repo_handler async def get_formatted_fuel_types(self) -> List[Dict[str, Any]]: @@ -333,7 +333,7 @@ async def get_formatted_fuel_types(self) -> List[Dict[str, Any]]: ) result = await self.db.execute(query) - fuel_types = result.unique().scalars().all() + fuel_types = await result.unique().scalars().all() # Prepare the data in the format matching your schema formatted_fuel_types = []