Skip to content

Commit

Permalink
Extend api capabilities (#16)
Browse files Browse the repository at this point in the history
* Extend the existing api capabilities

- connections: add all missing parameters
- stationboard: add all missing parameters
- urlencode: Add "True" flag for correct formatting

* Add locations endpoint and update example

* update version to 0.4.0 and update changelog

* fix black linting

* Fix AttributeError

* Update changelog entry

---------

Co-authored-by: Fabian Affolter <[email protected]>
  • Loading branch information
miaucl and fabaff authored Nov 20, 2023
1 parent b3a5599 commit 2077fe0
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 8 deletions.
14 changes: 14 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
Changes
=======

20231120 - 0.4.0
----------------

- Add support locations API (thanks @miaucl)
- Add missing connections parameters (thanks @miaucl)
- Add missing stationboard parameters (thanks @miaucl)
- Add "True" flag for correct formatting of lists (ex. via parameter `via[]=foo1&via[]=foo2`) (thanks @miaucl)

20211124 - 0.3.0
----------------

- Don't use async timeout (thanks @agners)
- Remove loop

20210317 - 0.2.2
----------------

Expand Down
23 changes: 23 additions & 0 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,34 @@

from opendata_transport import OpendataTransport
from opendata_transport import OpendataTransportStationboard
from opendata_transport import OpendataTransportLocation


async def main():
"""Example for getting the data."""
async with aiohttp.ClientSession() as session:
# Search a station by query
locations = OpendataTransportLocation(session, query="Stettb")
await locations.async_get_data()

# Print the locations data
print(locations.locations)

# Print as list
print(list(map(lambda x: x["name"], locations.locations)))

# Search a station by coordinates
locations = OpendataTransportLocation(session, x=47.2, y=8.7)
await locations.async_get_data()

# Print the locations data
print(locations.locations)

# Print as list
print(list(map(lambda x: x["name"], locations.locations)))

print()

# Get the connection for a defined route
connection = OpendataTransport(
"Zürich, Blumenfeldstrasse", "Zürich Oerlikon, Bahnhof", session, 4
Expand Down
173 changes: 165 additions & 8 deletions opendata_transport/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,110 @@ def __init__(self, session):
@staticmethod
def get_url(resource, params):
"""Generate the URL for the request."""
param = urllib.parse.urlencode(params)
param = urllib.parse.urlencode(params, True)
url = "{resource_url}{resource}?{param}".format(
resource_url=_RESOURCE_URL, resource=resource, param=param
)
print(url)
return url


class OpendataTransportLocation(OpendataTransportBase):
"""A class for handling locations from Opendata Transport."""

def __init__(self, session, query=None, x=None, y=None, type_="all", fields=None):
"""Initialize the location."""
super().__init__(session)

self.query = query
self.x = x
self.y = y
self.type = type_
self.fields = (
fields if fields is not None and isinstance(fields, list) else None
)

self.from_name = self.from_id = self.to_name = self.to_id = None

self.locations = []

@staticmethod
def get_station(station):
"""Get the station details."""
return {
"name": station["name"],
"score": station["score"],
"coordinate_type": station["coordinate"]["type"],
"x": station["coordinate"]["x"],
"y": station["coordinate"]["y"],
"distance": station["distance"],
}

async def async_get_data(self):
"""Retrieve the data for the location."""
params = {}
if self.query is not None:
params["query"] = self.query
else:
params["x"] = self.x
params["y"] = self.y

if self.fields:
params["fields"] = self.fields

url = self.get_url("locations", params)

try:
response = await self._session.get(url, raise_for_status=True)

_LOGGER.debug("Response from transport.opendata.ch: %s", response.status)
data = await response.json()
_LOGGER.debug(data)
except asyncio.TimeoutError:
_LOGGER.error("Can not load data from transport.opendata.ch")
raise exceptions.OpendataTransportConnectionError()
except aiohttp.ClientError as aiohttpClientError:
_LOGGER.error("Response from transport.opendata.ch: %s", aiohttpClientError)
raise exceptions.OpendataTransportConnectionError()

try:
for station in data["stations"]:
self.locations.append(self.get_station(station))
except (TypeError, IndexError):
raise exceptions.OpendataTransportError()


class OpendataTransportStationboard(OpendataTransportBase):
"""A class for handling stationsboards from Opendata Transport."""

def __init__(self, station, session, limit=5):
def __init__(
self,
station,
session,
limit=5,
transportations=None,
datetime=None,
type_="departure",
fields=None,
):
"""Initialize the journey."""
super().__init__(session)

self.station = station
self.limit = limit
self.datetime = datetime
self.transportations = (
transportations
if transportations is not None and isinstance(transportations, list)
else None
)
self.type = type_
self.fields = (
fields if fields is not None and isinstance(fields, list) else None
)

self.from_name = self.from_id = self.to_name = self.to_id = None

self.journeys = []

@staticmethod
Expand All @@ -53,11 +141,20 @@ def get_journey(journey):

async def __async_get_data(self, station):
"""Retrieve the data for the station."""
params = {"limit": self.limit}
params = {
"limit": self.limit,
"type": self.type,
}
if str.isdigit(station):
params["id"] = station
else:
params["station"] = station
if self.datetime:
params["datetime"] = self.date
if self.transportations:
params["transportations"] = self.transportations
if self.fields:
params["fields"] = self.fields

url = self.get_url("stationboard", params)

Expand Down Expand Up @@ -94,13 +191,52 @@ async def async_get_data(self):
class OpendataTransport(OpendataTransportBase):
"""A class for handling connections from Opendata Transport."""

def __init__(self, start, destination, session, limit=3):
def __init__(
self,
start,
destination,
session,
limit=3,
page=0,
date=None,
time=None,
isArrivalTime=False,
transportations=None,
direct=False,
sleeper=False,
couchette=False,
bike=False,
accessibility=None,
via=None,
fields=None,
):
"""Initialize the connection."""
super().__init__(session)

self.limit = limit
self.page = page
self.start = start
self.destination = destination
self.via = via[:5] if via is not None and isinstance(via, list) else None
self.date = date
self.time = time
self.isArrivalTime = 1 if isArrivalTime else 0
self.transportations = (
transportations
if transportations is not None and isinstance(transportations, list)
else None
)
self.direct = 1 if direct else 0
self.sleeper = 1 if sleeper else 0
self.couchette = 1 if couchette else 0
self.bike = 1 if bike else 0
self.accessibility = accessibility
self.fields = (
fields if fields is not None and isinstance(fields, list) else None
)

self.from_name = self.from_id = self.to_name = self.to_id = None

self.connections = dict()

@staticmethod
Expand All @@ -125,10 +261,31 @@ def get_connection(connection):

async def async_get_data(self):
"""Retrieve the data for the connection."""
url = self.get_url(
"connections",
{"from": self.start, "to": self.destination, "limit": self.limit},
)
params = {
"from": self.start,
"to": self.destination,
"limit": self.limit,
"page": self.page,
"isArrivalTime": self.isArrivalTime,
"direct": self.direct,
"sleeper": self.sleeper,
"couchette": self.couchette,
"bike": self.bike,
}
if self.via:
params["via"] = self.via
if self.time:
params["time"] = self.time
if self.date:
params["date"] = self.date
if self.transportations:
params["transportations"] = self.transportations
if self.accessibility:
params["accessibility"] = self.accessibility
if self.fields:
params["fields"] = self.fields

url = self.get_url("connections", params)

try:
response = await self._session.get(url, raise_for_status=True)
Expand Down

0 comments on commit 2077fe0

Please sign in to comment.