feat: Update GCP workflow to download latest dataset the day a feed is merged in the catalogs (#411)
cka-y authored Apr 30, 2024
1 parent c5addfd commit acf7011
Showing 5 changed files with 87 additions and 2 deletions.
1 change: 0 additions & 1 deletion .github/workflows/db-update.yml
@@ -99,7 +99,6 @@ jobs:
          ./scripts/tunnel-create.sh -project_id ${{ inputs.PROJECT_ID }} -zone ${{ inputs.REGION }}-a -instance ${{ env.GCP_FEED_BASTION_NAME }}-${{ inputs.DB_ENVIRONMENT}} -target_account ${{ env.GCP_FEED_SSH_USER }} -db_instance ${{ secrets.POSTGRE_SQL_INSTANCE_NAME }}
          sleep 10 # Wait for the tunnel to establish
      # Uncomment the following block to test the database connection through the tunnel
      - name: Test Database Connection Through Tunnel
        run: |
          sudo apt-get update && sudo apt-get install -y postgresql-client
1 change: 1 addition & 0 deletions api/requirements.txt
@@ -45,3 +45,4 @@ cloud-sql-python-connector[pg8000]
fastapi-filter[sqlalchemy]==0.6.1
PyJWT
shapely
google-cloud-pubsub
1 change: 1 addition & 0 deletions api/src/database/database.py
@@ -102,6 +102,7 @@ def start_new_db_session(self):
            logging.info("Starting new global database session.")
            self.engine = create_engine(self.SQLALCHEMY_DATABASE_URL, echo=self.echo_sql)
            global_session = sessionmaker(bind=self.engine)()
            global_session.expire_on_commit = False
            self.session = global_session
            return global_session
        except Exception as error:
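A note on the `expire_on_commit = False` addition above: with the SQLAlchemy default (`expire_on_commit=True`), ORM instances are expired when the session commits, so reading their attributes afterwards triggers a refresh query or fails for detached objects. Keeping the attributes loaded matters here because `populate_db.py` below reads fields such as `feed.stable_id` from the newly added feeds after the final commit, when it publishes them to Pub/Sub. A minimal sketch of the behaviour, using a hypothetical in-memory model rather than the generated `Feed` class:

# Illustrative only: shows the effect of expire_on_commit=False on a plain session.
from sqlalchemy import Column, String, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class DemoFeed(Base):  # hypothetical stand-in for the generated Feed model
    __tablename__ = "demo_feed"
    id = Column(String, primary_key=True)
    stable_id = Column(String)


engine = create_engine("sqlite://")  # in-memory database for the sketch
Base.metadata.create_all(engine)

session = sessionmaker(bind=engine)()
session.expire_on_commit = False  # same flag the commit sets on the global session

feed = DemoFeed(id="1", stable_id="mdb-1")
session.add(feed)
session.commit()

# The instance was not expired on commit, so this reads the loaded value
# without issuing a new SELECT.
print(feed.stable_id)
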
67 changes: 67 additions & 0 deletions api/src/scripts/load_dataset_on_create.py
@@ -0,0 +1,67 @@
import json
import logging
import os
import uuid
from typing import List

from database_gen.sqlacodegen_models import Feed
from google.cloud import pubsub_v1
from google.cloud.pubsub_v1.futures import Future

env = os.getenv("ENV", "dev")
pubsub_topic_name = f"datasets-batch-topic-{env}"
project_id = f"mobility-feeds-{env}"
pubsub_client = pubsub_v1.PublisherClient()


def get_topic_path():
    if pubsub_topic_name is None or project_id is None:
        raise ValueError("PUBSUB_TOPIC_NAME and PROJECT_ID must be set in the environment")
    return pubsub_client.topic_path(project_id, pubsub_topic_name)


def publish_callback(future: Future, stable_id: str, topic_path: str):
    """
    Callback function for when the message is published to Pub/Sub.
    This function logs the result of the publishing operation.
    @param future: Future object representing the result of the publishing operation
    @param stable_id: The stable_id of the feed that was published
    @param topic_path: The path to the Pub/Sub topic
    """
    if future.exception():
        logging.error(f"Error publishing feed {stable_id} to Pub/Sub topic {topic_path}: {future.exception()}")
    else:
        logging.info(f"Published stable_id = {stable_id}.")


def publish(feed: Feed, topic_path: str):
    """
    Publishes a feed to the Pub/Sub topic.
    :param feed: The feed to publish
    :param topic_path: The path to the Pub/Sub topic
    """
    payload = {
        "execution_id": f"batch-uuid-{uuid.uuid4()}",
        "producer_url": feed.producer_url,
        "feed_stable_id": feed.stable_id,
        "feed_id": feed.id,
        "dataset_id": None,  # The feed is not associated with a dataset
        "dataset_hash": None,
        "authentication_type": feed.authentication_type,
        "authentication_info_url": feed.authentication_info_url,
        "api_key_parameter_name": feed.api_key_parameter_name,
    }
    data_bytes = json.dumps(payload).encode("utf-8")
    future = pubsub_client.publish(topic_path, data=data_bytes)
    future.add_done_callback(lambda _: publish_callback(future, feed.stable_id, topic_path))


def publish_all(feeds: List[Feed]):
    """
    Publishes a list of feeds to the Pub/Sub topic.
    :param feeds: The list of feeds to publish
    """
    topic_path = get_topic_path()
    for feed in feeds:
        publish(feed, topic_path)
    logging.info(f"Published {len(feeds)} feeds to Pub/Sub topic {topic_path}.")
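
For context, the messages published above are consumed elsewhere to actually download the dataset; that consumer is not part of this commit. A minimal sketch of what a subscriber decoding this payload could look like, assuming a hypothetical subscription named `datasets-batch-subscription-{env}`:

# Illustrative only: a subscriber for the payload produced by publish() above.
import json
import os
from concurrent.futures import TimeoutError

from google.cloud import pubsub_v1

env = os.getenv("ENV", "dev")
project_id = f"mobility-feeds-{env}"
subscription_id = f"datasets-batch-subscription-{env}"  # hypothetical name, for illustration

subscriber = pubsub_v1.SubscriberClient()
subscription_path = subscriber.subscription_path(project_id, subscription_id)


def callback(message: pubsub_v1.subscriber.message.Message) -> None:
    payload = json.loads(message.data.decode("utf-8"))
    # A real handler would download the dataset from payload["producer_url"] here.
    print(f"Received feed {payload['feed_stable_id']} ({payload['execution_id']})")
    message.ack()


streaming_pull_future = subscriber.subscribe(subscription_path, callback=callback)
with subscriber:
    try:
        streaming_pull_future.result(timeout=30)  # listen briefly for the sketch
    except TimeoutError:
        streaming_pull_future.cancel()
        streaming_pull_future.result()
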
19 changes: 18 additions & 1 deletion api/src/scripts/populate_db.py
@@ -18,6 +18,7 @@
    t_feedsearch,
    Feed,
)
from scripts.load_dataset_on_create import publish_all
from utils.data_utils import set_up_defaults
from utils.logger import Logger

@@ -47,14 +47,15 @@ class DatabasePopulateHelper:
"""

    def __init__(self, filepath):
        self.logger = Logger(self.__class__.__module__).get_logger()
        self.logger = Logger(self.__class__.__name__).get_logger()
        self.logger.setLevel(logging.INFO)
        self.db = Database(echo_sql=False)
        self.df = pandas.read_csv(filepath)  # contains the data to populate the database

        # Filter unsupported data types
        self.df = self.df[(self.df.data_type == "gtfs") | (self.df.data_type == "gtfs-rt")]
        self.df = set_up_defaults(self.df)
        self.added_gtfs_feeds = []  # Keep track of the feeds that have been added to the database

    @staticmethod
    def get_model(data_type: str | None) -> Type[Gtfsrealtimefeed | Gtfsfeed | Feed]:
@@ -220,6 +222,8 @@ def populate_db(self):
            feed = self.get_model(data_type)(id=generate_unique_id(), data_type=data_type, stable_id=stable_id)
            self.logger.info(f"Creating {feed.__class__.__name__}: {stable_id}")
            self.db.session.add(feed)
            if data_type == "gtfs":
                self.added_gtfs_feeds.append(feed)
            feed.externalids = [
                Externalid(
                    feed_id=feed.id,
@@ -249,6 +253,18 @@ def populate_db(self):
        self.process_feed_references()
        self.process_redirects()

    def trigger_downstream_tasks(self):
        """
        Trigger downstream tasks after populating the database
        """
        self.logger.info("Triggering downstream tasks")
        self.logger.debug(
            f"New feeds added to the database: "
            f"{','.join([feed.stable_id for feed in self.added_gtfs_feeds] if self.added_gtfs_feeds else [])}"
        )
        if os.getenv("ENV", "local") != "local":
            publish_all(self.added_gtfs_feeds)  # Publishes the new feeds to the Pub/Sub topic to download the datasets


if __name__ == "__main__":
    db_helper = DatabasePopulateHelper(set_up_configs())
@@ -262,6 +278,7 @@ def populate_db(self):
        db_helper.logger.info("Refreshing MATERIALIZED FEED SEARCH VIEW - Completed")
        db_helper.db.session.commit()
        db_helper.logger.info("\n----- Database populated with sources.csv data. -----")
        db_helper.trigger_downstream_tasks()
    except Exception as e:
        db_helper.logger.error(f"\n------ Failed to populate the database with sources.csv: {e} -----\n")
        db_helper.db.session.rollback()
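The new `trigger_downstream_tasks` step only publishes when `ENV` is set to something other than `local`, so local runs of the population script never touch Pub/Sub. A minimal sketch of how that gating could be exercised without a database or GCP credentials, assuming `api/src` is on `PYTHONPATH` (the mocks and stand-in objects are hypothetical, not part of the commit):

# Illustrative only: checks the ENV gating of trigger_downstream_tasks with mocks.
from unittest.mock import MagicMock, patch

import scripts.populate_db as populate_db_module


def make_helper():
    # Bypass __init__ so no CSV file or database connection is needed.
    helper = object.__new__(populate_db_module.DatabasePopulateHelper)
    helper.logger = MagicMock()
    helper.added_gtfs_feeds = [MagicMock(stable_id="mdb-1")]
    return helper


with patch.dict("os.environ", {"ENV": "local"}), patch.object(populate_db_module, "publish_all") as publish_mock:
    make_helper().trigger_downstream_tasks()
    assert not publish_mock.called  # local runs skip Pub/Sub

with patch.dict("os.environ", {"ENV": "dev"}), patch.object(populate_db_module, "publish_all") as publish_mock:
    make_helper().trigger_downstream_tasks()
    assert publish_mock.called  # non-local runs publish the newly added GTFS feeds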
