From bbf0c12afda1bcc87438858542d65462976a84f2 Mon Sep 17 00:00:00 2001 From: David Bonnes Date: Sat, 2 Nov 2024 21:57:38 +0000 Subject: [PATCH] fixes --- src/evohomeasync2/auth.py | 52 ++++++++++++++++++++++++++++++++-- tests/tests_rf/conftest.py | 40 +++++++++----------------- tests/tests_rf/test_v2_task.py | 12 +++++--- 3 files changed, 72 insertions(+), 32 deletions(-) diff --git a/src/evohomeasync2/auth.py b/src/evohomeasync2/auth.py index 26ecedac..9609b1d2 100644 --- a/src/evohomeasync2/auth.py +++ b/src/evohomeasync2/auth.py @@ -7,6 +7,7 @@ from abc import ABC, abstractmethod from datetime import datetime as dt, timedelta as td from http import HTTPMethod, HTTPStatus +from types import TracebackType from typing import TYPE_CHECKING, Any, Final, TypedDict import aiohttp @@ -59,7 +60,7 @@ HTTPStatus.UNAUTHORIZED: "Unauthorized (expired access token/unknown entity id?)", } -SZ_USERNAME: Final = "Username" # TODO: is camelCase (and not PascalCase) OK? +SZ_USERNAME: Final = "Username" SZ_PASSWORD: Final = "Password" @@ -261,6 +262,51 @@ async def save_access_token(self) -> None: # HA: api """Save the access token to a cache.""" +class _RequestContextManager: + """A context manager for Auth's aiohttp request.""" + + _response: aiohttp.ClientResponse | None = None + + def __init__( + self, + websession: aiohttp.ClientSession, + method: HTTPMethod, + url: StrOrURL, + /, + **kwargs: Any, + ): + """Initialize the request context manager.""" + + self.websession = websession + self.method = method + self.url = url + self.kwargs = kwargs + + async def __aenter__(self) -> aiohttp.ClientResponse: + """Async context manager entry.""" + self._response = await self._await_impl() + return self._response + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + """Async context manager exit.""" + if self._response: + self._response.release() # or: close() + await self._response.wait_for_close() + + def __await__(self) -> aiohttp.ClientResponse: + """Make this class awaitable.""" + return self._await_impl().__await__() + + async def _await_impl(self) -> aiohttp.ClientResponse: + """Return the actual result.""" + return await self.websession.request(self.method, self.url, **self.kwargs) + + class AbstractAuth(ABC): # APIs esposed by/for HA def __init__(self, websession: aiohttp.ClientSession, host: str) -> None: """Initialize the auth.""" @@ -282,7 +328,9 @@ async def request( # type: ignore[no-untyped-def] } headers["Authorization"] = "bearer " + await self.get_access_token() - return self.websession.request(method, url, **kwargs, headers=headers) + return await _RequestContextManager( + self.websession, method, url, **kwargs, headers=headers + ) class Auth(AbstractAuth): diff --git a/tests/tests_rf/conftest.py b/tests/tests_rf/conftest.py index 84944ee3..1a074df6 100644 --- a/tests/tests_rf/conftest.py +++ b/tests/tests_rf/conftest.py @@ -6,10 +6,10 @@ import functools import json import os -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable from datetime import datetime as dt, timedelta as td from pathlib import Path -from typing import TYPE_CHECKING, Any, Final +from typing import TYPE_CHECKING, Any, Final, TypeVar import pytest @@ -21,7 +21,7 @@ # # normally, we want debug flags to be False -_DBG_USE_REAL_AIOHTTP = False +_DBG_USE_REAL_AIOHTTP = True _DBG_DISABLE_STRICT_ASSERTS = False # of response content-type, schema if TYPE_CHECKING: @@ -32,23 +32,25 @@ TEST_PASSWORD: Final[str] = "P@ssw0rd!!" # noqa: S105 +_F = TypeVar("_F", bound=Callable[..., Any]) + + # Global flag to indicate if AuthenticationFailedError has been encountered global_auth_failed = False -def skipif_auth_failed(fnc): +def skipif_auth_failed(fnc: _F) -> _F: """Decorator to skip tests if AuthenticationFailedError is encountered.""" @functools.wraps(fnc) - async def wrapper(*args, **kwargs): + async def wrapper(*args: Any, **kwargs: Any) -> Any: global global_auth_failed if global_auth_failed: pytest.skip("Unable to authenticate") try: - result = await fnc(*args, **kwargs) - return result + return await fnc(*args, **kwargs) except ( evo1.AuthenticationFailedError, @@ -60,7 +62,7 @@ async def wrapper(*args, **kwargs): global_auth_failed = True pytest.fail(f"Unable to authenticate: {err}") - return wrapper + return wrapper # type: ignore[return-value] @pytest.fixture(autouse=True) @@ -166,19 +168,12 @@ async def evohome_v1( ) -> AsyncGenerator[evo1.EvohomeClient, None]: """Yield an instance of a v1 EvohomeClient.""" - global skipif_auth_failed - evo = evo1.EvohomeClient(*credentials, websession=client_session) try: yield evo - - except evo1.AuthenticationFailedError as err: - if not _DBG_USE_REAL_AIOHTTP: - raise - - skipif_auth_failed = True - pytest.skip(f"Unable to authenticate: {err}") + finally: + pass @pytest.fixture @@ -187,16 +182,9 @@ async def evohome_v2( ) -> AsyncGenerator[evo2.EvohomeClientNew, None]: """Yield an instance of a v2 EvohomeClient.""" - global skipif_auth_failed - evo = evo2.EvohomeClientNew(token_manager) try: yield evo - - except evo2.AuthenticationFailedError as err: - if not _DBG_USE_REAL_AIOHTTP: - raise - - skipif_auth_failed = True - pytest.skip(f"Unable to authenticate: {err}") + finally: + pass diff --git a/tests/tests_rf/test_v2_task.py b/tests/tests_rf/test_v2_task.py index 23a4d83e..de49728e 100644 --- a/tests/tests_rf/test_v2_task.py +++ b/tests/tests_rf/test_v2_task.py @@ -9,7 +9,7 @@ import pytest import evohomeasync2 as evo2 -from evohomeasync2 import ControlSystem, Gateway, Location +from evohomeasync2 import ControlSystem, Gateway, HotWater, Location from evohomeasync2.const import API_STRFTIME, DhwState, ZoneMode from evohomeasync2.schema.const import ( SZ_MODE, @@ -38,6 +38,8 @@ async def _test_task_id(evo: evo2.EvohomeClientNew) -> None: _ = await evo.update(dont_update_status=True) + dhw: HotWater | None = None + for loc in evo.locations: for gwy in loc.gateways: for tcs in gwy.control_systems: @@ -45,9 +47,11 @@ async def _test_task_id(evo: evo2.EvohomeClientNew) -> None: # if (dhw := tcs.hotwater) and dhw.temperatureStatus['isAvailable']: dhw = tcs.hotwater break - # else: - # pytest.skip("No available DHW found") - # + + else: # No available DHW found + pytest.skip("No available DHW found") + + assert dhw is not None # mypy hint GET_URL = f"{dhw.TYPE}/{dhw.id}/status" PUT_URL = f"{dhw.TYPE}/{dhw.id}/state"