From 76eab0e7d89fec9a41e63d7641459d991c3086ac Mon Sep 17 00:00:00 2001 From: BobTheBuidler <70677534+BobTheBuidler@users.noreply.github.com> Date: Sat, 3 Feb 2024 18:33:25 -0500 Subject: [PATCH] fix: match asyncio.create_task api (#107) --- a_sync/task.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/a_sync/task.py b/a_sync/task.py index bc168e94..98f686cf 100644 --- a/a_sync/task.py +++ b/a_sync/task.py @@ -7,9 +7,10 @@ logger = logging.getLogger(__name__) -def create_task(awaitable: Awaitable[_T], *, name: Optional[str] = None, skip_gc_until_done: bool = False) -> "asyncio.Task[_T]": +def create_task(coro: Awaitable[_T], *, name: Optional[str] = None, skip_gc_until_done: bool = False) -> "asyncio.Task[_T]": """A wrapper over `asyncio.create_task` which will work with any `Awaitable` object, not just `Coroutine` objects""" - coro = awaitable if asyncio.iscoroutine(awaitable) else __await(awaitable) + if not asyncio.iscoroutine(coro): + coro = __await(coro) task = asyncio.create_task(coro, name=name) if skip_gc_until_done: __persist(task) @@ -30,4 +31,4 @@ def __prune_persisted_tasks(): if e := task.exception(): logger.exception(e) raise e - __persisted_tasks.discard(task) \ No newline at end of file + __persisted_tasks.discard(task)