Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve typing against Trio #239

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,16 @@ warn_unused_configs = true
warn_unused_ignores = true

[[tool.mypy.overrides]]
module =["aioquic.*", "cryptography.*", "h11.*", "h2.*", "priority.*", "pytest_asyncio.*", "trio.*", "uvloop.*"]
module =["aioquic.*", "cryptography.*", "h11.*", "h2.*", "priority.*", "pytest_asyncio.*", "uvloop.*"]
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = ["trio.*", "tests.trio.*"]
disallow_any_generics = true
disallow_untyped_calls = true
strict_optional = true
warn_return_any = true

[tool.pytest.ini_options]
addopts = "--no-cov-on-fail --showlocals --strict-markers"
asyncio_mode = "strict"
Expand Down
6 changes: 4 additions & 2 deletions src/hypercorn/middleware/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Callable, Dict

from ..asyncio.task_group import TaskGroup
from ..typing import ASGIFramework, Scope
from ..typing import ASGIFramework, ASGIReceiveEvent, Scope

MAX_QUEUE_SIZE = 10

Expand Down Expand Up @@ -74,7 +74,9 @@ class TrioDispatcherMiddleware(_DispatcherMiddleware):
async def _handle_lifespan(self, scope: Scope, receive: Callable, send: Callable) -> None:
import trio

self.app_queues = {path: trio.open_memory_channel(MAX_QUEUE_SIZE) for path in self.mounts}
self.app_queues = {
path: trio.open_memory_channel[ASGIReceiveEvent](MAX_QUEUE_SIZE) for path in self.mounts
}
self.startup_complete = {path: False for path in self.mounts}
self.shutdown_complete = {path: False for path in self.mounts}

