Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use future.annotations for modern typing #149

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions hishel/_async/_mock.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import typing as tp
from types import TracebackType

Expand All @@ -15,19 +17,19 @@ class MockAsyncConnectionPool(AsyncRequestInterface):
async def handle_async_request(self, request: httpcore.Request) -> httpcore.Response:
return self.mocked_responses.pop(0)

def add_responses(self, responses: tp.List[httpcore.Response]) -> None:
def add_responses(self, responses: list[httpcore.Response]) -> None:
if not hasattr(self, "mocked_responses"):
self.mocked_responses = []
self.mocked_responses.extend(responses)

async def __aenter__(self) -> "Self":
async def __aenter__(self) -> Self:
return self

async def __aexit__(
self,
exc_type: tp.Optional[tp.Type[BaseException]] = None,
exc_value: tp.Optional[BaseException] = None,
traceback: tp.Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
) -> None:
...

Expand All @@ -36,7 +38,7 @@ class MockAsyncTransport(httpx.AsyncBaseTransport):
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
return self.mocked_responses.pop(0)

def add_responses(self, responses: tp.List[httpx.Response]) -> None:
def add_responses(self, responses: list[httpx.Response]) -> None:
if not hasattr(self, "mocked_responses"):
self.mocked_responses = []
self.mocked_responses.extend(responses)
12 changes: 7 additions & 5 deletions hishel/_async/_pool.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import datetime
import types
import typing as tp
Expand Down Expand Up @@ -35,8 +37,8 @@ class AsyncCacheConnectionPool(AsyncRequestInterface):
def __init__(
self,
pool: AsyncRequestInterface,
storage: tp.Optional[AsyncBaseStorage] = None,
controller: tp.Optional[Controller] = None,
storage: AsyncBaseStorage | None = None,
controller: Controller | None = None,
) -> None:
self._pool = pool

Expand Down Expand Up @@ -143,8 +145,8 @@ async def __aenter__(self: T) -> T:

