From 9d4b8d5c6881f44dc85421bcb5e2c74a3bf6fa00 Mon Sep 17 00:00:00 2001
From: cka-y <60586858+cka-y@users.noreply.github.com>
Date: Wed, 14 Aug 2024 14:56:11 -0400
Subject: [PATCH] feat: Add GBFS feeds to the database (#674)
---
.github/workflows/db-update-dev.yml | 2 +-
.github/workflows/db-update-qa.yml | 2 +-
.github/workflows/db-update.yml | 22 +-
README.md | 5 +-
api/src/database/database.py | 6 +-
api/src/feeds/impl/feeds_api_impl.py | 2 +
api/src/feeds/impl/search_api_impl.py | 1 +
api/src/scripts/gbfs_utils/__init__.py | 0
api/src/scripts/gbfs_utils/comparison.py | 82 ++++++
api/src/scripts/gbfs_utils/fetching.py | 81 ++++++
api/src/scripts/gbfs_utils/gbfs_versions.py | 19 ++
api/src/scripts/gbfs_utils/license.py | 27 ++
api/src/scripts/populate_db.py | 266 ++------------------
api/src/scripts/populate_db_gbfs.py | 128 ++++++++++
api/src/scripts/populate_db_gtfs.py | 254 +++++++++++++++++++
api/src/scripts/populate_db_test_data.py | 20 +-
api/tests/test_utils/database.py | 4 +-
api/tests/unittest/test_feeds.py | 4 +-
liquibase/changelog.xml | 1 +
liquibase/changes/feat_565.sql | 22 ++
scripts/populate-db.sh | 15 +-
21 files changed, 689 insertions(+), 274 deletions(-)
create mode 100644 api/src/scripts/gbfs_utils/__init__.py
create mode 100644 api/src/scripts/gbfs_utils/comparison.py
create mode 100644 api/src/scripts/gbfs_utils/fetching.py
create mode 100644 api/src/scripts/gbfs_utils/gbfs_versions.py
create mode 100644 api/src/scripts/gbfs_utils/license.py
create mode 100644 api/src/scripts/populate_db_gbfs.py
create mode 100644 api/src/scripts/populate_db_gtfs.py
create mode 100644 liquibase/changes/feat_565.sql
diff --git a/.github/workflows/db-update-dev.yml b/.github/workflows/db-update-dev.yml
index 49b28a692..c225b1279 100644
--- a/.github/workflows/db-update-dev.yml
+++ b/.github/workflows/db-update-dev.yml
@@ -6,7 +6,7 @@ on:
- main
paths:
- 'liquibase/changelog.xml'
- - 'api/src/scripts/populate_db.py'
+ - 'api/src/scripts/populate_db*'
repository_dispatch: # Update on mobility-database-catalog repo dispatch
types: [ catalog-sources-updated ]
workflow_dispatch:
diff --git a/.github/workflows/db-update-qa.yml b/.github/workflows/db-update-qa.yml
index 04b265a7d..21b392e5a 100644
--- a/.github/workflows/db-update-qa.yml
+++ b/.github/workflows/db-update-qa.yml
@@ -6,7 +6,7 @@ on:
- main
paths:
- 'liquibase/changelog.xml'
- - 'api/src/scripts/populate_db.py'
+ - 'api/src/scripts/populate_db*'
workflow_dispatch:
jobs:
update:
diff --git a/.github/workflows/db-update.yml b/.github/workflows/db-update.yml
index a0d1ae703..9fd6f15e1 100644
--- a/.github/workflows/db-update.yml
+++ b/.github/workflows/db-update.yml
@@ -189,15 +189,33 @@ jobs:
id: getpath
run: echo "PATH=$(realpath sources.csv)" >> $GITHUB_OUTPUT
- - name: Update Database Content
+ - name: Download systems.csv
+ run: wget -O systems.csv https://raw.githubusercontent.com/MobilityData/gbfs/master/systems.csv
+
+ - name: Get full path of systems.csv
+ id: getsyspath
+ run: echo "PATH=$(realpath systems.csv)" >> $GITHUB_OUTPUT
+
+ - name: GTFS - Update Database Content
run: scripts/populate-db.sh ${{ steps.getpath.outputs.PATH }} > populate.log
- - name: Upload log file for verification
+ - name: GBFS - Update Database Content
+ run: scripts/populate-db.sh ${{ steps.getsyspath.outputs.PATH }} gbfs >> populate-gbfs.log
+
+ - name: GTFS - Upload log file for verification
+ if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: populate-${{ inputs.ENVIRONMENT }}.log
path: populate.log
+ - name: GBFS - Upload log file for verification
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: populate-gbfs-${{ inputs.ENVIRONMENT }}.log
+ path: populate-gbfs.log
+
update-gcp-secret:
name: Update GCP Secrets
if: ${{ github.event_name == 'repository_dispatch' || github.event_name == 'workflow_dispatch' }}
diff --git a/README.md b/README.md
index 1bd1dc6f1..02806934f 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
# Mobility Feed API
![Deploy Feeds API - QA](https://github.com/MobilityData/mobility-feed-api/workflows/Deploy%20Feeds%20API%20-%20QA/badge.svg?branch=main)
-![Deploy Web App - QA](https://github.com/MobilityData/mobility-feed-api/actions/workflows/web-app.yml/badge.svg?branch=main)
+![Deploy Web App - QA](https://github.com/MobilityData/mobility-feed-api/actions/workflows/web-qa.yml/badge.svg?branch=main)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
The Mobility Feed API serves a list of open mobility data sources from across the world. This repository is the initial effort to convert the current [The Mobility Database Catalogs](https://github.com/MobilityData/mobility-database-catalogs) into an API service.
@@ -10,6 +10,9 @@ The Mobility Feed API service a list of open mobility data sources from across t
Mobility Feed API is not released yet; any code or service hosted is considered as **Work in Progress**. For more information regarding the current Mobility Database Catalog, go to [The Mobility Database Catalogs](https://github.com/MobilityData/mobility-database-catalogs).
+## GBFS Feeds
+The repository also includes GBFS feeds extracted from [`systems.csv`](https://github.com/MobilityData/gbfs/blob/master/systems.csv) in the [GBFS repository](https://github.com/MobilityData/gbfs). However, these feeds are not being served yet. The supported versions of these feeds are specified in the file [api/src/scripts/gbfs_utils/gbfs_versions.py](https://github.com/MobilityData/mobility-feed-api/blob/main/api/src/scripts/gbfs_utils/gbfs_versions.py).
+
# Authentication
To access the Mobility Feed API, users need to authenticate using an access token. Here is the step-by-step process to obtain and use an access token:
diff --git a/api/src/database/database.py b/api/src/database/database.py
index 086dc772a..653550164 100644
--- a/api/src/database/database.py
+++ b/api/src/database/database.py
@@ -7,7 +7,7 @@
from sqlalchemy import create_engine, inspect
from sqlalchemy.orm import load_only, Query, class_mapper, Session
-from database_gen.sqlacodegen_models import Base, Feed, Gtfsfeed, Gtfsrealtimefeed
+from database_gen.sqlacodegen_models import Base, Feed, Gtfsfeed, Gtfsrealtimefeed, Gbfsfeed
from sqlalchemy.orm import sessionmaker
import logging
from typing import Final
@@ -43,6 +43,10 @@ def configure_polymorphic_mappers():
gtfsrealtimefeed_mapper.inherits = feed_mapper
gtfsrealtimefeed_mapper.polymorphic_identity = Gtfsrealtimefeed.__tablename__.lower()
+ gbfsfeed_mapper = class_mapper(Gbfsfeed)
+ gbfsfeed_mapper.inherits = feed_mapper
+ gbfsfeed_mapper.polymorphic_identity = Gbfsfeed.__tablename__.lower()
+
class Database:
"""
diff --git a/api/src/feeds/impl/feeds_api_impl.py b/api/src/feeds/impl/feeds_api_impl.py
index c7213e75f..6727516c2 100644
--- a/api/src/feeds/impl/feeds_api_impl.py
+++ b/api/src/feeds/impl/feeds_api_impl.py
@@ -59,6 +59,7 @@ def get_feed(
feed = (
FeedFilter(stable_id=id, provider__ilike=None, producer_url__ilike=None, status=None)
.filter(Database().get_query_model(Feed))
+ .filter(Feed.data_type != "gbfs") # Filter out GBFS feeds
.first()
)
if feed:
@@ -79,6 +80,7 @@ def get_feeds(
status=status, provider__ilike=provider, producer_url__ilike=producer_url, stable_id=None
)
feed_query = feed_filter.filter(Database().get_query_model(Feed))
+ feed_query = feed_query.filter(Feed.data_type != "gbfs") # Filter out GBFS feeds
# Results are sorted by provider
feed_query = feed_query.order_by(Feed.provider, Feed.stable_id)
feed_query = feed_query.options(*BasicFeedImpl.get_joinedload_options())
diff --git a/api/src/feeds/impl/search_api_impl.py b/api/src/feeds/impl/search_api_impl.py
index 18484f534..652a67de8 100644
--- a/api/src/feeds/impl/search_api_impl.py
+++ b/api/src/feeds/impl/search_api_impl.py
@@ -35,6 +35,7 @@ def add_search_query_filters(query, search_query, data_type, feed_id, status) ->
Filter values are trimmed and converted to lowercase.
The search query is also converted to its unaccented version.
"""
+ query = query.filter(t_feedsearch.c.data_type != "gbfs") # Filter out GBFS feeds
if feed_id:
query = query.where(t_feedsearch.c.feed_stable_id == feed_id.strip().lower())
if data_type:
diff --git a/api/src/scripts/gbfs_utils/__init__.py b/api/src/scripts/gbfs_utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/api/src/scripts/gbfs_utils/comparison.py b/api/src/scripts/gbfs_utils/comparison.py
new file mode 100644
index 000000000..5c2795c2a
--- /dev/null
+++ b/api/src/scripts/gbfs_utils/comparison.py
@@ -0,0 +1,82 @@
+import pandas as pd
+from sqlalchemy.orm import joinedload
+from database_gen.sqlacodegen_models import Gbfsfeed
+
+
+def generate_system_csv_from_db(df, db_session):
+ """Generate a DataFrame from the database with the same columns as the CSV file."""
+ stable_ids = "gbfs-" + df["System ID"]
+ query = db_session.query(Gbfsfeed)
+ query = query.filter(Gbfsfeed.stable_id.in_(stable_ids.to_list()))
+ query = query.options(
+ joinedload(Gbfsfeed.locations), joinedload(Gbfsfeed.gbfsversions), joinedload(Gbfsfeed.externalids)
+ )
+ feeds = query.all()
+ data = []
+ for feed in feeds:
+ system_id = feed.externalids[0].associated_id
+ auto_discovery_url = feed.auto_discovery_url
+ feed.gbfsversions.sort(key=lambda x: x.version, reverse=False)
+ supported_versions = [version.version for version in feed.gbfsversions]
+ data.append(
+ {
+ "System ID": system_id,
+ "Name": feed.operator,
+ "URL": feed.operator_url,
+ "Country Code": feed.locations[0].country_code,
+ "Location": feed.locations[0].municipality,
+ "Auto-Discovery URL": auto_discovery_url,
+ "Supported Versions": " ; ".join(supported_versions),
+ }
+ )
+ return pd.DataFrame(data)
+
+
+def compare_db_to_csv(df_from_db, df_from_csv, logger):
+ """Compare the database to the CSV file and return the differences."""
+ df_from_csv = df_from_csv[df_from_db.columns]
+ df_from_db = df_from_db.fillna("")
+ df_from_csv = df_from_csv.fillna("")
+
+ if df_from_db.empty:
+ logger.info("No data found in the database.")
+ return None, None
+
+ # Align both DataFrames by "System ID"
+ df_from_db.set_index("System ID", inplace=True)
+ df_from_csv.set_index("System ID", inplace=True)
+
+ # Find rows that are in the CSV but not in the DB (new feeds)
+ missing_in_db = df_from_csv[~df_from_csv.index.isin(df_from_db.index)]
+ if not missing_in_db.empty:
+ logger.info("New feeds found in CSV:")
+ logger.info(missing_in_db)
+
+ # Find rows that are in the DB but not in the CSV (deprecated feeds)
+ missing_in_csv = df_from_db[~df_from_db.index.isin(df_from_csv.index)]
+ if not missing_in_csv.empty:
+ logger.info("Deprecated feeds found in DB:")
+ logger.info(missing_in_csv)
+
+ # Find rows that are in both, but with differences
+ common_ids = df_from_db.index.intersection(df_from_csv.index)
+ df_db_common = df_from_db.loc[common_ids]
+ df_csv_common = df_from_csv.loc[common_ids]
+ differences = df_db_common != df_csv_common
+ differing_rows = df_db_common[differences.any(axis=1)]
+
+ if not differing_rows.empty:
+ logger.info("Rows with differences:")
+ for idx in differing_rows.index:
+ logger.info(f"Differences for System ID {idx}:")
+ db_row = df_db_common.loc[idx]
+ csv_row = df_csv_common.loc[idx]
+ diff = db_row != csv_row
+ logger.info(f"DB Row: {db_row[diff].to_dict()}")
+ logger.info(f"CSV Row: {csv_row[diff].to_dict()}")
+ logger.info(80 * "-")
+
+ # Merge differing rows with missing_in_db to capture all new or updated feeds
+ all_differing_or_new_rows = pd.concat([differing_rows, missing_in_db]).reset_index()
+
+ return all_differing_or_new_rows, missing_in_csv
diff --git a/api/src/scripts/gbfs_utils/fetching.py b/api/src/scripts/gbfs_utils/fetching.py
new file mode 100644
index 000000000..974b94d57
--- /dev/null
+++ b/api/src/scripts/gbfs_utils/fetching.py
@@ -0,0 +1,81 @@
+import requests
+
+
+def fetch_data(auto_discovery_url, logger, urls=[], fields=[]):
+ """Fetch data from the auto-discovery URL and return the specified fields."""
+ fetched_data = {}
+ if not auto_discovery_url:
+        return fetched_data  # Return the empty dict so callers can safely call .get() on the result
+ try:
+ response = requests.get(auto_discovery_url)
+ response.raise_for_status()
+ data = response.json()
+ for field in fields:
+ fetched_data[field] = data.get(field)
+ feeds = None
+ for lang_code, lang_data in data.get("data", {}).items():
+ if isinstance(lang_data, list):
+ lang_feeds = lang_data
+ else:
+ lang_feeds = lang_data.get("feeds", [])
+ if lang_code == "en":
+ feeds = lang_feeds
+ break
+ elif not feeds:
+ feeds = lang_feeds
+ logger.info(f"Feeds found from auto-discovery URL {auto_discovery_url}: {feeds}")
+ if feeds:
+ for url in urls:
+ fetched_data[url] = get_field_url(feeds, url)
+ return fetched_data
+ except requests.RequestException as e:
+ logger.error(f"Error fetching data for autodiscovery url {auto_discovery_url}: {e}")
+ return fetched_data
+
+
+def get_data_content(url, logger):
+ """Utility function to fetch data content from a URL."""
+ try:
+ if url:
+ response = requests.get(url)
+ response.raise_for_status()
+ system_info = response.json().get("data", {})
+ return system_info
+ except requests.RequestException as e:
+ logger.error(f"Error fetching data content for url {url}: {e}")
+ return None
+
+
+def get_field_url(fields, field_name):
+ """Utility function to get the URL of a specific feed by name."""
+ for field in fields:
+ if field.get("name") == field_name:
+ return field.get("url")
+ return None
+
+
+def get_gbfs_versions(gbfs_versions_url, auto_discovery_url, auto_discovery_version, logger):
+ """Get the GBFS versions from the gbfs_versions_url."""
+ # Default version info extracted from auto-discovery url
+ version_info = {
+ "version": auto_discovery_version if auto_discovery_version else "1.0",
+ "url": auto_discovery_url,
+ }
+ try:
+ if not gbfs_versions_url:
+ return [version_info]
+ logger.info(f"Fetching GBFS versions from: {gbfs_versions_url}")
+ data = get_data_content(gbfs_versions_url, logger)
+ if not data:
+ logger.warning(f"No data found in the GBFS versions URL -> {gbfs_versions_url}.")
+ return [version_info]
+ gbfs_versions = data.get("versions", [])
+
+ # Append the version info from auto-discovery if it doesn't exist
+ if not any(gv.get("version") == auto_discovery_version for gv in gbfs_versions):
+ gbfs_versions.append(version_info)
+
+ return gbfs_versions
+ except Exception as e:
+ logger.error(f"Error fetching version data: {e}")
+ return [version_info]
diff --git a/api/src/scripts/gbfs_utils/gbfs_versions.py b/api/src/scripts/gbfs_utils/gbfs_versions.py
new file mode 100644
index 000000000..9ffce502c
--- /dev/null
+++ b/api/src/scripts/gbfs_utils/gbfs_versions.py
@@ -0,0 +1,19 @@
+OFFICIAL_VERSIONS = [
+ "1.0",
+ "1.1-RC",
+ "1.1",
+ "2.0-RC",
+ "2.0",
+ "2.1-RC",
+ "2.1-RC2",
+ "2.1",
+ "2.2-RC",
+ "2.2",
+ "2.3-RC",
+ "2.3-RC2",
+ "2.3",
+ "3.0-RC",
+ "3.0-RC2",
+ "3.0",
+ "3.1-RC",
+]
diff --git a/api/src/scripts/gbfs_utils/license.py b/api/src/scripts/gbfs_utils/license.py
new file mode 100644
index 000000000..793ea4019
--- /dev/null
+++ b/api/src/scripts/gbfs_utils/license.py
@@ -0,0 +1,27 @@
+LICENSE_URL_MAP = {
+ "CC0-1.0": "https://creativecommons.org/publicdomain/zero/1.0/",
+ "CC-BY-4.0": "https://creativecommons.org/licenses/by/4.0/",
+ "CDLA-Permissive-1.0": "https://cdla.io/permissive-1-0/",
+ "ODC-By-1.0": "https://www.opendatacommons.org/licenses/by/1.0/",
+}
+
+DEFAULT_LICENSE_URL = "https://creativecommons.org/licenses/by/4.0/"
+
+
+def get_license_url(system_info, logger):
+ """Get the license URL from the system information."""
+ try:
+ if system_info is None:
+ return None
+
+ # Fetching license_url or license_id
+ license_url = system_info.get("license_url")
+ if not license_url:
+ license_id = system_info.get("license_id")
+ if license_id:
+ return LICENSE_URL_MAP.get(license_id, DEFAULT_LICENSE_URL)
+ return DEFAULT_LICENSE_URL
+ return license_url
+ except Exception as e:
+ logger.error(f"Error fetching license url data from system info {system_info}: \n{e}")
+ return None
diff --git a/api/src/scripts/populate_db.py b/api/src/scripts/populate_db.py
index 83c8f3127..533fcb2ea 100644
--- a/api/src/scripts/populate_db.py
+++ b/api/src/scripts/populate_db.py
@@ -1,34 +1,21 @@
import argparse
+import logging
import os
from pathlib import Path
from typing import Type
import pandas
from dotenv import load_dotenv
-from sqlalchemy import text
-from database.database import Database, generate_unique_id, configure_polymorphic_mappers
-from database_gen.sqlacodegen_models import (
- Entitytype,
- Externalid,
- Gtfsfeed,
- Gtfsrealtimefeed,
- Location,
- Redirectingid,
- t_feedsearch,
- Feed,
-)
-from scripts.load_dataset_on_create import publish_all
-from utils.data_utils import set_up_defaults
+from database.database import Database
+from database_gen.sqlacodegen_models import Feed, Gtfsrealtimefeed, Gtfsfeed, Gbfsfeed
from utils.logger import Logger
-from datetime import datetime
-import pytz
-
-import logging
logging.basicConfig()
logging.getLogger("sqlalchemy.engine").setLevel(logging.ERROR)
+feed_mapping = {"gtfs_rt": Gtfsrealtimefeed, "gtfs": Gtfsfeed, "gbfs": Gbfsfeed}
+
def set_up_configs():
"""
@@ -67,19 +54,21 @@ def __init__(self, filepaths):
new_df = pandas.read_csv(filepath, low_memory=False)
self.df = pandas.concat([self.df, new_df])
- # Filter unsupported data types
- self.df = self.df[(self.df.data_type == "gtfs") | (self.df.data_type == "gtfs-rt")]
- self.df = set_up_defaults(self.df)
- self.added_gtfs_feeds = [] # Keep track of the feeds that have been added to the database
+ self.filter_data()
+
+    def query_feed_by_stable_id(self, stable_id: str, data_type: str | None) -> Feed | None:
+ """
+ Query the feed by stable id
+ """
+ model = self.get_model(data_type)
+ return self.db.session.query(model).filter(model.stable_id == stable_id).first()
@staticmethod
- def get_model(data_type: str | None) -> Type[Gtfsrealtimefeed | Gtfsfeed | Feed]:
+ def get_model(data_type: str | None) -> Type[Feed]:
"""
Get the model based on the data type
"""
- if data_type is None:
- return Feed
- return Gtfsrealtimefeed if data_type == "gtfs_rt" else Gtfsfeed
+ return feed_mapping.get(data_type, Feed)
@staticmethod
def get_safe_value(row, column_name, default_value):
@@ -90,228 +79,17 @@ def get_safe_value(row, column_name, default_value):
return default_value if default_value is not None else None
return f"{row[column_name]}".strip()
- def get_data_type(self, row):
- """
- Get the data type from the row
- """
- data_type = self.get_safe_value(row, "data_type", "").lower()
- if data_type not in ["gtfs", "gtfs-rt", "gtfs_rt"]:
- self.logger.warning(f"Unsupported data type: {data_type}")
- return None
- return data_type.replace("-", "_")
-
- def query_feed_by_stable_id(self, stable_id: str, data_type: str | None) -> Gtfsrealtimefeed | Gtfsfeed | None:
- """
- Query the feed by stable id
- """
- model = self.get_model(data_type)
- return self.db.session.query(model).filter(model.stable_id == stable_id).first()
-
- def get_stable_id(self, row):
- """
- Get the stable id from the row
- """
- return f'mdb-{self.get_safe_value(row, "mdb_source_id", "")}'
-
- def populate_location(self, feed, row, stable_id):
+ @staticmethod
+ def get_location_id(country_code, subdivision_name, municipality):
"""
- Populate the location for the feed
+ Get the location ID
"""
- # TODO: validate behaviour for gtfs-rt feeds
- if feed.locations and feed.data_type == "gtfs":
- self.logger.warning(f"Location already exists for feed {stable_id}")
- return
-
- country_code = self.get_safe_value(row, "location.country_code", "")
- subdivision_name = self.get_safe_value(row, "location.subdivision_name", "")
- municipality = self.get_safe_value(row, "location.municipality", "")
composite_id = f"{country_code}-{subdivision_name}-{municipality}".replace(" ", "_")
location_id = composite_id if len(composite_id) > 2 else None
- if not location_id:
- self.logger.warning(f"Location ID is empty for feed {stable_id}")
- feed.locations.clear()
- else:
- location = self.db.session.get(Location, location_id)
- location = (
- location
- if location
- else Location(
- id=location_id,
- country_code=country_code,
- subdivision_name=subdivision_name,
- municipality=municipality,
- )
- )
- feed.locations = [location]
-
- def process_entity_types(self, feed: Gtfsrealtimefeed, row, stable_id):
- """
- Process the entity types for the feed
- """
- entity_types = self.get_safe_value(row, "entity_type", "").replace("|", "-").split("-")
- if len(entity_types) > 0:
- for entity_type_name in entity_types:
- entity_type = self.db.session.query(Entitytype).filter(Entitytype.name == entity_type_name).first()
- if not entity_type:
- entity_type = Entitytype(name=entity_type_name)
- if all(entity_type.name != entity.name for entity in feed.entitytypes):
- feed.entitytypes.append(entity_type)
- self.db.session.flush()
- else:
- self.logger.warning(f"Entity types array is empty for feed {stable_id}")
- feed.entitytypes.clear()
-
- def process_feed_references(self):
- """
- Process the feed references
- """
- self.logger.info("Processing feed references")
- for index, row in self.df.iterrows():
- stable_id = self.get_stable_id(row)
- data_type = self.get_data_type(row)
- if data_type != "gtfs_rt":
- continue
- gtfs_rt_feed = self.query_feed_by_stable_id(stable_id, "gtfs_rt")
- static_reference = self.get_safe_value(row, "static_reference", "")
- if static_reference:
- gtfs_stable_id = f"mdb-{int(float(static_reference))}"
- gtfs_feed = self.query_feed_by_stable_id(gtfs_stable_id, "gtfs")
- already_referenced_ids = {ref.id for ref in gtfs_feed.gtfs_rt_feeds}
- if gtfs_feed and gtfs_rt_feed.id not in already_referenced_ids:
- gtfs_feed.gtfs_rt_feeds.append(gtfs_rt_feed)
- # Flush to avoid FK violation
- self.db.session.flush()
+ return location_id
- def process_redirects(self):
- """
- Process the redirects
+ def filter_data(self):
"""
- self.logger.info("Processing redirects")
- for index, row in self.df.iterrows():
- stable_id = self.get_stable_id(row)
- raw_redirects = row.get("redirect.id", None)
- redirects_ids = str(raw_redirects).split("|") if raw_redirects is not None else []
- if len(redirects_ids) == 0:
- continue
- feed = self.query_feed_by_stable_id(stable_id, None)
- raw_comments = row.get("redirect.comment", None)
- comments = raw_comments.split("|") if raw_comments is not None else []
- if len(redirects_ids) != len(comments) and len(comments) > 0:
- self.logger.warning(f"Number of redirect ids and redirect comments differ for feed {stable_id}")
- for mdb_source_id in redirects_ids:
- if len(mdb_source_id) == 0:
- # since there is a 1:1 correspondence between redirect ids and comments, skip also the comment
- comments = comments[1:]
- continue
- if comments:
- comment = comments.pop(0)
- else:
- comment = ""
-
- target_stable_id = f"mdb-{int(float(mdb_source_id.strip()))}"
- target_feed = self.query_feed_by_stable_id(target_stable_id, None)
- if not target_feed:
- self.logger.warning(f"Could not find redirect target feed {target_stable_id} for feed {stable_id}")
- continue
-
- if feed.id == target_feed.id:
- self.logger.error(f"Feed has redirect pointing to itself {stable_id}")
- else:
- if all(redirect.target_id != target_feed.id for redirect in feed.redirectingids):
- feed.redirectingids.append(
- Redirectingid(source_id=feed.id, target_id=target_feed.id, redirect_comment=comment)
- )
- # Flush to avoid FK violation
- self.db.session.flush()
-
- def populate_db(self):
+        Filter the data to only include the supported feeds (implemented by the child classes)
"""
- Populate the database with the sources.csv data
- """
- self.logger.info("Populating the database with sources.csv data")
- for index, row in self.df.iterrows():
- self.logger.debug(f"Populating Database with Feed [stable_id = {row['mdb_source_id']}]")
- # Create or update the GTFS feed
- data_type = self.get_data_type(row)
- stable_id = self.get_stable_id(row)
- feed = self.query_feed_by_stable_id(stable_id, data_type)
- if feed:
- self.logger.debug(f"Updating {feed.__class__.__name__}: {stable_id}")
- else:
- feed = self.get_model(data_type)(
- id=generate_unique_id(),
- data_type=data_type,
- stable_id=stable_id,
- created_at=datetime.now(pytz.utc), # Current timestamp with UTC timezone
- )
- self.logger.info(f"Creating {feed.__class__.__name__}: {stable_id}")
- self.db.session.add(feed)
- if data_type == "gtfs":
- self.added_gtfs_feeds.append(feed)
- feed.externalids = [
- Externalid(
- feed_id=feed.id,
- associated_id=str(int(float(row["mdb_source_id"]))),
- source="mdb",
- )
- ]
- # Populate common fields from Feed
- feed.feed_name = self.get_safe_value(row, "name", "")
- feed.note = self.get_safe_value(row, "note", "")
- feed.producer_url = self.get_safe_value(row, "urls.direct_download", "")
- feed.authentication_type = str(int(float(self.get_safe_value(row, "urls.authentication_type", "0"))))
- feed.authentication_info_url = self.get_safe_value(row, "urls.authentication_info", "")
- feed.api_key_parameter_name = self.get_safe_value(row, "urls.api_key_parameter_name", "")
- feed.license_url = self.get_safe_value(row, "urls.license", "")
- feed.status = self.get_safe_value(row, "status", "active")
- feed.feed_contact_email = self.get_safe_value(row, "feed_contact_email", "")
- feed.provider = self.get_safe_value(row, "provider", "")
-
- self.populate_location(feed, row, stable_id)
- if data_type == "gtfs_rt":
- self.process_entity_types(feed, row, stable_id)
-
- self.db.session.add(feed)
- self.db.session.flush()
- # This need to be done after all feeds are added to the session to avoid FK violation
- self.process_feed_references()
- self.process_redirects()
-
- def trigger_downstream_tasks(self):
- """
- Trigger downstream tasks after populating the database
- """
- self.logger.info("Triggering downstream tasks")
- self.logger.debug(
- f"New feeds added to the database: "
- f"{','.join([feed.stable_id for feed in self.added_gtfs_feeds] if self.added_gtfs_feeds else [])}"
- )
-
- env = os.getenv("ENV")
- self.logger.info(f"ENV = {env}")
- if os.getenv("ENV", "local") != "local":
- publish_all(self.added_gtfs_feeds) # Publishes the new feeds to the Pub/Sub topic to download the datasets
-
- # Extracted the following code from main so it can be executed as a library function
- def initialize(self, trigger_downstream_tasks: bool = True):
- try:
- configure_polymorphic_mappers()
- self.populate_db()
- self.db.session.commit()
-
- self.logger.info("Refreshing MATERIALIZED FEED SEARCH VIEW - Started")
- self.db.session.execute(text(f"REFRESH MATERIALIZED VIEW CONCURRENTLY {t_feedsearch.name}"))
- self.logger.info("Refreshing MATERIALIZED FEED SEARCH VIEW - Completed")
- self.db.session.commit()
- self.logger.info("\n----- Database populated with sources.csv data. -----")
- if trigger_downstream_tasks:
- self.trigger_downstream_tasks()
- except Exception as e:
- self.logger.error(f"\n------ Failed to populate the database with sources.csv: {e} -----\n")
- self.db.session.rollback()
- exit(1)
-
-
-if __name__ == "__main__":
- db_helper = DatabasePopulateHelper(set_up_configs())
- db_helper.initialize()
+ pass # Should be implemented in the child class
diff --git a/api/src/scripts/populate_db_gbfs.py b/api/src/scripts/populate_db_gbfs.py
new file mode 100644
index 000000000..4bb7796eb
--- /dev/null
+++ b/api/src/scripts/populate_db_gbfs.py
@@ -0,0 +1,128 @@
+from datetime import datetime
+
+import pandas as pd
+import pytz
+
+from database.database import generate_unique_id, configure_polymorphic_mappers
+from database_gen.sqlacodegen_models import Gbfsfeed, Location, Gbfsversion, Externalid
+from scripts.gbfs_utils.comparison import generate_system_csv_from_db, compare_db_to_csv
+from scripts.gbfs_utils.fetching import fetch_data, get_data_content, get_gbfs_versions
+from scripts.gbfs_utils.license import get_license_url
+from scripts.populate_db import DatabasePopulateHelper, set_up_configs
+from scripts.gbfs_utils.gbfs_versions import OFFICIAL_VERSIONS
+
+
+class GBFSDatabasePopulateHelper(DatabasePopulateHelper):
+ def __init__(self, file_path):
+ super().__init__(file_path)
+
+ def filter_data(self):
+ """Filter out rows with Authentication Info and duplicate System IDs"""
+ self.df = self.df[pd.isna(self.df["Authentication Info"])]
+ self.df = self.df[~self.df.duplicated(subset="System ID", keep=False)]
+ self.logger.info(f"Data = {self.df}")
+
+ @staticmethod
+ def get_stable_id(row):
+ return f"gbfs-{row['System ID']}"
+
+ @staticmethod
+ def get_external_id(feed_id, system_id):
+ return Externalid(feed_id=feed_id, associated_id=str(system_id), source="gbfs")
+
+ def deprecate_feeds(self, deprecated_feeds):
+ """Deprecate feeds that are no longer in systems.csv"""
+ if deprecated_feeds is None or deprecated_feeds.empty:
+ self.logger.info("No feeds to deprecate.")
+ return
+ self.logger.info(f"Deprecating {len(deprecated_feeds)} feed(s).")
+ for index, row in deprecated_feeds.iterrows():
+ stable_id = self.get_stable_id(row)
+ gbfs_feed = self.query_feed_by_stable_id(stable_id, "gbfs")
+ if gbfs_feed:
+ self.logger.info(f"Deprecating feed with stable_id={stable_id}")
+ gbfs_feed.status = "deprecated"
+ self.db.session.flush()
+
+ def populate_db(self):
+ """Populate the database with the GBFS feeds"""
+ start_time = datetime.now()
+ configure_polymorphic_mappers()
+
+ # Compare the database to the CSV file
+ df_from_db = generate_system_csv_from_db(self.df, self.db.session)
+ added_or_updated_feeds, deprecated_feeds = compare_db_to_csv(df_from_db, self.df, self.logger)
+
+ self.deprecate_feeds(deprecated_feeds)
+ if added_or_updated_feeds is None:
+ added_or_updated_feeds = self.df
+ for index, row in added_or_updated_feeds.iterrows():
+ self.logger.info(f"Processing row {index + 1} of {len(added_or_updated_feeds)}")
+ stable_id = self.get_stable_id(row)
+ gbfs_feed = self.query_feed_by_stable_id(stable_id, "gbfs")
+ fetched_data = fetch_data(
+ row["Auto-Discovery URL"], self.logger, ["system_information", "gbfs_versions"], ["version"]
+ )
+ # If the feed already exists, update it. Otherwise, create a new feed.
+ if gbfs_feed:
+ feed_id = gbfs_feed.id
+ self.logger.info(f"Updating feed {stable_id} - {row['Name']}")
+ else:
+ feed_id = generate_unique_id()
+ self.logger.info(f"Creating new feed for {stable_id} - {row['Name']}")
+ gbfs_feed = Gbfsfeed(
+ id=feed_id,
+ data_type="gbfs",
+ stable_id=stable_id,
+ created_at=datetime.now(pytz.utc),
+ )
+ gbfs_feed.externalids = [self.get_external_id(feed_id, row["System ID"])]
+ self.db.session.add(gbfs_feed)
+
+ system_information_content = get_data_content(fetched_data.get("system_information"), self.logger)
+ gbfs_feed.license_url = get_license_url(system_information_content, self.logger)
+ gbfs_feed.feed_contact_email = (
+ system_information_content.get("feed_contact_email") if system_information_content else None
+ )
+ gbfs_feed.operator = row["Name"]
+ gbfs_feed.operator_url = row["URL"]
+ gbfs_feed.auto_discovery_url = row["Auto-Discovery URL"]
+ gbfs_feed.updated_at = datetime.now(pytz.utc)
+
+ country_code = self.get_safe_value(row, "Country Code", "")
+ municipality = self.get_safe_value(row, "Location", "")
+ location_id = self.get_location_id(country_code, None, municipality)
+ location = self.db.session.get(Location, location_id) or Location(
+ id=location_id,
+ country_code=country_code,
+ municipality=municipality,
+ )
+ gbfs_feed.locations.clear()
+ gbfs_feed.locations = [location]
+
+ # Add the GBFS versions
+ versions = get_gbfs_versions(
+ fetched_data.get("gbfs_versions"), row["Auto-Discovery URL"], fetched_data.get("version"), self.logger
+ )
+ existing_versions = [version.version for version in gbfs_feed.gbfsversions]
+ for version in versions:
+ version_value = version.get("version")
+ if version_value.upper() in OFFICIAL_VERSIONS and version_value not in existing_versions:
+ gbfs_feed.gbfsversions.append(
+ Gbfsversion(
+ feed_id=feed_id,
+ url=version.get("url"),
+ version=version_value,
+ )
+ )
+
+ self.db.session.flush()
+ self.logger.info(80 * "-")
+
+ self.db.session.commit()
+ end_time = datetime.now()
+ self.logger.info(f"Time taken: {end_time - start_time} seconds")
+
+
+if __name__ == "__main__":
+ GBFSDatabasePopulateHelper(set_up_configs()).populate_db()
diff --git a/api/src/scripts/populate_db_gtfs.py b/api/src/scripts/populate_db_gtfs.py
new file mode 100644
index 000000000..40eee6a4b
--- /dev/null
+++ b/api/src/scripts/populate_db_gtfs.py
@@ -0,0 +1,254 @@
+import os
+from datetime import datetime
+
+import pytz
+from sqlalchemy import text
+
+from database.database import generate_unique_id, configure_polymorphic_mappers
+from database_gen.sqlacodegen_models import (
+ Entitytype,
+ Externalid,
+ Gtfsrealtimefeed,
+ Location,
+ Redirectingid,
+ t_feedsearch,
+)
+from scripts.populate_db import DatabasePopulateHelper, set_up_configs
+from scripts.load_dataset_on_create import publish_all
+from utils.data_utils import set_up_defaults
+
+
+class GTFSDatabasePopulateHelper(DatabasePopulateHelper):
+ """
+ GTFS - Helper class to populate the database
+ """
+
+ def __init__(self, filepaths):
+ """
+ Specify a list of files to load the csv data from.
+ Can also be a single string with a file name.
+ """
+ super().__init__(filepaths)
+ self.added_gtfs_feeds = [] # Keep track of the feeds that have been added to the database
+
+ def filter_data(self):
+ self.df = self.df[(self.df.data_type == "gtfs") | (self.df.data_type == "gtfs-rt")]
+ self.df = set_up_defaults(self.df)
+ self.added_gtfs_feeds = [] # Keep track of the feeds that have been added to the database
+
+ def get_data_type(self, row):
+ """
+ Get the data type from the row
+ """
+ data_type = self.get_safe_value(row, "data_type", "").lower()
+ if data_type not in ["gtfs", "gtfs-rt", "gtfs_rt"]:
+ self.logger.warning(f"Unsupported data type: {data_type}")
+ return None
+ return data_type.replace("-", "_")
+
+ def get_stable_id(self, row):
+ """
+ Get the stable id from the row
+ """
+ return f'mdb-{self.get_safe_value(row, "mdb_source_id", "")}'
+
+ def populate_location(self, feed, row, stable_id):
+ """
+ Populate the location for the feed
+ """
+ if feed.locations:
+ self.logger.warning(f"Location already exists for feed {stable_id}")
+ return
+
+ country_code = self.get_safe_value(row, "location.country_code", "")
+ subdivision_name = self.get_safe_value(row, "location.subdivision_name", "")
+ municipality = self.get_safe_value(row, "location.municipality", "")
+ location_id = self.get_location_id(country_code, subdivision_name, municipality)
+ if not location_id:
+ self.logger.warning(f"Location ID is empty for feed {stable_id}")
+ feed.locations.clear()
+ else:
+ location = self.db.session.get(Location, location_id)
+ location = (
+ location
+ if location
+ else Location(
+ id=location_id,
+ country_code=country_code,
+ subdivision_name=subdivision_name,
+ municipality=municipality,
+ )
+ )
+ feed.locations = [location]
+
+ def process_entity_types(self, feed: Gtfsrealtimefeed, row, stable_id):
+ """
+ Process the entity types for the feed
+ """
+ entity_types = self.get_safe_value(row, "entity_type", "").replace("|", "-").split("-")
+ if len(entity_types) > 0:
+ for entity_type_name in entity_types:
+ entity_type = self.db.session.query(Entitytype).filter(Entitytype.name == entity_type_name).first()
+ if not entity_type:
+ entity_type = Entitytype(name=entity_type_name)
+ if all(entity_type.name != entity.name for entity in feed.entitytypes):
+ feed.entitytypes.append(entity_type)
+ self.db.session.flush()
+ else:
+ self.logger.warning(f"Entity types array is empty for feed {stable_id}")
+ feed.entitytypes.clear()
+
+ def process_feed_references(self):
+ """
+ Process the feed references
+ """
+ self.logger.info("Processing feed references")
+ for index, row in self.df.iterrows():
+ stable_id = self.get_stable_id(row)
+ data_type = self.get_data_type(row)
+ if data_type != "gtfs_rt":
+ continue
+ gtfs_rt_feed = self.query_feed_by_stable_id(stable_id, "gtfs_rt")
+ static_reference = self.get_safe_value(row, "static_reference", "")
+ if static_reference:
+ gtfs_stable_id = f"mdb-{int(float(static_reference))}"
+ gtfs_feed = self.query_feed_by_stable_id(gtfs_stable_id, "gtfs")
+                # Guard against a missing static feed before dereferencing it
+                if gtfs_feed and gtfs_rt_feed.id not in {ref.id for ref in gtfs_feed.gtfs_rt_feeds}:
+ gtfs_feed.gtfs_rt_feeds.append(gtfs_rt_feed)
+ # Flush to avoid FK violation
+ self.db.session.flush()
+
+ def process_redirects(self):
+ """
+ Process the redirects
+ """
+ self.logger.info("Processing redirects")
+ for index, row in self.df.iterrows():
+ stable_id = self.get_stable_id(row)
+ raw_redirects = row.get("redirect.id", None)
+ redirects_ids = str(raw_redirects).split("|") if raw_redirects is not None else []
+ if len(redirects_ids) == 0:
+ continue
+ feed = self.query_feed_by_stable_id(stable_id, None)
+ raw_comments = row.get("redirect.comment", None)
+ comments = raw_comments.split("|") if raw_comments is not None else []
+ if len(redirects_ids) != len(comments) and len(comments) > 0:
+ self.logger.warning(f"Number of redirect ids and redirect comments differ for feed {stable_id}")
+ for mdb_source_id in redirects_ids:
+ if len(mdb_source_id) == 0:
+ # since there is a 1:1 correspondence between redirect ids and comments, skip also the comment
+ comments = comments[1:]
+ continue
+ if comments:
+ comment = comments.pop(0)
+ else:
+ comment = ""
+
+ target_stable_id = f"mdb-{int(float(mdb_source_id.strip()))}"
+ target_feed = self.query_feed_by_stable_id(target_stable_id, None)
+ if not target_feed:
+ self.logger.warning(f"Could not find redirect target feed {target_stable_id} for feed {stable_id}")
+ continue
+
+ if feed.id == target_feed.id:
+ self.logger.error(f"Feed has redirect pointing to itself {stable_id}")
+ else:
+ if all(redirect.target_id != target_feed.id for redirect in feed.redirectingids):
+ feed.redirectingids.append(
+ Redirectingid(source_id=feed.id, target_id=target_feed.id, redirect_comment=comment)
+ )
+ # Flush to avoid FK violation
+ self.db.session.flush()
+
+ def populate_db(self):
+ """
+ Populate the database with the sources.csv data
+ """
+ self.logger.info("Populating the database with sources.csv data")
+ for index, row in self.df.iterrows():
+ self.logger.debug(f"Populating Database with Feed [stable_id = {row['mdb_source_id']}]")
+ # Create or update the GTFS feed
+ data_type = self.get_data_type(row)
+ stable_id = self.get_stable_id(row)
+ feed = self.query_feed_by_stable_id(stable_id, data_type)
+ if feed:
+ self.logger.debug(f"Updating {feed.__class__.__name__}: {stable_id}")
+ else:
+ feed = self.get_model(data_type)(
+ id=generate_unique_id(),
+ data_type=data_type,
+ stable_id=stable_id,
+ created_at=datetime.now(pytz.utc), # Current timestamp with UTC timezone
+ )
+ self.logger.info(f"Creating {feed.__class__.__name__}: {stable_id}")
+ self.db.session.add(feed)
+ if data_type == "gtfs":
+ self.added_gtfs_feeds.append(feed)
+ feed.externalids = [
+ Externalid(
+ feed_id=feed.id,
+ associated_id=str(int(float(row["mdb_source_id"]))),
+ source="mdb",
+ )
+ ]
+ # Populate common fields from Feed
+ feed.feed_name = self.get_safe_value(row, "name", "")
+ feed.note = self.get_safe_value(row, "note", "")
+ feed.producer_url = self.get_safe_value(row, "urls.direct_download", "")
+ feed.authentication_type = str(int(float(self.get_safe_value(row, "urls.authentication_type", "0"))))
+ feed.authentication_info_url = self.get_safe_value(row, "urls.authentication_info", "")
+ feed.api_key_parameter_name = self.get_safe_value(row, "urls.api_key_parameter_name", "")
+ feed.license_url = self.get_safe_value(row, "urls.license", "")
+ feed.status = self.get_safe_value(row, "status", "active")
+ feed.feed_contact_email = self.get_safe_value(row, "feed_contact_email", "")
+ feed.provider = self.get_safe_value(row, "provider", "")
+
+ self.populate_location(feed, row, stable_id)
+ if data_type == "gtfs_rt":
+ self.process_entity_types(feed, row, stable_id)
+
+ self.db.session.add(feed)
+ self.db.session.flush()
+        # This needs to be done after all feeds are added to the session to avoid FK violation
+ self.process_feed_references()
+ self.process_redirects()
+
+ def trigger_downstream_tasks(self):
+ """
+ Trigger downstream tasks after populating the database
+ """
+ self.logger.info("Triggering downstream tasks")
+ self.logger.debug(
+ f"New feeds added to the database: "
+ f"{','.join([feed.stable_id for feed in self.added_gtfs_feeds] if self.added_gtfs_feeds else [])}"
+ )
+
+ env = os.getenv("ENV")
+ self.logger.info(f"ENV = {env}")
+ if os.getenv("ENV", "local") != "local":
+ publish_all(self.added_gtfs_feeds) # Publishes the new feeds to the Pub/Sub topic to download the datasets
+
+ # Extracted the following code from main, so it can be executed as a library function
+ def initialize(self, trigger_downstream_tasks: bool = True):
+ try:
+ configure_polymorphic_mappers()
+ self.populate_db()
+ self.db.session.commit()
+
+ self.logger.info("Refreshing MATERIALIZED FEED SEARCH VIEW - Started")
+ self.db.session.execute(text(f"REFRESH MATERIALIZED VIEW CONCURRENTLY {t_feedsearch.name}"))
+ self.logger.info("Refreshing MATERIALIZED FEED SEARCH VIEW - Completed")
+ self.db.session.commit()
+ self.logger.info("\n----- Database populated with sources.csv data. -----")
+ if trigger_downstream_tasks:
+ self.trigger_downstream_tasks()
+ except Exception as e:
+ self.logger.error(f"\n------ Failed to populate the database with sources.csv: {e} -----\n")
+ self.db.session.rollback()
+ exit(1)
+
+
+if __name__ == "__main__":
+ db_helper = GTFSDatabasePopulateHelper(set_up_configs())
+ db_helper.initialize()
diff --git a/api/src/scripts/populate_db_test_data.py b/api/src/scripts/populate_db_test_data.py
index af637b845..fae1a1d2f 100644
--- a/api/src/scripts/populate_db_test_data.py
+++ b/api/src/scripts/populate_db_test_data.py
@@ -1,30 +1,14 @@
-import argparse
import json
-import os
-from pathlib import Path
-from dotenv import load_dotenv
+
from geoalchemy2 import WKTElement
from sqlalchemy import text
from database.database import Database
from database_gen.sqlacodegen_models import Gtfsdataset, Validationreport, Gtfsfeed, Notice, Feature, t_feedsearch
-
+from scripts.populate_db import set_up_configs
from utils.logger import Logger
-def set_up_configs():
- """
- Set up function
- """
- parser = argparse.ArgumentParser()
- parser.add_argument("--filepath", help="Absolute path for the JSON file containing the test data", required=True)
- args = parser.parse_args()
- current_path = Path(__file__).resolve()
- dotenv_path = os.path.join(current_path.parents[3], "config", ".env")
- load_dotenv(dotenv_path=dotenv_path)
- return args.filepath
-
-
class DatabasePopulateTestDataHelper:
"""
Helper class to populate
diff --git a/api/tests/test_utils/database.py b/api/tests/test_utils/database.py
index d32d99ea4..af814d3c3 100644
--- a/api/tests/test_utils/database.py
+++ b/api/tests/test_utils/database.py
@@ -8,7 +8,7 @@
from tests.test_utils.db_utils import dump_database, is_test_db, dump_raw_database, empty_database
from database.database import Database
-from scripts.populate_db import DatabasePopulateHelper
+from scripts.populate_db_gtfs import GTFSDatabasePopulateHelper
from scripts.populate_db_test_data import DatabasePopulateTestDataHelper
import os
@@ -48,7 +48,7 @@ def populate_database(db: Database, data_dirs: str):
if len(csv_filepaths) == 0:
raise Exception("No sources_test.csv file found in test_data directories")
- db_helper = DatabasePopulateHelper(csv_filepaths)
+ db_helper = GTFSDatabasePopulateHelper(csv_filepaths)
db_helper.initialize(trigger_downstream_tasks=False)
# Make a list of all the extra_test_data.json files in the test_data directories and load the data
diff --git a/api/tests/unittest/test_feeds.py b/api/tests/unittest/test_feeds.py
index ffd814c07..3ee91f336 100644
--- a/api/tests/unittest/test_feeds.py
+++ b/api/tests/unittest/test_feeds.py
@@ -84,7 +84,7 @@ def test_feeds_get(client: TestClient, mocker):
mock_filter_offset = Mock()
mock_filter_order_by = Mock()
mock_options = Mock()
- mock_filter.return_value.order_by.return_value = mock_filter_order_by
+ mock_filter.return_value.filter.return_value.order_by.return_value = mock_filter_order_by
mock_filter_order_by.options.return_value = mock_options
mock_options.offset.return_value = mock_filter_offset
# Target is set to None as deep copy is failing for unknown reasons
@@ -119,7 +119,7 @@ def test_feed_get(client: TestClient, mocker):
Unit test for get_feeds
"""
mock_filter = mocker.patch.object(FeedFilter, "filter")
- mock_filter.return_value.first.return_value = mock_feed
+ mock_filter.return_value.filter.return_value.first.return_value = mock_feed
response = client.request(
"GET",
diff --git a/liquibase/changelog.xml b/liquibase/changelog.xml
index e20df7c95..d7e4259bc 100644
--- a/liquibase/changelog.xml
+++ b/liquibase/changelog.xml
@@ -27,4 +27,5 @@
+    <include file="changes/feat_565.sql" relativeToChangelogFile="true"/>
\ No newline at end of file
diff --git a/liquibase/changes/feat_565.sql b/liquibase/changes/feat_565.sql
new file mode 100644
index 000000000..f3c98c238
--- /dev/null
+++ b/liquibase/changes/feat_565.sql
@@ -0,0 +1,22 @@
+ALTER TYPE datatype ADD VALUE IF NOT EXISTS 'gbfs';
+
+-- Create the tables if they do not exist
+CREATE TABLE IF NOT EXISTS GBFS_Feed(
+ id VARCHAR(255) PRIMARY KEY,
+ operator VARCHAR(255),
+ operator_url VARCHAR(255),
+ auto_discovery_url VARCHAR(255),
+ FOREIGN KEY (id) REFERENCES Feed(id)
+);
+
+CREATE TABLE IF NOT EXISTS GBFS_Version(
+ feed_id VARCHAR(255) NOT NULL,
+    version VARCHAR(10),
+ url VARCHAR(255),
+ PRIMARY KEY (feed_id, version),
+ FOREIGN KEY (feed_id) REFERENCES GBFS_Feed(id)
+);
+
+-- Rename tables to use convention like GBFSFeed and GBFSVersion
+ALTER TABLE GBFS_Feed RENAME TO GBFSFeed;
+ALTER TABLE GBFS_Version RENAME TO GBFSVersion;
diff --git a/scripts/populate-db.sh b/scripts/populate-db.sh
index e510ce688..6f1ef2d43 100755
--- a/scripts/populate-db.sh
+++ b/scripts/populate-db.sh
@@ -5,10 +5,21 @@
# As a requirement, you need to have the local instance of the database running on the port defined in config/.env.local
# The csv file containing the data has to be in the same format as https://bit.ly/catalogs-csv
# Usage:
-# populate-db.sh
+# populate-db.sh <csv_file_path> [data_type]
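+# Examples (the paths below are placeholders):
+#   populate-db.sh /path/to/sources.csv        # GTFS / GTFS-RT feeds (default)
+#   populate-db.sh /path/to/systems.csv gbfs   # GBFS feeds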
#
# relative path
SCRIPT_PATH="$(dirname -- "${BASH_SOURCE[0]}")"
-(cd $SCRIPT_PATH/../api/ && pip3 install -r requirements.txt && PYTHONPATH=src python src/scripts/populate_db.py --filepath "$1")
\ No newline at end of file
+# Set the data_type, defaulting to 'gtfs'
+DATA_TYPE=${2:-gtfs}
+
+# Determine the script to run based on the data_type
+if [ "$DATA_TYPE" = "gbfs" ]; then
+ SCRIPT_NAME="populate_db_gbfs.py"
+else
+ SCRIPT_NAME="populate_db_gtfs.py"
+fi
+
+# Run the appropriate script
+(cd "$SCRIPT_PATH"/../api/ && pip3 install -r requirements.txt && PYTHONPATH=src python src/scripts/$SCRIPT_NAME --filepath "$1")