diff --git a/a_sync/__init__.py b/a_sync/__init__.py index 78841c2c..d68d60c9 100644 --- a/a_sync/__init__.py +++ b/a_sync/__init__.py @@ -7,6 +7,7 @@ from a_sync.modifiers.semaphores import apply_semaphore from a_sync.primitives import * from a_sync.singleton import ASyncGenericSingleton +from a_sync.task import create_task from a_sync.utils import all, any, as_yielded from a_sync.utils.as_completed import as_completed from a_sync.utils.gather import gather @@ -25,9 +26,11 @@ "any", "as_completed", "as_yielded", + "create_task", "exhaust_iterator", "exhaust_iterators", "gather", "ASyncIterable", "ASyncIterator", + "ASyncGenericSingleton", ] \ No newline at end of file diff --git a/a_sync/task.py b/a_sync/task.py new file mode 100644 index 00000000..bc168e94 --- /dev/null +++ b/a_sync/task.py @@ -0,0 +1,33 @@ + +import asyncio +import logging +from typing import Any, Awaitable, Optional, Set, TypeVar + +_T = TypeVar('_T') + +logger = logging.getLogger(__name__) + +def create_task(awaitable: Awaitable[_T], *, name: Optional[str] = None, skip_gc_until_done: bool = False) -> "asyncio.Task[_T]": + """A wrapper over `asyncio.create_task` which will work with any `Awaitable` object, not just `Coroutine` objects""" + coro = awaitable if asyncio.iscoroutine(awaitable) else __await(awaitable) + task = asyncio.create_task(coro, name=name) + if skip_gc_until_done: + __persist(task) + return task + +__persisted_tasks: Set["asyncio.Task[Any]"] = set() + +async def __await(awaitable: Awaitable[_T]) -> _T: + return await awaitable + +def __persist(task: "asyncio.Task[Any]") -> None: + __persisted_tasks.add(task) + __prune_persisted_tasks() + +def __prune_persisted_tasks(): + for task in __persisted_tasks: + if task.done(): + if e := task.exception(): + logger.exception(e) + raise e + __persisted_tasks.discard(task) \ No newline at end of file