diff --git a/functions-python/batch_process_dataset/src/main.py b/functions-python/batch_process_dataset/src/main.py index d4e9d211a..68160c70e 100644 --- a/functions-python/batch_process_dataset/src/main.py +++ b/functions-python/batch_process_dataset/src/main.py @@ -31,7 +31,7 @@ from database_gen.sqlacodegen_models import Gtfsdataset, t_feedsearch from dataset_service.main import DatasetTraceService, DatasetTrace, Status -from helpers.database import Database +from helpers.database import Database, refresh_materialized_view import logging from helpers.logger import Logger @@ -239,7 +239,7 @@ def create_dataset(self, dataset_file: DatasetFile): session.add(latest_dataset) session.add(new_dataset) - db.refresh_materialized_view(session, t_feedsearch.name) + refresh_materialized_view(session, t_feedsearch.name) session.commit() logging.info(f"[{self.feed_stable_id}] Dataset created successfully.") except Exception as e: diff --git a/functions-python/extract_location/tests/test_location_extraction.py b/functions-python/extract_location/tests/test_location_extraction.py index 4a29c0b8a..03736ae91 100644 --- a/functions-python/extract_location/tests/test_location_extraction.py +++ b/functions-python/extract_location/tests/test_location_extraction.py @@ -300,7 +300,9 @@ def test_extract_location_batch( mock_dataset2, ] uuid_mock.return_value = "batch-uuid" - database_mock.return_value.start_db_session.return_value = mock_session + database_mock.return_value.start_db_session.return_value.__enter__.return_value = ( + mock_session + ) mock_publisher = MagicMock() publisher_client_mock.return_value = mock_publisher @@ -361,10 +363,9 @@ def test_extract_location_batch_no_topic_name(self, logger_mock): @patch("extract_location.src.main.Database") @patch("extract_location.src.main.Logger") def test_extract_location_batch_exception(self, logger_mock, database_mock): - mock_session = MagicMock() - mock_session.side_effect = Exception("Database error") - - database_mock.return_value.start_db_session.return_value = mock_session + database_mock.return_value.start_db_session.side_effect = Exception( + "Database error" + ) response = extract_location_batch(None) self.assertEqual(response, ("Error while fetching datasets.", 500)) diff --git a/functions-python/gbfs_validator/tests/test_gbfs_validator.py b/functions-python/gbfs_validator/tests/test_gbfs_validator.py index b77baeedd..64ff26607 100644 --- a/functions-python/gbfs_validator/tests/test_gbfs_validator.py +++ b/functions-python/gbfs_validator/tests/test_gbfs_validator.py @@ -13,10 +13,14 @@ gbfs_validator_batch, fetch_all_gbfs_feeds, ) -from test_utils.database_utils import default_db_url +from test_utils.database_utils import default_db_url, reset_database_class class TestMainFunctions(unittest.TestCase): + def tearDown(self) -> None: + reset_database_class() + return super().tearDown() + @patch.dict( os.environ, { @@ -131,11 +135,11 @@ def test_gbfs_validator_batch_missing_topic(self, _): # mock_logger result = gbfs_validator_batch(None) self.assertEqual(result[1], 500) - @patch("gbfs_validator.src.main.Database") + @patch("helpers.database.sessionmaker") @patch("gbfs_validator.src.main.Logger") - def test_fetch_all_gbfs_feeds(self, _, mock_database): + def test_fetch_all_gbfs_feeds(self, _, mock_sessionmaker): mock_session = MagicMock() - mock_database.return_value.start_db_session.return_value = mock_session + mock_sessionmaker.return_value.return_value = mock_session mock_feed = MagicMock() mock_session.query.return_value.options.return_value.all.return_value = [ mock_feed @@ -144,14 +148,14 @@ def test_fetch_all_gbfs_feeds(self, _, mock_database): result = fetch_all_gbfs_feeds() self.assertEqual(result, [mock_feed]) - mock_database.assert_called_once() + mock_sessionmaker.return_value.assert_called_once() mock_session.close.assert_called_once() - @patch("gbfs_validator.src.main.Database") + @patch("helpers.database.sessionmaker") @patch("gbfs_validator.src.main.Logger") - def test_fetch_all_gbfs_feeds_exception(self, _, mock_database): + def test_fetch_all_gbfs_feeds_exception(self, _, mock_sessionmaker): mock_session = MagicMock() - mock_database.return_value.start_db_session.return_value = mock_session + mock_sessionmaker.return_value.return_value = mock_session # Simulate an exception when querying the database mock_session.query.side_effect = Exception("Database error") @@ -161,7 +165,7 @@ def test_fetch_all_gbfs_feeds_exception(self, _, mock_database): self.assertTrue("Database error" in str(context.exception)) - mock_database.assert_called_once() + mock_sessionmaker.return_value.assert_called_once() mock_session.close.assert_called_once() @patch("gbfs_validator.src.main.Database") @@ -207,8 +211,6 @@ def test_gbfs_validator_batch_publish_exception( self, _, mock_fetch_all_gbfs_feeds, mock_publisher_client, mock_database ): # Prepare mocks - mock_session = MagicMock() - mock_database.return_value.start_db_session.return_value = mock_session mock_publisher_client.side_effect = Exception("Pub/Sub error") diff --git a/functions-python/helpers/database.py b/functions-python/helpers/database.py index 1e6c6a4ee..730e3302c 100644 --- a/functions-python/helpers/database.py +++ b/functions-python/helpers/database.py @@ -20,10 +20,11 @@ from typing import Final, Optional from sqlalchemy import create_engine, text -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import sessionmaker, Session import logging DB_REUSE_SESSION: Final[str] = "DB_REUSE_SESSION" +LOGGER = logging.getLogger(__name__) def with_db_session(func): @@ -75,7 +76,6 @@ def __init__(self, database_url: Optional[str] = None, echo: bool = True): ) self.connection_attempts: int = 0 self.Session = sessionmaker(bind=self.engine, autoflush=False) - self.logger = logging.getLogger(__name__) @contextmanager def start_db_session(self): @@ -92,14 +92,15 @@ def start_db_session(self): def is_session_reusable(): return os.getenv("%s" % DB_REUSE_SESSION, "false").lower() == "true" - def refresh_materialized_view(self, session, view_name: str) -> bool: - """ - Refresh Materialized view by name. - @return: True if the view was refreshed successfully, False otherwise - """ - try: - session.execute(text(f"REFRESH MATERIALIZED VIEW CONCURRENTLY {view_name}")) - return True - except Exception as error: - self.logger.error(f"Error raised while refreshing view: {error}") - return False + +def refresh_materialized_view(session: "Session", view_name: str) -> bool: + """ + Refresh Materialized view by name. + @return: True if the view was refreshed successfully, False otherwise + """ + try: + session.execute(text(f"REFRESH MATERIALIZED VIEW CONCURRENTLY {view_name}")) + return True + except Exception as error: + LOGGER.error(f"Error raised while refreshing view: {error}") + return False diff --git a/functions-python/helpers/tests/test_database.py b/functions-python/helpers/tests/test_database.py index 7a3effaab..b92f7a182 100644 --- a/functions-python/helpers/tests/test_database.py +++ b/functions-python/helpers/tests/test_database.py @@ -3,7 +3,7 @@ from typing import Final from unittest import mock -from helpers.database import refresh_materialized_view, start_db_session +from helpers.database import refresh_materialized_view, Database default_db_url: Final[ str @@ -23,15 +23,17 @@ class TestDatabase(unittest.TestCase): def test_refresh_materialized_view_existing_view(self): view_name = "feedsearch" - session = start_db_session(os.getenv("FEEDS_DATABASE_URL")) - result = refresh_materialized_view(session, view_name) + db = Database(database_url=os.getenv("FEEDS_DATABASE_URL")) + with db.start_db_session() as session: + result = refresh_materialized_view(session, view_name) self.assertTrue(result) def test_refresh_materialized_view_invalid_view(self): view_name = "invalid_view_name" - session = start_db_session(os.getenv("FEEDS_DATABASE_URL")) - result = refresh_materialized_view(session, view_name) + db = Database(database_url=os.getenv("FEEDS_DATABASE_URL")) + with db.start_db_session() as session: + result = refresh_materialized_view(session, view_name) self.assertFalse(result) diff --git a/functions-python/preprocessed_analytics/src/processors/base_analytics_processor.py b/functions-python/preprocessed_analytics/src/processors/base_analytics_processor.py index 461be3cb2..bb738b06c 100644 --- a/functions-python/preprocessed_analytics/src/processors/base_analytics_processor.py +++ b/functions-python/preprocessed_analytics/src/processors/base_analytics_processor.py @@ -13,7 +13,7 @@ Gtfsfeed, Gtfsdataset, ) -from helpers.database import start_db_session +from helpers.database import Database class NoFeedDataException(Exception): @@ -23,7 +23,9 @@ class NoFeedDataException(Exception): class BaseAnalyticsProcessor: def __init__(self, run_date): self.run_date = run_date - self.session = start_db_session(os.getenv("FEEDS_DATABASE_URL"), echo=False) + self.session = Database( + database_url=os.getenv("FEEDS_DATABASE_URL"), echo=False + ).Session() self.processed_feeds = set() self.data = [] self.feed_metrics_data = [] diff --git a/functions-python/preprocessed_analytics/tests/test_base_processor.py b/functions-python/preprocessed_analytics/tests/test_base_processor.py index bf709da22..0bcb3ae01 100644 --- a/functions-python/preprocessed_analytics/tests/test_base_processor.py +++ b/functions-python/preprocessed_analytics/tests/test_base_processor.py @@ -9,16 +9,11 @@ class TestBaseAnalyticsProcessor(unittest.TestCase): - @patch( - "preprocessed_analytics.src.processors.base_analytics_processor.start_db_session" - ) + @patch("preprocessed_analytics.src.processors.base_analytics_processor.Database") @patch( "preprocessed_analytics.src.processors.base_analytics_processor.storage.Client" ) - def setUp(self, mock_storage_client, mock_start_db_session): - self.mock_session = MagicMock() - mock_start_db_session.return_value = self.mock_session - + def setUp(self, mock_storage_client, _): self.mock_storage_client = mock_storage_client self.mock_bucket = MagicMock() self.mock_storage_client().bucket.return_value = self.mock_bucket diff --git a/functions-python/preprocessed_analytics/tests/test_gbfs_processor.py b/functions-python/preprocessed_analytics/tests/test_gbfs_processor.py index 215f4782d..6e63edea6 100644 --- a/functions-python/preprocessed_analytics/tests/test_gbfs_processor.py +++ b/functions-python/preprocessed_analytics/tests/test_gbfs_processor.py @@ -7,16 +7,11 @@ class TestGBFSAnalyticsProcessor(unittest.TestCase): - @patch( - "preprocessed_analytics.src.processors.base_analytics_processor.start_db_session" - ) + @patch("preprocessed_analytics.src.processors.base_analytics_processor.Database") @patch( "preprocessed_analytics.src.processors.base_analytics_processor.storage.Client" ) - def setUp(self, mock_storage_client, mock_start_db_session): - self.mock_session = MagicMock() - mock_start_db_session.return_value = self.mock_session - + def setUp(self, mock_storage_client, _): self.mock_storage_client = mock_storage_client self.mock_bucket = MagicMock() self.mock_storage_client().bucket.return_value = self.mock_bucket diff --git a/functions-python/preprocessed_analytics/tests/test_gtfs_processor.py b/functions-python/preprocessed_analytics/tests/test_gtfs_processor.py index fc587ba90..3b39dd3ad 100644 --- a/functions-python/preprocessed_analytics/tests/test_gtfs_processor.py +++ b/functions-python/preprocessed_analytics/tests/test_gtfs_processor.py @@ -7,16 +7,11 @@ class TestGTFSAnalyticsProcessor(unittest.TestCase): - @patch( - "preprocessed_analytics.src.processors.base_analytics_processor.start_db_session" - ) + @patch("preprocessed_analytics.src.processors.base_analytics_processor.Database") @patch( "preprocessed_analytics.src.processors.base_analytics_processor.storage.Client" ) - def setUp(self, mock_storage_client, mock_start_db_session): - self.mock_session = MagicMock() - mock_start_db_session.return_value = self.mock_session - + def setUp(self, mock_storage_client, _): self.mock_storage_client = mock_storage_client self.mock_bucket = MagicMock() self.mock_storage_client().bucket.return_value = self.mock_bucket diff --git a/functions-python/test_utils/database_utils.py b/functions-python/test_utils/database_utils.py index 98367c976..df53a46e2 100644 --- a/functions-python/test_utils/database_utils.py +++ b/functions-python/test_utils/database_utils.py @@ -84,3 +84,9 @@ def clean_testing_db(): except Exception as error: trans.rollback() logging.error(f"Error while deleting from test db: {error}") + + +def reset_database_class(): + """Resets the Database class to its initial state.""" + Database.instance = None + Database.initialized = False diff --git a/functions-python/update_validation_report/src/main.py b/functions-python/update_validation_report/src/main.py index d7c5ed299..3a46d8d91 100644 --- a/functions-python/update_validation_report/src/main.py +++ b/functions-python/update_validation_report/src/main.py @@ -29,7 +29,7 @@ from sqlalchemy.engine.interfaces import Any from database_gen.sqlacodegen_models import Gtfsdataset, Gtfsfeed, Validationreport -from helpers.database import start_db_session +from helpers.database import Database from google.cloud import workflows_v1 from google.cloud.workflows import executions_v1 from google.cloud.workflows.executions_v1 import Execution @@ -72,10 +72,11 @@ def update_validation_report(request: flask.Request): validator_version = get_validator_version(validator_endpoint) logging.info(f"Accessing bucket {bucket_name}") - session = start_db_session(os.getenv("FEEDS_DATABASE_URL"), echo=False) - latest_datasets = get_latest_datasets_without_validation_reports( - session, validator_version, force_update - ) + db = Database(os.getenv("FEEDS_DATABASE_URL"), echo=False) + with db.start_db_session() as session: + latest_datasets = get_latest_datasets_without_validation_reports( + session, validator_version, force_update + ) logging.info(f"Retrieved {len(latest_datasets)} latest datasets.") valid_latest_datasets = get_datasets_for_validation(latest_datasets) diff --git a/functions-python/validation_report_processor/tests/test_validation_report.py b/functions-python/validation_report_processor/tests/test_validation_report.py index b071b8a29..50896b547 100644 --- a/functions-python/validation_report_processor/tests/test_validation_report.py +++ b/functions-python/validation_report_processor/tests/test_validation_report.py @@ -11,7 +11,7 @@ Gtfsfeed, Validationreport, ) -from helpers.database import start_db_session +from helpers.database import Database from test_utils.database_utils import default_db_url from validation_report_processor.src.main import ( read_json_report, @@ -51,10 +51,11 @@ def test_read_json_report_failure(self, mock_get): def test_get_feature(self): """Test get_feature function.""" - session = start_db_session(default_db_url) + session = Database(default_db_url).Session() feature_name = faker.word() feature = get_feature(feature_name, session) session.add(feature) + session.flush() same_feature = get_feature(feature_name, session) self.assertIsInstance(feature, Feature) @@ -65,7 +66,7 @@ def test_get_feature(self): def test_get_dataset(self): """Test get_dataset function.""" - session = start_db_session(default_db_url) + session = Database(default_db_url).Session() dataset_stable_id = faker.word() dataset = get_dataset(dataset_stable_id, session) self.assertIsNone(dataset) @@ -79,6 +80,7 @@ def test_get_dataset(self): try: session.add(feed) session.add(dataset) + session.flush() returned_dataset = get_dataset(dataset_stable_id, session) self.assertIsNotNone(returned_dataset) self.assertEqual(returned_dataset, dataset) @@ -116,7 +118,7 @@ def test_create_validation_report_entities(self, mock_get): dataset = Gtfsdataset( id=faker.word(), feed_id=feed.id, stable_id=dataset_stable_id, latest=True ) - session = start_db_session(default_db_url) + session = Database(default_db_url).Session() try: session.add(feed) session.add(dataset) diff --git a/functions-python/validation_to_ndjson/tests/test_locations.py b/functions-python/validation_to_ndjson/tests/test_locations.py index 0afecf790..996befe19 100644 --- a/functions-python/validation_to_ndjson/tests/test_locations.py +++ b/functions-python/validation_to_ndjson/tests/test_locations.py @@ -12,7 +12,9 @@ def test_get_feeds_locations_map(self, _, mock_getenv, mock_database): mock_getenv.return_value = "fake_database_url" mock_session = MagicMock() - mock_database.return_value.start_db_session.return_value = mock_session + mock_database.return_value.start_db_session.return_value.__enter__.return_value = ( + mock_session + ) mock_feed = MagicMock() mock_feed.stable_id = "feed1" @@ -28,7 +30,7 @@ def test_get_feeds_locations_map(self, _, mock_getenv, mock_database): mock_session.query.return_value = mock_query result = get_feed_location("gtfs", "feed1") - mock_database.assert_called_once_with("fake_database_url") + mock_database.assert_called_once_with(database_url="fake_database_url") mock_session.query.assert_called_once() # Verify that query was called mock_query.filter.assert_called_once() # Verify that filter was applied mock_query.filter.return_value.filter.return_value.options.assert_called_once()