From 2f3dfa6050526730c173673c6f7e5135396af404 Mon Sep 17 00:00:00 2001 From: Fabian Seitz Date: Tue, 24 Dec 2024 16:48:34 +0100 Subject: [PATCH] fix blocking call & switch to aiohttp --- custom_components/deutschebahn/sensor.py | 44 +++++++++++++++++------- 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/custom_components/deutschebahn/sensor.py b/custom_components/deutschebahn/sensor.py index 7dd1501..d3c12c9 100644 --- a/custom_components/deutschebahn/sensor.py +++ b/custom_components/deutschebahn/sensor.py @@ -2,6 +2,8 @@ import logging from typing import Optional import async_timeout +import aiohttp +import asyncio from deutsche_bahn_api.api_authentication import ApiAuthentication from deutsche_bahn_api.station_helper import StationHelper from deutsche_bahn_api.timetable_helper import TimetableHelper @@ -66,8 +68,8 @@ def __init__(self, config, hass: core.HomeAssistant, scan_interval: timedelta): config[CONF_CLIENT_ID], config[CONF_CLIENT_SECRET], ) - if not self.api_auth.test_credentials(): - raise ValueError("Invalid Deutsche Bahn API credentials.") + + self.hass.async_create_task(self._test_credentials()) # Station helpers self.station_helper = StationHelper() @@ -76,6 +78,16 @@ def __init__(self, config, hass: core.HomeAssistant, scan_interval: timedelta): # Connections self.connections = [] + async def _test_credentials(self): + """Test API login credentials using asyncio.to_thread.""" + try: + result = await asyncio.to_thread(self.api_auth.test_credentials) + if not result: + _LOGGER.error("Invalid Deutsche Bahn API credentials.") + raise ValueError("Invalid Deutsche Bahn API credentials.") + except Exception as e: + _LOGGER.error(f"Error during Deutsche Bahn API credentials test: {e}") + @property def name(self): """Return the name of the sensor.""" @@ -130,26 +142,34 @@ async def async_update(self): with async_timeout.timeout(30): _LOGGER.debug(f"Updating data for {self.start} -> {self.goal}") - # Find stations - start_station = self.station_helper.find_stations_by_name(self.start)[0] - goal_station = self.station_helper.find_stations_by_name(self.goal)[0] + # Find stations (no await needed as find_stations_by_name() is not async) + start_station = await asyncio.to_thread(self.station_helper.find_stations_by_name, self.start) + goal_station = await asyncio.to_thread(self.station_helper.find_stations_by_name, self.goal) # Initialize timetable helper - self.timetable_helper = TimetableHelper(start_station, self.api_auth) + self.timetable_helper = TimetableHelper(start_station[0], self.api_auth) - # Get timetable - raw_connections = self.timetable_helper.get_timetable() - self.connections = self.timetable_helper.get_timetable_changes(raw_connections) + # Use asyncio.to_thread to offload the blocking call + raw_connections = await asyncio.to_thread(self.timetable_helper.get_timetable) + self.connections = await asyncio.to_thread(self.timetable_helper.get_timetable_changes, raw_connections) if not self.connections: self.connections = [] self._available = True if self.connections: + # Log the first connection to inspect its structure + _LOGGER.debug(f"First connection: {self.connections[0]}") + first_connection = self.connections[0] - departure_time = first_connection.get("departure") - delay = first_connection.get("delay", 0) - self._state = f"{departure_time} (+{delay})" if delay else departure_time + + # If 'first_connection' is a 'Train' object, access its attributes directly: + if hasattr(first_connection, "departure"): + departure_time = first_connection.departure + delay = getattr(first_connection, "delay", 0) + self._state = f"{departure_time} (+{delay})" if delay else departure_time + else: + _LOGGER.error(f"First connection does not have the expected attributes: {first_connection}") except Exception as e: self._available = False