Skip to content

Commit

Permalink
Fixed Python functions unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
qcdyx committed Nov 26, 2024
1 parent d8a498a commit b06b42b
Show file tree
Hide file tree
Showing 13 changed files with 74 additions and 70 deletions.
4 changes: 2 additions & 2 deletions functions-python/batch_process_dataset/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

from database_gen.sqlacodegen_models import Gtfsdataset, t_feedsearch
from dataset_service.main import DatasetTraceService, DatasetTrace, Status
from helpers.database import Database
from helpers.database import Database, refresh_materialized_view
import logging

from helpers.logger import Logger
Expand Down Expand Up @@ -239,7 +239,7 @@ def create_dataset(self, dataset_file: DatasetFile):
session.add(latest_dataset)
session.add(new_dataset)

db.refresh_materialized_view(session, t_feedsearch.name)
refresh_materialized_view(session, t_feedsearch.name)
session.commit()
logging.info(f"[{self.feed_stable_id}] Dataset created successfully.")
except Exception as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,9 @@ def test_extract_location_batch(
mock_dataset2,
]
uuid_mock.return_value = "batch-uuid"
database_mock.return_value.start_db_session.return_value = mock_session
database_mock.return_value.start_db_session.return_value.__enter__.return_value = (
mock_session
)

mock_publisher = MagicMock()
publisher_client_mock.return_value = mock_publisher
Expand Down Expand Up @@ -361,10 +363,9 @@ def test_extract_location_batch_no_topic_name(self, logger_mock):
@patch("extract_location.src.main.Database")
@patch("extract_location.src.main.Logger")
def test_extract_location_batch_exception(self, logger_mock, database_mock):
mock_session = MagicMock()
mock_session.side_effect = Exception("Database error")

database_mock.return_value.start_db_session.return_value = mock_session
database_mock.return_value.start_db_session.side_effect = Exception(
"Database error"
)

response = extract_location_batch(None)
self.assertEqual(response, ("Error while fetching datasets.", 500))
24 changes: 13 additions & 11 deletions functions-python/gbfs_validator/tests/test_gbfs_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@
gbfs_validator_batch,
fetch_all_gbfs_feeds,
)
from test_utils.database_utils import default_db_url
from test_utils.database_utils import default_db_url, reset_database_class


class TestMainFunctions(unittest.TestCase):
def tearDown(self) -> None:
    """Clear the Database singleton state after each test so state from one
    test cannot leak into the next."""
    reset_database_class()
    super().tearDown()

