diff --git a/docs/source/reference-core.rst b/docs/source/reference-core.rst index a1d3a83389..9790cd61f9 100644 --- a/docs/source/reference-core.rst +++ b/docs/source/reference-core.rst @@ -1221,7 +1221,7 @@ many cases, you just want to pass objects between different tasks inside a single process, and for that you can use :func:`trio.open_memory_channel`: -.. autofunction:: open_memory_channel +.. autofunction:: open_memory_channel(max_buffer_size) .. note:: If you've used the :mod:`threading` or :mod:`asyncio` modules, you may be familiar with :class:`queue.Queue` or diff --git a/newsfragments/908.feature.rst b/newsfragments/908.feature.rst new file mode 100644 index 0000000000..45379e753c --- /dev/null +++ b/newsfragments/908.feature.rst @@ -0,0 +1,7 @@ +:class:`~trio.abc.SendChannel`, :class:`~trio.abc.ReceiveChannel`, :class:`~trio.abc.Listener`, +and :func:`~trio.open_memory_channel` can now be referenced using a generic type parameter +(the type of object sent over the channel or produced by the listener) using PEP 484 syntax: +``trio.abc.SendChannel[bytes]``, ``trio.abc.Listener[trio.SocketStream]``, +``trio.open_memory_channel[MyMessage](5)``, etc. The added type information does not change +the runtime semantics, but permits better integration with external static type checkers. + diff --git a/trio/_abc.py b/trio/_abc.py index c9957f1c5c..3ca395b7d5 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -1,4 +1,5 @@ from abc import ABCMeta, abstractmethod +from typing import Generic, TypeVar from ._util import aiter_compat from . import _core @@ -483,7 +484,22 @@ async def send_eof(self): """ -class Listener(AsyncResource): +# The type of object produced by a ReceiveChannel (covariant because +# ReceiveChannel[Derived] can be passed to someone expecting +# ReceiveChannel[Base]) +T_co = TypeVar("T_co", covariant=True) + +# The type of object accepted by a SendChannel (contravariant because +# SendChannel[Base] can be passed to someone expecting +# SendChannel[Derived]) +T_contra = TypeVar("T_contra", contravariant=True) + +# The type of object produced by a Listener (covariant plus must be +# an AsyncResource) +T_resource = TypeVar("T_resource", bound=AsyncResource, covariant=True) + + +class Listener(AsyncResource, Generic[T_resource]): """A standard interface for listening for incoming connections. :class:`Listener` objects also implement the :class:`AsyncResource` @@ -521,7 +537,7 @@ async def accept(self): """ -class SendChannel(AsyncResource): +class SendChannel(AsyncResource, Generic[T_contra]): """A standard interface for sending Python objects to some receiver. :class:`SendChannel` objects also implement the :class:`AsyncResource` @@ -595,7 +611,7 @@ def clone(self): """ -class ReceiveChannel(AsyncResource): +class ReceiveChannel(AsyncResource, Generic[T_co]): """A standard interface for receiving Python objects from some sender. You can iterate over a :class:`ReceiveChannel` using an ``async for`` diff --git a/trio/_channel.py b/trio/_channel.py index 87c2714a78..8b0a7d7426 100644 --- a/trio/_channel.py +++ b/trio/_channel.py @@ -6,8 +6,10 @@ from . import _core from .abc import SendChannel, ReceiveChannel +from ._util import generic_function +@generic_function def open_memory_channel(max_buffer_size): """Open a channel for passing objects between tasks within a process. diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py index b6219c0bf9..aadd766880 100644 --- a/trio/_highlevel_socket.py +++ b/trio/_highlevel_socket.py @@ -318,7 +318,7 @@ def getsockopt(self, level, option, buffersize=0): pass -class SocketListener(Listener): +class SocketListener(Listener[SocketStream]): """A :class:`~trio.abc.Listener` that uses a listening socket to accept incoming connections as :class:`SocketStream` objects. diff --git a/trio/_ssl.py b/trio/_ssl.py index 2a3204bd3a..6f62121ccc 100644 --- a/trio/_ssl.py +++ b/trio/_ssl.py @@ -827,7 +827,7 @@ async def wait_send_all_might_not_block(self): await self.transport_stream.wait_send_all_might_not_block() -class SSLListener(Listener): +class SSLListener(Listener[SSLStream]): """A :class:`~trio.abc.Listener` for SSL/TLS-encrypted servers. :class:`SSLListener` wraps around another Listener, and converts diff --git a/trio/_util.py b/trio/_util.py index bfe7138191..d8dffa5628 100644 --- a/trio/_util.py +++ b/trio/_util.py @@ -4,7 +4,7 @@ import signal import sys import pathlib -from functools import wraps +from functools import wraps, update_wrapper import typing as t import async_generator @@ -22,6 +22,7 @@ "ConflictDetector", "fixup_module_metadata", "fspath", + "generic_function", ] # Equivalent to the C function raise(), which Python doesn't wrap @@ -171,7 +172,15 @@ def decorator(func): def fixup_module_metadata(module_name, namespace): + seen_ids = set() + def fix_one(obj): + # avoid infinite recursion (relevant when using + # typing.Generic, for example) + if id(obj) in seen_ids: + return + seen_ids.add(id(obj)) + mod = getattr(obj, "__module__", None) if mod is not None and mod.startswith("trio."): obj.__module__ = module_name @@ -242,3 +251,31 @@ def fspath(path) -> t.Union[str, bytes]: if hasattr(os, "fspath"): fspath = os.fspath # noqa + + +class generic_function: + """Decorator that makes a function indexable, to communicate + non-inferrable generic type parameters to a static type checker. + + If you write:: + + @generic_function + def open_memory_channel(max_buffer_size: int) -> Tuple[ + SendChannel[T], ReceiveChannel[T] + ]: ... + + it is valid at runtime to say ``open_memory_channel[bytes](5)``. + This behaves identically to ``open_memory_channel(5)`` at runtime, + and currently won't type-check without a mypy plugin or clever stubs, + but at least it becomes possible to write those. + """ + + def __init__(self, fn): + update_wrapper(self, fn) + self._fn = fn + + def __call__(self, *args, **kwargs): + return self._fn(*args, **kwargs) + + def __getitem__(self, _): + return self diff --git a/trio/tests/test_abc.py b/trio/tests/test_abc.py index c8af0927d6..b55267b24f 100644 --- a/trio/tests/test_abc.py +++ b/trio/tests/test_abc.py @@ -19,3 +19,31 @@ async def aclose(self): assert myar.record == [] assert myar.record == ["ac"] + + +def test_abc_generics(): + # Pythons below 3.5.2 had a typing.Generic that would throw + # errors when instantiating or subclassing a parameterized + # version of a class with any __slots__. This is why RunVar + # (which has slots) is not generic. This tests that + # the generic ABCs are fine, because while they are slotted + # they don't actually define any slots. + + class SlottedChannel(tabc.SendChannel[tabc.Stream]): + __slots__ = ("x",) + + def send_nowait(self, value): + raise RuntimeError + + async def send(self, value): + raise RuntimeError # pragma: no cover + + def clone(self): + raise RuntimeError # pragma: no cover + + async def aclose(self): + pass # pragma: no cover + + channel = SlottedChannel() + with pytest.raises(RuntimeError): + channel.send_nowait(None) diff --git a/trio/tests/test_util.py b/trio/tests/test_util.py index cbfe6fd546..ea7b74deeb 100644 --- a/trio/tests/test_util.py +++ b/trio/tests/test_util.py @@ -7,7 +7,9 @@ from .. import _core from .._threads import run_sync_in_worker_thread -from .._util import signal_raise, ConflictDetector, fspath, is_main_thread +from .._util import ( + signal_raise, ConflictDetector, fspath, is_main_thread, generic_function +) from ..testing import wait_all_tasks_blocked, assert_checkpoints @@ -168,3 +170,17 @@ def not_main_thread(): assert not is_main_thread() await run_sync_in_worker_thread(not_main_thread) + + +def test_generic_function(): + @generic_function + def test_func(arg): + """Look, a docstring!""" + return arg + + assert test_func is test_func[int] is test_func[int, str] + assert test_func(42) == test_func[int](42) == 42 + assert test_func.__doc__ == "Look, a docstring!" + assert test_func.__qualname__ == "test_generic_function..test_func" + assert test_func.__name__ == "test_func" + assert test_func.__module__ == __name__