Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
zxdavb committed Nov 2, 2024
1 parent 1fb1594 commit bbf0c12
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 32 deletions.
52 changes: 50 additions & 2 deletions src/evohomeasync2/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from abc import ABC, abstractmethod
from datetime import datetime as dt, timedelta as td
from http import HTTPMethod, HTTPStatus
from types import TracebackType
from typing import TYPE_CHECKING, Any, Final, TypedDict

import aiohttp
Expand Down Expand Up @@ -59,7 +60,7 @@
HTTPStatus.UNAUTHORIZED: "Unauthorized (expired access token/unknown entity id?)",
}

SZ_USERNAME: Final = "Username" # TODO: is camelCase (and not PascalCase) OK?
SZ_USERNAME: Final = "Username"
SZ_PASSWORD: Final = "Password"


Expand Down Expand Up @@ -261,6 +262,51 @@ async def save_access_token(self) -> None: # HA: api
"""Save the access token to a cache."""


class _RequestContextManager:
"""A context manager for Auth's aiohttp request."""

_response: aiohttp.ClientResponse | None = None

def __init__(
self,
websession: aiohttp.ClientSession,
method: HTTPMethod,
url: StrOrURL,
/,
**kwargs: Any,
):
"""Initialize the request context manager."""

self.websession = websession
self.method = method
self.url = url
self.kwargs = kwargs

async def __aenter__(self) -> aiohttp.ClientResponse:
"""Async context manager entry."""
self._response = await self._await_impl()
return self._response

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
tb: TracebackType | None,
) -> None:
"""Async context manager exit."""
if self._response:
self._response.release() # or: close()
await self._response.wait_for_close()

def __await__(self) -> aiohttp.ClientResponse:
"""Make this class awaitable."""
return self._await_impl().__await__()

async def _await_impl(self) -> aiohttp.ClientResponse:
"""Return the actual result."""
return await self.websession.request(self.method, self.url, **self.kwargs)


class AbstractAuth(ABC): # APIs esposed by/for HA
def __init__(self, websession: aiohttp.ClientSession, host: str) -> None:
"""Initialize the auth."""
Expand All @@ -282,7 +328,9 @@ async def request( # type: ignore[no-untyped-def]
}
headers["Authorization"] = "bearer " + await self.get_access_token()

return self.websession.request(method, url, **kwargs, headers=headers)
return await _RequestContextManager(
self.websession, method, url, **kwargs, headers=headers
)


class Auth(AbstractAuth):
Expand Down
40 changes: 14 additions & 26 deletions tests/tests_rf/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import functools
import json
import os
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable
from datetime import datetime as dt, timedelta as td
from pathlib import Path
from typing import TYPE_CHECKING, Any, Final
from typing import TYPE_CHECKING, Any, Final, TypeVar

import pytest

Expand All @@ -21,7 +21,7 @@

#
# normally, we want debug flags to be False
_DBG_USE_REAL_AIOHTTP = False
_DBG_USE_REAL_AIOHTTP = True
_DBG_DISABLE_STRICT_ASSERTS = False # of response content-type, schema

if TYPE_CHECKING:
Expand All @@ -32,23 +32,25 @@
TEST_PASSWORD: Final[str] = "P@ssw0rd!!" # noqa: S105


_F = TypeVar("_F", bound=Callable[..., Any])


# Global flag to indicate if AuthenticationFailedError has been encountered
global_auth_failed = False


def skipif_auth_failed(fnc):
def skipif_auth_failed(fnc: _F) -> _F:
"""Decorator to skip tests if AuthenticationFailedError is encountered."""

@functools.wraps(fnc)
async def wrapper(*args, **kwargs):
async def wrapper(*args: Any, **kwargs: Any) -> Any:
global global_auth_failed

if global_auth_failed:
pytest.skip("Unable to authenticate")

try:
result = await fnc(*args, **kwargs)
return result
return await fnc(*args, **kwargs)

except (
evo1.AuthenticationFailedError,
Expand All @@ -60,7 +62,7 @@ async def wrapper(*args, **kwargs):
global_auth_failed = True
pytest.fail(f"Unable to authenticate: {err}")

return wrapper
return wrapper # type: ignore[return-value]


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -166,19 +168,12 @@ async def evohome_v1(
) -> AsyncGenerator[evo1.EvohomeClient, None]:
"""Yield an instance of a v1 EvohomeClient."""

global skipif_auth_failed

evo = evo1.EvohomeClient(*credentials, websession=client_session)

try:
yield evo

except evo1.AuthenticationFailedError as err:
if not _DBG_USE_REAL_AIOHTTP:
raise

skipif_auth_failed = True
pytest.skip(f"Unable to authenticate: {err}")
finally:
pass


@pytest.fixture
Expand All @@ -187,16 +182,9 @@ async def evohome_v2(
) -> AsyncGenerator[evo2.EvohomeClientNew, None]:
"""Yield an instance of a v2 EvohomeClient."""

global skipif_auth_failed

evo = evo2.EvohomeClientNew(token_manager)

try:
yield evo

except evo2.AuthenticationFailedError as err:
if not _DBG_USE_REAL_AIOHTTP:
raise

skipif_auth_failed = True
pytest.skip(f"Unable to authenticate: {err}")
finally:
pass
12 changes: 8 additions & 4 deletions tests/tests_rf/test_v2_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest

import evohomeasync2 as evo2
from evohomeasync2 import ControlSystem, Gateway, Location
from evohomeasync2 import ControlSystem, Gateway, HotWater, Location
from evohomeasync2.const import API_STRFTIME, DhwState, ZoneMode
from evohomeasync2.schema.const import (
SZ_MODE,
Expand Down Expand Up @@ -38,16 +38,20 @@ async def _test_task_id(evo: evo2.EvohomeClientNew) -> None:

_ = await evo.update(dont_update_status=True)

dhw: HotWater | None = None

for loc in evo.locations:
for gwy in loc.gateways:
for tcs in gwy.control_systems:
if tcs.hotwater:
# if (dhw := tcs.hotwater) and dhw.temperatureStatus['isAvailable']:
dhw = tcs.hotwater
break
# else:
# pytest.skip("No available DHW found")
#

else: # No available DHW found
pytest.skip("No available DHW found")

assert dhw is not None # mypy hint

GET_URL = f"{dhw.TYPE}/{dhw.id}/status"
PUT_URL = f"{dhw.TYPE}/{dhw.id}/state"
Expand Down

0 comments on commit bbf0c12

Please sign in to comment.