Skip to content

Commit

Permalink
feat: automate locations - modified API responses (#661)
Browse files Browse the repository at this point in the history
  • Loading branch information
cka-y authored Aug 13, 2024
1 parent 9f6000e commit daeef88
Show file tree
Hide file tree
Showing 20 changed files with 315 additions and 84 deletions.
3 changes: 2 additions & 1 deletion api/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,5 @@ cloud-sql-python-connector[pg8000]
fastapi-filter[sqlalchemy]==1.0.0
PyJWT
shapely
google-cloud-pubsub
google-cloud-pubsub
pycountry
153 changes: 91 additions & 62 deletions api/src/feeds/impl/feeds_api_impl.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from datetime import datetime
from typing import List, Union
from typing import List, Union, TypeVar

from sqlalchemy.orm import joinedload

from sqlalchemy.orm.query import Query
from database.database import Database
from database_gen.sqlacodegen_models import (
Feed,
Expand All @@ -12,6 +12,7 @@
Location,
Validationreport,
Entitytype,
t_location_with_translations_en,
)
from feeds.filters.feed_filter import FeedFilter
from feeds.filters.gtfs_dataset_filter import GtfsDatasetFilter
Expand All @@ -36,6 +37,9 @@
from feeds_gen.models.gtfs_feed import GtfsFeed
from feeds_gen.models.gtfs_rt_feed import GtfsRTFeed
from utils.date_utils import valid_iso_date
from utils.location_translation import create_location_translation_object, LocationTranslation

T = TypeVar("T", bound="BasicFeed")


class FeedsApiImpl(BaseFeedsApi):
Expand Down Expand Up @@ -91,20 +95,35 @@ def get_gtfs_feed(
id: str,
) -> GtfsFeed:
"""Get the specified gtfs feed from the Mobility Database."""
feed = (
feed, translations = self._get_gtfs_feed(id)
if feed:
return GtfsFeedImpl.from_orm(feed, translations)
else:
raise_http_error(404, gtfs_feed_not_found.format(id))

@staticmethod
def _get_gtfs_feed(stable_id: str) -> tuple[Gtfsfeed | None, dict[str, LocationTranslation]]:
results = (
FeedFilter(
stable_id=id,
stable_id=stable_id,
status=None,
provider__ilike=None,
producer_url__ilike=None,
)
.filter(Database().get_query_model(Gtfsfeed))
.first()
)
if feed:
return GtfsFeedImpl.from_orm(feed)
else:
raise_http_error(404, gtfs_feed_not_found.format(id))
.filter(Database().get_session().query(Gtfsfeed, t_location_with_translations_en))
.outerjoin(Location, Feed.locations)
.outerjoin(t_location_with_translations_en, Location.id == t_location_with_translations_en.c.location_id)
.options(
joinedload(Gtfsfeed.gtfsdatasets)
.joinedload(Gtfsdataset.validation_reports)
.joinedload(Validationreport.notices),
*BasicFeedImpl.get_joinedload_options(),
)
).all()
if len(results) > 0 and results[0].Gtfsfeed:
translations = {result[1]: create_location_translation_object(result) for result in results}
return results[0].Gtfsfeed, translations
return None, {}

def get_gtfs_feed_datasets(
self,
Expand Down Expand Up @@ -176,43 +195,54 @@ def get_gtfs_feeds(
municipality__ilike=municipality,
),
)
gtfs_feed_query = gtfs_feed_filter.filter(Database().get_query_model(Gtfsfeed))

gtfs_feed_query = gtfs_feed_query.outerjoin(Location, Feed.locations).options(
joinedload(Gtfsfeed.gtfsdatasets)
.joinedload(Gtfsdataset.validation_reports)
.joinedload(Validationreport.notices),
*BasicFeedImpl.get_joinedload_options(),
gtfs_feed_query = gtfs_feed_filter.filter(
Database().get_session().query(Gtfsfeed, t_location_with_translations_en)
)
gtfs_feed_query = (
gtfs_feed_query.outerjoin(Location, Feed.locations)
.outerjoin(t_location_with_translations_en, Location.id == t_location_with_translations_en.c.location_id)
.options(
joinedload(Gtfsfeed.gtfsdatasets)
.joinedload(Gtfsdataset.validation_reports)
.joinedload(Validationreport.notices),
*BasicFeedImpl.get_joinedload_options(),
)
.order_by(Gtfsfeed.provider, Gtfsfeed.stable_id)
)
gtfs_feed_query = gtfs_feed_query.order_by(Gtfsfeed.provider, Gtfsfeed.stable_id)
gtfs_feed_query = DatasetsApiImpl.apply_bounding_filtering(
gtfs_feed_query, dataset_latitudes, dataset_longitudes, bounding_filter_method
)
if limit is not None:
gtfs_feed_query = gtfs_feed_query.limit(limit)
if offset is not None:
gtfs_feed_query = gtfs_feed_query.offset(offset)
results = gtfs_feed_query.all()
return [GtfsFeedImpl.from_orm(gtfs_feed) for gtfs_feed in results]
return self._get_response(gtfs_feed_query, limit, offset, GtfsFeedImpl)

def get_gtfs_rt_feed(
self,
id: str,
) -> GtfsRTFeed:
"""Get the specified GTFS Realtime feed from the Mobility Database."""
feed = (
GtfsRtFeedFilter(
stable_id=id,
provider__ilike=None,
producer_url__ilike=None,
entity_types=None,
location=None,
)
.filter(Database().get_query_model(Gtfsrealtimefeed))
.first()
gtfs_rt_feed_filter = GtfsRtFeedFilter(
stable_id=id,
provider__ilike=None,
producer_url__ilike=None,
entity_types=None,
location=None,
)
if feed:
return GtfsRTFeedImpl.from_orm(feed)
results = gtfs_rt_feed_filter.filter(
Database()
.get_session()
.query(Gtfsrealtimefeed, t_location_with_translations_en)
.outerjoin(Location, Gtfsrealtimefeed.locations)
.outerjoin(t_location_with_translations_en, Location.id == t_location_with_translations_en.c.location_id)
.options(
joinedload(Gtfsrealtimefeed.entitytypes),
joinedload(Gtfsrealtimefeed.gtfs_feeds),
*BasicFeedImpl.get_joinedload_options(),
)
).all()

if len(results) > 0 and results[0].Gtfsrealtimefeed:
translations = {result[1]: create_location_translation_object(result) for result in results}
return GtfsRTFeedImpl.from_orm(results[0].Gtfsrealtimefeed, translations)
else:
raise_http_error(404, gtfs_rt_feed_not_found.format(id))

Expand Down Expand Up @@ -251,42 +281,41 @@ def get_gtfs_rt_feeds(
municipality__ilike=municipality,
),
)
gtfs_rt_feed_query = gtfs_rt_feed_filter.filter(Database().get_query_model(Gtfsrealtimefeed)).options(
*BasicFeedImpl.get_joinedload_options()
gtfs_rt_feed_query = gtfs_rt_feed_filter.filter(
Database().get_session().query(Gtfsrealtimefeed, t_location_with_translations_en)
)
gtfs_rt_feed_query = gtfs_rt_feed_query.outerjoin(Entitytype, Gtfsrealtimefeed.entitytypes).options(
joinedload(Gtfsrealtimefeed.entitytypes)
)
gtfs_rt_feed_query = gtfs_rt_feed_query.outerjoin(Location, Feed.locations).options(
joinedload(Gtfsrealtimefeed.locations)
)
gtfs_rt_feed_query = gtfs_rt_feed_query.outerjoin(Gtfsfeed, Gtfsrealtimefeed.gtfs_feeds).options(
joinedload(Gtfsrealtimefeed.gtfs_feeds)
gtfs_rt_feed_query = (
gtfs_rt_feed_query.outerjoin(Location, Gtfsrealtimefeed.locations)
.outerjoin(t_location_with_translations_en, Location.id == t_location_with_translations_en.c.location_id)
.outerjoin(Entitytype, Gtfsrealtimefeed.entitytypes)
.options(
joinedload(Gtfsrealtimefeed.entitytypes),
joinedload(Gtfsrealtimefeed.gtfs_feeds),
*BasicFeedImpl.get_joinedload_options(),
)
.order_by(Gtfsrealtimefeed.provider, Gtfsrealtimefeed.stable_id)
)
gtfs_rt_feed_query = gtfs_rt_feed_query.order_by(Gtfsrealtimefeed.provider, Gtfsrealtimefeed.stable_id)
return self._get_response(gtfs_rt_feed_query, limit, offset, GtfsRTFeedImpl)

@staticmethod
def _get_response(feed_query: Query, limit: int, offset: int, impl_cls: type[T]) -> List[T]:
"""Get the response for the feed query."""
if limit is not None:
gtfs_rt_feed_query = gtfs_rt_feed_query.limit(limit)
feed_query = feed_query.limit(limit)
if offset is not None:
gtfs_rt_feed_query = gtfs_rt_feed_query.offset(offset)
results = gtfs_rt_feed_query.all()
return [GtfsRTFeedImpl.from_orm(gtfs_rt_feed) for gtfs_rt_feed in results]
feed_query = feed_query.offset(offset)
results = feed_query.all()
location_translations = {row[1]: create_location_translation_object(row) for row in results}
response = [impl_cls.from_orm(feed[0], location_translations) for feed in results]
return list({feed.id: feed for feed in response}.values())

def get_gtfs_feed_gtfs_rt_feeds(
self,
id: str,
) -> List[GtfsRTFeed]:
"""Get a list of GTFS Realtime related to a GTFS feed."""
feed = (
FeedFilter(
stable_id=id,
status=None,
provider__ilike=None,
producer_url__ilike=None,
)
.filter(Database().get_query_model(Gtfsfeed))
.first()
)
feed, translations = self._get_gtfs_feed(id)
if feed:
return [GtfsRTFeedImpl.from_orm(gtfs_rt_feed) for gtfs_rt_feed in feed.gtfs_rt_feeds]
return [GtfsRTFeedImpl.from_orm(gtfs_rt_feed, translations) for gtfs_rt_feed in feed.gtfs_rt_feeds]
else:
raise_http_error(404, gtfs_feed_not_found.format(id))
2 changes: 1 addition & 1 deletion api/src/feeds/impl/models/basic_feed_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class Config:
from_attributes = True

@classmethod
def from_orm(cls, feed: Feed | None) -> BasicFeed | None:
def from_orm(cls, feed: Feed | None, _=None) -> BasicFeed | None:
if not feed:
return None
return cls(
Expand Down
12 changes: 9 additions & 3 deletions api/src/feeds/impl/models/gtfs_feed_impl.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import Dict

from database_gen.sqlacodegen_models import Gtfsfeed as GtfsfeedOrm
from feeds.impl.models.basic_feed_impl import BaseFeedImpl
from feeds.impl.models.latest_dataset_impl import LatestDatasetImpl
from feeds.impl.models.location_impl import LocationImpl
from feeds_gen.models.gtfs_feed import GtfsFeed
from utils.location_translation import LocationTranslation, translate_feed_locations


class GtfsFeedImpl(BaseFeedImpl, GtfsFeed):
Expand All @@ -17,12 +20,15 @@ class Config:
from_attributes = True

@classmethod
def from_orm(cls, feed: GtfsfeedOrm | None) -> GtfsFeed | None:
gtfs_feed = super().from_orm(feed)
def from_orm(
cls, feed: GtfsfeedOrm | None, location_translations: Dict[str, LocationTranslation] = None
) -> GtfsFeed | None:
if location_translations is not None:
translate_feed_locations(feed, location_translations)
gtfs_feed: GtfsFeed = super().from_orm(feed)
if not gtfs_feed:
return None
gtfs_feed.locations = [LocationImpl.from_orm(item) for item in feed.locations]

latest_dataset = next(
(dataset for dataset in feed.gtfsdatasets if dataset is not None and dataset.latest), None
)
Expand Down
11 changes: 9 additions & 2 deletions api/src/feeds/impl/models/gtfs_rt_feed_impl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Dict

from database_gen.sqlacodegen_models import Gtfsrealtimefeed as GtfsRTFeedOrm
from feeds.impl.models.basic_feed_impl import BaseFeedImpl
from feeds.impl.models.location_impl import LocationImpl
from feeds_gen.models.gtfs_rt_feed import GtfsRTFeed
from utils.location_translation import LocationTranslation, translate_feed_locations


class GtfsRTFeedImpl(BaseFeedImpl, GtfsRTFeed):
Expand All @@ -14,8 +17,12 @@ class Config:
from_attributes = True

@classmethod
def from_orm(cls, feed: GtfsRTFeedOrm | None) -> GtfsRTFeed | None:
gtfs_rt_feed = super().from_orm(feed)
def from_orm(
cls, feed: GtfsRTFeedOrm | None, location_translations: Dict[str, LocationTranslation] = None
) -> GtfsRTFeed | None:
if location_translations is not None:
translate_feed_locations(feed, location_translations)
gtfs_rt_feed: GtfsRTFeed = super().from_orm(feed)
if not gtfs_rt_feed:
return None
gtfs_rt_feed.locations = [LocationImpl.from_orm(item) for item in feed.locations]
Expand Down
1 change: 1 addition & 0 deletions api/src/feeds/impl/models/location_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def from_orm(cls, location: LocationOrm | None) -> Location | None:
return None
return cls(
country_code=location.country_code,
country=location.country,
subdivision_name=location.subdivision_name,
municipality=location.municipality,
)
30 changes: 30 additions & 0 deletions api/src/feeds/impl/models/search_feed_item_result_impl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from feeds_gen.models.latest_dataset import LatestDataset
from feeds_gen.models.search_feed_item_result import SearchFeedItemResult
from feeds_gen.models.source_info import SourceInfo
import pycountry


class SearchFeedItemResultImpl(SearchFeedItemResult):
Expand Down Expand Up @@ -74,11 +75,40 @@ def from_orm_gtfs_rt(cls, feed_search_row):
feed_references=feed_search_row.feed_reference_ids,
)

@classmethod
def _translate_locations(cls, feed_search_row):
"""Translate location information in the feed search row."""
country_translations = cls._create_translation_dict(feed_search_row.country_translations)
subdivision_translations = cls._create_translation_dict(feed_search_row.subdivision_name_translations)
municipality_translations = cls._create_translation_dict(feed_search_row.municipality_translations)

for location in feed_search_row.locations:
location["country"] = country_translations.get(location["country"], location["country"])
if location["country"] is None:
location["country"] = pycountry.countries.get(alpha_2=location["country_code"]).name
location["subdivision_name"] = subdivision_translations.get(
location["subdivision_name"], location["subdivision_name"]
)
location["municipality"] = municipality_translations.get(location["municipality"], location["municipality"])

@staticmethod
def _create_translation_dict(translations):
"""Helper method to create a translation dictionary."""
if translations:
return {
elem.get("key"): elem.get("value") for elem in translations if elem.get("key") and elem.get("value")
}
return {}

@classmethod
def from_orm(cls, feed_search_row):
"""Create a model instance from a SQLAlchemy row object."""
if feed_search_row is None:
return None

# Translate location data
cls._translate_locations(feed_search_row)

match feed_search_row.data_type:
case "gtfs":
return cls.from_orm_gtfs(feed_search_row)
Expand Down
Loading

0 comments on commit daeef88

Please sign in to comment.