From 7848c1e7c45d1a723c137ae9bcc117e050b5b660 Mon Sep 17 00:00:00 2001 From: Jonathan Slenders Date: Fri, 26 Jan 2024 16:31:01 +0000 Subject: [PATCH 1/3] Add stricter types for TaskGroup.start() --- src/anyio/abc/_tasks.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/anyio/abc/_tasks.py b/src/anyio/abc/_tasks.py index 7ad4938c..65664944 100644 --- a/src/anyio/abc/_tasks.py +++ b/src/anyio/abc/_tasks.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: from .._core._tasks import CancelScope -T_Retval = TypeVar("T_Retval") +T_Retval = TypeVar("T_Retval", covariant=True) T_contra = TypeVar("T_contra", contravariant=True) PosArgsT = TypeVarTuple("PosArgsT") @@ -36,6 +36,13 @@ def started(self, value: T_contra | None = None) -> None: """ +class _StartFunc(Protocol[Unpack[PosArgsT], T_Retval]): + async def __call__( + self, *args: Unpack[PosArgsT], task_status: TaskStatus[T_Retval] + ) -> None: + ... + + class TaskGroup(metaclass=ABCMeta): """ Groups several asynchronous tasks together. @@ -66,10 +73,10 @@ def start_soon( @abstractmethod async def start( self, - func: Callable[..., Awaitable[Any]], - *args: object, + func: _StartFunc[Unpack[PosArgsT], T_Retval], + *args: Unpack[PosArgsT], name: object = None, - ) -> Any: + ) -> T_Retval: """ Start a new task and wait until it signals for readiness. From e2fc1ac5963bcea1044b388871e2e87274f08358 Mon Sep 17 00:00:00 2001 From: Jonathan Slenders Date: Fri, 26 Jan 2024 16:42:49 +0000 Subject: [PATCH 2/3] Fixed typing in unit tests. --- src/anyio/abc/_tasks.py | 29 +++++++++++++++++++++++++++-- src/anyio/streams/stapled.py | 4 +++- tests/test_debugging.py | 2 +- tests/test_taskgroups.py | 14 +++++++------- 4 files changed, 38 insertions(+), 11 deletions(-) diff --git a/src/anyio/abc/_tasks.py b/src/anyio/abc/_tasks.py index 65664944..ae729edf 100644 --- a/src/anyio/abc/_tasks.py +++ b/src/anyio/abc/_tasks.py @@ -7,9 +7,9 @@ from typing import TYPE_CHECKING, Any, Protocol, TypeVar, overload if sys.version_info >= (3, 11): - from typing import TypeVarTuple, Unpack + from typing import Never, TypeVarTuple, Unpack else: - from typing_extensions import TypeVarTuple, Unpack + from typing_extensions import Never, TypeVarTuple, Unpack if TYPE_CHECKING: from .._core._tasks import CancelScope @@ -70,6 +70,31 @@ def start_soon( .. versionadded:: 3.0 """ + @overload + async def start( + self, + func: _StartFunc[Unpack[PosArgsT], object], + *args: Unpack[PosArgsT], + name: object = None, + ) -> Never: + # Overload added for when the returned value is never captured. + # In that case, it doesn't matter which kind of function is given. + # Without this overload, mypy would fail: + # Argument 1 to "start" of "TaskGroup" has incompatible type + # "Callable[[TaskStatus[None]], Coroutine[Any, Any, None]]"; + # expected "_StartFunc[Never]" + ... + + @overload + async def start( + self, + func: _StartFunc[Unpack[PosArgsT], T_Retval], + *args: Unpack[PosArgsT], + name: object = None, + ) -> T_Retval: + ... + + @abstractmethod async def start( self, diff --git a/src/anyio/streams/stapled.py b/src/anyio/streams/stapled.py index 80f64a2e..55978659 100644 --- a/src/anyio/streams/stapled.py +++ b/src/anyio/streams/stapled.py @@ -120,7 +120,9 @@ def __post_init__(self) -> None: self.listeners = listeners async def serve( - self, handler: Callable[[T_Stream], Any], task_group: TaskGroup | None = None + self, + handler: Callable[[T_Stream], Any], + task_group: TaskGroup | None = None, ) -> None: from .. import create_task_group diff --git a/tests/test_debugging.py b/tests/test_debugging.py index fda505cd..e6c6d193 100644 --- a/tests/test_debugging.py +++ b/tests/test_debugging.py @@ -62,7 +62,7 @@ async def main() -> None: async def test_non_main_task_name( name_input: bytes | str | None, expected: str ) -> None: - async def non_main(*, task_status: TaskStatus) -> None: + async def non_main(*, task_status: TaskStatus[str | None]) -> None: task_status.started(anyio.get_current_task().name) async with anyio.create_task_group() as tg: diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py index bc4a289d..cba1ed34 100644 --- a/tests/test_taskgroups.py +++ b/tests/test_taskgroups.py @@ -101,16 +101,16 @@ async def test_start_soon_after_error() -> None: async def test_start_no_value() -> None: - async def taskfunc(*, task_status: TaskStatus) -> None: + async def taskfunc(*, task_status: TaskStatus[None]) -> None: task_status.started() async with create_task_group() as tg: - value = await tg.start(taskfunc) + value: None = await tg.start(taskfunc) assert value is None async def test_start_called_twice() -> None: - async def taskfunc(*, task_status: TaskStatus) -> None: + async def taskfunc(*, task_status: TaskStatus[None]) -> None: task_status.started() with pytest.raises( @@ -119,12 +119,12 @@ async def taskfunc(*, task_status: TaskStatus) -> None: task_status.started() async with create_task_group() as tg: - value = await tg.start(taskfunc) + value: None = await tg.start(taskfunc) assert value is None async def test_start_with_value() -> None: - async def taskfunc(*, task_status: TaskStatus) -> None: + async def taskfunc(*, task_status: TaskStatus[str]) -> None: task_status.started("foo") async with create_task_group() as tg: @@ -144,7 +144,7 @@ async def taskfunc(*, task_status: TaskStatus) -> NoReturn: async def test_start_crash_after_started_call() -> None: - async def taskfunc(*, task_status: TaskStatus) -> NoReturn: + async def taskfunc(*, task_status: TaskStatus[int]) -> NoReturn: task_status.started(2) raise Exception("foo") @@ -250,7 +250,7 @@ async def taskfunc() -> None: async def test_start_exception_delivery(anyio_backend_name: str) -> None: - def task_fn(*, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None: + def task_fn(*, task_status: TaskStatus[str] = TASK_STATUS_IGNORED) -> None: task_status.started("hello") if anyio_backend_name == "trio": From 99ba020f44106741820a8ea791087ca9c9708153 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 26 Jan 2024 17:09:48 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anyio/abc/_tasks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/anyio/abc/_tasks.py b/src/anyio/abc/_tasks.py index ae729edf..b11a3d11 100644 --- a/src/anyio/abc/_tasks.py +++ b/src/anyio/abc/_tasks.py @@ -94,7 +94,6 @@ async def start( ) -> T_Retval: ... - @abstractmethod async def start( self,