diff --git a/src/evohomeasync2/broker.py b/src/evohomeasync2/broker.py index 1f603e21..de787cee 100644 --- a/src/evohomeasync2/broker.py +++ b/src/evohomeasync2/broker.py @@ -4,9 +4,10 @@ from __future__ import annotations import logging +from abc import ABC, abstractmethod from datetime import datetime as dt, timedelta as td from http import HTTPMethod, HTTPStatus -from typing import TYPE_CHECKING, Any, Final +from typing import TYPE_CHECKING, Any, Final, NotRequired, TypedDict import aiohttp import voluptuous as vol @@ -21,9 +22,16 @@ CREDS_USER_PASSWORD, URL_BASE, ) -from .schema import SCH_OAUTH_TOKEN, SZ_ACCESS_TOKEN, SZ_EXPIRES_IN, SZ_REFRESH_TOKEN +from .schema import ( + SCH_OAUTH_TOKEN, + SZ_ACCESS_TOKEN, + SZ_ACCESS_TOKEN_EXPIRES, + SZ_EXPIRES_IN, + SZ_REFRESH_TOKEN, +) if TYPE_CHECKING: + from . import EvohomeClient from .schema import _EvoDictT, _EvoListT, _EvoSchemaT @@ -50,6 +58,19 @@ } +class OAuthTokenData(TypedDict): + access_token: str + expires_in: int # number of seconds + refresh_token: str + + +class _EvoTokenData(TypedDict): + access_token: str + access_token_expires: str # dt.isoformat() + refresh_token: str + username: NotRequired[str] + + class Broker: """A class for interacting with the Evohome API.""" @@ -295,3 +316,151 @@ async def put( raise exc.RequestFailed(str(err)) from err return content + + +class AbstractTokenManager(ABC): + """Abstract class to manage an OAuth access token and its refresh token.""" + + access_token: str + access_token_expires: dt + refresh_token: str + + def __init__( + self, + username: str, + password: str, + websession: aiohttp.ClientSession, + ) -> None: + """Initialize the token manager.""" + + self._user_credentials = { + "Username": username, + "Password": password, + } # TODO: are only ever PascalCase? + + self.websession = websession + + self._token_data_reset() + + @property + def username(self) -> str: + """Return the username.""" + return self._user_credentials["Username"] + + def _token_data_reset(self) -> None: + self.access_token = "" + self.access_token_expires = dt.now() + self.refresh_token = "" + + def _token_data_from_dict(self, tokens: _EvoTokenData) -> None: + self.access_token = tokens[SZ_ACCESS_TOKEN] + self.access_token_expires = dt.fromisoformat(tokens[SZ_ACCESS_TOKEN_EXPIRES]) + self.refresh_token = tokens[SZ_REFRESH_TOKEN] + + # HACK: sometimes using evo, not self + def _token_data_as_dict(self, evo: EvohomeClient | None) -> _EvoTokenData: + if evo is not None: + return { + SZ_ACCESS_TOKEN: evo.access_token, + SZ_ACCESS_TOKEN_EXPIRES: evo.access_token_expires.isoformat(), + SZ_REFRESH_TOKEN: evo.refresh_token, + } + return { + SZ_ACCESS_TOKEN: self.access_token, + SZ_ACCESS_TOKEN_EXPIRES: self.access_token_expires.isoformat(), + SZ_REFRESH_TOKEN: self.refresh_token, + } + + async def is_token_data_valid(self) -> bool: + """Return True if we have a valid access token.""" + return bool(self.access_token) and dt.now() < self.access_token_expires + + async def fetch_access_token(self) -> None: # HA api + """Obtain a new access token from the vendor (as ours is expired/invalid). + + First, try using the refresh token, if one is available, otherwise authenticate + using the user credentials. + """ + """Fetch an access token. + + If there is a valid cached token use that, otherwise fetch via the web API. + """ + + _LOGGER.debug("No/Expired/Invalid access_token, re-authenticating.") + if self.refresh_token: + _LOGGER.debug("Authenticating with the refresh_token...") + try: + response = await self._request_access_token( + CREDS_REFRESH_TOKEN | {SZ_REFRESH_TOKEN: self.refresh_token} + ) + except exc.AuthenticationFailed as err: + if err.status != HTTPStatus.BAD_REQUEST: # e.g. != invalid tokens + raise + + _LOGGER.warning( + "Likely Invalid refresh_token (will try username/password)" + ) + self.refresh_token = None + + if self.refresh_token is None: + _LOGGER.debug("Authenticating with username/password...") + response = await self._request_access_token( + CREDS_USER_PASSWORD | self._user_credentials + ) + + try: + token_data = SCH_OAUTH_TOKEN(await response.json()) + except vol.Invalid as err: + raise exc.AuthenticationFailed(f"Server response invalid: {err}") from err + + try: + self.access_token = token_data[SZ_ACCESS_TOKEN] + self.access_token_expires = dt.now() + td( + seconds=token_data[SZ_EXPIRES_IN] - 15 + ) + self.refresh_token = token_data[SZ_REFRESH_TOKEN] + except (KeyError, TypeError) as err: + raise exc.AuthenticationFailed(f"Server response invalid: {err}") from err + + _LOGGER.debug(f"refresh_token = {self.refresh_token}") + _LOGGER.debug(f"access_token = {self.access_token}") + _LOGGER.debug(f"access_token_expires = {self.access_token_expires}") + + async def _request_access_token(self, **kwargs: Any) -> aiohttp.ClientResponse: + """Fetch an access token via the vendor's web API.""" + + try: + response = await self._request(HTTPMethod.POST, AUTH_URL, **kwargs) + response.raise_for_status() + + except aiohttp.ClientResponseError as err: + if hint := _ERR_MSG_LOOKUP_AUTH.get(err.status): + raise exc.AuthenticationFailed(hint, status=err.status) from err + raise exc.AuthenticationFailed(str(err), status=err.status) from err + + except aiohttp.ClientError as err: # e.g. ClientConnectionError + raise exc.AuthenticationFailed(str(err)) from err + + if response.content_type != "application/json": # or likely "text/html" + # Authorize error <h1>Authorization failed + # <p>The authorization server have encoutered an error while processing... + content = await response.text() + raise exc.AuthenticationFailed( + f"Server response is not JSON: {HTTPMethod.POST} {AUTH_URL}: {content}" + ) + + return response + + async def _request( + self, method: HTTPMethod, url: str, **kwargs: Any + ) -> aiohttp.ClientResponse: + # The credentials can be either a refresh token or username + + async with self.websession.request( + method, url, headers=AUTH_HEADER, json=kwargs["credentials"] + ) as response: + return response + + @abstractmethod + async def save_access_token(self) -> None: # HA: api + """Save the access token to a cache.""" diff --git a/src/evohomeasync2/client.py b/src/evohomeasync2/client.py index 57bf4826..806604b2 100644 --- a/src/evohomeasync2/client.py +++ b/src/evohomeasync2/client.py @@ -4,16 +4,20 @@ import asyncio import json import logging -import os import sys -from datetime import datetime as dt +import tempfile from io import TextIOWrapper +from pathlib import Path from typing import Final -import click +import aiofiles +import aiofiles.os +import aiohttp +import asyncclick as click from . import HotWater, Zone from .base import EvohomeClient +from .broker import AbstractTokenManager, _EvoTokenData from .const import SZ_NAME, SZ_SCHEDULE from .controlsystem import ControlSystem from .schema import SZ_ACCESS_TOKEN, SZ_ACCESS_TOKEN_EXPIRES, SZ_REFRESH_TOKEN @@ -27,10 +31,11 @@ SZ_CACHE_TOKENS: Final = "cache_tokens" SZ_EVO: Final = "evo" +SZ_TOKEN_MANAGER: Final = "token_manager" SZ_USERNAME: Final = "username" +SZ_WEBSESSION: Final = "websession" -TOKEN_FILE: Final = ".evo-cache.tmp" - +TOKEN_CACHE: Final = Path(tempfile.gettempdir() + "/.evo-cache.tmp") _LOGGER: Final = logging.getLogger(__name__) @@ -83,41 +88,70 @@ def _get_tcs(evo: EvohomeClient, loc_idx: int | None) -> ControlSystem: return evo.locations[int(loc_idx)]._gateways[0]._control_systems[0] -def _dump_tokens(evo: EvohomeClient) -> None: - """Dump the tokens to a cache (temporary file).""" +class TokenManager(AbstractTokenManager): + """A token manager that uses a cache file to store the tokens.""" - expires = evo.access_token_expires.isoformat() if evo.access_token_expires else None + def __init__( + self, + username: str, + password: str, + websession: aiohttp.ClientSession, + /, + *, + token_cache: Path | None = None, + ) -> None: + super().__init__(username, password, websession) - with open(TOKEN_FILE, "w") as fp: - json.dump( - { - # SZ_USERNAME: evo.username, - SZ_REFRESH_TOKEN: evo.refresh_token, - SZ_ACCESS_TOKEN: evo.access_token, - SZ_ACCESS_TOKEN_EXPIRES: expires, - }, - fp, - ) + self._token_cache = token_cache + + @property + def token_cache(self) -> str: + """Return the token cache path.""" + return str(self._token_cache) + + async def fetch_access_token(self) -> None: # HA api + """If required, fetch an (updated) access token (somehow). + + If there is a valid cached token use that, otherwise fetch via the web API. + """ + + if self.is_token_data_valid(): + return + + self._load_access_token() + + if not self.is_token_data_valid(): + await super().fetch_access_token() + self.save_access_token() - _LOGGER.warning("Access tokens cached to: %s", TOKEN_FILE) + async def _load_access_token(self) -> None: + """Load the tokens from a cache (temporary file).""" + self._token_data_reset() -def _load_tokens() -> dict[str, dt | str]: - """Load the tokens from a cache (temporary file).""" + try: + async with aiofiles.open(self._token_cache) as fp: + content = await fp.read() + except FileNotFoundError: + return - if not os.path.exists(TOKEN_FILE): - return {} + try: + tokens: _EvoTokenData = json.loads(content) + except json.JSONDecodeError: + return - with open(TOKEN_FILE) as f: - tokens = json.load(f) + if tokens.pop(SZ_USERNAME) == self.username: + self._token_data_from_dict(tokens) - if SZ_ACCESS_TOKEN_EXPIRES not in tokens: - return tokens # type: ignore[no-any-return] + async def save_access_token(self, evo: EvohomeClient) -> None: # HA api + """Dump the tokens to a cache (temporary file).""" - if expires := tokens[SZ_ACCESS_TOKEN_EXPIRES]: - tokens[SZ_ACCESS_TOKEN_EXPIRES] = dt.fromisoformat(expires) + content = json.dumps( + {SZ_USERNAME: self.username} | self._token_data_as_dict(evo) + ) - return tokens # type: ignore[no-any-return] + async with aiofiles.open(self._token_cache, "w") as fp: + await fp.write(content) @click.group() @@ -126,11 +160,10 @@ def _load_tokens() -> dict[str, dt | str]: @click.option("--cache-tokens", "-c", is_flag=True, help="Use a token cache.") @click.option("--debug", "-d", is_flag=True, help="Enable debug logging.") @click.pass_context -def cli( +async def cli( ctx: click.Context, username: str, password: str, - location: int | None = None, cache_tokens: bool | None = None, debug: bool | None = None, ) -> None: @@ -144,16 +177,30 @@ def cli( stream=sys.stdout, ) - ctx.obj = ctx.obj or {} # may be None - ctx.obj[SZ_CACHE_TOKENS] = cache_tokens + ctx.obj[SZ_WEBSESSION] = websession = ( + aiohttp.ClientSession() + ) # timeout=aiohttp.ClientTimeout(total=30)) - tokens = _load_tokens() if cache_tokens else {} + ctx.obj[SZ_TOKEN_MANAGER] = token_manager = TokenManager( + username, password, websession, token_cache=TOKEN_CACHE + ) + + if not cache_tokens: + tokens = {} + else: + await token_manager._load_access_token() # not: fetch_access_token() + tokens = { + SZ_ACCESS_TOKEN: token_manager.access_token, + SZ_ACCESS_TOKEN_EXPIRES: token_manager.access_token_expires, + SZ_REFRESH_TOKEN: token_manager.refresh_token, + } ctx.obj[SZ_EVO] = EvohomeClient( username, password, + **tokens, + session=websession, debug=bool(debug), - **tokens, # type: ignore[arg-type] ) @@ -170,7 +217,7 @@ def cli( "--filename", "-f", type=click.File("w"), default="-", help="The output file." ) @click.pass_context -def dump(ctx: click.Context, loc_idx: int, filename: TextIOWrapper) -> None: +async def dump(ctx: click.Context, loc_idx: int, filename: TextIOWrapper) -> None: """Download all the global config and the location status.""" async def get_state(evo: EvohomeClient, loc_idx: int | None) -> _ScheduleT: @@ -191,13 +238,12 @@ async def get_state(evo: EvohomeClient, loc_idx: int | None) -> _ScheduleT: print("\r\nclient.py: Starting dump of config and status...") evo: EvohomeClient = ctx.obj[SZ_EVO] - coro = get_state(evo, loc_idx) - result = asyncio.get_event_loop().run_until_complete(coro) + result = await get_state(evo, loc_idx) filename.write(json.dumps(result, indent=4) + "\r\n\r\n") - if ctx.obj[SZ_CACHE_TOKENS]: - _dump_tokens(evo) + await ctx.obj[SZ_TOKEN_MANAGER].save_access_token(ctx.obj[SZ_EVO]) + result = await ctx.obj[SZ_WEBSESSION].close() print(" - finished.\r\n") @@ -215,7 +261,7 @@ async def get_state(evo: EvohomeClient, loc_idx: int | None) -> _ScheduleT: "--filename", "-f", type=click.File("w"), default="-", help="The output file." ) @click.pass_context -def get_schedule( +async def get_schedule( ctx: click.Context, zone_id: str, loc_idx: int, filename: TextIOWrapper ) -> None: """Download the schedule of a zone of a TCS (WIP).""" @@ -245,13 +291,11 @@ async def get_schedule( print("\r\nclient.py: Starting backup of zone schedule (WIP)...") evo = ctx.obj[SZ_EVO] - coro = get_schedule(evo, zone_id, loc_idx) - schedule = asyncio.get_event_loop().run_until_complete(coro) + schedule = await get_schedule(evo, zone_id, loc_idx) filename.write(json.dumps(schedule, indent=4) + "\r\n\r\n") - if ctx.obj[SZ_CACHE_TOKENS]: - _dump_tokens(evo) + await ctx.obj[SZ_TOKEN_MANAGER].save_access_token(evo) print(" - finished.\r\n") @@ -268,7 +312,9 @@ async def get_schedule( "--filename", "-f", type=click.File("w"), default="-", help="The output file." ) @click.pass_context -def get_schedules(ctx: click.Context, loc_idx: int, filename: TextIOWrapper) -> None: +async def get_schedules( + ctx: click.Context, loc_idx: int, filename: TextIOWrapper +) -> None: """Download all the schedules from a TCS.""" async def get_schedules(evo: EvohomeClient, loc_idx: int | None) -> _ScheduleT: @@ -287,13 +333,11 @@ async def get_schedules(evo: EvohomeClient, loc_idx: int | None) -> _ScheduleT: print("\r\nclient.py: Starting backup of schedules...") evo: EvohomeClient = ctx.obj[SZ_EVO] - coro = get_schedules(evo, loc_idx) - schedules = asyncio.get_event_loop().run_until_complete(coro) + schedules = await get_schedules(evo, loc_idx) filename.write(json.dumps(schedules, indent=4) + "\r\n\r\n") - if ctx.obj[SZ_CACHE_TOKENS]: - _dump_tokens(evo) + await ctx.obj[SZ_TOKEN_MANAGER].save_access_token(evo) print(" - finished.\r\n") @@ -308,7 +352,9 @@ async def get_schedules(evo: EvohomeClient, loc_idx: int | None) -> _ScheduleT: ) @click.option("--filename", "-f", type=click.File(), help="The input file.") @click.pass_context -def set_schedules(ctx: click.Context, loc_idx: int, filename: TextIOWrapper) -> None: +async def set_schedules( + ctx: click.Context, loc_idx: int, filename: TextIOWrapper +) -> None: """Upload schedules to a TCS.""" async def set_schedules( @@ -331,17 +377,17 @@ async def set_schedules( schedules = json.loads(filename.read()) - coro = set_schedules(evo, schedules, loc_idx) - success = asyncio.get_event_loop().run_until_complete(coro) + success = await set_schedules(evo, schedules, loc_idx) - if ctx.obj[SZ_CACHE_TOKENS]: - _dump_tokens(evo) + await ctx.obj[SZ_TOKEN_MANAGER].save_access_token(evo) print(f" - finished{'' if success else ' (with errors)'}.\r\n") def main() -> None: + """Run the CLI.""" + try: - cli(obj={}) # default for ctx.obj is None + asyncio.run(cli(obj={})) # default for ctx.obj is None except click.ClickException as err: print(f"Error: {err}")