From d8a498a7aeae462c4f141a10dffa938bdf2599a5 Mon Sep 17 00:00:00 2001
From: Jingsi Lu
Date: Mon, 25 Nov 2024 12:31:10 -0500
Subject: [PATCH] more refactoring

---
 .../tests/test_batch_datasets_main.py  | 10 +++--
 .../tests/test_location_extraction.py  | 15 ++++---
 .../tests/test_gbfs_validator.py       | 42 +++++++++----------
 .../src/utils/locations.py             | 10 ++---
 .../tests/test_locations.py            | 16 +++----
 5 files changed, 48 insertions(+), 45 deletions(-)

diff --git a/functions-python/batch_datasets/tests/test_batch_datasets_main.py b/functions-python/batch_datasets/tests/test_batch_datasets_main.py
index b8423f8a1..be6b175f8 100644
--- a/functions-python/batch_datasets/tests/test_batch_datasets_main.py
+++ b/functions-python/batch_datasets/tests/test_batch_datasets_main.py
@@ -64,10 +64,14 @@ def test_batch_datasets(mock_client, mock_publish):
     ]
 
 
-@patch("batch_datasets.src.main.start_db_session")
-def test_batch_datasets_exception(start_db_session_mock):
+@patch("batch_datasets.src.main.Database")
+def test_batch_datasets_exception(database_mock):
     exception_message = "Failure occurred"
-    start_db_session_mock.side_effect = Exception(exception_message)
+    mock_session = MagicMock()
+    mock_session.side_effect = Exception(exception_message)
+
+    database_mock.return_value.start_db_session.return_value = mock_session
+
     with pytest.raises(Exception) as exec_info:
         batch_datasets(Mock())
 
diff --git a/functions-python/extract_location/tests/test_location_extraction.py b/functions-python/extract_location/tests/test_location_extraction.py
index b57889631..4a29c0b8a 100644
--- a/functions-python/extract_location/tests/test_location_extraction.py
+++ b/functions-python/extract_location/tests/test_location_extraction.py
@@ -268,12 +268,12 @@ def test_extract_location_exception_2(
             "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json",
         },
     )
-    @patch("extract_location.src.main.start_db_session")
+    @patch("extract_location.src.main.Database")
    @patch("extract_location.src.main.pubsub_v1.PublisherClient")
    @patch("extract_location.src.main.Logger")
    @patch("uuid.uuid4")
    def test_extract_location_batch(
-        self, uuid_mock, logger_mock, publisher_client_mock, start_db_session_mock
+        self, uuid_mock, logger_mock, publisher_client_mock, database_mock
     ):
         mock_session = MagicMock()
         mock_dataset1 = Gtfsdataset(
@@ -300,7 +300,7 @@ def test_extract_location_batch(
             mock_dataset2,
         ]
         uuid_mock.return_value = "batch-uuid"
-        start_db_session_mock.return_value = mock_session
+        database_mock.return_value.start_db_session.return_value = mock_session
 
         mock_publisher = MagicMock()
         publisher_client_mock.return_value = mock_publisher
@@ -358,10 +358,13 @@ def test_extract_location_batch_no_topic_name(self, logger_mock):
             "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json",
         },
     )
-    @patch("extract_location.src.main.start_db_session")
+    @patch("extract_location.src.main.Database")
     @patch("extract_location.src.main.Logger")
-    def test_extract_location_batch_exception(self, logger_mock, start_db_session_mock):
-        start_db_session_mock.side_effect = Exception("Database error")
+    def test_extract_location_batch_exception(self, logger_mock, database_mock):
+        mock_session = MagicMock()
+        mock_session.side_effect = Exception("Database error")
+
+        database_mock.return_value.start_db_session.return_value = mock_session
 
         response = extract_location_batch(None)
         self.assertEqual(response, ("Error while fetching datasets.", 500))
diff --git a/functions-python/gbfs_validator/tests/test_gbfs_validator.py b/functions-python/gbfs_validator/tests/test_gbfs_validator.py
index 26e242941..b77baeedd 100644
--- a/functions-python/gbfs_validator/tests/test_gbfs_validator.py
+++ b/functions-python/gbfs_validator/tests/test_gbfs_validator.py
@@ -28,7 +28,7 @@ class TestMainFunctions(unittest.TestCase):
             "VALIDATOR_URL": "https://mock-validator-url.com",
         },
     )
-    @patch("gbfs_validator.src.main.start_db_session")
+    @patch("gbfs_validator.src.main.Database")
     @patch("gbfs_validator.src.main.DatasetTraceService")
     @patch("gbfs_validator.src.main.fetch_gbfs_files")
     @patch("gbfs_validator.src.main.GBFSValidator.create_gbfs_json_with_bucket_paths")
