Skip to content

Commit

Permalink
add a token mager, use async-click
Browse files Browse the repository at this point in the history
  • Loading branch information
zxdavb committed Aug 20, 2024
1 parent a09666a commit a5e7c61
Show file tree
Hide file tree
Showing 2 changed files with 274 additions and 59 deletions.
173 changes: 171 additions & 2 deletions src/evohomeasync2/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from datetime import datetime as dt, timedelta as td
from http import HTTPMethod, HTTPStatus
from typing import TYPE_CHECKING, Any, Final
from typing import TYPE_CHECKING, Any, Final, NotRequired, TypedDict

import aiohttp
import voluptuous as vol
Expand All @@ -21,9 +22,16 @@
CREDS_USER_PASSWORD,
URL_BASE,
)
from .schema import SCH_OAUTH_TOKEN, SZ_ACCESS_TOKEN, SZ_EXPIRES_IN, SZ_REFRESH_TOKEN
from .schema import (
SCH_OAUTH_TOKEN,
SZ_ACCESS_TOKEN,
SZ_ACCESS_TOKEN_EXPIRES,
SZ_EXPIRES_IN,
SZ_REFRESH_TOKEN,
)

if TYPE_CHECKING:
from . import EvohomeClient
from .schema import _EvoDictT, _EvoListT, _EvoSchemaT


Expand All @@ -50,6 +58,19 @@
}


class OAuthTokenData(TypedDict):
access_token: str
expires_in: int # number of seconds
refresh_token: str


class _EvoTokenData(TypedDict):
access_token: str
access_token_expires: str # dt.isoformat()
refresh_token: str
username: NotRequired[str]


class Broker:
"""A class for interacting with the Evohome API."""

Expand Down Expand Up @@ -295,3 +316,151 @@ async def put(
raise exc.RequestFailed(str(err)) from err

return content


class AbstractTokenManager(ABC):
"""Abstract class to manage an OAuth access token and its refresh token."""

access_token: str
access_token_expires: dt
refresh_token: str

def __init__(
self,
username: str,
password: str,
websession: aiohttp.ClientSession,
) -> None:
"""Initialize the token manager."""

self._user_credentials = {
"Username": username,
"Password": password,
} # TODO: are only ever PascalCase?

self.websession = websession

self._token_data_reset()

@property
def username(self) -> str:
"""Return the username."""
return self._user_credentials["Username"]

def _token_data_reset(self) -> None:
self.access_token = ""
self.access_token_expires = dt.now()
self.refresh_token = ""

def _token_data_from_dict(self, tokens: _EvoTokenData) -> None:
self.access_token = tokens[SZ_ACCESS_TOKEN]
self.access_token_expires = dt.fromisoformat(tokens[SZ_ACCESS_TOKEN_EXPIRES])
self.refresh_token = tokens[SZ_REFRESH_TOKEN]

# HACK: sometimes using evo, not self
def _token_data_as_dict(self, evo: EvohomeClient | None) -> _EvoTokenData:
if evo is not None:
return {
SZ_ACCESS_TOKEN: evo.access_token,
SZ_ACCESS_TOKEN_EXPIRES: evo.access_token_expires.isoformat(),
SZ_REFRESH_TOKEN: evo.refresh_token,
}
return {
SZ_ACCESS_TOKEN: self.access_token,
SZ_ACCESS_TOKEN_EXPIRES: self.access_token_expires.isoformat(),
SZ_REFRESH_TOKEN: self.refresh_token,
}

async def is_token_data_valid(self) -> bool:
"""Return True if we have a valid access token."""
return bool(self.access_token) and dt.now() < self.access_token_expires

async def fetch_access_token(self) -> None: # HA api
"""Obtain a new access token from the vendor (as ours is expired/invalid).
First, try using the refresh token, if one is available, otherwise authenticate
using the user credentials.
"""
"""Fetch an access token.
If there is a valid cached token use that, otherwise fetch via the web API.
"""

_LOGGER.debug("No/Expired/Invalid access_token, re-authenticating.")
if self.refresh_token:
_LOGGER.debug("Authenticating with the refresh_token...")
try:
response = await self._request_access_token(
CREDS_REFRESH_TOKEN | {SZ_REFRESH_TOKEN: self.refresh_token}
)
except exc.AuthenticationFailed as err:
if err.status != HTTPStatus.BAD_REQUEST: # e.g. != invalid tokens
raise

_LOGGER.warning(
"Likely Invalid refresh_token (will try username/password)"
)
self.refresh_token = None

if self.refresh_token is None:
_LOGGER.debug("Authenticating with username/password...")
response = await self._request_access_token(
CREDS_USER_PASSWORD | self._user_credentials
)

try:
token_data = SCH_OAUTH_TOKEN(await response.json())
except vol.Invalid as err:
raise exc.AuthenticationFailed(f"Server response invalid: {err}") from err

try:
self.access_token = token_data[SZ_ACCESS_TOKEN]
self.access_token_expires = dt.now() + td(
seconds=token_data[SZ_EXPIRES_IN] - 15
)
self.refresh_token = token_data[SZ_REFRESH_TOKEN]
except (KeyError, TypeError) as err:
raise exc.AuthenticationFailed(f"Server response invalid: {err}") from err

_LOGGER.debug(f"refresh_token = {self.refresh_token}")
_LOGGER.debug(f"access_token = {self.access_token}")
_LOGGER.debug(f"access_token_expires = {self.access_token_expires}")

async def _request_access_token(self, **kwargs: Any) -> aiohttp.ClientResponse:
"""Fetch an access token via the vendor's web API."""

try:
response = await self._request(HTTPMethod.POST, AUTH_URL, **kwargs)
response.raise_for_status()

except aiohttp.ClientResponseError as err:
if hint := _ERR_MSG_LOOKUP_AUTH.get(err.status):
raise exc.AuthenticationFailed(hint, status=err.status) from err
raise exc.AuthenticationFailed(str(err), status=err.status) from err

except aiohttp.ClientError as err: # e.g. ClientConnectionError
raise exc.AuthenticationFailed(str(err)) from err

if response.content_type != "application/json": # or likely "text/html"
# <title>Authorize error <h1>Authorization failed
# <p>The authorization server have encoutered an error while processing...
content = await response.text()
raise exc.AuthenticationFailed(
f"Server response is not JSON: {HTTPMethod.POST} {AUTH_URL}: {content}"
)

return response

async def _request(
self, method: HTTPMethod, url: str, **kwargs: Any
) -> aiohttp.ClientResponse:
# The credentials can be either a refresh token or username

async with self.websession.request(
method, url, headers=AUTH_HEADER, json=kwargs["credentials"]
) as response:
return response

@abstractmethod
async def save_access_token(self) -> None: # HA: api
"""Save the access token to a cache."""
Loading

0 comments on commit a5e7c61

Please sign in to comment.