Expand Down
2 changes: 1 addition & 1 deletion src/hypercorn/trio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ async def serve(
config: Config,
*,
shutdown_trigger: Optional[Callable[..., Awaitable[None]]] = None,
task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED,
task_status: trio.TaskStatus = trio.TASK_STATUS_IGNORED,
mode: Optional[Literal["asgi", "wsgi"]] = None,
) -> None:
"""Serve an ASGI framework app given the config.
Expand Down
8 changes: 4 additions & 4 deletions src/hypercorn/trio/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ def __init__(self, app: AppWrapper, config: Config, state: LifespanState) -> Non
self.config = config
self.startup = trio.Event()
self.shutdown = trio.Event()
self.app_send_channel, self.app_receive_channel = trio.open_memory_channel(
config.max_app_queue_size
)
self.app_send_channel, self.app_receive_channel = trio.open_memory_channel[
ASGIReceiveEvent
](config.max_app_queue_size)
self.state = state
self.supported = True

async def handle_lifespan(
self, *, task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED
self, *, task_status: trio.TaskStatus = trio.TASK_STATUS_IGNORED
) -> None:
task_status.started()
scope: LifespanScope = {
Expand Down
4 changes: 2 additions & 2 deletions src/hypercorn/trio/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def worker_serve(
*,
sockets: Optional[Sockets] = None,
shutdown_trigger: Optional[Callable[..., Awaitable[None]]] = None,
task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED,
task_status: trio.TaskStatus = trio.TASK_STATUS_IGNORED,
) -> None:
config.set_statsd_logger_class(StatsdLogger)

Expand All @@ -57,7 +57,7 @@ async def worker_serve(
sock.listen(config.backlog)

ssl_context = config.create_ssl_context()
listeners = []
listeners: list[trio.SSLListener[trio.SocketStream] | trio.SocketListener] = []
binds = []
for sock in sockets.secure_sockets:
listeners.append(
Expand Down
9 changes: 6 additions & 3 deletions src/hypercorn/trio/task_group.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import sys
from contextlib import AbstractAsyncContextManager
from types import TracebackType
from typing import Any, Awaitable, Callable, Optional

Expand Down Expand Up @@ -41,8 +42,8 @@ async def _handle(

class TaskGroup:
def __init__(self) -> None:
self._nursery: Optional[trio._core._run.Nursery] = None
self._nursery_manager: Optional[trio._core._run.NurseryManager] = None
self._nursery: trio.Nursery | None = None
self._nursery_manager: AbstractAsyncContextManager[trio.Nursery] | None = None

async def spawn_app(
self,
Expand All @@ -51,7 +52,9 @@ async def spawn_app(
scope: Scope,
send: Callable[[Optional[ASGISendEvent]], Awaitable[None]],
) -> Callable[[ASGIReceiveEvent], Awaitable[None]]:
app_send_channel, app_receive_channel = trio.open_memory_channel(config.max_app_queue_size)
app_send_channel, app_receive_channel = trio.open_memory_channel[ASGIReceiveEvent](
config.max_app_queue_size
)
self._nursery.start_soon(
_handle,
app,
Expand Down
2 changes: 1 addition & 1 deletion src/hypercorn/trio/tcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(
config: Config,
context: WorkerContext,
state: LifespanState,
stream: trio.abc.Stream,
stream: trio.SSLStream[trio.SocketStream],
) -> None:
self.app = app
self.config = config
Expand Down
8 changes: 4 additions & 4 deletions src/hypercorn/trio/udp_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import socket

import trio

from .task_group import TaskGroup
Expand All @@ -19,17 +21,15 @@ def __init__(
config: Config,
context: WorkerContext,
state: LifespanState,
socket: trio.socket.socket,
socket: socket.socket,
) -> None:
self.app = app
self.config = config
self.context = context
self.socket = trio.socket.from_stdlib_socket(socket)
self.state = state

async def run(
self, task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED
) -> None:
async def run(self, task_status: trio.TaskStatus = trio.TASK_STATUS_IGNORED) -> None:
from ..protocol.quic import QuicProtocol # h3/Quic is an optional part of Hypercorn

task_status.started()
Expand Down
2 changes: 1 addition & 1 deletion src/hypercorn/trio/worker_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
def _cancel_wrapper(func: Callable[[], Awaitable[None]]) -> Callable[[], Awaitable[None]]:
@wraps(func)
async def wrapper(
task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED,
task_status: trio.TaskStatus = trio.TASK_STATUS_IGNORED,
) -> None:
cancel_scope = trio.CancelScope()
task_status.started(cancel_scope)
Expand Down
6 changes: 4 additions & 2 deletions src/hypercorn/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,12 @@ class LifespanScope(TypedDict):
Scope = Union[HTTPScope, WebsocketScope, LifespanScope]


# A lot of fields should probably be marked with `NotRequired`, but only
# added these for now. See https://github.com/django/asgiref/issues/460
class HTTPRequestEvent(TypedDict):
type: Literal["http.request"]
body: bytes
more_body: bool
body: NotRequired[bytes]
more_body: NotRequired[bool]


class HTTPResponseStartEvent(TypedDict):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_app_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import trio

from hypercorn.app_wrappers import _build_environ, InvalidPathError, WSGIWrapper
from hypercorn.typing import ASGISendEvent, ConnectionState, HTTPScope
from hypercorn.typing import ASGIReceiveEvent, ASGISendEvent, ConnectionState, HTTPScope


def echo_body(environ: dict, start_response: Callable) -> List[bytes]:
Expand Down Expand Up @@ -41,7 +41,7 @@ async def test_wsgi_trio() -> None:
"extensions": {},
"state": ConnectionState({}),
}
send_channel, receive_channel = trio.open_memory_channel(1)
send_channel, receive_channel = trio.open_memory_channel[ASGIReceiveEvent](1)
await send_channel.send({"type": "http.request"})

messages = []
Expand Down
51 changes: 32 additions & 19 deletions tests/trio/test_keep_alive.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Callable, Generator
from typing import Awaitable, Callable, cast, Generator, TYPE_CHECKING

import h11
import pytest
Expand All @@ -10,22 +10,33 @@
from hypercorn.config import Config
from hypercorn.trio.tcp_server import TCPServer
from hypercorn.trio.worker_context import WorkerContext
from hypercorn.typing import Scope
from hypercorn.typing import ASGIReceiveEvent, ASGISendEvent, Scope
from ..helpers import MockSocket

if TYPE_CHECKING:
from typing_extensions import TypeAlias

KEEP_ALIVE_TIMEOUT = 0.01
REQUEST = h11.Request(method="GET", target="/", headers=[(b"host", b"hypercorn")])

ClientStream: TypeAlias = trio.StapledStream[
trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream
]


async def slow_framework(scope: Scope, receive: Callable, send: Callable) -> None:
async def slow_framework(
scope: Scope,
receive: Callable[[], Awaitable[ASGIReceiveEvent]],
send: Callable[[ASGISendEvent], Awaitable[None]],
) -> None:
while True:
event = await receive()
if event["type"] == "http.disconnect":
break
elif event["type"] == "lifespan.startup":
await send({"type": "lifspan.startup.complete"})
await send({"type": "lifespan.startup.complete"})
elif event["type"] == "lifespan.shutdown":
await send({"type": "lifspan.shutdown.complete"})
await send({"type": "lifespan.shutdown.complete"})
elif event["type"] == "http.request" and not event.get("more_body", False):
await trio.sleep(2 * KEEP_ALIVE_TIMEOUT)
await send(
Expand All @@ -41,21 +52,20 @@ async def slow_framework(scope: Scope, receive: Callable, send: Callable) -> Non

@pytest.fixture(name="client_stream", scope="function")
def _client_stream(
nursery: trio._core._run.Nursery,
) -> Generator[trio.testing._memory_streams.MemorySendStream, None, None]:
nursery: trio.Nursery,
) -> Generator[ClientStream, None, None]:
config = Config()
config.keep_alive_timeout = KEEP_ALIVE_TIMEOUT
client_stream, server_stream = trio.testing.memory_stream_pair()
server_stream = cast("trio.SSLStream[trio.SocketStream]", server_stream)
server_stream.socket = MockSocket()
server = TCPServer(ASGIWrapper(slow_framework), config, WorkerContext(None), {}, server_stream)
nursery.start_soon(server.run)
yield client_stream


@pytest.mark.trio
async def test_http1_keep_alive_pre_request(
client_stream: trio.testing._memory_streams.MemorySendStream,
) -> None:
async def test_http1_keep_alive_pre_request(client_stream: ClientStream) -> None:
await client_stream.send_all(b"GET")
await trio.sleep(2 * KEEP_ALIVE_TIMEOUT)
# Only way to confirm closure is to invoke an error
Expand All @@ -65,23 +75,26 @@ async def test_http1_keep_alive_pre_request(

@pytest.mark.trio
async def test_http1_keep_alive_during(
client_stream: trio.testing._memory_streams.MemorySendStream,
client_stream: ClientStream,
) -> None:
client = h11.Connection(h11.CLIENT)
await client_stream.send_all(client.send(REQUEST))
# client.send(h11.Request) and client.send(h11.EndOfMessage) only returns bytes.
# Fixed on master/ in the h11 repo, once released the ignore's can be removed.
# See https://github.com/python-hyper/h11/issues/175
await client_stream.send_all(client.send(REQUEST)) # type: ignore[arg-type]
await trio.sleep(2 * KEEP_ALIVE_TIMEOUT)
# Key is that this doesn't error
await client_stream.send_all(client.send(h11.EndOfMessage()))
await client_stream.send_all(client.send(h11.EndOfMessage())) # type: ignore[arg-type]


@pytest.mark.trio
async def test_http1_keep_alive(
client_stream: trio.testing._memory_streams.MemorySendStream,
client_stream: ClientStream,
) -> None:
client = h11.Connection(h11.CLIENT)
await client_stream.send_all(client.send(REQUEST))
await client_stream.send_all(client.send(REQUEST)) # type: ignore[arg-type]
await trio.sleep(2 * KEEP_ALIVE_TIMEOUT)
await client_stream.send_all(client.send(h11.EndOfMessage()))
await client_stream.send_all(client.send(h11.EndOfMessage())) # type: ignore[arg-type]
while True:
event = client.next_event()
if event == h11.NEED_DATA:
Expand All @@ -90,15 +103,15 @@ async def test_http1_keep_alive(
elif isinstance(event, h11.EndOfMessage):
break
client.start_next_cycle()
await client_stream.send_all(client.send(REQUEST))
await client_stream.send_all(client.send(REQUEST)) # type: ignore[arg-type]
await trio.sleep(2 * KEEP_ALIVE_TIMEOUT)
# Key is that this doesn't error
await client_stream.send_all(client.send(h11.EndOfMessage()))
await client_stream.send_all(client.send(h11.EndOfMessage())) # type: ignore[arg-type]


@pytest.mark.trio
async def test_http1_keep_alive_pipelining(
client_stream: trio.testing._memory_streams.MemorySendStream,
client_stream: ClientStream,
) -> None:
await client_stream.send_all(
b"GET / HTTP/1.1\r\nHost: hypercorn\r\n\r\nGET / HTTP/1.1\r\nHost: hypercorn\r\n\r\n"
Expand Down
16 changes: 11 additions & 5 deletions tests/trio/test_sanity.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from typing import cast
from unittest.mock import Mock, PropertyMock

import h2
Expand All @@ -24,14 +25,16 @@
@pytest.mark.trio
async def test_http1_request(nursery: trio._core._run.Nursery) -> None:
client_stream, server_stream = trio.testing.memory_stream_pair()
server_stream = cast("trio.SSLStream[trio.SocketStream]", server_stream)
server_stream.socket = MockSocket()
server = TCPServer(
ASGIWrapper(sanity_framework), Config(), WorkerContext(None), {}, server_stream
)
nursery.start_soon(server.run)
client = h11.Connection(h11.CLIENT)
await client_stream.send_all(
client.send(
# h11 types are incorrect, awaiting release.
client.send( # type: ignore[arg-type]
h11.Request(
method="POST",
target="/",
Expand All @@ -43,8 +46,8 @@ async def test_http1_request(nursery: trio._core._run.Nursery) -> None:
)
)
)
await client_stream.send_all(client.send(h11.Data(data=SANITY_BODY)))
await client_stream.send_all(client.send(h11.EndOfMessage()))
await client_stream.send_all(client.send(h11.Data(data=SANITY_BODY))) # type: ignore[arg-type]
await client_stream.send_all(client.send(h11.EndOfMessage())) # type: ignore[arg-type]
events = []
while True:
event = client.next_event()
Expand Down Expand Up @@ -77,6 +80,7 @@ async def test_http1_request(nursery: trio._core._run.Nursery) -> None:
@pytest.mark.trio
async def test_http1_websocket(nursery: trio._core._run.Nursery) -> None:
client_stream, server_stream = trio.testing.memory_stream_pair()
server_stream = cast("trio.SSLStream[trio.SocketStream]", server_stream)
server_stream.socket = MockSocket()
server = TCPServer(
ASGIWrapper(sanity_framework), Config(), WorkerContext(None), {}, server_stream
Expand Down Expand Up @@ -104,8 +108,9 @@ async def test_http1_websocket(nursery: trio._core._run.Nursery) -> None:
@pytest.mark.trio
async def test_http2_request(nursery: trio._core._run.Nursery) -> None:
client_stream, server_stream = trio.testing.memory_stream_pair()
server_stream = cast("trio.SSLStream[trio.SocketStream]", server_stream)
server_stream.transport_stream = Mock(return_value=PropertyMock(return_value=MockSocket()))
server_stream.do_handshake = AsyncMock()
server_stream.do_handshake = AsyncMock() # type: ignore[method-assign]
server_stream.selected_alpn_protocol = Mock(return_value="h2")
server = TCPServer(
ASGIWrapper(sanity_framework), Config(), WorkerContext(None), {}, server_stream
Expand Down Expand Up @@ -161,8 +166,9 @@ async def test_http2_request(nursery: trio._core._run.Nursery) -> None:
@pytest.mark.trio
async def test_http2_websocket(nursery: trio._core._run.Nursery) -> None:
client_stream, server_stream = trio.testing.memory_stream_pair()
server_stream = cast("trio.SSLStream[trio.SocketStream]", server_stream)
server_stream.transport_stream = Mock(return_value=PropertyMock(return_value=MockSocket()))
server_stream.do_handshake = AsyncMock()
server_stream.do_handshake = AsyncMock() # type: ignore[method-assign]
server_stream.selected_alpn_protocol = Mock(return_value="h2")
server = TCPServer(
ASGIWrapper(sanity_framework), Config(), WorkerContext(None), {}, server_stream
Expand Down
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ basepython = python3.12
deps =
mypy
pytest
trio
commands =
mypy src/hypercorn/ tests/

Expand Down