Skip to content

Commit

Permalink
more refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
qcdyx committed Nov 25, 2024
1 parent 4745fef commit d8a498a
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,14 @@ def test_batch_datasets(mock_client, mock_publish):
]


@patch("batch_datasets.src.main.start_db_session")
def test_batch_datasets_exception(start_db_session_mock):
@patch("batch_datasets.src.main.Database")
def test_batch_datasets_exception(database_mock):
exception_message = "Failure occurred"
start_db_session_mock.side_effect = Exception(exception_message)
mock_session = MagicMock()
mock_session.side_effect = Exception(exception_message)

database_mock.return_value.start_db_session.return_value = mock_session

with pytest.raises(Exception) as exec_info:
batch_datasets(Mock())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,12 +268,12 @@ def test_extract_location_exception_2(
"GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json",
},
)
@patch("extract_location.src.main.start_db_session")
@patch("extract_location.src.main.Database")
@patch("extract_location.src.main.pubsub_v1.PublisherClient")
@patch("extract_location.src.main.Logger")
@patch("uuid.uuid4")
def test_extract_location_batch(
self, uuid_mock, logger_mock, publisher_client_mock, start_db_session_mock
self, uuid_mock, logger_mock, publisher_client_mock, database_mock
):
mock_session = MagicMock()
mock_dataset1 = Gtfsdataset(
Expand All @@ -300,7 +300,7 @@ def test_extract_location_batch(
mock_dataset2,
]
uuid_mock.return_value = "batch-uuid"
start_db_session_mock.return_value = mock_session
database_mock.return_value.start_db_session.return_value = mock_session

mock_publisher = MagicMock()
publisher_client_mock.return_value = mock_publisher
Expand Down Expand Up @@ -358,10 +358,13 @@ def test_extract_location_batch_no_topic_name(self, logger_mock):
"GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json",
},
)
@patch("extract_location.src.main.start_db_session")
@patch("extract_location.src.main.Database")
@patch("extract_location.src.main.Logger")
def test_extract_location_batch_exception(self, logger_mock, start_db_session_mock):
start_db_session_mock.side_effect = Exception("Database error")
def test_extract_location_batch_exception(self, logger_mock, database_mock):
mock_session = MagicMock()
mock_session.side_effect = Exception("Database error")

database_mock.return_value.start_db_session.return_value = mock_session

response = extract_location_batch(None)
self.assertEqual(response, ("Error while fetching datasets.", 500))
42 changes: 21 additions & 21 deletions functions-python/gbfs_validator/tests/test_gbfs_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class TestMainFunctions(unittest.TestCase):
"VALIDATOR_URL": "https://mock-validator-url.com",
},
)
@patch("gbfs_validator.src.main.start_db_session")
@patch("gbfs_validator.src.main.Database")
@patch("gbfs_validator.src.main.DatasetTraceService")
@patch("gbfs_validator.src.main.fetch_gbfs_files")
@patch("gbfs_validator.src.main.GBFSValidator.create_gbfs_json_with_bucket_paths")
Expand All @@ -47,11 +47,11 @@ def test_gbfs_validator_pubsub(
mock_create_gbfs_json,
mock_fetch_gbfs_files,
mock_dataset_trace_service,
mock_start_db_session,
mock_database,
):
# Prepare mocks
mock_session = MagicMock()
mock_start_db_session.return_value = mock_session
mock_database.return_value.start_db_session.return_value = mock_session