async def __aexit__(
self,
exc_type: tp.Optional[tp.Type[BaseException]] = None,
exc_value: tp.Optional[BaseException] = None,
traceback: tp.Optional[types.TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: types.TracebackType | None = None,
) -> None:
await self.aclose()
42 changes: 22 additions & 20 deletions hishel/_async/_storages.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import logging
import time
import typing as tp
Expand Down Expand Up @@ -35,16 +37,16 @@
class AsyncBaseStorage:
def __init__(
self,
serializer: tp.Optional[BaseSerializer] = None,
ttl: tp.Optional[tp.Union[int, float]] = None,
serializer: BaseSerializer | None = None,
ttl: int | float | None = None,
) -> None:
self._serializer = serializer or JSONSerializer()
self._ttl = ttl

async def store(self, key: str, response: Response, request: Request, metadata: Metadata) -> None:
raise NotImplementedError()

async def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
async def retrieve(self, key: str) -> StoredResponse | None:
raise NotImplementedError()

async def aclose(self) -> None:
Expand All @@ -65,9 +67,9 @@ class AsyncFileStorage(AsyncBaseStorage):

def __init__(
self,
serializer: tp.Optional[BaseSerializer] = None,
base_path: tp.Optional[Path] = None,
ttl: tp.Optional[tp.Union[int, float]] = None,
serializer: BaseSerializer | None = None,
base_path: Path | None = None,
ttl: int | float | None = None,
) -> None:
super().__init__(serializer, ttl)

Expand Down Expand Up @@ -101,7 +103,7 @@ async def store(self, key: str, response: Response, request: Request, metadata:
)
await self._remove_expired_caches()

async def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
async def retrieve(self, key: str) -> StoredResponse | None:
"""
Retreives the response from the cache using his key.

Expand Down Expand Up @@ -148,9 +150,9 @@ class AsyncSQLiteStorage(AsyncBaseStorage):

def __init__(
self,
serializer: tp.Optional[BaseSerializer] = None,
connection: tp.Optional["anysqlite.Connection"] = None,
ttl: tp.Optional[tp.Union[int, float]] = None,
serializer: BaseSerializer | None = None,
connection: anysqlite.Connection | None = None,
ttl: int | float | None = None,
) -> None:
if anysqlite is None: # pragma: no cover
raise RuntimeError(
Expand All @@ -162,7 +164,7 @@ def __init__(
)
super().__init__(serializer, ttl)

self._connection: tp.Optional[anysqlite.Connection] = connection or None
self._connection: anysqlite.Connection | None = connection or None
self._setup_lock = AsyncLock()
self._setup_completed: bool = False
self._lock = AsyncLock()
Expand Down Expand Up @@ -204,7 +206,7 @@ async def store(self, key: str, response: Response, request: Request, metadata:
await self._connection.commit()
await self._remove_expired_caches()

async def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
async def retrieve(self, key: str) -> StoredResponse | None:
"""
Retreives the response from the cache using his key.

Expand Down Expand Up @@ -255,9 +257,9 @@ class AsyncRedisStorage(AsyncBaseStorage):

def __init__(
self,
serializer: tp.Optional[BaseSerializer] = None,
client: tp.Optional["redis.Redis"] = None, # type: ignore
ttl: tp.Optional[tp.Union[int, float]] = None,
serializer: BaseSerializer | None = None,
client: redis.Redis | None = None, # type: ignore
ttl: int | float | None = None,
) -> None:
if redis is None: # pragma: no cover
raise RuntimeError(
Expand Down Expand Up @@ -297,7 +299,7 @@ async def store(self, key: str, response: Response, request: Request, metadata:
key, self._serializer.dumps(response=response, request=request, metadata=metadata), px=px
)

async def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
async def retrieve(self, key: str) -> StoredResponse | None:
"""
Retreives the response from the cache using his key.

Expand Down Expand Up @@ -331,8 +333,8 @@ class AsyncInMemoryStorage(AsyncBaseStorage):

def __init__(
self,
serializer: tp.Optional[BaseSerializer] = None,
ttl: tp.Optional[tp.Union[int, float]] = None,
serializer: BaseSerializer | None = None,
ttl: int | float | None = None,
capacity: int = 128,
) -> None:
super().__init__(serializer, ttl)
Expand All @@ -342,7 +344,7 @@ def __init__(

from hishel import LFUCache

self._cache: LFUCache[str, tp.Tuple[StoredResponse, float]] = LFUCache(capacity=capacity)
self._cache: LFUCache[str, tuple[StoredResponse, float]] = LFUCache(capacity=capacity)
self._lock = AsyncLock()

async def store(self, key: str, response: Response, request: Request, metadata: Metadata) -> None:
Expand All @@ -366,7 +368,7 @@ async def store(self, key: str, response: Response, request: Request, metadata:
self._cache.put(key, (stored_response, time.monotonic()))
await self._remove_expired_caches()

async def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
async def retrieve(self, key: str) -> StoredResponse | None:
"""
Retreives the response from the cache using his key.

Expand Down
14 changes: 8 additions & 6 deletions hishel/_async/_transports.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import datetime
import types
import typing as tp
Expand Down Expand Up @@ -56,8 +58,8 @@ class AsyncCacheTransport(httpx.AsyncBaseTransport):
def __init__(
self,
transport: httpx.AsyncBaseTransport,
storage: tp.Optional[AsyncBaseStorage] = None,
controller: tp.Optional[Controller] = None,
storage: AsyncBaseStorage | None = None,
controller: Controller | None = None,
) -> None:
self._transport = transport

Expand Down Expand Up @@ -237,13 +239,13 @@ async def aclose(self) -> None:
await self._storage.aclose()
await self._transport.aclose()

async def __aenter__(self) -> "Self":
async def __aenter__(self) -> Self:
return self

async def __aexit__(
self,
exc_type: tp.Optional[tp.Type[BaseException]] = None,
exc_value: tp.Optional[BaseException] = None,
traceback: tp.Optional[types.TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: types.TracebackType | None = None,
) -> None:
await self.aclose()
26 changes: 14 additions & 12 deletions hishel/_controller.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import typing as tp

from httpcore import Request, Response
Expand All @@ -20,9 +22,9 @@


def get_updated_headers(
stored_response_headers: tp.List[tp.Tuple[bytes, bytes]],
new_response_headers: tp.List[tp.Tuple[bytes, bytes]],
) -> tp.List[tp.Tuple[bytes, bytes]]:
stored_response_headers: list[tuple[bytes, bytes]],
new_response_headers: list[tuple[bytes, bytes]],
) -> list[tuple[bytes, bytes]]:
updated_headers = []

checked = set()
Expand All @@ -46,7 +48,7 @@ def get_updated_headers(
return updated_headers


def get_freshness_lifetime(response: Response) -> tp.Optional[int]:
def get_freshness_lifetime(response: Response) -> int | None:
response_cache_control = parse_cache_control(extract_header_values_decoded(response.headers, b"Cache-Control"))

if response_cache_control.max_age is not None:
Expand All @@ -62,7 +64,7 @@ def get_freshness_lifetime(response: Response) -> tp.Optional[int]:
return None


def get_heuristic_freshness(response: Response, clock: "BaseClock") -> int:
def get_heuristic_freshness(response: Response, clock: BaseClock) -> int:
last_modified = extract_header_values_decoded(response.headers, b"last-modified", single=True)

if last_modified:
Expand All @@ -77,7 +79,7 @@ def get_heuristic_freshness(response: Response, clock: "BaseClock") -> int:
return ONE_DAY


def get_age(response: Response, clock: "BaseClock") -> int:
def get_age(response: Response, clock: BaseClock) -> int:
if not header_presents(response.headers, b"date"): # pragma: no cover
raise RuntimeError("The `Date` header is missing in the response.")

Expand All @@ -104,13 +106,13 @@ def allowed_stale(response: Response) -> bool:
class Controller:
def __init__(
self,
cacheable_methods: tp.Optional[tp.List[str]] = None,
cacheable_status_codes: tp.Optional[tp.List[int]] = None,
cacheable_methods: list[str] | None = None,
cacheable_status_codes: list[int] | None = None,
allow_heuristics: bool = False,
clock: tp.Optional[BaseClock] = None,
clock: BaseClock | None = None,
allow_stale: bool = False,
always_revalidate: bool = False,
key_generator: tp.Optional[tp.Callable[[Request], str]] = None,
key_generator: tp.Callable[[Request], str] | None = None,
):
self._cacheable_methods = []

Expand Down Expand Up @@ -216,7 +218,7 @@ def _make_request_conditional(self, request: Request, response: Response) -> Non
else:
etag = None

precondition_headers: tp.List[tp.Tuple[bytes, bytes]] = []
precondition_headers: list[tuple[bytes, bytes]] = []
if last_modified:
precondition_headers.append((b"If-Modified-Since", last_modified))
if etag:
Expand Down Expand Up @@ -246,7 +248,7 @@ def _validate_vary(self, request: Request, response: Response, original_request:

def construct_response_from_cache(
self, request: Request, response: Response, original_request: Request
) -> tp.Union[Response, Request, None]:
) -> Response | Request | None:
"""
Specifies whether the response should be used, skipped, or validated by the cache.

Expand Down
18 changes: 10 additions & 8 deletions hishel/_files.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import typing as tp

import anyio
Expand All @@ -7,21 +9,21 @@ class AsyncBaseFileManager:
def __init__(self, is_binary: bool) -> None:
self.is_binary = is_binary

async def write_to(self, path: str, data: tp.Union[bytes, str], is_binary: tp.Optional[bool] = None) -> None:
async def write_to(self, path: str, data: bytes | str, is_binary: bool | None = None) -> None:
raise NotImplementedError()

async def read_from(self, path: str, is_binary: tp.Optional[bool] = None) -> tp.Union[bytes, str]:
async def read_from(self, path: str, is_binary: bool | None = None) -> bytes | str:
raise NotImplementedError()


class AsyncFileManager(AsyncBaseFileManager):
async def write_to(self, path: str, data: tp.Union[bytes, str], is_binary: tp.Optional[bool] = None) -> None:
async def write_to(self, path: str, data: bytes | str, is_binary: bool | None = None) -> None:
is_binary = self.is_binary if is_binary is None else is_binary
mode = "wb" if is_binary else "wt"
async with await anyio.open_file(path, mode) as f: # type: ignore[call-overload]
await f.write(data)

async def read_from(self, path: str, is_binary: tp.Optional[bool] = None) -> tp.Union[bytes, str]:
async def read_from(self, path: str, is_binary: bool | None = None) -> bytes | str:
is_binary = self.is_binary if is_binary is None else is_binary
mode = "rb" if is_binary else "rt"

Expand All @@ -33,21 +35,21 @@ class BaseFileManager:
def __init__(self, is_binary: bool) -> None:
self.is_binary = is_binary

def write_to(self, path: str, data: tp.Union[bytes, str], is_binary: tp.Optional[bool] = None) -> None:
def write_to(self, path: str, data: bytes | str, is_binary: bool | None = None) -> None:
raise NotImplementedError()

def read_from(self, path: str, is_binary: tp.Optional[bool] = None) -> tp.Union[bytes, str]:
def read_from(self, path: str, is_binary: bool | None = None) -> bytes | str:
raise NotImplementedError()


class FileManager(BaseFileManager):
def write_to(self, path: str, data: tp.Union[bytes, str], is_binary: tp.Optional[bool] = None) -> None:
def write_to(self, path: str, data: bytes | str, is_binary: bool | None = None) -> None:
is_binary = self.is_binary if is_binary is None else is_binary
mode = "wb" if is_binary else "wt"
with open(path, mode) as f:
f.write(data)

def read_from(self, path: str, is_binary: tp.Optional[bool] = None) -> tp.Union[bytes, str]:
def read_from(self, path: str, is_binary: bool | None = None) -> bytes | str:
is_binary = self.is_binary if is_binary is None else is_binary
mode = "rb" if is_binary else "rt"
with open(path, mode) as f:
Expand Down
Loading