@@ -47,11 +47,11 @@ def test_gbfs_validator_pubsub(
         mock_create_gbfs_json,
         mock_fetch_gbfs_files,
         mock_dataset_trace_service,
-        mock_start_db_session,
+        mock_database,
     ):
         # Prepare mocks
         mock_session = MagicMock()
-        mock_start_db_session.return_value = mock_session
+        mock_database.return_value.start_db_session.return_value = mock_session
 
         mock_trace_service = MagicMock()
         mock_dataset_trace_service.return_value = mock_trace_service
@@ -95,16 +95,16 @@ def test_gbfs_validator_pubsub(
             "PUBSUB_TOPIC_NAME": "mock-topic",
         },
     )
-    @patch("gbfs_validator.src.main.start_db_session")
+    @patch("gbfs_validator.src.main.Database")
     @patch("gbfs_validator.src.main.pubsub_v1.PublisherClient")
     @patch("gbfs_validator.src.main.fetch_all_gbfs_feeds")
     @patch("gbfs_validator.src.main.Logger")
     def test_gbfs_validator_batch(
-        self, _, mock_fetch_all_gbfs_feeds, mock_publisher_client, mock_start_db_session
+        self, _, mock_fetch_all_gbfs_feeds, mock_publisher_client, mock_database
     ):
         # Prepare mocks
         mock_session = MagicMock()
-        mock_start_db_session.return_value = mock_session
+        mock_database.return_value.start_db_session.return_value = mock_session
 
         mock_publisher = MagicMock()
         mock_publisher_client.return_value = mock_publisher
@@ -131,11 +131,11 @@ def test_gbfs_validator_batch_missing_topic(self, _):  # mock_logger
         result = gbfs_validator_batch(None)
         self.assertEqual(result[1], 500)
 
-    @patch("gbfs_validator.src.main.start_db_session")
+    @patch("gbfs_validator.src.main.Database")
     @patch("gbfs_validator.src.main.Logger")
-    def test_fetch_all_gbfs_feeds(self, _, mock_start_db_session):
+    def test_fetch_all_gbfs_feeds(self, _, mock_database):
         mock_session = MagicMock()
