Skip to content

Commit

Permalink
Set the country if missing from DB in the populate script (#701)
Browse files Browse the repository at this point in the history
  • Loading branch information
jcpitre authored Aug 20, 2024
1 parent 1d42f36 commit 48f865d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
24 changes: 24 additions & 0 deletions api/src/scripts/populate_db_gtfs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from datetime import datetime

import pycountry
import pytz
from sqlalchemy import text

Expand Down Expand Up @@ -52,6 +53,11 @@ def get_stable_id(self, row):
"""
return f'mdb-{self.get_safe_value(row, "mdb_source_id", "")}'

def get_country(self, country_code):
    """
    Translate an ISO 3166-1 alpha-2 country code into a country name.

    :param country_code: two-letter country code; may be empty or None
    :return: the name reported by pycountry for that code, or None when the
        code is empty/None or pycountry does not recognise it
    """
    if not country_code:
        return None
    try:
        country = pycountry.countries.get(alpha_2=country_code)
    except LookupError:
        # Some pycountry releases signal an unknown code by raising
        # (KeyError is a LookupError) instead of returning None.
        country = None
    # An unrecognised code must not abort the whole populate run; callers
    # already treat a None country as "leave it unset".
    if country is None:
        return None
    return country.name

def populate_location(self, feed, row, stable_id):
"""
Populate the location for the feed
Expand All @@ -63,6 +69,7 @@ def populate_location(self, feed, row, stable_id):
country_code = self.get_safe_value(row, "location.country_code", "")
subdivision_name = self.get_safe_value(row, "location.subdivision_name", "")
municipality = self.get_safe_value(row, "location.municipality", "")
country = self.get_country(country_code)
location_id = self.get_location_id(country_code, subdivision_name, municipality)
if not location_id:
self.logger.warning(f"Location ID is empty for feed {stable_id}")
Expand All @@ -77,6 +84,7 @@ def populate_location(self, feed, row, stable_id):
country_code=country_code,
subdivision_name=subdivision_name,
municipality=municipality,
country=country,
)
)
feed.locations = [location]
Expand Down Expand Up @@ -213,6 +221,7 @@ def populate_db(self):
# This needs to be done after all feeds are added to the session to avoid FK violations
self.process_feed_references()
self.process_redirects()
self.post_process_locations()

def trigger_downstream_tasks(self):
"""
Expand All @@ -229,6 +238,21 @@ def trigger_downstream_tasks(self):
if os.getenv("ENV", "local") != "local":
publish_all(self.added_gtfs_feeds) # Publishes the new feeds to the Pub/Sub topic to download the datasets

def post_process_locations(self):
    """
    Backfill the country on every stored location whose country is NULL.

    Each such location has its country resolved from its country_code via
    get_country; locations for which no country is resolved are left
    untouched. The session is committed once after the loop and the number
    of updated locations is logged.
    """
    locations_without_country = (
        self.db.session.query(Location)
        .filter(Location.country.is_(None))
        .all()
    )
    updated = 0
    for entry in locations_without_country:
        resolved_name = self.get_country(entry.country_code)
        if not resolved_name:
            # Nothing could be resolved for this code; keep the row as is.
            continue
        entry.country = resolved_name
        updated += 1
    # Single commit covering every location modified above.
    self.db.session.commit()
    self.logger.info(f"Had to set the country for {updated} locations")

# Extracted the following code from main, so it can be executed as a library function
def initialize(self, trigger_downstream_tasks: bool = True):
try:
Expand Down
3 changes: 1 addition & 2 deletions api/tests/test_utils/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
def populate_database(db: Database, data_dirs: str):
try:

# Check if connected to localhost
# Check if connected to test DB.
url = make_url(db.engine.url)
if not is_test_db(url):
raise Exception("Not connected to MobilityDatabaseTest, aborting operation")
Expand All @@ -52,7 +52,6 @@ def populate_database(db: Database, data_dirs: str):
db_helper.initialize(trigger_downstream_tasks=False)

# Make a list of all the extra_test_data.json files in the test_data directories and load the data
# Make a list of all the sources_test.csv in test_data and keep only if the file exists
json_filepaths = []
for dir in data_dirs:

Expand Down

0 comments on commit 48f865d

Please sign in to comment.