From 9d4b8d5c6881f44dc85421bcb5e2c74a3bf6fa00 Mon Sep 17 00:00:00 2001 From: cka-y <60586858+cka-y@users.noreply.github.com> Date: Wed, 14 Aug 2024 14:56:11 -0400 Subject: [PATCH] feat: Add GBFS feeds to the database (#674) --- .github/workflows/db-update-dev.yml | 2 +- .github/workflows/db-update-qa.yml | 2 +- .github/workflows/db-update.yml | 22 +- README.md | 5 +- api/src/database/database.py | 6 +- api/src/feeds/impl/feeds_api_impl.py | 2 + api/src/feeds/impl/search_api_impl.py | 1 + api/src/scripts/gbfs_utils/__init__.py | 0 api/src/scripts/gbfs_utils/comparison.py | 82 ++++++ api/src/scripts/gbfs_utils/fetching.py | 81 ++++++ api/src/scripts/gbfs_utils/gbfs_versions.py | 19 ++ api/src/scripts/gbfs_utils/license.py | 27 ++ api/src/scripts/populate_db.py | 266 ++------------------ api/src/scripts/populate_db_gbfs.py | 128 ++++++++++ api/src/scripts/populate_db_gtfs.py | 254 +++++++++++++++++++ api/src/scripts/populate_db_test_data.py | 20 +- api/tests/test_utils/database.py | 4 +- api/tests/unittest/test_feeds.py | 4 +- liquibase/changelog.xml | 1 + liquibase/changes/feat_565.sql | 22 ++ scripts/populate-db.sh | 15 +- 21 files changed, 689 insertions(+), 274 deletions(-) create mode 100644 api/src/scripts/gbfs_utils/__init__.py create mode 100644 api/src/scripts/gbfs_utils/comparison.py create mode 100644 api/src/scripts/gbfs_utils/fetching.py create mode 100644 api/src/scripts/gbfs_utils/gbfs_versions.py create mode 100644 api/src/scripts/gbfs_utils/license.py create mode 100644 api/src/scripts/populate_db_gbfs.py create mode 100644 api/src/scripts/populate_db_gtfs.py create mode 100644 liquibase/changes/feat_565.sql diff --git a/.github/workflows/db-update-dev.yml b/.github/workflows/db-update-dev.yml index 49b28a692..c225b1279 100644 --- a/.github/workflows/db-update-dev.yml +++ b/.github/workflows/db-update-dev.yml @@ -6,7 +6,7 @@ on: - main paths: - 'liquibase/changelog.xml' - - 'api/src/scripts/populate_db.py' + - 'api/src/scripts/populate_db*' repository_dispatch: # Update on mobility-database-catalog repo dispatch types: [ catalog-sources-updated ] workflow_dispatch: diff --git a/.github/workflows/db-update-qa.yml b/.github/workflows/db-update-qa.yml index 04b265a7d..21b392e5a 100644 --- a/.github/workflows/db-update-qa.yml +++ b/.github/workflows/db-update-qa.yml @@ -6,7 +6,7 @@ on: - main paths: - 'liquibase/changelog.xml' - - 'api/src/scripts/populate_db.py' + - 'api/src/scripts/populate_db*' workflow_dispatch: jobs: update: diff --git a/.github/workflows/db-update.yml b/.github/workflows/db-update.yml index a0d1ae703..9fd6f15e1 100644 --- a/.github/workflows/db-update.yml +++ b/.github/workflows/db-update.yml @@ -189,15 +189,33 @@ jobs: id: getpath run: echo "PATH=$(realpath sources.csv)" >> $GITHUB_OUTPUT - - name: Update Database Content + - name: Download systems.csv + run: wget -O systems.csv https://raw.githubusercontent.com/MobilityData/gbfs/master/systems.csv + + - name: Get full path of systems.csv + id: getsyspath + run: echo "PATH=$(realpath systems.csv)" >> $GITHUB_OUTPUT + + - name: GTFS - Update Database Content run: scripts/populate-db.sh ${{ steps.getpath.outputs.PATH }} > populate.log - - name: Upload log file for verification + - name: GBFS - Update Database Content + run: scripts/populate-db.sh ${{ steps.getsyspath.outputs.PATH }} gbfs >> populate-gbfs.log + + - name: GTFS - Upload log file for verification + if: ${{ always() }} uses: actions/upload-artifact@v4 with: name: populate-${{ inputs.ENVIRONMENT }}.log path: populate.log + - name: GBFS - Upload 
log file for verification + if: ${{ always() }} + uses: actions/upload-artifact@v4 + with: + name: populate-gbfs-${{ inputs.ENVIRONMENT }}.log + path: populate-gbfs.log + update-gcp-secret: name: Update GCP Secrets if: ${{ github.event_name == 'repository_dispatch' || github.event_name == 'workflow_dispatch' }} diff --git a/README.md b/README.md index 1bd1dc6f1..02806934f 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # Mobility Feed API ![Deploy Feeds API - QA](https://github.com/MobilityData/mobility-feed-api/workflows/Deploy%20Feeds%20API%20-%20QA/badge.svg?branch=main) -![Deploy Web App - QA](https://github.com/MobilityData/mobility-feed-api/actions/workflows/web-app.yml/badge.svg?branch=main) +![Deploy Web App - QA](https://github.com/MobilityData/mobility-feed-api/actions/workflows/web-qa.yml/badge.svg?branch=main) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) The Mobility Feed API service a list of open mobility data sources from across the world. This repository is the effort the initial effort to convert the current [The Mobility Database Catalogs](https://github.com/MobilityData/mobility-database-catalogs) in an API service. @@ -10,6 +10,9 @@ The Mobility Feed API service a list of open mobility data sources from across t Mobility Feed API is not released yet; any code or service hosted is considered as **Work in Progress**. For more information regarding the current Mobility Database Catalog, go to [The Mobility Database Catalogs](https://github.com/MobilityData/mobility-database-catalogs). +## GBFS Feeds +The repository also includes GBFS feeds extracted from [`systems.csv`](https://github.com/MobilityData/gbfs/blob/master/systems.csv) in the [GBFS repository](https://github.com/MobilityData/gbfs). However, these feeds are not being served yet. The supported versions of these feeds are specified in the file [api/src/scripts/gbfs_utils/gbfs_versions.py](https://github.com/MobilityData/mobility-feed-api/blob/main/api/src/scripts/gbfs_utils/gbfs_versions.py). + # Authentication To access the Mobility Feed API, users need to authenticate using an access token. 
Here is the step-by-step process to obtain and use an access token: diff --git a/api/src/database/database.py b/api/src/database/database.py index 086dc772a..653550164 100644 --- a/api/src/database/database.py +++ b/api/src/database/database.py @@ -7,7 +7,7 @@ from sqlalchemy import create_engine, inspect from sqlalchemy.orm import load_only, Query, class_mapper, Session -from database_gen.sqlacodegen_models import Base, Feed, Gtfsfeed, Gtfsrealtimefeed +from database_gen.sqlacodegen_models import Base, Feed, Gtfsfeed, Gtfsrealtimefeed, Gbfsfeed from sqlalchemy.orm import sessionmaker import logging from typing import Final @@ -43,6 +43,10 @@ def configure_polymorphic_mappers(): gtfsrealtimefeed_mapper.inherits = feed_mapper gtfsrealtimefeed_mapper.polymorphic_identity = Gtfsrealtimefeed.__tablename__.lower() + gbfsfeed_mapper = class_mapper(Gbfsfeed) + gbfsfeed_mapper.inherits = feed_mapper + gbfsfeed_mapper.polymorphic_identity = Gbfsfeed.__tablename__.lower() + class Database: """ diff --git a/api/src/feeds/impl/feeds_api_impl.py b/api/src/feeds/impl/feeds_api_impl.py index c7213e75f..6727516c2 100644 --- a/api/src/feeds/impl/feeds_api_impl.py +++ b/api/src/feeds/impl/feeds_api_impl.py @@ -59,6 +59,7 @@ def get_feed( feed = ( FeedFilter(stable_id=id, provider__ilike=None, producer_url__ilike=None, status=None) .filter(Database().get_query_model(Feed)) + .filter(Feed.data_type != "gbfs") # Filter out GBFS feeds .first() ) if feed: @@ -79,6 +80,7 @@ def get_feeds( status=status, provider__ilike=provider, producer_url__ilike=producer_url, stable_id=None ) feed_query = feed_filter.filter(Database().get_query_model(Feed)) + feed_query = feed_query.filter(Feed.data_type != "gbfs") # Filter out GBFS feeds # Results are sorted by provider feed_query = feed_query.order_by(Feed.provider, Feed.stable_id) feed_query = feed_query.options(*BasicFeedImpl.get_joinedload_options()) diff --git a/api/src/feeds/impl/search_api_impl.py b/api/src/feeds/impl/search_api_impl.py index 18484f534..652a67de8 100644 --- a/api/src/feeds/impl/search_api_impl.py +++ b/api/src/feeds/impl/search_api_impl.py @@ -35,6 +35,7 @@ def add_search_query_filters(query, search_query, data_type, feed_id, status) -> Filter values are trimmed and converted to lowercase. The search query is also converted to its unaccented version. 
""" + query = query.filter(t_feedsearch.c.data_type != "gbfs") # Filter out GBFS feeds if feed_id: query = query.where(t_feedsearch.c.feed_stable_id == feed_id.strip().lower()) if data_type: diff --git a/api/src/scripts/gbfs_utils/__init__.py b/api/src/scripts/gbfs_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/src/scripts/gbfs_utils/comparison.py b/api/src/scripts/gbfs_utils/comparison.py new file mode 100644 index 000000000..5c2795c2a --- /dev/null +++ b/api/src/scripts/gbfs_utils/comparison.py @@ -0,0 +1,82 @@ +import pandas as pd +from sqlalchemy.orm import joinedload +from database_gen.sqlacodegen_models import Gbfsfeed + + +def generate_system_csv_from_db(df, db_session): + """Generate a DataFrame from the database with the same columns as the CSV file.""" + stable_ids = "gbfs-" + df["System ID"] + query = db_session.query(Gbfsfeed) + query = query.filter(Gbfsfeed.stable_id.in_(stable_ids.to_list())) + query = query.options( + joinedload(Gbfsfeed.locations), joinedload(Gbfsfeed.gbfsversions), joinedload(Gbfsfeed.externalids) + ) + feeds = query.all() + data = [] + for feed in feeds: + system_id = feed.externalids[0].associated_id + auto_discovery_url = feed.auto_discovery_url + feed.gbfsversions.sort(key=lambda x: x.version, reverse=False) + supported_versions = [version.version for version in feed.gbfsversions] + data.append( + { + "System ID": system_id, + "Name": feed.operator, + "URL": feed.operator_url, + "Country Code": feed.locations[0].country_code, + "Location": feed.locations[0].municipality, + "Auto-Discovery URL": auto_discovery_url, + "Supported Versions": " ; ".join(supported_versions), + } + ) + return pd.DataFrame(data) + + +def compare_db_to_csv(df_from_db, df_from_csv, logger): + """Compare the database to the CSV file and return the differences.""" + df_from_csv = df_from_csv[df_from_db.columns] + df_from_db = df_from_db.fillna("") + df_from_csv = df_from_csv.fillna("") + + if df_from_db.empty: + logger.info("No data found in the database.") + return None, None + + # Align both DataFrames by "System ID" + df_from_db.set_index("System ID", inplace=True) + df_from_csv.set_index("System ID", inplace=True) + + # Find rows that are in the CSV but not in the DB (new feeds) + missing_in_db = df_from_csv[~df_from_csv.index.isin(df_from_db.index)] + if not missing_in_db.empty: + logger.info("New feeds found in CSV:") + logger.info(missing_in_db) + + # Find rows that are in the DB but not in the CSV (deprecated feeds) + missing_in_csv = df_from_db[~df_from_db.index.isin(df_from_csv.index)] + if not missing_in_csv.empty: + logger.info("Deprecated feeds found in DB:") + logger.info(missing_in_csv) + + # Find rows that are in both, but with differences + common_ids = df_from_db.index.intersection(df_from_csv.index) + df_db_common = df_from_db.loc[common_ids] + df_csv_common = df_from_csv.loc[common_ids] + differences = df_db_common != df_csv_common + differing_rows = df_db_common[differences.any(axis=1)] + + if not differing_rows.empty: + logger.info("Rows with differences:") + for idx in differing_rows.index: + logger.info(f"Differences for System ID {idx}:") + db_row = df_db_common.loc[idx] + csv_row = df_csv_common.loc[idx] + diff = db_row != csv_row + logger.info(f"DB Row: {db_row[diff].to_dict()}") + logger.info(f"CSV Row: {csv_row[diff].to_dict()}") + logger.info(80 * "-") + + # Merge differing rows with missing_in_db to capture all new or updated feeds + all_differing_or_new_rows = pd.concat([differing_rows, 
missing_in_db]).reset_index() + + return all_differing_or_new_rows, missing_in_csv diff --git a/api/src/scripts/gbfs_utils/fetching.py b/api/src/scripts/gbfs_utils/fetching.py new file mode 100644 index 000000000..974b94d57 --- /dev/null +++ b/api/src/scripts/gbfs_utils/fetching.py @@ -0,0 +1,81 @@ +import requests + + +def fetch_data(auto_discovery_url, logger, urls=[], fields=[]): + """Fetch data from the auto-discovery URL and return the specified fields.""" + fetched_data = {} + if not auto_discovery_url: + return + try: + response = requests.get(auto_discovery_url) + response.raise_for_status() + data = response.json() + for field in fields: + fetched_data[field] = data.get(field) + feeds = None + for lang_code, lang_data in data.get("data", {}).items(): + if isinstance(lang_data, list): + lang_feeds = lang_data + else: + lang_feeds = lang_data.get("feeds", []) + if lang_code == "en": + feeds = lang_feeds + break + elif not feeds: + feeds = lang_feeds + logger.info(f"Feeds found from auto-discovery URL {auto_discovery_url}: {feeds}") + if feeds: + for url in urls: + fetched_data[url] = get_field_url(feeds, url) + return fetched_data + except requests.RequestException as e: + logger.error(f"Error fetching data for autodiscovery url {auto_discovery_url}: {e}") + return fetched_data + + +def get_data_content(url, logger): + """Utility function to fetch data content from a URL.""" + try: + if url: + response = requests.get(url) + response.raise_for_status() + system_info = response.json().get("data", {}) + return system_info + except requests.RequestException as e: + logger.error(f"Error fetching data content for url {url}: {e}") + return None + + +def get_field_url(fields, field_name): + """Utility function to get the URL of a specific feed by name.""" + for field in fields: + if field.get("name") == field_name: + return field.get("url") + return None + + +def get_gbfs_versions(gbfs_versions_url, auto_discovery_url, auto_discovery_version, logger): + """Get the GBFS versions from the gbfs_versions_url.""" + # Default version info extracted from auto-discovery url + version_info = { + "version": auto_discovery_version if auto_discovery_version else "1.0", + "url": auto_discovery_url, + } + try: + if not gbfs_versions_url: + return [version_info] + logger.info(f"Fetching GBFS versions from: {gbfs_versions_url}") + data = get_data_content(gbfs_versions_url, logger) + if not data: + logger.warning(f"No data found in the GBFS versions URL -> {gbfs_versions_url}.") + return [version_info] + gbfs_versions = data.get("versions", []) + + # Append the version info from auto-discovery if it doesn't exist + if not any(gv.get("version") == auto_discovery_version for gv in gbfs_versions): + gbfs_versions.append(version_info) + + return gbfs_versions + except Exception as e: + logger.error(f"Error fetching version data: {e}") + return [version_info] diff --git a/api/src/scripts/gbfs_utils/gbfs_versions.py b/api/src/scripts/gbfs_utils/gbfs_versions.py new file mode 100644 index 000000000..9ffce502c --- /dev/null +++ b/api/src/scripts/gbfs_utils/gbfs_versions.py @@ -0,0 +1,19 @@ +OFFICIAL_VERSIONS = [ + "1.0", + "1.1-RC", + "1.1", + "2.0-RC", + "2.0", + "2.1-RC", + "2.1-RC2", + "2.1", + "2.2-RC", + "2.2", + "2.3-RC", + "2.3-RC2", + "2.3", + "3.0-RC", + "3.0-RC2", + "3.0", + "3.1-RC", +] diff --git a/api/src/scripts/gbfs_utils/license.py b/api/src/scripts/gbfs_utils/license.py new file mode 100644 index 000000000..793ea4019 --- /dev/null +++ b/api/src/scripts/gbfs_utils/license.py @@ -0,0 +1,27 @@ 
+LICENSE_URL_MAP = { + "CC0-1.0": "https://creativecommons.org/publicdomain/zero/1.0/", + "CC-BY-4.0": "https://creativecommons.org/licenses/by/4.0/", + "CDLA-Permissive-1.0": "https://cdla.io/permissive-1-0/", + "ODC-By-1.0": "https://www.opendatacommons.org/licenses/by/1.0/", +} + +DEFAULT_LICENSE_URL = "https://creativecommons.org/licenses/by/4.0/" + + +def get_license_url(system_info, logger): + """Get the license URL from the system information.""" + try: + if system_info is None: + return None + + # Fetching license_url or license_id + license_url = system_info.get("license_url") + if not license_url: + license_id = system_info.get("license_id") + if license_id: + return LICENSE_URL_MAP.get(license_id, DEFAULT_LICENSE_URL) + return DEFAULT_LICENSE_URL + return license_url + except Exception as e: + logger.error(f"Error fetching license url data from system info {system_info}: \n{e}") + return None diff --git a/api/src/scripts/populate_db.py b/api/src/scripts/populate_db.py index 83c8f3127..533fcb2ea 100644 --- a/api/src/scripts/populate_db.py +++ b/api/src/scripts/populate_db.py @@ -1,34 +1,21 @@ import argparse +import logging import os from pathlib import Path from typing import Type import pandas from dotenv import load_dotenv -from sqlalchemy import text -from database.database import Database, generate_unique_id, configure_polymorphic_mappers -from database_gen.sqlacodegen_models import ( - Entitytype, - Externalid, - Gtfsfeed, - Gtfsrealtimefeed, - Location, - Redirectingid, - t_feedsearch, - Feed, -) -from scripts.load_dataset_on_create import publish_all -from utils.data_utils import set_up_defaults +from database.database import Database +from database_gen.sqlacodegen_models import Feed, Gtfsrealtimefeed, Gtfsfeed, Gbfsfeed from utils.logger import Logger -from datetime import datetime -import pytz - -import logging logging.basicConfig() logging.getLogger("sqlalchemy.engine").setLevel(logging.ERROR) +feed_mapping = {"gtfs_rt": Gtfsrealtimefeed, "gtfs": Gtfsfeed, "gbfs": Gbfsfeed} + def set_up_configs(): """ @@ -67,19 +54,21 @@ def __init__(self, filepaths): new_df = pandas.read_csv(filepath, low_memory=False) self.df = pandas.concat([self.df, new_df]) - # Filter unsupported data types - self.df = self.df[(self.df.data_type == "gtfs") | (self.df.data_type == "gtfs-rt")] - self.df = set_up_defaults(self.df) - self.added_gtfs_feeds = [] # Keep track of the feeds that have been added to the database + self.filter_data() + + def query_feed_by_stable_id(self, stable_id: str, data_type: str | None) -> Gtfsrealtimefeed | Gtfsfeed | None: + """ + Query the feed by stable id + """ + model = self.get_model(data_type) + return self.db.session.query(model).filter(model.stable_id == stable_id).first() @staticmethod - def get_model(data_type: str | None) -> Type[Gtfsrealtimefeed | Gtfsfeed | Feed]: + def get_model(data_type: str | None) -> Type[Feed]: """ Get the model based on the data type """ - if data_type is None: - return Feed - return Gtfsrealtimefeed if data_type == "gtfs_rt" else Gtfsfeed + return feed_mapping.get(data_type, Feed) @staticmethod def get_safe_value(row, column_name, default_value): @@ -90,228 +79,17 @@ def get_safe_value(row, column_name, default_value): return default_value if default_value is not None else None return f"{row[column_name]}".strip() - def get_data_type(self, row): - """ - Get the data type from the row - """ - data_type = self.get_safe_value(row, "data_type", "").lower() - if data_type not in ["gtfs", "gtfs-rt", "gtfs_rt"]: - 
self.logger.warning(f"Unsupported data type: {data_type}") - return None - return data_type.replace("-", "_") - - def query_feed_by_stable_id(self, stable_id: str, data_type: str | None) -> Gtfsrealtimefeed | Gtfsfeed | None: - """ - Query the feed by stable id - """ - model = self.get_model(data_type) - return self.db.session.query(model).filter(model.stable_id == stable_id).first() - - def get_stable_id(self, row): - """ - Get the stable id from the row - """ - return f'mdb-{self.get_safe_value(row, "mdb_source_id", "")}' - - def populate_location(self, feed, row, stable_id): + @staticmethod + def get_location_id(country_code, subdivision_name, municipality): """ - Populate the location for the feed + Get the location ID """ - # TODO: validate behaviour for gtfs-rt feeds - if feed.locations and feed.data_type == "gtfs": - self.logger.warning(f"Location already exists for feed {stable_id}") - return - - country_code = self.get_safe_value(row, "location.country_code", "") - subdivision_name = self.get_safe_value(row, "location.subdivision_name", "") - municipality = self.get_safe_value(row, "location.municipality", "") composite_id = f"{country_code}-{subdivision_name}-{municipality}".replace(" ", "_") location_id = composite_id if len(composite_id) > 2 else None - if not location_id: - self.logger.warning(f"Location ID is empty for feed {stable_id}") - feed.locations.clear() - else: - location = self.db.session.get(Location, location_id) - location = ( - location - if location - else Location( - id=location_id, - country_code=country_code, - subdivision_name=subdivision_name, - municipality=municipality, - ) - ) - feed.locations = [location] - - def process_entity_types(self, feed: Gtfsrealtimefeed, row, stable_id): - """ - Process the entity types for the feed - """ - entity_types = self.get_safe_value(row, "entity_type", "").replace("|", "-").split("-") - if len(entity_types) > 0: - for entity_type_name in entity_types: - entity_type = self.db.session.query(Entitytype).filter(Entitytype.name == entity_type_name).first() - if not entity_type: - entity_type = Entitytype(name=entity_type_name) - if all(entity_type.name != entity.name for entity in feed.entitytypes): - feed.entitytypes.append(entity_type) - self.db.session.flush() - else: - self.logger.warning(f"Entity types array is empty for feed {stable_id}") - feed.entitytypes.clear() - - def process_feed_references(self): - """ - Process the feed references - """ - self.logger.info("Processing feed references") - for index, row in self.df.iterrows(): - stable_id = self.get_stable_id(row) - data_type = self.get_data_type(row) - if data_type != "gtfs_rt": - continue - gtfs_rt_feed = self.query_feed_by_stable_id(stable_id, "gtfs_rt") - static_reference = self.get_safe_value(row, "static_reference", "") - if static_reference: - gtfs_stable_id = f"mdb-{int(float(static_reference))}" - gtfs_feed = self.query_feed_by_stable_id(gtfs_stable_id, "gtfs") - already_referenced_ids = {ref.id for ref in gtfs_feed.gtfs_rt_feeds} - if gtfs_feed and gtfs_rt_feed.id not in already_referenced_ids: - gtfs_feed.gtfs_rt_feeds.append(gtfs_rt_feed) - # Flush to avoid FK violation - self.db.session.flush() + return location_id - def process_redirects(self): - """ - Process the redirects + def filter_data(self): """ - self.logger.info("Processing redirects") - for index, row in self.df.iterrows(): - stable_id = self.get_stable_id(row) - raw_redirects = row.get("redirect.id", None) - redirects_ids = str(raw_redirects).split("|") if raw_redirects is not None else 
[] - if len(redirects_ids) == 0: - continue - feed = self.query_feed_by_stable_id(stable_id, None) - raw_comments = row.get("redirect.comment", None) - comments = raw_comments.split("|") if raw_comments is not None else [] - if len(redirects_ids) != len(comments) and len(comments) > 0: - self.logger.warning(f"Number of redirect ids and redirect comments differ for feed {stable_id}") - for mdb_source_id in redirects_ids: - if len(mdb_source_id) == 0: - # since there is a 1:1 correspondence between redirect ids and comments, skip also the comment - comments = comments[1:] - continue - if comments: - comment = comments.pop(0) - else: - comment = "" - - target_stable_id = f"mdb-{int(float(mdb_source_id.strip()))}" - target_feed = self.query_feed_by_stable_id(target_stable_id, None) - if not target_feed: - self.logger.warning(f"Could not find redirect target feed {target_stable_id} for feed {stable_id}") - continue - - if feed.id == target_feed.id: - self.logger.error(f"Feed has redirect pointing to itself {stable_id}") - else: - if all(redirect.target_id != target_feed.id for redirect in feed.redirectingids): - feed.redirectingids.append( - Redirectingid(source_id=feed.id, target_id=target_feed.id, redirect_comment=comment) - ) - # Flush to avoid FK violation - self.db.session.flush() - - def populate_db(self): + Filter the data to only include the necessary columns """ - Populate the database with the sources.csv data - """ - self.logger.info("Populating the database with sources.csv data") - for index, row in self.df.iterrows(): - self.logger.debug(f"Populating Database with Feed [stable_id = {row['mdb_source_id']}]") - # Create or update the GTFS feed - data_type = self.get_data_type(row) - stable_id = self.get_stable_id(row) - feed = self.query_feed_by_stable_id(stable_id, data_type) - if feed: - self.logger.debug(f"Updating {feed.__class__.__name__}: {stable_id}") - else: - feed = self.get_model(data_type)( - id=generate_unique_id(), - data_type=data_type, - stable_id=stable_id, - created_at=datetime.now(pytz.utc), # Current timestamp with UTC timezone - ) - self.logger.info(f"Creating {feed.__class__.__name__}: {stable_id}") - self.db.session.add(feed) - if data_type == "gtfs": - self.added_gtfs_feeds.append(feed) - feed.externalids = [ - Externalid( - feed_id=feed.id, - associated_id=str(int(float(row["mdb_source_id"]))), - source="mdb", - ) - ] - # Populate common fields from Feed - feed.feed_name = self.get_safe_value(row, "name", "") - feed.note = self.get_safe_value(row, "note", "") - feed.producer_url = self.get_safe_value(row, "urls.direct_download", "") - feed.authentication_type = str(int(float(self.get_safe_value(row, "urls.authentication_type", "0")))) - feed.authentication_info_url = self.get_safe_value(row, "urls.authentication_info", "") - feed.api_key_parameter_name = self.get_safe_value(row, "urls.api_key_parameter_name", "") - feed.license_url = self.get_safe_value(row, "urls.license", "") - feed.status = self.get_safe_value(row, "status", "active") - feed.feed_contact_email = self.get_safe_value(row, "feed_contact_email", "") - feed.provider = self.get_safe_value(row, "provider", "") - - self.populate_location(feed, row, stable_id) - if data_type == "gtfs_rt": - self.process_entity_types(feed, row, stable_id) - - self.db.session.add(feed) - self.db.session.flush() - # This need to be done after all feeds are added to the session to avoid FK violation - self.process_feed_references() - self.process_redirects() - - def trigger_downstream_tasks(self): - """ - Trigger 
downstream tasks after populating the database - """ - self.logger.info("Triggering downstream tasks") - self.logger.debug( - f"New feeds added to the database: " - f"{','.join([feed.stable_id for feed in self.added_gtfs_feeds] if self.added_gtfs_feeds else [])}" - ) - - env = os.getenv("ENV") - self.logger.info(f"ENV = {env}") - if os.getenv("ENV", "local") != "local": - publish_all(self.added_gtfs_feeds) # Publishes the new feeds to the Pub/Sub topic to download the datasets - - # Extracted the following code from main so it can be executed as a library function - def initialize(self, trigger_downstream_tasks: bool = True): - try: - configure_polymorphic_mappers() - self.populate_db() - self.db.session.commit() - - self.logger.info("Refreshing MATERIALIZED FEED SEARCH VIEW - Started") - self.db.session.execute(text(f"REFRESH MATERIALIZED VIEW CONCURRENTLY {t_feedsearch.name}")) - self.logger.info("Refreshing MATERIALIZED FEED SEARCH VIEW - Completed") - self.db.session.commit() - self.logger.info("\n----- Database populated with sources.csv data. -----") - if trigger_downstream_tasks: - self.trigger_downstream_tasks() - except Exception as e: - self.logger.error(f"\n------ Failed to populate the database with sources.csv: {e} -----\n") - self.db.session.rollback() - exit(1) - - -if __name__ == "__main__": - db_helper = DatabasePopulateHelper(set_up_configs()) - db_helper.initialize() + pass # Should be implemented in the child class diff --git a/api/src/scripts/populate_db_gbfs.py b/api/src/scripts/populate_db_gbfs.py new file mode 100644 index 000000000..4bb7796eb --- /dev/null +++ b/api/src/scripts/populate_db_gbfs.py @@ -0,0 +1,128 @@ +from datetime import datetime + +import pandas as pd +import pytz + +from database.database import generate_unique_id, configure_polymorphic_mappers +from database_gen.sqlacodegen_models import Gbfsfeed, Location, Gbfsversion, Externalid +from scripts.gbfs_utils.comparison import generate_system_csv_from_db, compare_db_to_csv +from scripts.gbfs_utils.fetching import fetch_data, get_data_content, get_gbfs_versions +from scripts.gbfs_utils.license import get_license_url +from scripts.populate_db import DatabasePopulateHelper, set_up_configs +from scripts.gbfs_utils.gbfs_versions import OFFICIAL_VERSIONS + + +class GBFSDatabasePopulateHelper(DatabasePopulateHelper): + def __init__(self, file_path): + super().__init__(file_path) + + def filter_data(self): + """Filter out rows with Authentication Info and duplicate System IDs""" + self.df = self.df[pd.isna(self.df["Authentication Info"])] + self.df = self.df[~self.df.duplicated(subset="System ID", keep=False)] + self.logger.info(f"Data = {self.df}") + + @staticmethod + def get_stable_id(row): + return f"gbfs-{row['System ID']}" + + @staticmethod + def get_external_id(feed_id, system_id): + return Externalid(feed_id=feed_id, associated_id=str(system_id), source="gbfs") + + def deprecate_feeds(self, deprecated_feeds): + """Deprecate feeds that are no longer in systems.csv""" + if deprecated_feeds is None or deprecated_feeds.empty: + self.logger.info("No feeds to deprecate.") + return + self.logger.info(f"Deprecating {len(deprecated_feeds)} feed(s).") + for index, row in deprecated_feeds.iterrows(): + stable_id = self.get_stable_id(row) + gbfs_feed = self.query_feed_by_stable_id(stable_id, "gbfs") + if gbfs_feed: + self.logger.info(f"Deprecating feed with stable_id={stable_id}") + gbfs_feed.status = "deprecated" + self.db.session.flush() + + def populate_db(self): + """Populate the database with the GBFS 
feeds""" + start_time = datetime.now() + configure_polymorphic_mappers() + + # Compare the database to the CSV file + df_from_db = generate_system_csv_from_db(self.df, self.db.session) + added_or_updated_feeds, deprecated_feeds = compare_db_to_csv(df_from_db, self.df, self.logger) + + self.deprecate_feeds(deprecated_feeds) + if added_or_updated_feeds is None: + added_or_updated_feeds = self.df + for index, row in added_or_updated_feeds.iterrows(): + self.logger.info(f"Processing row {index + 1} of {len(added_or_updated_feeds)}") + stable_id = self.get_stable_id(row) + gbfs_feed = self.query_feed_by_stable_id(stable_id, "gbfs") + fetched_data = fetch_data( + row["Auto-Discovery URL"], self.logger, ["system_information", "gbfs_versions"], ["version"] + ) + # If the feed already exists, update it. Otherwise, create a new feed. + if gbfs_feed: + feed_id = gbfs_feed.id + self.logger.info(f"Updating feed {stable_id} - {row['Name']}") + else: + feed_id = generate_unique_id() + self.logger.info(f"Creating new feed for {stable_id} - {row['Name']}") + gbfs_feed = Gbfsfeed( + id=feed_id, + data_type="gbfs", + stable_id=stable_id, + created_at=datetime.now(pytz.utc), + ) + gbfs_feed.externalids = [self.get_external_id(feed_id, row["System ID"])] + self.db.session.add(gbfs_feed) + + system_information_content = get_data_content(fetched_data.get("system_information"), self.logger) + gbfs_feed.license_url = get_license_url(system_information_content, self.logger) + gbfs_feed.feed_contact_email = ( + system_information_content.get("feed_contact_email") if system_information_content else None + ) + gbfs_feed.operator = row["Name"] + gbfs_feed.operator_url = row["URL"] + gbfs_feed.auto_discovery_url = row["Auto-Discovery URL"] + gbfs_feed.updated_at = datetime.now(pytz.utc) + + country_code = self.get_safe_value(row, "Country Code", "") + municipality = self.get_safe_value(row, "Location", "") + location_id = self.get_location_id(country_code, None, municipality) + location = self.db.session.get(Location, location_id) or Location( + id=location_id, + country_code=country_code, + municipality=municipality, + ) + gbfs_feed.locations.clear() + gbfs_feed.locations = [location] + + # Add the GBFS versions + versions = get_gbfs_versions( + fetched_data.get("gbfs_versions"), row["Auto-Discovery URL"], fetched_data.get("version"), self.logger + ) + existing_versions = [version.version for version in gbfs_feed.gbfsversions] + for version in versions: + version_value = version.get("version") + if version_value.upper() in OFFICIAL_VERSIONS and version_value not in existing_versions: + gbfs_feed.gbfsversions.append( + Gbfsversion( + feed_id=feed_id, + url=version.get("url"), + version=version_value, + ) + ) + + self.db.session.flush() + self.logger.info(80 * "-") + + self.db.session.commit() + end_time = datetime.now() + self.logger.info(f"Time taken: {end_time - start_time} seconds") + + +if __name__ == "__main__": + GBFSDatabasePopulateHelper(set_up_configs()).populate_db() diff --git a/api/src/scripts/populate_db_gtfs.py b/api/src/scripts/populate_db_gtfs.py new file mode 100644 index 000000000..40eee6a4b --- /dev/null +++ b/api/src/scripts/populate_db_gtfs.py @@ -0,0 +1,254 @@ +import os +from datetime import datetime + +import pytz +from sqlalchemy import text + +from database.database import generate_unique_id, configure_polymorphic_mappers +from database_gen.sqlacodegen_models import ( + Entitytype, + Externalid, + Gtfsrealtimefeed, + Location, + Redirectingid, + t_feedsearch, +) +from scripts.populate_db import 
DatabasePopulateHelper, set_up_configs +from scripts.load_dataset_on_create import publish_all +from utils.data_utils import set_up_defaults + + +class GTFSDatabasePopulateHelper(DatabasePopulateHelper): + """ + GTFS - Helper class to populate the database + """ + + def __init__(self, filepaths): + """ + Specify a list of files to load the csv data from. + Can also be a single string with a file name. + """ + super().__init__(filepaths) + self.added_gtfs_feeds = [] # Keep track of the feeds that have been added to the database + + def filter_data(self): + self.df = self.df[(self.df.data_type == "gtfs") | (self.df.data_type == "gtfs-rt")] + self.df = set_up_defaults(self.df) + self.added_gtfs_feeds = [] # Keep track of the feeds that have been added to the database + + def get_data_type(self, row): + """ + Get the data type from the row + """ + data_type = self.get_safe_value(row, "data_type", "").lower() + if data_type not in ["gtfs", "gtfs-rt", "gtfs_rt"]: + self.logger.warning(f"Unsupported data type: {data_type}") + return None + return data_type.replace("-", "_") + + def get_stable_id(self, row): + """ + Get the stable id from the row + """ + return f'mdb-{self.get_safe_value(row, "mdb_source_id", "")}' + + def populate_location(self, feed, row, stable_id): + """ + Populate the location for the feed + """ + if feed.locations: + self.logger.warning(f"Location already exists for feed {stable_id}") + return + + country_code = self.get_safe_value(row, "location.country_code", "") + subdivision_name = self.get_safe_value(row, "location.subdivision_name", "") + municipality = self.get_safe_value(row, "location.municipality", "") + location_id = self.get_location_id(country_code, subdivision_name, municipality) + if not location_id: + self.logger.warning(f"Location ID is empty for feed {stable_id}") + feed.locations.clear() + else: + location = self.db.session.get(Location, location_id) + location = ( + location + if location + else Location( + id=location_id, + country_code=country_code, + subdivision_name=subdivision_name, + municipality=municipality, + ) + ) + feed.locations = [location] + + def process_entity_types(self, feed: Gtfsrealtimefeed, row, stable_id): + """ + Process the entity types for the feed + """ + entity_types = self.get_safe_value(row, "entity_type", "").replace("|", "-").split("-") + if len(entity_types) > 0: + for entity_type_name in entity_types: + entity_type = self.db.session.query(Entitytype).filter(Entitytype.name == entity_type_name).first() + if not entity_type: + entity_type = Entitytype(name=entity_type_name) + if all(entity_type.name != entity.name for entity in feed.entitytypes): + feed.entitytypes.append(entity_type) + self.db.session.flush() + else: + self.logger.warning(f"Entity types array is empty for feed {stable_id}") + feed.entitytypes.clear() + + def process_feed_references(self): + """ + Process the feed references + """ + self.logger.info("Processing feed references") + for index, row in self.df.iterrows(): + stable_id = self.get_stable_id(row) + data_type = self.get_data_type(row) + if data_type != "gtfs_rt": + continue + gtfs_rt_feed = self.query_feed_by_stable_id(stable_id, "gtfs_rt") + static_reference = self.get_safe_value(row, "static_reference", "") + if static_reference: + gtfs_stable_id = f"mdb-{int(float(static_reference))}" + gtfs_feed = self.query_feed_by_stable_id(gtfs_stable_id, "gtfs") + already_referenced_ids = {ref.id for ref in gtfs_feed.gtfs_rt_feeds} + if gtfs_feed and gtfs_rt_feed.id not in already_referenced_ids: + 
gtfs_feed.gtfs_rt_feeds.append(gtfs_rt_feed) + # Flush to avoid FK violation + self.db.session.flush() + + def process_redirects(self): + """ + Process the redirects + """ + self.logger.info("Processing redirects") + for index, row in self.df.iterrows(): + stable_id = self.get_stable_id(row) + raw_redirects = row.get("redirect.id", None) + redirects_ids = str(raw_redirects).split("|") if raw_redirects is not None else [] + if len(redirects_ids) == 0: + continue + feed = self.query_feed_by_stable_id(stable_id, None) + raw_comments = row.get("redirect.comment", None) + comments = raw_comments.split("|") if raw_comments is not None else [] + if len(redirects_ids) != len(comments) and len(comments) > 0: + self.logger.warning(f"Number of redirect ids and redirect comments differ for feed {stable_id}") + for mdb_source_id in redirects_ids: + if len(mdb_source_id) == 0: + # since there is a 1:1 correspondence between redirect ids and comments, skip also the comment + comments = comments[1:] + continue + if comments: + comment = comments.pop(0) + else: + comment = "" + + target_stable_id = f"mdb-{int(float(mdb_source_id.strip()))}" + target_feed = self.query_feed_by_stable_id(target_stable_id, None) + if not target_feed: + self.logger.warning(f"Could not find redirect target feed {target_stable_id} for feed {stable_id}") + continue + + if feed.id == target_feed.id: + self.logger.error(f"Feed has redirect pointing to itself {stable_id}") + else: + if all(redirect.target_id != target_feed.id for redirect in feed.redirectingids): + feed.redirectingids.append( + Redirectingid(source_id=feed.id, target_id=target_feed.id, redirect_comment=comment) + ) + # Flush to avoid FK violation + self.db.session.flush() + + def populate_db(self): + """ + Populate the database with the sources.csv data + """ + self.logger.info("Populating the database with sources.csv data") + for index, row in self.df.iterrows(): + self.logger.debug(f"Populating Database with Feed [stable_id = {row['mdb_source_id']}]") + # Create or update the GTFS feed + data_type = self.get_data_type(row) + stable_id = self.get_stable_id(row) + feed = self.query_feed_by_stable_id(stable_id, data_type) + if feed: + self.logger.debug(f"Updating {feed.__class__.__name__}: {stable_id}") + else: + feed = self.get_model(data_type)( + id=generate_unique_id(), + data_type=data_type, + stable_id=stable_id, + created_at=datetime.now(pytz.utc), # Current timestamp with UTC timezone + ) + self.logger.info(f"Creating {feed.__class__.__name__}: {stable_id}") + self.db.session.add(feed) + if data_type == "gtfs": + self.added_gtfs_feeds.append(feed) + feed.externalids = [ + Externalid( + feed_id=feed.id, + associated_id=str(int(float(row["mdb_source_id"]))), + source="mdb", + ) + ] + # Populate common fields from Feed + feed.feed_name = self.get_safe_value(row, "name", "") + feed.note = self.get_safe_value(row, "note", "") + feed.producer_url = self.get_safe_value(row, "urls.direct_download", "") + feed.authentication_type = str(int(float(self.get_safe_value(row, "urls.authentication_type", "0")))) + feed.authentication_info_url = self.get_safe_value(row, "urls.authentication_info", "") + feed.api_key_parameter_name = self.get_safe_value(row, "urls.api_key_parameter_name", "") + feed.license_url = self.get_safe_value(row, "urls.license", "") + feed.status = self.get_safe_value(row, "status", "active") + feed.feed_contact_email = self.get_safe_value(row, "feed_contact_email", "") + feed.provider = self.get_safe_value(row, "provider", "") + + 
self.populate_location(feed, row, stable_id) + if data_type == "gtfs_rt": + self.process_entity_types(feed, row, stable_id) + + self.db.session.add(feed) + self.db.session.flush() + # This need to be done after all feeds are added to the session to avoid FK violation + self.process_feed_references() + self.process_redirects() + + def trigger_downstream_tasks(self): + """ + Trigger downstream tasks after populating the database + """ + self.logger.info("Triggering downstream tasks") + self.logger.debug( + f"New feeds added to the database: " + f"{','.join([feed.stable_id for feed in self.added_gtfs_feeds] if self.added_gtfs_feeds else [])}" + ) + + env = os.getenv("ENV") + self.logger.info(f"ENV = {env}") + if os.getenv("ENV", "local") != "local": + publish_all(self.added_gtfs_feeds) # Publishes the new feeds to the Pub/Sub topic to download the datasets + + # Extracted the following code from main, so it can be executed as a library function + def initialize(self, trigger_downstream_tasks: bool = True): + try: + configure_polymorphic_mappers() + self.populate_db() + self.db.session.commit() + + self.logger.info("Refreshing MATERIALIZED FEED SEARCH VIEW - Started") + self.db.session.execute(text(f"REFRESH MATERIALIZED VIEW CONCURRENTLY {t_feedsearch.name}")) + self.logger.info("Refreshing MATERIALIZED FEED SEARCH VIEW - Completed") + self.db.session.commit() + self.logger.info("\n----- Database populated with sources.csv data. -----") + if trigger_downstream_tasks: + self.trigger_downstream_tasks() + except Exception as e: + self.logger.error(f"\n------ Failed to populate the database with sources.csv: {e} -----\n") + self.db.session.rollback() + exit(1) + + +if __name__ == "__main__": + db_helper = GTFSDatabasePopulateHelper(set_up_configs()) + db_helper.initialize() diff --git a/api/src/scripts/populate_db_test_data.py b/api/src/scripts/populate_db_test_data.py index af637b845..fae1a1d2f 100644 --- a/api/src/scripts/populate_db_test_data.py +++ b/api/src/scripts/populate_db_test_data.py @@ -1,30 +1,14 @@ -import argparse import json -import os -from pathlib import Path -from dotenv import load_dotenv + from geoalchemy2 import WKTElement from sqlalchemy import text from database.database import Database from database_gen.sqlacodegen_models import Gtfsdataset, Validationreport, Gtfsfeed, Notice, Feature, t_feedsearch - +from scripts.populate_db import set_up_configs from utils.logger import Logger -def set_up_configs(): - """ - Set up function - """ - parser = argparse.ArgumentParser() - parser.add_argument("--filepath", help="Absolute path for the JSON file containing the test data", required=True) - args = parser.parse_args() - current_path = Path(__file__).resolve() - dotenv_path = os.path.join(current_path.parents[3], "config", ".env") - load_dotenv(dotenv_path=dotenv_path) - return args.filepath - - class DatabasePopulateTestDataHelper: """ Helper class to populate diff --git a/api/tests/test_utils/database.py b/api/tests/test_utils/database.py index d32d99ea4..af814d3c3 100644 --- a/api/tests/test_utils/database.py +++ b/api/tests/test_utils/database.py @@ -8,7 +8,7 @@ from tests.test_utils.db_utils import dump_database, is_test_db, dump_raw_database, empty_database from database.database import Database -from scripts.populate_db import DatabasePopulateHelper +from scripts.populate_db_gtfs import GTFSDatabasePopulateHelper from scripts.populate_db_test_data import DatabasePopulateTestDataHelper import os @@ -48,7 +48,7 @@ def populate_database(db: Database, data_dirs: str): if 
len(csv_filepaths) == 0: raise Exception("No sources_test.csv file found in test_data directories") - db_helper = DatabasePopulateHelper(csv_filepaths) + db_helper = GTFSDatabasePopulateHelper(csv_filepaths) db_helper.initialize(trigger_downstream_tasks=False) # Make a list of all the extra_test_data.json files in the test_data directories and load the data diff --git a/api/tests/unittest/test_feeds.py b/api/tests/unittest/test_feeds.py index ffd814c07..3ee91f336 100644 --- a/api/tests/unittest/test_feeds.py +++ b/api/tests/unittest/test_feeds.py @@ -84,7 +84,7 @@ def test_feeds_get(client: TestClient, mocker): mock_filter_offset = Mock() mock_filter_order_by = Mock() mock_options = Mock() - mock_filter.return_value.order_by.return_value = mock_filter_order_by + mock_filter.return_value.filter.return_value.order_by.return_value = mock_filter_order_by mock_filter_order_by.options.return_value = mock_options mock_options.offset.return_value = mock_filter_offset # Target is set to None as deep copy is failing for unknown reasons @@ -119,7 +119,7 @@ def test_feed_get(client: TestClient, mocker): Unit test for get_feeds """ mock_filter = mocker.patch.object(FeedFilter, "filter") - mock_filter.return_value.first.return_value = mock_feed + mock_filter.return_value.filter.return_value.first.return_value = mock_feed response = client.request( "GET", diff --git a/liquibase/changelog.xml b/liquibase/changelog.xml index e20df7c95..d7e4259bc 100644 --- a/liquibase/changelog.xml +++ b/liquibase/changelog.xml @@ -27,4 +27,5 @@ + \ No newline at end of file diff --git a/liquibase/changes/feat_565.sql b/liquibase/changes/feat_565.sql new file mode 100644 index 000000000..f3c98c238 --- /dev/null +++ b/liquibase/changes/feat_565.sql @@ -0,0 +1,22 @@ +ALTER TYPE datatype ADD VALUE IF NOT EXISTS 'gbfs'; + +-- Create the tables if they do not exist +CREATE TABLE IF NOT EXISTS GBFS_Feed( + id VARCHAR(255) PRIMARY KEY, + operator VARCHAR(255), + operator_url VARCHAR(255), + auto_discovery_url VARCHAR(255), + FOREIGN KEY (id) REFERENCES Feed(id) +); + +CREATE TABLE IF NOT EXISTS GBFS_Version( + feed_id VARCHAR(255) NOT NULL, + version VARCHAR(6), + url VARCHAR(255), + PRIMARY KEY (feed_id, version), + FOREIGN KEY (feed_id) REFERENCES GBFS_Feed(id) +); + +-- Rename tables to use convention like GBFSFeed and GBFSVersion +ALTER TABLE GBFS_Feed RENAME TO GBFSFeed; +ALTER TABLE GBFS_Version RENAME TO GBFSVersion; diff --git a/scripts/populate-db.sh b/scripts/populate-db.sh index e510ce688..6f1ef2d43 100755 --- a/scripts/populate-db.sh +++ b/scripts/populate-db.sh @@ -5,10 +5,21 @@ # As a requirement, you need to have the local instance of the database running on the port defined in config/.env.local # The csv file containing the data has to be in the same format as https://bit.ly/catalogs-csv # Usage: -# populate-db.sh +# populate-db.sh [data_type] # # relative path SCRIPT_PATH="$(dirname -- "${BASH_SOURCE[0]}")" -(cd $SCRIPT_PATH/../api/ && pip3 install -r requirements.txt && PYTHONPATH=src python src/scripts/populate_db.py --filepath "$1") \ No newline at end of file +# Set the data_type, defaulting to 'gtfs' +DATA_TYPE=${2:-gtfs} + +# Determine the script to run based on the data_type +if [ "$DATA_TYPE" = "gbfs" ]; then + SCRIPT_NAME="populate_db_gbfs.py" +else + SCRIPT_NAME="populate_db_gtfs.py" +fi + +# Run the appropriate script +(cd "$SCRIPT_PATH"/../api/ && pip3 install -r requirements.txt && PYTHONPATH=src python src/scripts/$SCRIPT_NAME --filepath "$1")
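
For local testing, the updated scripts/populate-db.sh accepts an optional second argument selecting the data type. A minimal sketch of the GBFS path, assuming a local database configured via config/.env.local and network access to the GBFS repository, is:

    wget -O systems.csv https://raw.githubusercontent.com/MobilityData/gbfs/master/systems.csv
    scripts/populate-db.sh "$(realpath systems.csv)" gbfs > populate-gbfs.log

Omitting the second argument (or passing anything other than gbfs) falls back to populate_db_gtfs.py, mirroring the previous single-script behaviour of populate-db.sh.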