mock_trace_service = MagicMock()
mock_dataset_trace_service.return_value = mock_trace_service
Expand Down Expand Up @@ -95,16 +95,16 @@ def test_gbfs_validator_pubsub(
"PUBSUB_TOPIC_NAME": "mock-topic",
},
)
@patch("gbfs_validator.src.main.start_db_session")
@patch("gbfs_validator.src.main.Database")
@patch("gbfs_validator.src.main.pubsub_v1.PublisherClient")
@patch("gbfs_validator.src.main.fetch_all_gbfs_feeds")
@patch("gbfs_validator.src.main.Logger")
def test_gbfs_validator_batch(
self, _, mock_fetch_all_gbfs_feeds, mock_publisher_client, mock_start_db_session
self, _, mock_fetch_all_gbfs_feeds, mock_publisher_client, mock_database
):
# Prepare mocks
mock_session = MagicMock()
mock_start_db_session.return_value = mock_session
mock_database.return_value.start_db_session.return_value = mock_session

mock_publisher = MagicMock()
mock_publisher_client.return_value = mock_publisher
Expand All @@ -131,11 +131,11 @@ def test_gbfs_validator_batch_missing_topic(self, _): # mock_logger
result = gbfs_validator_batch(None)
self.assertEqual(result[1], 500)

