
Commit 82aaef6
more refactoring
qcdyx committed Nov 27, 2024
1 parent 0ccb413 commit 82aaef6
Showing 17 changed files with 182 additions and 164 deletions.
@@ -64,10 +64,14 @@ def test_batch_datasets(mock_client, mock_publish):
     ]
 
 
-@patch("batch_datasets.src.main.start_db_session")
-def test_batch_datasets_exception(start_db_session_mock):
+@patch("batch_datasets.src.main.Database")
+def test_batch_datasets_exception(database_mock):
     exception_message = "Failure occurred"
-    start_db_session_mock.side_effect = Exception(exception_message)
+    mock_session = MagicMock()
+    mock_session.side_effect = Exception(exception_message)
+
+    database_mock.return_value.start_db_session.return_value = mock_session
+
     with pytest.raises(Exception) as exec_info:
         batch_datasets(Mock())
 
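Note: the `helpers.database` module itself is not part of this diff, so the exact shape of the new `Database` API is not shown here. A minimal sketch of the contract these mocks target; everything beyond the names `Database`, `start_db_session`, `_get_session`, and `FEEDS_DATABASE_URL` that appear in this commit is an assumption:

```python
# Hypothetical sketch of helpers/database.py -- illustrative only.
import os
from contextlib import contextmanager

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker


class Database:
    """Assumed wrapper that owns the engine and hands out sessions."""

    def __init__(self, database_url: str | None = None):
        # Falls back to the env var referenced elsewhere in this commit.
        self._engine = create_engine(database_url or os.getenv("FEEDS_DATABASE_URL"))

    def _get_session(self):
        # Returns a session factory; tests in this commit patch this attribute.
        return sessionmaker(bind=self._engine)

    @contextmanager
    def start_db_session(self):
        # The seam the tests now mock:
        #   database_mock.return_value.start_db_session.return_value = mock_session
        session = self._get_session()()
        try:
            yield session
        finally:
            session.close()
```

Mocking `Database` instead of a free `start_db_session` function gives each test a single explicit seam for substituting the session.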
76 changes: 38 additions & 38 deletions functions-python/batch_process_dataset/src/main.py
@@ -22,7 +22,7 @@
 import zipfile
 from dataclasses import dataclass
 from datetime import datetime
-from typing import Optional
+from typing import Optional, TYPE_CHECKING
 
 import functions_framework
 from cloudevents.http import CloudEvent
@@ -31,12 +31,15 @@
 
 from database_gen.sqlacodegen_models import Gtfsdataset, t_feedsearch
 from dataset_service.main import DatasetTraceService, DatasetTrace, Status
-from helpers.database import Database
+from helpers.database import Database, refresh_materialized_view, with_db_session
 import logging
 
 from helpers.logger import Logger
 from helpers.utils import download_and_get_hash
 
+if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
 
 @dataclass
 class DatasetFile:
@@ -202,50 +205,47 @@ def generate_temp_filename(self):
         )
         return temporary_file_path
 
-    def create_dataset(self, dataset_file: DatasetFile):
+    @with_db_session
+    def create_dataset(self, dataset_file: DatasetFile, db_session: "Session"):
         """
         Creates a new dataset in the database
         """
-        db = Database(database_url=os.getenv("FEEDS_DATABASE_URL"))
         try:
-            with db.start_db_session() as session:
-                # # Check latest version of the dataset
-                latest_dataset = (
-                    session.query(Gtfsdataset)
-                    .filter_by(latest=True, feed_id=self.feed_id)
-                    .one_or_none()
-                )
-                if not latest_dataset:
-                    logging.info(
-                        f"[{self.feed_stable_id}] No latest dataset found for feed."
-                    )
+            # Check latest version of the dataset
+            latest_dataset = (
+                db_session.query(Gtfsdataset)
+                .filter_by(latest=True, feed_id=self.feed_id)
+                .one_or_none()
+            )
+            if not latest_dataset:
+                logging.info(
+                    f"[{self.feed_stable_id}] No latest dataset found for feed."
+                )
 
-                logging.info(
-                    f"[{self.feed_stable_id}] Creating new dataset for feed with stable id {dataset_file.stable_id}."
-                )
-                new_dataset = Gtfsdataset(
-                    id=str(uuid.uuid4()),
-                    feed_id=self.feed_id,
-                    stable_id=dataset_file.stable_id,
-                    latest=True,
-                    bounding_box=None,
-                    note=None,
-                    hash=dataset_file.file_sha256_hash,
-                    downloaded_at=func.now(),
-                    hosted_url=dataset_file.hosted_url,
-                )
-                if latest_dataset:
-                    latest_dataset.latest = False
-                    session.add(latest_dataset)
-                session.add(new_dataset)
+            logging.info(
+                f"[{self.feed_stable_id}] Creating new dataset for feed with stable id {dataset_file.stable_id}."
+            )
+            new_dataset = Gtfsdataset(
+                id=str(uuid.uuid4()),
+                feed_id=self.feed_id,
+                stable_id=dataset_file.stable_id,
+                latest=True,
+                bounding_box=None,
+                note=None,
+                hash=dataset_file.file_sha256_hash,
+                downloaded_at=func.now(),
+                hosted_url=dataset_file.hosted_url,
+            )
+            if latest_dataset:
+                latest_dataset.latest = False
+                db_session.add(latest_dataset)
+            db_session.add(new_dataset)
 
-                db.refresh_materialized_view(session, t_feedsearch.name)
-                session.commit()
-                logging.info(f"[{self.feed_stable_id}] Dataset created successfully.")
+            refresh_materialized_view(db_session, t_feedsearch.name)
+            db_session.commit()
+            logging.info(f"[{self.feed_stable_id}] Dataset created successfully.")
         except Exception as e:
             raise Exception(f"Error creating dataset: {e}")
-        finally:
-            pass
 
     def process(self) -> DatasetFile or None:
         """
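Note: `create_dataset` now receives its session from the `with_db_session` decorator imported above, and `refresh_materialized_view` has become a module-level helper. Neither body appears in this diff; a plausible sketch, assuming the decorator injects a `db_session` kwarg only when the caller omits one and reuses the `Database` class sketched earlier:

```python
# Hypothetical sketch of with_db_session and refresh_materialized_view from
# helpers.database -- only their call sites appear in this diff. Database is
# the class sketched above, assumed to live in the same module.
import functools
import logging
from typing import Any, Callable

from sqlalchemy import text
from sqlalchemy.orm import Session


def with_db_session(func: Callable[..., Any]) -> Callable[..., Any]:
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if kwargs.get("db_session") is not None:
            # The caller (e.g. a test) supplied a session explicitly.
            return func(*args, **kwargs)
        with Database().start_db_session() as session:
            return func(*args, db_session=session, **kwargs)

    return wrapper


def refresh_materialized_view(session: Session, view_name: str) -> bool:
    """Assumed to issue a PostgreSQL refresh; returns False instead of raising."""
    try:
        session.execute(text(f"REFRESH MATERIALIZED VIEW CONCURRENTLY {view_name}"))
        return True
    except Exception as error:
        logging.error(f"Error refreshing materialized view {view_name}: {error}")
        return False
```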
2 changes: 0 additions & 2 deletions functions-python/extract_location/src/main.py
@@ -132,8 +132,6 @@ def extract_location_pubsub(cloud_event: CloudEvent):
         error = f"Error updating location information in database: {e}"
         logging.error(f"[{dataset_id}] Error while processing: {e}")
         raise e
-    finally:
-        pass
     logging.info(
         f"[{stable_id} - {dataset_id}] Location information updated successfully."
     )
@@ -268,12 +268,12 @@ def test_extract_location_exception_2(
             "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json",
         },
     )
-    @patch("extract_location.src.main.start_db_session")
+    @patch("extract_location.src.main.Database")
     @patch("extract_location.src.main.pubsub_v1.PublisherClient")
     @patch("extract_location.src.main.Logger")
     @patch("uuid.uuid4")
     def test_extract_location_batch(
-        self, uuid_mock, logger_mock, publisher_client_mock, start_db_session_mock
+        self, uuid_mock, logger_mock, publisher_client_mock, database_mock
     ):
         mock_session = MagicMock()
         mock_dataset1 = Gtfsdataset(
@@ -300,7 +300,9 @@ def test_extract_location_batch(
             mock_dataset2,
         ]
         uuid_mock.return_value = "batch-uuid"
-        start_db_session_mock.return_value = mock_session
+        database_mock.return_value.start_db_session.return_value.__enter__.return_value = (
+            mock_session
+        )
 
         mock_publisher = MagicMock()
         publisher_client_mock.return_value = mock_publisher
@@ -358,10 +360,12 @@ def test_extract_location_batch_no_topic_name(self, logger_mock):
             "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json",
         },
     )
-    @patch("extract_location.src.main.start_db_session")
+    @patch("extract_location.src.main.Database")
     @patch("extract_location.src.main.Logger")
-    def test_extract_location_batch_exception(self, logger_mock, start_db_session_mock):
-        start_db_session_mock.side_effect = Exception("Database error")
+    def test_extract_location_batch_exception(self, logger_mock, database_mock):
+        database_mock.return_value.start_db_session.side_effect = Exception(
+            "Database error"
+        )
 
         response = extract_location_batch(None)
         self.assertEqual(response, ("Error while fetching datasets.", 500))
21 changes: 8 additions & 13 deletions functions-python/gbfs_validator/src/main.py
@@ -8,7 +8,7 @@
 import functions_framework
 from cloudevents.http import CloudEvent
 from google.cloud import pubsub_v1, storage
-from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import joinedload, Session
 import traceback
 from database_gen.sqlacodegen_models import Gbfsfeed
 from dataset_service.main import (
@@ -18,7 +18,7 @@
     PipelineStage,
     MaxExecutionsReachedError,
 )
-from helpers.database import Database
+from helpers.database import Database, with_db_session
 from helpers.logger import Logger, StableIdFilter
 from helpers.parser import jsonify_pubsub
 from .gbfs_utils import (
@@ -33,19 +33,16 @@
 BUCKET_NAME = os.getenv("BUCKET_NAME", "mobilitydata-gbfs-snapshots-dev")
 
 
-def fetch_all_gbfs_feeds() -> List[Gbfsfeed]:
-    db = Database(database_url=os.getenv("FEEDS_DATABASE_URL"))
+@with_db_session
+def fetch_all_gbfs_feeds(db_session: "Session") -> List[Gbfsfeed]:
     try:
-        with db.start_db_session() as session:
-            gbfs_feeds = (
-                session.query(Gbfsfeed).options(joinedload(Gbfsfeed.gbfsversions)).all()
-            )
-            return gbfs_feeds
+        gbfs_feeds = (
+            db_session.query(Gbfsfeed).options(joinedload(Gbfsfeed.gbfsversions)).all()
+        )
+        return gbfs_feeds
     except Exception as e:
         logging.error(f"Error fetching all GBFS feeds: {e}")
         raise e
-    finally:
-        pass
 
 
 @functions_framework.cloud_event
@@ -114,8 +111,6 @@ def gbfs_validator_pubsub(cloud_event: CloudEvent):
         logging.error(f"{error_message}\nTraceback:\n{traceback.format_exc()}")
         save_trace_with_error(trace, error_message, trace_service)
         return error_message
-    finally:
-        pass
 
     trace.status = Status.SUCCESS
     trace_service.save(trace)
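With the decorator in place, `fetch_all_gbfs_feeds` can be called with no arguments in production code while tests inject a session or a whole `Database` instance. A usage sketch, assuming the injection semantics described above:

```python
# Hypothetical usage of the refactored fetch_all_gbfs_feeds.
from gbfs_validator.src.main import fetch_all_gbfs_feeds
from helpers.database import Database

feeds = fetch_all_gbfs_feeds()  # the decorator opens and closes the session

# Or share one session across several calls:
with Database().start_db_session() as session:
    feeds = fetch_all_gbfs_feeds(db_session=session)
```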
56 changes: 33 additions & 23 deletions functions-python/gbfs_validator/tests/test_gbfs_validator.py
@@ -13,10 +13,15 @@
     gbfs_validator_batch,
     fetch_all_gbfs_feeds,
 )
-from test_utils.database_utils import default_db_url
+from test_utils.database_utils import default_db_url, reset_database_class
+from helpers.database import Database
 
 
 class TestMainFunctions(unittest.TestCase):
+    def tearDown(self) -> None:
+        reset_database_class()
+        return super().tearDown()
+
     @patch.dict(
         os.environ,
         {
@@ -28,7 +33,7 @@ class TestMainFunctions(unittest.TestCase):
             "VALIDATOR_URL": "https://mock-validator-url.com",
         },
     )
-    @patch("gbfs_validator.src.main.start_db_session")
+    @patch("gbfs_validator.src.main.Database")
     @patch("gbfs_validator.src.main.DatasetTraceService")
     @patch("gbfs_validator.src.main.fetch_gbfs_files")
     @patch("gbfs_validator.src.main.GBFSValidator.create_gbfs_json_with_bucket_paths")
@@ -47,11 +52,11 @@ def test_gbfs_validator_pubsub(
         mock_create_gbfs_json,
         mock_fetch_gbfs_files,
         mock_dataset_trace_service,
-        mock_start_db_session,
+        mock_database,
     ):
         # Prepare mocks
         mock_session = MagicMock()
-        mock_start_db_session.return_value = mock_session
+        mock_database.return_value.start_db_session.return_value = mock_session
 
         mock_trace_service = MagicMock()
         mock_dataset_trace_service.return_value = mock_trace_service
@@ -95,16 +100,16 @@
             "PUBSUB_TOPIC_NAME": "mock-topic",
         },
     )
-    @patch("gbfs_validator.src.main.start_db_session")
+    @patch("helpers.database.Database")
     @patch("gbfs_validator.src.main.pubsub_v1.PublisherClient")
     @patch("gbfs_validator.src.main.fetch_all_gbfs_feeds")
     @patch("gbfs_validator.src.main.Logger")
     def test_gbfs_validator_batch(
-        self, _, mock_fetch_all_gbfs_feeds, mock_publisher_client, mock_start_db_session
+        self, _, mock_fetch_all_gbfs_feeds, mock_publisher_client, mock_database
     ):
         # Prepare mocks
         mock_session = MagicMock()
-        mock_start_db_session.return_value = mock_session
+        mock_database.return_value.start_db_session.return_value = mock_session
 
         mock_publisher = MagicMock()
         mock_publisher_client.return_value = mock_publisher
@@ -131,11 +136,15 @@ def test_gbfs_validator_batch_missing_topic(self, _):  # mock_logger
         result = gbfs_validator_batch(None)
         self.assertEqual(result[1], 500)
 
-    @patch("gbfs_validator.src.main.start_db_session")
+    @patch("helpers.database.Database")
     @patch("gbfs_validator.src.main.Logger")
-    def test_fetch_all_gbfs_feeds(self, _, mock_start_db_session):
+    def test_fetch_all_gbfs_feeds(self, _, mock_database):
         mock_session = MagicMock()
-        mock_start_db_session.return_value = mock_session
+        db = Database()
+        db._get_session = MagicMock()
+        db._get_session.return_value.return_value = mock_session
+        mock_database.return_value = db
+
         mock_feed = MagicMock()
         mock_session.query.return_value.options.return_value.all.return_value = [
             mock_feed
         ]
 
         result = fetch_all_gbfs_feeds()
         self.assertEqual(result, [mock_feed])
 
-        mock_start_db_session.assert_called_once()
+        db._get_session.return_value.assert_called_once()
         mock_session.close.assert_called_once()
 
-    @patch("gbfs_validator.src.main.start_db_session")
+    @patch("helpers.database.Database")
     @patch("gbfs_validator.src.main.Logger")
-    def test_fetch_all_gbfs_feeds_exception(self, _, mock_start_db_session):
+    def test_fetch_all_gbfs_feeds_exception(self, _, mock_database):
         mock_session = MagicMock()
-        mock_start_db_session.return_value = mock_session
+        db = Database()
+        db._get_session = MagicMock()
+        db._get_session.return_value.return_value = mock_session
+        mock_database.return_value = db
 
         # Simulate an exception when querying the database
         mock_session.query.side_effect = Exception("Database error")
@@ -161,19 +173,19 @@ def test_fetch_all_gbfs_feeds_exception(self, _, mock_start_db_session):
 
         self.assertTrue("Database error" in str(context.exception))
 
-        mock_start_db_session.assert_called_once()
+        db._get_session.return_value.assert_called_once()
         mock_session.close.assert_called_once()
 
-    @patch("gbfs_validator.src.main.start_db_session")
-    def test_fetch_all_gbfs_feeds_none_session(self, mock_start_db_session):
-        mock_start_db_session.return_value = None
+    @patch("helpers.database.Database")
+    def test_fetch_all_gbfs_feeds_none_session(self, mock_database):
+        mock_database.return_value = None
 
         with self.assertRaises(Exception) as context:
             fetch_all_gbfs_feeds()
 
         self.assertTrue("NoneType" in str(context.exception))
 
-        mock_start_db_session.assert_called_once()
+        mock_database.assert_called_once()
 
     @patch.dict(
         os.environ,
@@ -199,16 +211,14 @@ def test_gbfs_validator_batch_fetch_exception(self, _, mock_fetch_all_gbfs_feeds
             "PUBSUB_TOPIC_NAME": "mock-topic",
         },
     )
-    @patch("gbfs_validator.src.main.start_db_session")
+    @patch("helpers.database.Database")
     @patch("gbfs_validator.src.main.pubsub_v1.PublisherClient")
     @patch("gbfs_validator.src.main.fetch_all_gbfs_feeds")
     @patch("gbfs_validator.src.main.Logger")
     def test_gbfs_validator_batch_publish_exception(
-        self, _, mock_fetch_all_gbfs_feeds, mock_publisher_client, mock_start_db_session
+        self, _, mock_fetch_all_gbfs_feeds, mock_publisher_client, mock_database
     ):
         # Prepare mocks
-        mock_session = MagicMock()
-        mock_start_db_session.return_value = mock_session
 
         mock_publisher_client.side_effect = Exception("Pub/Sub error")
 
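Note: the new `tearDown` calls `reset_database_class`, and the `_get_session` mocks above imply that `Database` caches a shared instance between calls. Neither helper's body is in this commit; a sketch of the assumed singleton pattern:

```python
# Assumed singleton behaviour behind reset_database_class -- illustrative only.
class Database:
    _instance = None

    def __new__(cls, *args, **kwargs):
        # Reuse one instance per process so callers share the engine.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance


def reset_database_class() -> None:
    # Test helper: drop the cached instance so each test starts clean.
    Database._instance = None
```

Without the reset, a `Database` instance created (or patched) in one test would leak into the next through the cached instance.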