-        mock_start_db_session.return_value = mock_session
+        mock_database.return_value.start_db_session.return_value = mock_session
         mock_feed = MagicMock()
         mock_session.query.return_value.options.return_value.all.return_value = [
             mock_feed
@@ -144,14 +144,14 @@ def test_fetch_all_gbfs_feeds(self, _, mock_start_db_session):
         result = fetch_all_gbfs_feeds()
 
         self.assertEqual(result, [mock_feed])
-        mock_start_db_session.assert_called_once()
+        mock_database.assert_called_once()
         mock_session.close.assert_called_once()
 
-    @patch("gbfs_validator.src.main.start_db_session")
+    @patch("gbfs_validator.src.main.Database")
     @patch("gbfs_validator.src.main.Logger")
-    def test_fetch_all_gbfs_feeds_exception(self, _, mock_start_db_session):
+    def test_fetch_all_gbfs_feeds_exception(self, _, mock_database):
         mock_session = MagicMock()
-        mock_start_db_session.return_value = mock_session
+        mock_database.return_value.start_db_session.return_value = mock_session
 
         # Simulate an exception when querying the database
         mock_session.query.side_effect = Exception("Database error")
@@ -161,19 +161,19 @@ def test_fetch_all_gbfs_feeds_exception(self, _, mock_start_db_session):
 
         self.assertTrue("Database error" in str(context.exception))
 
-        mock_start_db_session.assert_called_once()
+        mock_database.assert_called_once()
         mock_session.close.assert_called_once()
 
-    @patch("gbfs_validator.src.main.start_db_session")
-    def test_fetch_all_gbfs_feeds_none_session(self, mock_start_db_session):
-        mock_start_db_session.return_value = None
+    @patch("gbfs_validator.src.main.Database")
+    def test_fetch_all_gbfs_feeds_none_session(self, mock_database):
+        mock_database.return_value = None
 
         with self.assertRaises(Exception) as context:
             fetch_all_gbfs_feeds()
 
         self.assertTrue("NoneType" in str(context.exception))
 
-        mock_start_db_session.assert_called_once()
+        mock_database.assert_called_once()
 
     @patch.dict(
         os.environ,
@@ -199,16 +199,16 @@ def test_gbfs_validator_batch_fetch_exception(self, _, mock_fetch_all_gbfs_feeds
             "PUBSUB_TOPIC_NAME": "mock-topic",
         },
     )
-    @patch("gbfs_validator.src.main.start_db_session")
+    @patch("gbfs_validator.src.main.Database")
     @patch("gbfs_validator.src.main.pubsub_v1.PublisherClient")
     @patch("gbfs_validator.src.main.fetch_all_gbfs_feeds")
     @patch("gbfs_validator.src.main.Logger")
     def test_gbfs_validator_batch_publish_exception(
-        self, _, mock_fetch_all_gbfs_feeds, mock_publisher_client, mock_start_db_session
+        self, _, mock_fetch_all_gbfs_feeds, mock_publisher_client, mock_database
     ):
         # Prepare mocks
         mock_session = MagicMock()
-        mock_start_db_session.return_value = mock_session
+        mock_database.return_value.start_db_session.return_value = mock_session
 
         mock_publisher_client.side_effect = Exception("Pub/Sub error")
 
diff --git a/functions-python/validation_to_ndjson/src/utils/locations.py b/functions-python/validation_to_ndjson/src/utils/locations.py
index b86187287..5a55cdd0b 100644
--- a/functions-python/validation_to_ndjson/src/utils/locations.py
+++ b/functions-python/validation_to_ndjson/src/utils/locations.py
@@ -4,7 +4,7 @@
 from sqlalchemy.orm import joinedload
 
 from database_gen.sqlacodegen_models import Feed, Location
-from helpers.database import start_db_session
+from helpers.database import Database
 
 
 def get_feed_location(data_type: str, stable_id: str) -> List[Location]:
@@ -14,9 +14,8 @@ def get_feed_location(data_type: str, stable_id: str) -> List[Location]:
     @param stable_id: The stable ID of the feed.
     @return: A list of locations.
""" - session = None - try: - session = start_db_session(os.getenv("FEEDS_DATABASE_URL")) + db = Database(database_url=os.getenv("FEEDS_DATABASE_URL")) + with db.start_db_session() as session: feeds = ( session.query(Feed) .filter(Feed.data_type == data_type) @@ -25,6 +24,3 @@ def get_feed_location(data_type: str, stable_id: str) -> List[Location]: .all() ) return feeds[0].locations if feeds is not None and len(feeds) > 0 else [] - finally: - if session: - session.close() diff --git a/functions-python/validation_to_ndjson/tests/test_locations.py b/functions-python/validation_to_ndjson/tests/test_locations.py index b6fcb6958..0afecf790 100644 --- a/functions-python/validation_to_ndjson/tests/test_locations.py +++ b/functions-python/validation_to_ndjson/tests/test_locations.py @@ -5,14 +5,14 @@ class TestFeedsLocations(unittest.TestCase): - @patch("validation_to_ndjson.src.utils.locations.start_db_session") + @patch("validation_to_ndjson.src.utils.locations.Database") @patch("validation_to_ndjson.src.utils.locations.os.getenv") @patch("validation_to_ndjson.src.utils.locations.joinedload") - def test_get_feeds_locations_map(self, _, mock_getenv, mock_start_db_session): + def test_get_feeds_locations_map(self, _, mock_getenv, mock_database): mock_getenv.return_value = "fake_database_url" mock_session = MagicMock() - mock_start_db_session.return_value = mock_session + mock_database.return_value.start_db_session.return_value = mock_session mock_feed = MagicMock() mock_feed.stable_id = "feed1" @@ -28,7 +28,7 @@ def test_get_feeds_locations_map(self, _, mock_getenv, mock_start_db_session): mock_session.query.return_value = mock_query result = get_feed_location("gtfs", "feed1") - mock_start_db_session.assert_called_once_with("fake_database_url") + mock_database.assert_called_once_with("fake_database_url") mock_session.query.assert_called_once() # Verify that query was called mock_query.filter.assert_called_once() # Verify that filter was applied mock_query.filter.return_value.filter.return_value.options.assert_called_once() @@ -36,10 +36,10 @@ def test_get_feeds_locations_map(self, _, mock_getenv, mock_start_db_session): self.assertEqual(result, [mock_location1, mock_location2]) # Verify the mapping - @patch("validation_to_ndjson.src.utils.locations.start_db_session") - def test_get_feeds_locations_map_no_feeds(self, mock_start_db_session): + @patch("validation_to_ndjson.src.utils.locations.Database") + def test_get_feeds_locations_map_no_feeds(self, mock_database): mock_session = MagicMock() - mock_start_db_session.return_value = mock_session + mock_database.return_value.start_db_session.return_value = mock_session mock_query = MagicMock() mock_query.filter.return_value.filter.return_value.options.return_value.all.return_value = ( @@ -50,5 +50,5 @@ def test_get_feeds_locations_map_no_feeds(self, mock_start_db_session): result = get_feed_location("test_data_type", "test_stable_id") - mock_start_db_session.assert_called_once() + mock_database.return_value.start_db_session.assert_called_once() self.assertEqual(result, []) # The result should be an empty dictionary