@patch("gbfs_validator.src.main.start_db_session")
@patch("gbfs_validator.src.main.Database")
@patch("gbfs_validator.src.main.Logger")
def test_fetch_all_gbfs_feeds(self, _, mock_start_db_session):
def test_fetch_all_gbfs_feeds(self, _, mock_database):
mock_session = MagicMock()
mock_start_db_session.return_value = mock_session
mock_database.return_value.start_db_session.return_value = mock_session
mock_feed = MagicMock()
mock_session.query.return_value.options.return_value.all.return_value = [
mock_feed
Expand All @@ -144,14 +144,14 @@ def test_fetch_all_gbfs_feeds(self, _, mock_start_db_session):
result = fetch_all_gbfs_feeds()
self.assertEqual(result, [mock_feed])

mock_start_db_session.assert_called_once()
mock_database.assert_called_once()
mock_session.close.assert_called_once()

@patch("gbfs_validator.src.main.start_db_session")
@patch("gbfs_validator.src.main.Database")
@patch("gbfs_validator.src.main.Logger")
def test_fetch_all_gbfs_feeds_exception(self, _, mock_start_db_session):
def test_fetch_all_gbfs_feeds_exception(self, _, mock_database):
mock_session = MagicMock()
mock_start_db_session.return_value = mock_session
mock_database.return_value.start_db_session.return_value = mock_session

# Simulate an exception when querying the database
mock_session.query.side_effect = Exception("Database error")
Expand All @@ -161,19 +161,19 @@ def test_fetch_all_gbfs_feeds_exception(self, _, mock_start_db_session):

self.assertTrue("Database error" in str(context.exception))

mock_start_db_session.assert_called_once()
mock_database.assert_called_once()
mock_session.close.assert_called_once()

@patch("gbfs_validator.src.main.start_db_session")
def test_fetch_all_gbfs_feeds_none_session(self, mock_start_db_session):
mock_start_db_session.return_value = None
@patch("gbfs_validator.src.main.Database")
def test_fetch_all_gbfs_feeds_none_session(self, mock_database):
mock_database.return_value = None

with self.assertRaises(Exception) as context:
fetch_all_gbfs_feeds()

self.assertTrue("NoneType" in str(context.exception))

mock_start_db_session.assert_called_once()
mock_database.assert_called_once()

@patch.dict(
os.environ,
Expand All @@ -199,16 +199,16 @@ def test_gbfs_validator_batch_fetch_exception(self, _, mock_fetch_all_gbfs_feeds
"PUBSUB_TOPIC_NAME": "mock-topic",
},
)
@patch("gbfs_validator.src.main.start_db_session")
@patch("gbfs_validator.src.main.Database")
@patch("gbfs_validator.src.main.pubsub_v1.PublisherClient")
@patch("gbfs_validator.src.main.fetch_all_gbfs_feeds")
@patch("gbfs_validator.src.main.Logger")
def test_gbfs_validator_batch_publish_exception(
self, _, mock_fetch_all_gbfs_feeds, mock_publisher_client, mock_start_db_session
self, _, mock_fetch_all_gbfs_feeds, mock_publisher_client, mock_database
):
# Prepare mocks
mock_session = MagicMock()
mock_start_db_session.return_value = mock_session
mock_database.return_value.start_db_session.return_value = mock_session

mock_publisher_client.side_effect = Exception("Pub/Sub error")

Expand Down
10 changes: 3 additions & 7 deletions functions-python/validation_to_ndjson/src/utils/locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from sqlalchemy.orm import joinedload

from database_gen.sqlacodegen_models import Feed, Location
from helpers.database import start_db_session
from helpers.database import Database


def get_feed_location(data_type: str, stable_id: str) -> List[Location]:
Expand All @@ -14,9 +14,8 @@ def get_feed_location(data_type: str, stable_id: str) -> List[Location]:
@param stable_id: The stable ID of the feed.
@return: A list of locations.
"""
session = None
try:
session = start_db_session(os.getenv("FEEDS_DATABASE_URL"))
db = Database(database_url=os.getenv("FEEDS_DATABASE_URL"))
with db.start_db_session() as session:
feeds = (
session.query(Feed)
.filter(Feed.data_type == data_type)
Expand All @@ -25,6 +24,3 @@ def get_feed_location(data_type: str, stable_id: str) -> List[Location]:
.all()
)
return feeds[0].locations if feeds is not None and len(feeds) > 0 else []
finally:
if session:
session.close()
16 changes: 8 additions & 8 deletions functions-python/validation_to_ndjson/tests/test_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@


class TestFeedsLocations(unittest.TestCase):
@patch("validation_to_ndjson.src.utils.locations.start_db_session")
@patch("validation_to_ndjson.src.utils.locations.Database")
@patch("validation_to_ndjson.src.utils.locations.os.getenv")
@patch("validation_to_ndjson.src.utils.locations.joinedload")
def test_get_feeds_locations_map(self, _, mock_getenv, mock_start_db_session):
def test_get_feeds_locations_map(self, _, mock_getenv, mock_database):
mock_getenv.return_value = "fake_database_url"

mock_session = MagicMock()
mock_start_db_session.return_value = mock_session
mock_database.return_value.start_db_session.return_value = mock_session

mock_feed = MagicMock()
mock_feed.stable_id = "feed1"
Expand All @@ -28,18 +28,18 @@ def test_get_feeds_locations_map(self, _, mock_getenv, mock_start_db_session):
mock_session.query.return_value = mock_query
result = get_feed_location("gtfs", "feed1")

mock_start_db_session.assert_called_once_with("fake_database_url")
mock_database.assert_called_once_with("fake_database_url")
mock_session.query.assert_called_once() # Verify that query was called
mock_query.filter.assert_called_once() # Verify that filter was applied
mock_query.filter.return_value.filter.return_value.options.assert_called_once()
mock_query.filter.return_value.filter.return_value.options.return_value.all.assert_called_once()

self.assertEqual(result, [mock_location1, mock_location2]) # Verify the mapping

@patch("validation_to_ndjson.src.utils.locations.start_db_session")
def test_get_feeds_locations_map_no_feeds(self, mock_start_db_session):
@patch("validation_to_ndjson.src.utils.locations.Database")
def test_get_feeds_locations_map_no_feeds(self, mock_database):
mock_session = MagicMock()
mock_start_db_session.return_value = mock_session
mock_database.return_value.start_db_session.return_value = mock_session

mock_query = MagicMock()
mock_query.filter.return_value.filter.return_value.options.return_value.all.return_value = (
Expand All @@ -50,5 +50,5 @@ def test_get_feeds_locations_map_no_feeds(self, mock_start_db_session):

result = get_feed_location("test_data_type", "test_stable_id")

mock_start_db_session.assert_called_once()
mock_database.return_value.start_db_session.assert_called_once()
self.assertEqual(result, []) # The result should be an empty dictionary

0 comments on commit d8a498a

Please sign in to comment.