@patch.dict(
os.environ,
{
Expand Down Expand Up @@ -131,11 +135,11 @@ def test_gbfs_validator_batch_missing_topic(self, _): # mock_logger
result = gbfs_validator_batch(None)
self.assertEqual(result[1], 500)

@patch("gbfs_validator.src.main.Database")
@patch("helpers.database.sessionmaker")
@patch("gbfs_validator.src.main.Logger")
def test_fetch_all_gbfs_feeds(self, _, mock_database):
def test_fetch_all_gbfs_feeds(self, _, mock_sessionmaker):
mock_session = MagicMock()
mock_database.return_value.start_db_session.return_value = mock_session
mock_sessionmaker.return_value.return_value = mock_session
mock_feed = MagicMock()
mock_session.query.return_value.options.return_value.all.return_value = [
mock_feed
Expand All @@ -144,14 +148,14 @@ def test_fetch_all_gbfs_feeds(self, _, mock_database):
result = fetch_all_gbfs_feeds()
self.assertEqual(result, [mock_feed])

mock_database.assert_called_once()
mock_sessionmaker.return_value.assert_called_once()
mock_session.close.assert_called_once()

@patch("gbfs_validator.src.main.Database")
@patch("helpers.database.sessionmaker")
@patch("gbfs_validator.src.main.Logger")
def test_fetch_all_gbfs_feeds_exception(self, _, mock_database):
def test_fetch_all_gbfs_feeds_exception(self, _, mock_sessionmaker):
mock_session = MagicMock()
mock_database.return_value.start_db_session.return_value = mock_session
mock_sessionmaker.return_value.return_value = mock_session

# Simulate an exception when querying the database
mock_session.query.side_effect = Exception("Database error")
Expand All @@ -161,7 +165,7 @@ def test_fetch_all_gbfs_feeds_exception(self, _, mock_database):

self.assertTrue("Database error" in str(context.exception))

mock_database.assert_called_once()
mock_sessionmaker.return_value.assert_called_once()
mock_session.close.assert_called_once()

@patch("gbfs_validator.src.main.Database")
Expand Down Expand Up @@ -207,8 +211,6 @@ def test_gbfs_validator_batch_publish_exception(
self, _, mock_fetch_all_gbfs_feeds, mock_publisher_client, mock_database
):
# Prepare mocks
mock_session = MagicMock()
mock_database.return_value.start_db_session.return_value = mock_session

mock_publisher_client.side_effect = Exception("Pub/Sub error")

Expand Down
27 changes: 14 additions & 13 deletions functions-python/helpers/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
from typing import Final, Optional

from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import sessionmaker, Session
import logging

DB_REUSE_SESSION: Final[str] = "DB_REUSE_SESSION"
LOGGER = logging.getLogger(__name__)


def with_db_session(func):
Expand Down Expand Up @@ -75,7 +76,6 @@ def __init__(self, database_url: Optional[str] = None, echo: bool = True):
)
self.connection_attempts: int = 0
self.Session = sessionmaker(bind=self.engine, autoflush=False)
self.logger = logging.getLogger(__name__)

@contextmanager
def start_db_session(self):
Expand All @@ -92,14 +92,15 @@ def start_db_session(self):
def is_session_reusable():
    """Return True when the DB_REUSE_SESSION environment variable is set to
    "true" (case-insensitive); defaults to False when the variable is unset."""
    # os.getenv takes the variable name directly; the previous
    # `"%s" % DB_REUSE_SESSION` formatting was a no-op string round-trip.
    return os.getenv(DB_REUSE_SESSION, "false").lower() == "true"

def refresh_materialized_view(self, session, view_name: str) -> bool:
"""
Refresh Materialized view by name.
@return: True if the view was refreshed successfully, False otherwise
"""
try:
session.execute(text(f"REFRESH MATERIALIZED VIEW CONCURRENTLY {view_name}"))
return True
except Exception as error:
self.logger.error(f"Error raised while refreshing view: {error}")
return False

def refresh_materialized_view(session: "Session", view_name: str) -> bool:
    """
    Refresh Materialized view by name.
    Uses REFRESH MATERIALIZED VIEW CONCURRENTLY, so reads can proceed during
    the refresh (PostgreSQL requires a unique index on the view for this).
    NOTE(review): view_name is interpolated into the SQL text — callers must
    pass trusted identifiers only.
    @return: True if the view was refreshed successfully, False otherwise
    """
    refresh_stmt = text(f"REFRESH MATERIALIZED VIEW CONCURRENTLY {view_name}")
    try:
        session.execute(refresh_stmt)
    except Exception as error:
        # Best-effort: log and report failure rather than propagate.
        LOGGER.error(f"Error raised while refreshing view: {error}")
        return False
    return True
12 changes: 7 additions & 5 deletions functions-python/helpers/tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Final
from unittest import mock

from helpers.database import refresh_materialized_view, start_db_session
from helpers.database import refresh_materialized_view, Database

default_db_url: Final[
str
Expand All @@ -23,15 +23,17 @@ class TestDatabase(unittest.TestCase):
def test_refresh_materialized_view_existing_view(self):
view_name = "feedsearch"

session = start_db_session(os.getenv("FEEDS_DATABASE_URL"))
result = refresh_materialized_view(session, view_name)
db = Database(database_url=os.getenv("FEEDS_DATABASE_URL"))
with db.start_db_session() as session:
result = refresh_materialized_view(session, view_name)

self.assertTrue(result)

def test_refresh_materialized_view_invalid_view(self):
view_name = "invalid_view_name"

session = start_db_session(os.getenv("FEEDS_DATABASE_URL"))
result = refresh_materialized_view(session, view_name)
db = Database(database_url=os.getenv("FEEDS_DATABASE_URL"))
with db.start_db_session() as session:
result = refresh_materialized_view(session, view_name)

self.assertFalse(result)
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
Gtfsfeed,
Gtfsdataset,
)
from helpers.database import start_db_session
from helpers.database import Database


class NoFeedDataException(Exception):
Expand All @@ -23,7 +23,9 @@ class NoFeedDataException(Exception):
class BaseAnalyticsProcessor:
def __init__(self, run_date):
self.run_date = run_date
self.session = start_db_session(os.getenv("FEEDS_DATABASE_URL"), echo=False)
self.session = Database(
database_url=os.getenv("FEEDS_DATABASE_URL"), echo=False
).Session()
self.processed_feeds = set()
self.data = []
self.feed_metrics_data = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,11 @@


class TestBaseAnalyticsProcessor(unittest.TestCase):
@patch(
"preprocessed_analytics.src.processors.base_analytics_processor.start_db_session"
)
@patch("preprocessed_analytics.src.processors.base_analytics_processor.Database")
@patch(
"preprocessed_analytics.src.processors.base_analytics_processor.storage.Client"
)
def setUp(self, mock_storage_client, mock_start_db_session):
self.mock_session = MagicMock()
mock_start_db_session.return_value = self.mock_session

def setUp(self, mock_storage_client, _):
self.mock_storage_client = mock_storage_client
self.mock_bucket = MagicMock()
self.mock_storage_client().bucket.return_value = self.mock_bucket
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,11 @@


class TestGBFSAnalyticsProcessor(unittest.TestCase):
@patch(
"preprocessed_analytics.src.processors.base_analytics_processor.start_db_session"
)
@patch("preprocessed_analytics.src.processors.base_analytics_processor.Database")
@patch(
"preprocessed_analytics.src.processors.base_analytics_processor.storage.Client"
)
def setUp(self, mock_storage_client, mock_start_db_session):
self.mock_session = MagicMock()
mock_start_db_session.return_value = self.mock_session

def setUp(self, mock_storage_client, _):
self.mock_storage_client = mock_storage_client
self.mock_bucket = MagicMock()
self.mock_storage_client().bucket.return_value = self.mock_bucket
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,11 @@


class TestGTFSAnalyticsProcessor(unittest.TestCase):
@patch(
"preprocessed_analytics.src.processors.base_analytics_processor.start_db_session"
)
@patch("preprocessed_analytics.src.processors.base_analytics_processor.Database")
@patch(
"preprocessed_analytics.src.processors.base_analytics_processor.storage.Client"
)
def setUp(self, mock_storage_client, mock_start_db_session):
self.mock_session = MagicMock()
mock_start_db_session.return_value = self.mock_session

def setUp(self, mock_storage_client, _):
self.mock_storage_client = mock_storage_client
self.mock_bucket = MagicMock()
self.mock_storage_client().bucket.return_value = self.mock_bucket
Expand Down
6 changes: 6 additions & 0 deletions functions-python/test_utils/database_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,9 @@ def clean_testing_db():
except Exception as error:
trans.rollback()
logging.error(f"Error while deleting from test db: {error}")


def reset_database_class():
    """Resets the Database class to its initial state.

    Clears the class-level singleton slots so the next Database()
    construction starts from scratch — used for test isolation between
    unit tests.
    NOTE(review): assumes Database implements its singleton via the class
    attributes `instance` and `initialized` — confirm against
    helpers.database before reuse.
    """
    Database.instance = None
    Database.initialized = False
11 changes: 6 additions & 5 deletions functions-python/update_validation_report/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from sqlalchemy.engine.interfaces import Any

from database_gen.sqlacodegen_models import Gtfsdataset, Gtfsfeed, Validationreport
from helpers.database import start_db_session
from helpers.database import Database
from google.cloud import workflows_v1
from google.cloud.workflows import executions_v1
from google.cloud.workflows.executions_v1 import Execution
Expand Down Expand Up @@ -72,10 +72,11 @@ def update_validation_report(request: flask.Request):
validator_version = get_validator_version(validator_endpoint)
logging.info(f"Accessing bucket {bucket_name}")

session = start_db_session(os.getenv("FEEDS_DATABASE_URL"), echo=False)
latest_datasets = get_latest_datasets_without_validation_reports(
session, validator_version, force_update
)
db = Database(os.getenv("FEEDS_DATABASE_URL"), echo=False)
with db.start_db_session() as session:
latest_datasets = get_latest_datasets_without_validation_reports(
session, validator_version, force_update
)
logging.info(f"Retrieved {len(latest_datasets)} latest datasets.")

valid_latest_datasets = get_datasets_for_validation(latest_datasets)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
Gtfsfeed,
Validationreport,
)
from helpers.database import start_db_session
from helpers.database import Database
from test_utils.database_utils import default_db_url
from validation_report_processor.src.main import (
read_json_report,
Expand Down Expand Up @@ -51,10 +51,11 @@ def test_read_json_report_failure(self, mock_get):

def test_get_feature(self):
"""Test get_feature function."""
session = start_db_session(default_db_url)
session = Database(default_db_url).Session()
feature_name = faker.word()
feature = get_feature(feature_name, session)
session.add(feature)
session.flush()
same_feature = get_feature(feature_name, session)

self.assertIsInstance(feature, Feature)
Expand All @@ -65,7 +66,7 @@ def test_get_feature(self):

def test_get_dataset(self):
"""Test get_dataset function."""
session = start_db_session(default_db_url)
session = Database(default_db_url).Session()
dataset_stable_id = faker.word()
dataset = get_dataset(dataset_stable_id, session)
self.assertIsNone(dataset)
Expand All @@ -79,6 +80,7 @@ def test_get_dataset(self):
try:
session.add(feed)
session.add(dataset)
session.flush()
returned_dataset = get_dataset(dataset_stable_id, session)
self.assertIsNotNone(returned_dataset)
self.assertEqual(returned_dataset, dataset)
Expand Down Expand Up @@ -116,7 +118,7 @@ def test_create_validation_report_entities(self, mock_get):
dataset = Gtfsdataset(
id=faker.word(), feed_id=feed.id, stable_id=dataset_stable_id, latest=True
)
session = start_db_session(default_db_url)
session = Database(default_db_url).Session()
try:
session.add(feed)
session.add(dataset)
Expand Down
6 changes: 4 additions & 2 deletions functions-python/validation_to_ndjson/tests/test_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ def test_get_feeds_locations_map(self, _, mock_getenv, mock_database):
mock_getenv.return_value = "fake_database_url"

mock_session = MagicMock()
mock_database.return_value.start_db_session.return_value = mock_session
mock_database.return_value.start_db_session.return_value.__enter__.return_value = (
mock_session
)

mock_feed = MagicMock()
mock_feed.stable_id = "feed1"
Expand All @@ -28,7 +30,7 @@ def test_get_feeds_locations_map(self, _, mock_getenv, mock_database):
mock_session.query.return_value = mock_query
result = get_feed_location("gtfs", "feed1")

mock_database.assert_called_once_with("fake_database_url")
mock_database.assert_called_once_with(database_url="fake_database_url")
mock_session.query.assert_called_once() # Verify that query was called
mock_query.filter.assert_called_once() # Verify that filter was applied
mock_query.filter.return_value.filter.return_value.options.assert_called_once()
Expand Down

0 comments on commit b06b42b

Please sign in to comment.