Skip to content

Commit

Permalink
fix blocking call & switch to aiohttp
Browse files Browse the repository at this point in the history
  • Loading branch information
FaserF committed Dec 24, 2024
1 parent f9b34a4 commit 2f3dfa6
Showing 1 changed file with 32 additions and 12 deletions.
44 changes: 32 additions & 12 deletions custom_components/deutschebahn/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import logging
from typing import Optional
import async_timeout
import aiohttp
import asyncio
from deutsche_bahn_api.api_authentication import ApiAuthentication
from deutsche_bahn_api.station_helper import StationHelper
from deutsche_bahn_api.timetable_helper import TimetableHelper
Expand Down Expand Up @@ -66,8 +68,8 @@ def __init__(self, config, hass: core.HomeAssistant, scan_interval: timedelta):
config[CONF_CLIENT_ID],
config[CONF_CLIENT_SECRET],
)
if not self.api_auth.test_credentials():
raise ValueError("Invalid Deutsche Bahn API credentials.")

self.hass.async_create_task(self._test_credentials())

# Station helpers
self.station_helper = StationHelper()
Expand All @@ -76,6 +78,16 @@ def __init__(self, config, hass: core.HomeAssistant, scan_interval: timedelta):
# Connections
self.connections = []

async def _test_credentials(self):
"""Test API login credentials using asyncio.to_thread."""
try:
result = await asyncio.to_thread(self.api_auth.test_credentials)
if not result:
_LOGGER.error("Invalid Deutsche Bahn API credentials.")
raise ValueError("Invalid Deutsche Bahn API credentials.")
except Exception as e:
_LOGGER.error(f"Error during Deutsche Bahn API credentials test: {e}")

@property
def name(self):
"""Return the name of the sensor."""
Expand Down Expand Up @@ -130,26 +142,34 @@ async def async_update(self):
with async_timeout.timeout(30):
_LOGGER.debug(f"Updating data for {self.start} -> {self.goal}")

# Find stations
start_station = self.station_helper.find_stations_by_name(self.start)[0]
goal_station = self.station_helper.find_stations_by_name(self.goal)[0]
# Find stations (no await needed as find_stations_by_name() is not async)
start_station = await asyncio.to_thread(self.station_helper.find_stations_by_name, self.start)
goal_station = await asyncio.to_thread(self.station_helper.find_stations_by_name, self.goal)

# Initialize timetable helper
self.timetable_helper = TimetableHelper(start_station, self.api_auth)
self.timetable_helper = TimetableHelper(start_station[0], self.api_auth)

# Get timetable
raw_connections = self.timetable_helper.get_timetable()
self.connections = self.timetable_helper.get_timetable_changes(raw_connections)
# Use asyncio.to_thread to offload the blocking call
raw_connections = await asyncio.to_thread(self.timetable_helper.get_timetable)
self.connections = await asyncio.to_thread(self.timetable_helper.get_timetable_changes, raw_connections)

if not self.connections:
self.connections = []
self._available = True

if self.connections:
# Log the first connection to inspect its structure
_LOGGER.debug(f"First connection: {self.connections[0]}")

first_connection = self.connections[0]
departure_time = first_connection.get("departure")
delay = first_connection.get("delay", 0)
self._state = f"{departure_time} (+{delay})" if delay else departure_time

# If 'first_connection' is a 'Train' object, access its attributes directly:
if hasattr(first_connection, "departure"):
departure_time = first_connection.departure
delay = getattr(first_connection, "delay", 0)
self._state = f"{departure_time} (+{delay})" if delay else departure_time
else:
_LOGGER.error(f"First connection does not have the expected attributes: {first_connection}")

except Exception as e:
self._available = False
Expand Down

0 comments on commit 2f3dfa6

Please sign in to comment.