From 687b369dc02d839bf3afdb1c9285f771c1ce5368 Mon Sep 17 00:00:00 2001 From: BobTheBuidler <70677534+BobTheBuidler@users.noreply.github.com> Date: Sun, 10 Nov 2024 15:33:04 -0400 Subject: [PATCH] chore: run black on all PRs (#331) * chore: run black on all PRs * chore: `black .` * Update prio_semaphore.py * Create pyproject.yaml * Update __init__.py * Update future.py * Update property.py * chore: `black .` * Update task.py * Update task.py * Update prio_semaphore.py * chore: `black .` * Update property.py --------- Co-authored-by: github-actions[bot] --- .github/workflows/black.yaml | 45 ++ a_sync/ENVIRONMENT_VARIABLES.py | 7 +- a_sync/__init__.py | 21 +- a_sync/_smart.py | 79 ++- a_sync/_typing.py | 68 ++- a_sync/a_sync/__init__.py | 28 +- a_sync/a_sync/_descriptor.py | 90 +++- a_sync/a_sync/_flags.py | 6 +- a_sync/a_sync/_helpers.py | 6 +- a_sync/a_sync/_kwargs.py | 5 +- a_sync/a_sync/_meta.py | 115 +++- a_sync/a_sync/abstract.py | 41 +- a_sync/a_sync/base.py | 50 +- a_sync/a_sync/config.py | 36 +- a_sync/a_sync/decorator.py | 139 +++-- a_sync/a_sync/function.py | 354 ++++++++++--- a_sync/a_sync/method.py | 256 ++++++--- a_sync/a_sync/modifiers/__init__.py | 10 +- a_sync/a_sync/modifiers/cache/__init__.py | 40 +- a_sync/a_sync/modifiers/cache/memory.py | 40 +- a_sync/a_sync/modifiers/limiter.py | 33 +- a_sync/a_sync/modifiers/manager.py | 51 +- a_sync/a_sync/modifiers/semaphores.py | 27 +- a_sync/a_sync/property.py | 276 +++++++--- a_sync/a_sync/singleton.py | 1 + a_sync/aliases.py | 1 - a_sync/asyncio/as_completed.py | 173 ++++-- a_sync/asyncio/create_task.py | 20 +- a_sync/asyncio/gather.py | 92 ++-- a_sync/asyncio/utils.py | 6 +- a_sync/exceptions.py | 62 ++- a_sync/executor.py | 119 +++-- a_sync/future.py | 611 ++++++++++++++++------ a_sync/iter.py | 123 +++-- a_sync/primitives/__init__.py | 1 - a_sync/primitives/_debug.py | 24 +- a_sync/primitives/_loggable.py | 9 +- a_sync/primitives/locks/__init__.py | 7 +- a_sync/primitives/locks/counter.py | 44 +- a_sync/primitives/locks/event.py | 28 +- a_sync/primitives/locks/prio_semaphore.py | 114 ++-- a_sync/primitives/locks/semaphore.py | 83 +-- a_sync/primitives/queue.py | 153 ++++-- a_sync/sphinx/__init__.py | 1 - a_sync/sphinx/ext.py | 110 ++-- a_sync/task.py | 313 +++++++---- a_sync/utils/__init__.py | 13 +- a_sync/utils/iterators.py | 147 ++++-- docs/conf.py | 106 ++-- pyproject.yaml | 2 + setup.py | 14 +- tests/conftest.py | 3 +- tests/executor.py | 3 + tests/fixtures.py | 93 ++-- tests/test_abstract.py | 11 +- tests/test_as_completed.py | 75 ++- tests/test_base.py | 83 ++- tests/test_cache.py | 46 +- tests/test_decorator.py | 82 ++- tests/test_executor.py | 21 +- tests/test_future.py | 53 +- tests/test_gather.py | 10 +- tests/test_helpers.py | 2 + tests/test_iter.py | 68 ++- tests/test_limiter.py | 8 +- tests/test_meta.py | 32 +- tests/test_modified.py | 5 +- tests/test_semaphore.py | 23 +- tests/test_singleton.py | 25 +- tests/test_task.py | 93 +++- 70 files changed, 3505 insertions(+), 1431 deletions(-) create mode 100644 .github/workflows/black.yaml create mode 100644 pyproject.yaml diff --git a/.github/workflows/black.yaml b/.github/workflows/black.yaml new file mode 100644 index 00000000..0a6f9738 --- /dev/null +++ b/.github/workflows/black.yaml @@ -0,0 +1,45 @@ +name: Black Formatter + +on: + pull_request: + branches: + - master + +jobs: + format: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v3 + with: + ref: ${{ github.head_ref }} # Check out the PR branch + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.12' + + - name: Install Black + run: pip install black + + - name: Run Black + run: black . + + - name: Check for changes + id: changes + run: | + if [[ -n $(git status --porcelain) ]]; then + echo "changes_detected=true" >> $GITHUB_ENV + else + echo "changes_detected=false" >> $GITHUB_ENV + fi + + - name: Commit changes + if: env.changes_detected == 'true' + run: | + git config --local user.name "github-actions[bot]" + git config --local user.email "github-actions[bot]@users.noreply.github.com" + git add . + git commit -m "chore: \`black .\`" + git push diff --git a/a_sync/ENVIRONMENT_VARIABLES.py b/a_sync/ENVIRONMENT_VARIABLES.py index 2749ca40..7cfd71fc 100644 --- a/a_sync/ENVIRONMENT_VARIABLES.py +++ b/a_sync/ENVIRONMENT_VARIABLES.py @@ -1,4 +1,3 @@ - from typed_envs import EnvVarFactory envs = EnvVarFactory("EZASYNC") @@ -6,7 +5,9 @@ # We have some envs here to help you debug your custom class implementations # If you're only interested in debugging a specific class, set this to the class name -DEBUG_CLASS_NAME = envs.create_env("DEBUG_CLASS_NAME", str, default='', verbose=False) +DEBUG_CLASS_NAME = envs.create_env("DEBUG_CLASS_NAME", str, default="", verbose=False) # Set this to enable debug mode on all classes -DEBUG_MODE = envs.create_env("DEBUG_MODE", bool, default=DEBUG_CLASS_NAME, verbose=False) +DEBUG_MODE = envs.create_env( + "DEBUG_MODE", bool, default=DEBUG_CLASS_NAME, verbose=False +) diff --git a/a_sync/__init__.py b/a_sync/__init__.py index 56137702..95420b78 100644 --- a/a_sync/__init__.py +++ b/a_sync/__init__.py @@ -1,8 +1,12 @@ - from a_sync import aliases, exceptions, iter, task from a_sync.a_sync import ASyncGenericBase, ASyncGenericSingleton, a_sync from a_sync.a_sync.modifiers.semaphores import apply_semaphore -from a_sync.a_sync.property import ASyncCachedPropertyDescriptor, ASyncPropertyDescriptor, cached_property, property +from a_sync.a_sync.property import ( + ASyncCachedPropertyDescriptor, + ASyncPropertyDescriptor, + cached_property, + property, +) from a_sync.asyncio import as_completed, create_task, gather from a_sync.executor import * from a_sync.executor import AsyncThreadPoolExecutor as ThreadPoolExecutor @@ -29,16 +33,13 @@ "exceptions", "iter", "task", - # builtins "sorted", "filter", - # asyncio "create_task", - "gather", + "gather", "as_completed", - # functions "a_sync", "all", @@ -47,33 +48,27 @@ "exhaust_iterator", "exhaust_iterators", "map", - # classes "ASyncIterable", "ASyncIterator", "ASyncGenericSingleton", - "TaskMapping", - + "TaskMapping", # property "cached_property", "property", "ASyncPropertyDescriptor", "ASyncCachedPropertyDescriptor", - # semaphores "Semaphore", "PrioritySemaphore", "ThreadsafeSemaphore", - # queues "Queue", "ProcessingQueue", "SmartProcessingQueue", - # locks "CounterLock", "Event", - # executors "AsyncThreadPoolExecutor", "PruningThreadPoolExecutor", diff --git a/a_sync/_smart.py b/a_sync/_smart.py index 79b5777d..4c21ffc8 100644 --- a/a_sync/_smart.py +++ b/a_sync/_smart.py @@ -1,4 +1,3 @@ - import asyncio import logging import warnings @@ -17,10 +16,12 @@ logger = logging.getLogger(__name__) + class _SmartFutureMixin(Generic[T]): _queue: Optional["SmartProcessingQueue[Any, Any, T]"] = None _key: _Key _waiters: "weakref.WeakSet[SmartTask[T]]" + def __await__(self: Union["SmartFuture", "SmartTask"]) -> Generator[Any, None, T]: if self.done(): return self.result() # May raise too. @@ -32,17 +33,22 @@ def __await__(self: Union["SmartFuture", "SmartTask"]) -> Generator[Any, None, T if not self.done(): raise RuntimeError("await wasn't used with future") return self.result() # May raise too. + @property def num_waiters(self: Union["SmartFuture", "SmartTask"]) -> int: # NOTE: we check .done() because the callback may not have ran yet and its very lightweight if self.done(): # if there are any waiters left, there won't be once the event loop runs once return 0 - return sum(getattr(waiter, 'num_waiters', 1) or 1 for waiter in self._waiters) - def _waiter_done_cleanup_callback(self: Union["SmartFuture", "SmartTask"], waiter: "SmartTask") -> None: + return sum(getattr(waiter, "num_waiters", 1) or 1 for waiter in self._waiters) + + def _waiter_done_cleanup_callback( + self: Union["SmartFuture", "SmartTask"], waiter: "SmartTask" + ) -> None: "Removes the waiter from _waiters, and _queue._futs if applicable" if not self.done(): self._waiters.remove(waiter) + def _self_done_cleanup_callback(self: Union["SmartFuture", "SmartTask"]) -> None: self._waiters.clear() if queue := self._queue: @@ -52,11 +58,12 @@ def _self_done_cleanup_callback(self: Union["SmartFuture", "SmartTask"]) -> None class SmartFuture(_SmartFutureMixin[T], asyncio.Future): _queue = None _key = None + def __init__( - self, - *, - queue: Optional["SmartProcessingQueue[Any, Any, T]"], - key: Optional[_Key] = None, + self, + *, + queue: Optional["SmartProcessingQueue[Any, Any, T]"], + key: Optional[_Key] = None, loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: super().__init__(loop=loop) @@ -66,55 +73,63 @@ def __init__( self._key = key self._waiters = weakref.WeakSet() self.add_done_callback(SmartFuture._self_done_cleanup_callback) + def __repr__(self): return f"<{type(self).__name__} key={self._key} waiters={self.num_waiters} {self._state}>" + def __lt__(self, other: "SmartFuture[T]") -> bool: """heap considers lower values as higher priority so a future with more waiters will be 'less than' a future with less waiters.""" - #other = other_ref() - #if other is None: + # other = other_ref() + # if other is None: # # garbage collected refs should always process first so they can be popped from the queue # return False return self.num_waiters > other.num_waiters + def create_future( *, - queue: Optional["SmartProcessingQueue"] = None, - key: Optional[_Key] = None, + queue: Optional["SmartProcessingQueue"] = None, + key: Optional[_Key] = None, loop: Optional[asyncio.AbstractEventLoop] = None, ) -> SmartFuture[V]: return SmartFuture(queue=queue, key=key, loop=loop or asyncio.get_event_loop()) + class SmartTask(_SmartFutureMixin[T], asyncio.Task): def __init__( - self, - coro: Awaitable[T], - *, - loop: Optional[asyncio.AbstractEventLoop] = None, + self, + coro: Awaitable[T], + *, + loop: Optional[asyncio.AbstractEventLoop] = None, name: Optional[str] = None, ) -> None: super().__init__(coro, loop=loop, name=name) self._waiters: Set["asyncio.Task[T]"] = set() self.add_done_callback(SmartTask._self_done_cleanup_callback) -def smart_task_factory(loop: asyncio.AbstractEventLoop, coro: Awaitable[T]) -> SmartTask[T]: + +def smart_task_factory( + loop: asyncio.AbstractEventLoop, coro: Awaitable[T] +) -> SmartTask[T]: """ Task factory function that an event loop calls to create new tasks. - + This factory function utilizes ez-a-sync's custom :class:`~SmartTask` implementation. - + Args: loop: The event loop. coro: The coroutine to run in the task. - + Returns: A SmartTask instance running the provided coroutine. """ return SmartTask(coro, loop=loop) + def set_smart_task_factory(loop: asyncio.AbstractEventLoop = None) -> None: """ Set the event loop's task factory to :func:`~smart_task_factory` so all tasks will be SmartTask instances. - + Args: loop: Optional; the event loop. If None, the current event loop is used. """ @@ -122,7 +137,10 @@ def set_smart_task_factory(loop: asyncio.AbstractEventLoop = None) -> None: loop = a_sync.asyncio.get_event_loop() loop.set_task_factory(smart_task_factory) -def shield(arg: Awaitable[T], *, loop: Optional[asyncio.AbstractEventLoop] = None) -> SmartFuture[T]: + +def shield( + arg: Awaitable[T], *, loop: Optional[asyncio.AbstractEventLoop] = None +) -> SmartFuture[T]: """ Wait for a future, shielding it from cancellation. @@ -150,9 +168,12 @@ def shield(arg: Awaitable[T], *, loop: Optional[asyncio.AbstractEventLoop] = Non res = None """ if loop is not None: - warnings.warn("The loop argument is deprecated since Python 3.8, " - "and scheduled for removal in Python 3.10.", - DeprecationWarning, stacklevel=2) + warnings.warn( + "The loop argument is deprecated since Python 3.8, " + "and scheduled for removal in Python 3.10.", + DeprecationWarning, + stacklevel=2, + ) inner = asyncio.ensure_future(arg, loop=loop) if inner.done(): # Shortcut. @@ -162,6 +183,7 @@ def shield(arg: Awaitable[T], *, loop: Optional[asyncio.AbstractEventLoop] = Non # special handling to connect SmartFutures to SmartTasks if enabled if (waiters := getattr(inner, "_waiters", None)) is not None: waiters.add(outer) + def _inner_done_callback(inner): if outer.cancelled(): if not inner.cancelled(): @@ -187,4 +209,11 @@ def _outer_done_callback(outer): return outer -__all__ = ["create_future", "shield", "SmartFuture", "SmartTask", "smart_task_factory", "set_smart_task_factory"] +__all__ = [ + "create_future", + "shield", + "SmartFuture", + "SmartTask", + "smart_task_factory", + "set_smart_task_factory", +] diff --git a/a_sync/_typing.py b/a_sync/_typing.py index 3175fa7f..dcf0ab7e 100644 --- a/a_sync/_typing.py +++ b/a_sync/_typing.py @@ -7,16 +7,47 @@ import asyncio from concurrent.futures._base import Executor from decimal import Decimal -from typing import (TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, AsyncIterator, - Awaitable, Callable, Coroutine, DefaultDict, Deque, Dict, Generator, - Generic, ItemsView, Iterable, Iterator, KeysView, List, Literal, - Mapping, NoReturn, Optional, Protocol, Set, Tuple, Type, TypedDict, - TypeVar, Union, ValuesView, final, overload, runtime_checkable) +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + AsyncIterable, + AsyncIterator, + Awaitable, + Callable, + Coroutine, + DefaultDict, + Deque, + Dict, + Generator, + Generic, + ItemsView, + Iterable, + Iterator, + KeysView, + List, + Literal, + Mapping, + NoReturn, + Optional, + Protocol, + Set, + Tuple, + Type, + TypedDict, + TypeVar, + Union, + ValuesView, + final, + overload, + runtime_checkable, +) from typing_extensions import Concatenate, ParamSpec, Self, Unpack if TYPE_CHECKING: from a_sync import ASyncGenericBase + B = TypeVar("B", bound=ASyncGenericBase) else: B = TypeVar("B") @@ -27,7 +58,7 @@ I = TypeVar("I") """A :class:`TypeVar` that is used to represent instances of a common class.""" -E = TypeVar('E', bound=Exception) +E = TypeVar("E", bound=Exception) TYPE = TypeVar("TYPE", bound=Type) P = ParamSpec("P") @@ -51,6 +82,7 @@ AnyFn = Union[CoroFn[P, T], SyncFn[P, T]] "Type alias for any function, whether synchronous or asynchronous." + class CoroBoundMethod(Protocol[I, P, T]): """ Protocol for coroutine bound methods. @@ -59,13 +91,15 @@ class CoroBoundMethod(Protocol[I, P, T]): class MyClass: async def my_method(self, x: int) -> str: return str(x) - + instance = MyClass() bound_method: CoroBoundMethod[MyClass, [int], str] = instance.my_method """ + __self__: I __call__: Callable[P, Awaitable[T]] + class SyncBoundMethod(Protocol[I, P, T]): """ Protocol for synchronous bound methods. @@ -74,36 +108,43 @@ class SyncBoundMethod(Protocol[I, P, T]): class MyClass: def my_method(self, x: int) -> str: return str(x) - + instance = MyClass() bound_method: SyncBoundMethod[MyClass, [int], str] = instance.my_method """ + __self__: I __call__: Callable[P, T] + AnyBoundMethod = Union[CoroBoundMethod[Any, P, T], SyncBoundMethod[Any, P, T]] "Type alias for any bound method, whether synchronous or asynchronous." + @runtime_checkable class AsyncUnboundMethod(Protocol[I, P, T]): """ Protocol for unbound asynchronous methods. - + An unbound method is a method that hasn't been bound to an instance of a class yet. It's essentially the function object itself, before it's accessed through an instance. """ + __get__: Callable[[I, Type], CoroBoundMethod[I, P, T]] + @runtime_checkable class SyncUnboundMethod(Protocol[I, P, T]): """ Protocol for unbound synchronous methods. - + An unbound method is a method that hasn't been bound to an instance of a class yet. It's essentially the function object itself, before it's accessed through an instance. """ + __get__: Callable[[I, Type], SyncBoundMethod[I, P, T]] + AnyUnboundMethod = Union[AsyncUnboundMethod[I, P, T], SyncUnboundMethod[I, P, T]] "Type alias for any unbound method, whether synchronous or asynchronous." @@ -122,19 +163,21 @@ class SyncUnboundMethod(Protocol[I, P, T]): AsyncDecoratorOrCoroFn = Union[AsyncDecorator[P, T], CoroFn[P, T]] "Type alias for either an asynchronous decorator or a coroutine function." -DefaultMode = Literal['sync', 'async', None] +DefaultMode = Literal["sync", "async", None] "Type alias for default modes of operation." -CacheType = Literal['memory', None] +CacheType = Literal["memory", None] "Type alias for cache types." SemaphoreSpec = Optional[Union[asyncio.Semaphore, int]] "Type alias for semaphore specifications." + class ModifierKwargs(TypedDict, total=False): """ TypedDict for keyword arguments that modify the behavior of asynchronous operations. """ + default: DefaultMode cache_type: CacheType cache_typed: bool @@ -145,6 +188,7 @@ class ModifierKwargs(TypedDict, total=False): # sync modifiers executor: Executor + AnyIterable = Union[AsyncIterable[K], Iterable[K]] "Type alias for any iterable, whether synchronous or asynchronous." diff --git a/a_sync/a_sync/__init__.py b/a_sync/a_sync/__init__.py index 18694f46..a2c466b3 100644 --- a/a_sync/a_sync/__init__.py +++ b/a_sync/a_sync/__init__.py @@ -1,13 +1,25 @@ - from a_sync.a_sync.base import ASyncGenericBase from a_sync.a_sync.decorator import a_sync -from a_sync.a_sync.function import ASyncFunction, ASyncFunctionAsyncDefault, ASyncFunctionSyncDefault +from a_sync.a_sync.function import ( + ASyncFunction, + ASyncFunctionAsyncDefault, + ASyncFunctionSyncDefault, +) from a_sync.a_sync.modifiers.semaphores import apply_semaphore + # NOTE: Some of these we purposely import without including in __all__. Do not remove. -from a_sync.a_sync.property import (ASyncCachedPropertyDescriptor, ASyncCachedPropertyDescriptorAsyncDefault, - ASyncCachedPropertyDescriptorSyncDefault, ASyncPropertyDescriptor, - ASyncPropertyDescriptorAsyncDefault, ASyncPropertyDescriptorSyncDefault, - HiddenMethod, HiddenMethodDescriptor, cached_property, property) +from a_sync.a_sync.property import ( + ASyncCachedPropertyDescriptor, + ASyncCachedPropertyDescriptorAsyncDefault, + ASyncCachedPropertyDescriptorSyncDefault, + ASyncPropertyDescriptor, + ASyncPropertyDescriptorAsyncDefault, + ASyncPropertyDescriptorSyncDefault, + HiddenMethod, + HiddenMethodDescriptor, + cached_property, + property, +) from a_sync.a_sync.singleton import ASyncGenericSingleton @@ -15,14 +27,12 @@ # entrypoints "a_sync", "ASyncGenericBase", - # classes "ASyncFunction", - "property", "cached_property", "ASyncPropertyDescriptor", "ASyncCachedPropertyDescriptor", "HiddenMethod", "HiddenMethodDescriptor", -] \ No newline at end of file +] diff --git a/a_sync/a_sync/_descriptor.py b/a_sync/a_sync/_descriptor.py index 97e8d146..afb3e95d 100644 --- a/a_sync/a_sync/_descriptor.py +++ b/a_sync/a_sync/_descriptor.py @@ -14,6 +14,7 @@ if TYPE_CHECKING: from a_sync import TaskMapping + class ASyncDescriptor(ModifiedMixin, Generic[I, P, T]): """ A descriptor base class for asynchronous methods and properties. @@ -28,9 +29,9 @@ class ASyncDescriptor(ModifiedMixin, Generic[I, P, T]): __slots__ = "field_name", "_fget" def __init__( - self, - _fget: AnyFn[Concatenate[I, P], T], - field_name: Optional[str] = None, + self, + _fget: AnyFn[Concatenate[I, P], T], + field_name: Optional[str] = None, **modifiers: ModifierKwargs, ) -> None: """ @@ -45,19 +46,21 @@ def __init__( ValueError: If _fget is not callable. """ if not callable(_fget): - raise ValueError(f'Unable to decorate {_fget}') + raise ValueError(f"Unable to decorate {_fget}") self.modifiers = ModifierManager(modifiers) if isinstance(_fget, ASyncFunction): self.modifiers.update(_fget.modifiers) self.__wrapped__ = _fget elif asyncio.iscoroutinefunction(_fget): - self.__wrapped__: AsyncUnboundMethod[I, P, T] = self.modifiers.apply_async_modifiers(_fget) + self.__wrapped__: AsyncUnboundMethod[I, P, T] = ( + self.modifiers.apply_async_modifiers(_fget) + ) else: self.__wrapped__ = _fget self.field_name = field_name or _fget.__name__ """The name of the field the {cls} is bound to.""" - + functools.update_wrapper(self, self.__wrapped__) def __repr__(self) -> str: @@ -73,7 +76,9 @@ def __set_name__(self, owner, name): """ self.field_name = name - def map(self, *instances: AnyIterable[I], **bound_method_kwargs: P.kwargs) -> "TaskMapping[I, T]": + def map( + self, *instances: AnyIterable[I], **bound_method_kwargs: P.kwargs + ) -> "TaskMapping[I, T]": """ Create a TaskMapping for the given instances. @@ -85,6 +90,7 @@ def map(self, *instances: AnyIterable[I], **bound_method_kwargs: P.kwargs) -> "T A TaskMapping object. """ from a_sync.task import TaskMapping + return TaskMapping(self, *instances, **bound_method_kwargs) @functools.cached_property @@ -137,7 +143,13 @@ def sum(self) -> ASyncFunction[Concatenate[AnyIterable[I], P], T]: """ return decorator.a_sync(default=self.default)(self._sum) - async def _all(self, *instances: AnyIterable[I], concurrency: Optional[int] = None, name: str = "", **kwargs: P.kwargs) -> bool: + async def _all( + self, + *instances: AnyIterable[I], + concurrency: Optional[int] = None, + name: str = "", + **kwargs: P.kwargs, + ) -> bool: """ Check if all results are truthy. @@ -150,9 +162,17 @@ async def _all(self, *instances: AnyIterable[I], concurrency: Optional[int] = No Returns: A boolean indicating if all results are truthy. """ - return await self.map(*instances, concurrency=concurrency, name=name, **kwargs).all(pop=True, sync=False) - - async def _any(self, *instances: AnyIterable[I], concurrency: Optional[int] = None, name: str = "", **kwargs: P.kwargs) -> bool: + return await self.map( + *instances, concurrency=concurrency, name=name, **kwargs + ).all(pop=True, sync=False) + + async def _any( + self, + *instances: AnyIterable[I], + concurrency: Optional[int] = None, + name: str = "", + **kwargs: P.kwargs, + ) -> bool: """ Check if any result is truthy. @@ -165,9 +185,17 @@ async def _any(self, *instances: AnyIterable[I], concurrency: Optional[int] = No Returns: A boolean indicating if any result is truthy. """ - return await self.map(*instances, concurrency=concurrency, name=name, **kwargs).any(pop=True, sync=False) - - async def _min(self, *instances: AnyIterable[I], concurrency: Optional[int] = None, name: str = "", **kwargs: P.kwargs) -> T: + return await self.map( + *instances, concurrency=concurrency, name=name, **kwargs + ).any(pop=True, sync=False) + + async def _min( + self, + *instances: AnyIterable[I], + concurrency: Optional[int] = None, + name: str = "", + **kwargs: P.kwargs, + ) -> T: """ Find the minimum result. @@ -180,9 +208,17 @@ async def _min(self, *instances: AnyIterable[I], concurrency: Optional[int] = No Returns: The minimum result. """ - return await self.map(*instances, concurrency=concurrency, name=name, **kwargs).min(pop=True, sync=False) - - async def _max(self, *instances: AnyIterable[I], concurrency: Optional[int] = None, name: str = "", **kwargs: P.kwargs) -> T: + return await self.map( + *instances, concurrency=concurrency, name=name, **kwargs + ).min(pop=True, sync=False) + + async def _max( + self, + *instances: AnyIterable[I], + concurrency: Optional[int] = None, + name: str = "", + **kwargs: P.kwargs, + ) -> T: """ Find the maximum result. @@ -195,9 +231,17 @@ async def _max(self, *instances: AnyIterable[I], concurrency: Optional[int] = No Returns: The maximum result. """ - return await self.map(*instances, concurrency=concurrency, name=name, **kwargs).max(pop=True, sync=False) - - async def _sum(self, *instances: AnyIterable[I], concurrency: Optional[int] = None, name: str = "", **kwargs: P.kwargs) -> T: + return await self.map( + *instances, concurrency=concurrency, name=name, **kwargs + ).max(pop=True, sync=False) + + async def _sum( + self, + *instances: AnyIterable[I], + concurrency: Optional[int] = None, + name: str = "", + **kwargs: P.kwargs, + ) -> T: """ Calculate the sum of results. @@ -210,10 +254,12 @@ async def _sum(self, *instances: AnyIterable[I], concurrency: Optional[int] = No Returns: The sum of the results. """ - return await self.map(*instances, concurrency=concurrency, name=name, **kwargs).sum(pop=True, sync=False) + return await self.map( + *instances, concurrency=concurrency, name=name, **kwargs + ).sum(pop=True, sync=False) def __init_subclass__(cls) -> None: for attr in cls.__dict__.values(): if attr.__doc__ and "{cls}" in attr.__doc__: attr.__doc__ = attr.__doc__.replace("{cls}", f":class:`{cls.__name__}`") - return super().__init_subclass__() \ No newline at end of file + return super().__init_subclass__() diff --git a/a_sync/a_sync/_flags.py b/a_sync/a_sync/_flags.py index 53b80433..db94877f 100644 --- a/a_sync/a_sync/_flags.py +++ b/a_sync/a_sync/_flags.py @@ -11,15 +11,16 @@ from a_sync import exceptions -AFFIRMATIVE_FLAGS = {'sync'} +AFFIRMATIVE_FLAGS = {"sync"} """Set of flags indicating synchronous behavior.""" -NEGATIVE_FLAGS = {'asynchronous'} +NEGATIVE_FLAGS = {"asynchronous"} """Set of flags indicating asynchronous behavior.""" VIABLE_FLAGS = AFFIRMATIVE_FLAGS | NEGATIVE_FLAGS """Set of all valid flags.""" + def negate_if_necessary(flag: str, flag_value: bool) -> bool: """Negate the flag value if necessary based on the flag type. @@ -40,6 +41,7 @@ def negate_if_necessary(flag: str, flag_value: bool) -> bool: return bool(not flag_value) raise exceptions.InvalidFlag(flag) + def validate_flag_value(flag: str, flag_value: Any) -> bool: """ Validate that the flag value is a boolean. diff --git a/a_sync/a_sync/_helpers.py b/a_sync/a_sync/_helpers.py index 28993200..8f6a2ed5 100644 --- a/a_sync/a_sync/_helpers.py +++ b/a_sync/a_sync/_helpers.py @@ -32,6 +32,7 @@ def _await(awaitable: Awaitable[T]) -> T: raise exceptions.SyncModeInAsyncContextError from None raise + def _asyncify(func: SyncFn[P, T], executor: Executor) -> CoroFn[P, T]: # type: ignore [misc] """ Convert a synchronous function to a coroutine function. @@ -47,12 +48,15 @@ def _asyncify(func: SyncFn[P, T], executor: Executor) -> CoroFn[P, T]: # type: :class:`exceptions.FunctionNotSync`: If the input function is already asynchronous. """ from a_sync.a_sync.function import ASyncFunction + if asyncio.iscoroutinefunction(func) or isinstance(func, ASyncFunction): raise exceptions.FunctionNotSync(func) + @functools.wraps(func) async def _asyncify_wrap(*args: P.args, **kwargs: P.kwargs) -> T: return await asyncio.futures.wrap_future( - executor.submit(func, *args, **kwargs), + executor.submit(func, *args, **kwargs), loop=a_sync.asyncio.get_event_loop(), ) + return _asyncify_wrap diff --git a/a_sync/a_sync/_kwargs.py b/a_sync/a_sync/_kwargs.py index 67b044d6..c086d10a 100644 --- a/a_sync/a_sync/_kwargs.py +++ b/a_sync/a_sync/_kwargs.py @@ -25,9 +25,10 @@ def get_flag_name(kwargs: dict) -> Optional[str]: if len(present_flags) == 0: return None if len(present_flags) != 1: - raise exceptions.TooManyFlags('kwargs', present_flags) + raise exceptions.TooManyFlags("kwargs", present_flags) return present_flags[0] + def is_sync(flag: str, kwargs: dict, pop_flag: bool = False) -> bool: """ Determine if the operation should be synchronous based on the flag value. @@ -41,4 +42,4 @@ def is_sync(flag: str, kwargs: dict, pop_flag: bool = False) -> bool: True if the operation should be synchronous, False otherwise. """ flag_value = kwargs.pop(flag) if pop_flag else kwargs[flag] - return _flags.negate_if_necessary(flag, flag_value) \ No newline at end of file + return _flags.negate_if_necessary(flag, flag_value) diff --git a/a_sync/a_sync/_meta.py b/a_sync/a_sync/_meta.py index 90e3e9f0..8ff408e8 100644 --- a/a_sync/a_sync/_meta.py +++ b/a_sync/a_sync/_meta.py @@ -1,4 +1,3 @@ - import inspect import logging import threading @@ -9,54 +8,112 @@ from a_sync.a_sync import modifiers from a_sync.a_sync.function import ASyncFunction, ModifiedMixin from a_sync.a_sync.method import ASyncMethodDescriptor -from a_sync.a_sync.property import ASyncPropertyDescriptor, ASyncCachedPropertyDescriptor +from a_sync.a_sync.property import ( + ASyncPropertyDescriptor, + ASyncCachedPropertyDescriptor, +) from a_sync.future import _ASyncFutureWrappedFn # type: ignore [attr-defined] from a_sync.iter import ASyncGeneratorFunction from a_sync.primitives.locks.semaphore import Semaphore logger = logging.getLogger(__name__) + class ASyncMeta(ABCMeta): """Any class with metaclass ASyncMeta will have its functions wrapped with a_sync upon class instantiation.""" + def __new__(cls, new_class_name, bases, attrs): _update_logger(new_class_name) - logger.debug("woah, you're defining a new ASync class `%s`! let's walk thru it together", new_class_name) - logger.debug("first, I check whether you've defined any modifiers on `%s`", new_class_name) + logger.debug( + "woah, you're defining a new ASync class `%s`! let's walk thru it together", + new_class_name, + ) + logger.debug( + "first, I check whether you've defined any modifiers on `%s`", + new_class_name, + ) # NOTE: Open quesion: what do we do when a parent class and subclass define the same modifier differently? - # Currently the parent value is used for functions defined on the parent, + # Currently the parent value is used for functions defined on the parent, # and the subclass value is used for functions defined on the subclass. class_defined_modifiers = modifiers.get_modifiers_from(attrs) - logger.debug('found modifiers: %s', class_defined_modifiers) - logger.debug("now I inspect the class definition to figure out which attributes need to be wrapped") + logger.debug("found modifiers: %s", class_defined_modifiers) + logger.debug( + "now I inspect the class definition to figure out which attributes need to be wrapped" + ) for attr_name, attr_value in list(attrs.items()): if attr_name.startswith("_"): - logger.debug("`%s.%s` starts with an underscore, skipping", new_class_name, attr_name) + logger.debug( + "`%s.%s` starts with an underscore, skipping", + new_class_name, + attr_name, + ) continue elif "__" in attr_name: - logger.debug("`%s.%s` incluldes a double-underscore, skipping", new_class_name, attr_name) + logger.debug( + "`%s.%s` incluldes a double-underscore, skipping", + new_class_name, + attr_name, + ) continue elif isinstance(attr_value, (_ASyncFutureWrappedFn, Semaphore)): - logger.debug("`%s.%s` is a %s, skipping", new_class_name, attr_name, attr_value.__class__.__name__) + logger.debug( + "`%s.%s` is a %s, skipping", + new_class_name, + attr_name, + attr_value.__class__.__name__, + ) continue - logger.debug(f"inspecting `{new_class_name}.{attr_name}` of type {attr_value.__class__.__name__}") + logger.debug( + f"inspecting `{new_class_name}.{attr_name}` of type {attr_value.__class__.__name__}" + ) fn_modifiers = dict(class_defined_modifiers) # Special handling for functions decorated with a_sync decorators if isinstance(attr_value, ModifiedMixin): - logger.debug("`%s.%s` is a `ModifiedMixin` object, which means you decorated it with an a_sync decorator even though `%s` is an ASyncABC class", new_class_name, attr_name, new_class_name) - logger.debug("you probably did this so you could apply some modifiers to `%s` specifically", attr_name) + logger.debug( + "`%s.%s` is a `ModifiedMixin` object, which means you decorated it with an a_sync decorator even though `%s` is an ASyncABC class", + new_class_name, + attr_name, + new_class_name, + ) + logger.debug( + "you probably did this so you could apply some modifiers to `%s` specifically", + attr_name, + ) modified_modifiers = attr_value.modifiers._modifiers if modified_modifiers: - logger.debug("I found `%s.%s` is modified with %s", new_class_name, attr_name, modified_modifiers) + logger.debug( + "I found `%s.%s` is modified with %s", + new_class_name, + attr_name, + modified_modifiers, + ) fn_modifiers.update(modified_modifiers) else: logger.debug("I did not find any modifiers") - logger.debug("full modifier set for `%s.%s`: %s", new_class_name, attr_name, fn_modifiers) - if isinstance(attr_value, (ASyncPropertyDescriptor, ASyncCachedPropertyDescriptor)): + logger.debug( + "full modifier set for `%s.%s`: %s", + new_class_name, + attr_name, + fn_modifiers, + ) + if isinstance( + attr_value, (ASyncPropertyDescriptor, ASyncCachedPropertyDescriptor) + ): # Wrap property logger.debug("`%s is a property, now let's wrap it", attr_name) - logger.debug("since `%s` is a property, we will add a hidden dundermethod so you can still access it both sync and async", attr_name) - attrs[attr_value.hidden_method_name] = attr_value.hidden_method_descriptor - logger.debug("`%s.%s` is now %s", new_class_name, attr_value.hidden_method_name, attr_value.hidden_method_descriptor) + logger.debug( + "since `%s` is a property, we will add a hidden dundermethod so you can still access it both sync and async", + attr_name, + ) + attrs[attr_value.hidden_method_name] = ( + attr_value.hidden_method_descriptor + ) + logger.debug( + "`%s.%s` is now %s", + new_class_name, + attr_value.hidden_method_name, + attr_value.hidden_method_descriptor, + ) elif isinstance(attr_value, ASyncFunction): attrs[attr_name] = ASyncMethodDescriptor(attr_value, **fn_modifiers) else: @@ -67,15 +124,22 @@ def __new__(cls, new_class_name, bases, attrs): # NOTE We will need to improve this logic if somebody needs to use it with classmethods or staticmethods. attrs[attr_name] = ASyncMethodDescriptor(attr_value, **fn_modifiers) else: - logger.debug("`%s.%s` is not callable, we will take no action with it", new_class_name, attr_name) - return super(ASyncMeta, cls).__new__(cls, new_class_name, bases, attrs) + logger.debug( + "`%s.%s` is not callable, we will take no action with it", + new_class_name, + attr_name, + ) + return super(ASyncMeta, cls).__new__(cls, new_class_name, bases, attrs) class ASyncSingletonMeta(ASyncMeta): - def __init__(cls, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any]) -> None: + def __init__( + cls, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any] + ) -> None: cls.__instances: Dict[bool, object] = {} cls.__lock = threading.Lock() super().__init__(name, bases, namespace) + def __call__(cls, *args: Any, **kwargs: Any): is_sync = cls.__a_sync_instance_will_be_sync__(args, kwargs) # type: ignore [attr-defined] if is_sync not in cls.__instances: @@ -85,8 +149,12 @@ def __call__(cls, *args: Any, **kwargs: Any): cls.__instances[is_sync] = super().__call__(*args, **kwargs) return cls.__instances[is_sync] + def _update_logger(new_class_name: str) -> None: - if ENVIRONMENT_VARIABLES.DEBUG_MODE or ENVIRONMENT_VARIABLES.DEBUG_CLASS_NAME == new_class_name: + if ( + ENVIRONMENT_VARIABLES.DEBUG_MODE + or ENVIRONMENT_VARIABLES.DEBUG_CLASS_NAME == new_class_name + ): logger.addHandler(_debug_handler) logger.setLevel(logging.DEBUG) logger.info("debug mode activated") @@ -94,6 +162,7 @@ def _update_logger(new_class_name: str) -> None: logger.removeHandler(_debug_handler) logger.setLevel(logging.INFO) + _debug_handler = logging.StreamHandler() __all__ = ["ASyncMeta", "ASyncSingletonMeta"] diff --git a/a_sync/a_sync/abstract.py b/a_sync/a_sync/abstract.py index 1802dde3..4c8bdaa1 100644 --- a/a_sync/a_sync/abstract.py +++ b/a_sync/a_sync/abstract.py @@ -1,4 +1,3 @@ - import abc import functools import logging @@ -11,6 +10,7 @@ logger = logging.getLogger(__name__) + class ASyncABC(metaclass=ASyncMeta): ################################## @@ -30,33 +30,46 @@ def __a_sync_should_await__(self, kwargs: dict) -> bool: def __a_sync_instance_should_await__(self) -> bool: """ A flag indicating whether the instance should default to asynchronous execution. - - You can override this if you want. + + You can override this if you want. If you want to be able to hotswap instance modes, you can redefine this as a non-cached property. """ - return _flags.negate_if_necessary(self.__a_sync_flag_name__, self.__a_sync_flag_value__) - + return _flags.negate_if_necessary( + self.__a_sync_flag_name__, self.__a_sync_flag_value__ + ) + def __a_sync_should_await_from_kwargs__(self, kwargs: dict) -> bool: """You can override this if you want.""" if flag := _kwargs.get_flag_name(kwargs): return _kwargs.is_sync(flag, kwargs, pop_flag=True) # type: ignore [arg-type] raise NoFlagsFound("kwargs", kwargs.keys()) - + @classmethod def __a_sync_instance_will_be_sync__(cls, args: tuple, kwargs: dict) -> bool: """You can override this if you want.""" - logger.debug("checking `%s.%s.__init__` signature against provided kwargs to determine a_sync mode for the new instance", cls.__module__, cls.__name__) + logger.debug( + "checking `%s.%s.__init__` signature against provided kwargs to determine a_sync mode for the new instance", + cls.__module__, + cls.__name__, + ) if flag := _kwargs.get_flag_name(kwargs): sync = _kwargs.is_sync(flag, kwargs) # type: ignore [arg-type] - logger.debug("kwargs indicate the new instance created with args %s %s is %ssynchronous", args, kwargs, 'a' if sync is False else '') + logger.debug( + "kwargs indicate the new instance created with args %s %s is %ssynchronous", + args, + kwargs, + "a" if sync is False else "", + ) return sync - logger.debug("No valid flags found in kwargs, checking class definition for defined default") + logger.debug( + "No valid flags found in kwargs, checking class definition for defined default" + ) return cls.__a_sync_default_mode__() # type: ignore [arg-type] ###################################### # Concrete Methods (non-overridable) # ###################################### - + @property def __a_sync_modifiers__(self: "ASyncABC") -> ModifierKwargs: """You should not override this.""" @@ -69,12 +82,12 @@ def __a_sync_modifiers__(self: "ASyncABC") -> ModifierKwargs: @abc.abstractproperty def __a_sync_flag_name__(self) -> str: pass - + @abc.abstractproperty def __a_sync_flag_value__(self) -> bool: pass - + @abc.abstractclassmethod # type: ignore [arg-type, misc] - def __a_sync_default_mode__(cls) -> bool: # type: ignore [empty-body] - # mypy doesnt recognize this abc member + def __a_sync_default_mode__(cls) -> bool: # type: ignore [empty-body] + # mypy doesnt recognize this abc member pass diff --git a/a_sync/a_sync/base.py b/a_sync/a_sync/base.py index 5c4d7ae3..bf635a94 100644 --- a/a_sync/a_sync/base.py +++ b/a_sync/a_sync/base.py @@ -11,6 +11,7 @@ logger = logging.getLogger(__name__) + class ASyncGenericBase(ASyncABC): """ Base class for creating dual-function sync/async-capable classes without writing all your code twice. @@ -35,7 +36,7 @@ async def my_property(self): @a_sync async def my_method(self): return await another_async_operation() - + # Synchronous usage obj = MyClass(sync=True) sync_result = obj.my_property @@ -56,8 +57,10 @@ async def my_method(self): def __init__(self): if type(self) is ASyncGenericBase: cls_name = type(self).__name__ - raise NotImplementedError(f"You should not create instances of `{cls_name}` directly, you should subclass `ASyncGenericBase` instead.") - + raise NotImplementedError( + f"You should not create instances of `{cls_name}` directly, you should subclass `ASyncGenericBase` instead." + ) + @functools.cached_property def __a_sync_flag_name__(self) -> str: logger.debug("checking a_sync flag for %s", self) @@ -67,8 +70,14 @@ def __a_sync_flag_name__(self) -> str: # We can't get the flag name from the __init__ signature, # but maybe the implementation sets the flag somewhere else. # Let's check the instance's atributes - logger.debug("unable to find flag name using `%s.__init__` signature, checking for flag attributes defined on %s", self.__class__.__name__, self) - present_flags = [flag for flag in _flags.VIABLE_FLAGS if hasattr(self, flag)] + logger.debug( + "unable to find flag name using `%s.__init__` signature, checking for flag attributes defined on %s", + self.__class__.__name__, + self, + ) + present_flags = [ + flag for flag in _flags.VIABLE_FLAGS if hasattr(self, flag) + ] if not present_flags: raise exceptions.NoFlagsFound(self) from None if len(present_flags) > 1: @@ -77,7 +86,7 @@ def __a_sync_flag_name__(self) -> str: if not isinstance(flag, str): raise exceptions.InvalidFlag(flag) return flag - + @functools.cached_property def __a_sync_flag_value__(self) -> bool: """If you wish to be able to hotswap default modes, just duplicate this def as a non-cached property.""" @@ -85,7 +94,7 @@ def __a_sync_flag_value__(self) -> bool: flag_value = getattr(self, flag) if not isinstance(flag_value, bool): raise exceptions.InvalidFlagValue(flag, flag_value) - logger.debug('`%s.%s` is currently %s', self, flag, flag_value) + logger.debug("`%s.%s` is currently %s", self, flag, flag_value) return flag_value @classmethod # type: ignore [misc] @@ -97,14 +106,21 @@ def __a_sync_default_mode__(cls) -> bool: # type: ignore [override] flag = cls.__get_a_sync_flag_name_from_class_def() flag_value = cls.__get_a_sync_flag_value_from_class_def(flag) sync = _flags.negate_if_necessary(flag, flag_value) # type: ignore [arg-type] - logger.debug("`%s.%s` indicates default mode is %ssynchronous", cls, flag, 'a' if sync is False else '') + logger.debug( + "`%s.%s` indicates default mode is %ssynchronous", + cls, + flag, + "a" if sync is False else "", + ) return sync - + @classmethod def __get_a_sync_flag_name_from_signature(cls) -> Optional[str]: logger.debug("Searching for flags defined on %s.__init__", cls) if cls.__name__ == "ASyncGenericBase": - logger.debug("There are no flags defined on the base class, this is expected. Skipping.") + logger.debug( + "There are no flags defined on the base class, this is expected. Skipping." + ) return None parameters = inspect.signature(cls.__init__).parameters logger.debug("parameters: %s", parameters) @@ -115,17 +131,19 @@ def __get_a_sync_flag_name_from_class_def(cls) -> str: logger.debug("Searching for flags defined on %s", cls) try: return cls.__parse_flag_name_from_list(cls.__dict__) # type: ignore [arg-type] - # idk why __dict__ doesn't type check as a dict + # idk why __dict__ doesn't type check as a dict except exceptions.NoFlagsFound: for base in cls.__bases__: with suppress(exceptions.NoFlagsFound): - return cls.__parse_flag_name_from_list(base.__dict__) # type: ignore [arg-type] - # idk why __dict__ doesn't type check as a dict + return cls.__parse_flag_name_from_list(base.__dict__) # type: ignore [arg-type] + # idk why __dict__ doesn't type check as a dict raise exceptions.NoFlagsFound(cls, list(cls.__dict__.keys())) @classmethod # type: ignore [misc] def __a_sync_flag_default_value_from_signature(cls) -> bool: - logger.debug("checking `__init__` signature for default %s a_sync flag value", cls) + logger.debug( + "checking `__init__` signature for default %s a_sync flag value", cls + ) signature = inspect.signature(cls.__init__) flag = cls.__get_a_sync_flag_name_from_signature() flag_value = signature.parameters[flag].default @@ -133,7 +151,7 @@ def __a_sync_flag_default_value_from_signature(cls) -> bool: raise NotImplementedError( "The implementation for 'cls' uses an arg to specify sync mode, instead of a kwarg. We are unable to proceed. I suppose we can extend the code to accept positional arg flags if necessary" ) - logger.debug('%s defines %s, default value %s', cls, flag, flag_value) + logger.debug("%s defines %s, default value %s", cls, flag, flag_value) return flag_value @classmethod @@ -154,4 +172,4 @@ def __parse_flag_name_from_list(cls, items: Dict[str, Any]) -> str: raise exceptions.TooManyFlags(cls, present_flags) flag = present_flags[0] logger.debug("found flag %s", flag) - return flag \ No newline at end of file + return flag diff --git a/a_sync/a_sync/config.py b/a_sync/a_sync/config.py index a3dc4921..b8449057 100644 --- a/a_sync/a_sync/config.py +++ b/a_sync/a_sync/config.py @@ -16,6 +16,7 @@ EXECUTOR_TYPE = os.environ.get("A_SYNC_EXECUTOR_TYPE", "threads") EXECUTOR_VALUE = int(os.environ.get("A_SYNC_EXECUTOR_VALUE", 8)) + @functools.lru_cache(maxsize=1) def get_default_executor() -> Executor: """ @@ -27,11 +28,14 @@ def get_default_executor() -> Executor: Raises: :class:`ValueError`: If an invalid EXECUTOR_TYPE is specified. """ - if EXECUTOR_TYPE.lower().startswith('p'): # p, P, proc, Processes, etc + if EXECUTOR_TYPE.lower().startswith("p"): # p, P, proc, Processes, etc return ProcessPoolExecutor(EXECUTOR_VALUE) - elif EXECUTOR_TYPE.lower().startswith('t'): # t, T, thread, THREADS, etc + elif EXECUTOR_TYPE.lower().startswith("t"): # t, T, thread, THREADS, etc return ThreadPoolExecutor(EXECUTOR_VALUE) - raise ValueError("Invalid value for A_SYNC_EXECUTOR_TYPE. Please use 'threads' or 'processes'.") + raise ValueError( + "Invalid value for A_SYNC_EXECUTOR_TYPE. Please use 'threads' or 'processes'." + ) + default_sync_executor = get_default_executor() @@ -53,13 +57,29 @@ def get_default_executor() -> Executor: # User configurable default modifiers to be applied to any a_sync decorated function if you do not specify kwarg values for each modifier. DEFAULT_MODE: DefaultMode = os.environ.get("A_SYNC_DEFAULT_MODE") # type: ignore [assignment] -CACHE_TYPE: CacheType = typ if (typ := os.environ.get("A_SYNC_CACHE_TYPE", "").lower()) else null_modifiers['cache_type'] +CACHE_TYPE: CacheType = ( + typ + if (typ := os.environ.get("A_SYNC_CACHE_TYPE", "").lower()) + else null_modifiers["cache_type"] +) CACHE_TYPED = bool(os.environ.get("A_SYNC_CACHE_TYPED")) -RAM_CACHE_MAXSIZE = int(os.environ.get("A_SYNC_RAM_CACHE_MAXSIZE", -1)) -RAM_CACHE_TTL = ttl if (ttl := float(os.environ.get("A_SYNC_RAM_CACHE_TTL", 0))) else null_modifiers['ram_cache_ttl'] +RAM_CACHE_MAXSIZE = int(os.environ.get("A_SYNC_RAM_CACHE_MAXSIZE", -1)) +RAM_CACHE_TTL = ( + ttl + if (ttl := float(os.environ.get("A_SYNC_RAM_CACHE_TTL", 0))) + else null_modifiers["ram_cache_ttl"] +) -RUNS_PER_MINUTE = rpm if (rpm := int(os.environ.get("A_SYNC_RUNS_PER_MINUTE", 0))) else null_modifiers['runs_per_minute'] -SEMAPHORE = rpm if (rpm := int(os.environ.get("A_SYNC_SEMAPHORE", 0))) else null_modifiers['semaphore'] +RUNS_PER_MINUTE = ( + rpm + if (rpm := int(os.environ.get("A_SYNC_RUNS_PER_MINUTE", 0))) + else null_modifiers["runs_per_minute"] +) +SEMAPHORE = ( + rpm + if (rpm := int(os.environ.get("A_SYNC_SEMAPHORE", 0))) + else null_modifiers["semaphore"] +) user_set_default_modifiers = ModifierKwargs( default=DEFAULT_MODE, diff --git a/a_sync/a_sync/decorator.py b/a_sync/a_sync/decorator.py index cda61465..ab5e71c6 100644 --- a/a_sync/a_sync/decorator.py +++ b/a_sync/a_sync/decorator.py @@ -2,9 +2,14 @@ # mypy: disable-error-code=misc from a_sync._typing import * from a_sync.a_sync import _flags, config -from a_sync.a_sync.function import (ASyncDecorator, ASyncFunction, ASyncDecoratorAsyncDefault, - ASyncDecoratorSyncDefault, ASyncFunctionAsyncDefault, - ASyncFunctionSyncDefault) +from a_sync.a_sync.function import ( + ASyncDecorator, + ASyncFunction, + ASyncDecoratorAsyncDefault, + ASyncDecoratorSyncDefault, + ASyncFunctionAsyncDefault, + ASyncFunctionSyncDefault, +) ######################## # The a_sync decorator # @@ -18,36 +23,42 @@ # async def some_fn(): # pass + @overload def a_sync( default: Literal["async"], **modifiers: Unpack[ModifierKwargs], -) -> ASyncDecoratorAsyncDefault:... +) -> ASyncDecoratorAsyncDefault: ... + @overload def a_sync( default: Literal["sync"], **modifiers: Unpack[ModifierKwargs], -) -> ASyncDecoratorSyncDefault:... +) -> ASyncDecoratorSyncDefault: ... + @overload def a_sync( **modifiers: Unpack[ModifierKwargs], -) -> ASyncDecorator:... +) -> ASyncDecorator: ... -@overload # async def, None default -def a_sync( + +@overload # async def, None default +def a_sync( coro_fn: CoroFn[P, T], default: Literal[None] = None, **modifiers: Unpack[ModifierKwargs], -) -> ASyncFunctionAsyncDefault[P, T]:... +) -> ASyncFunctionAsyncDefault[P, T]: ... + -@overload # sync def none default -def a_sync( +@overload # sync def none default +def a_sync( coro_fn: SyncFn[P, T], default: Literal[None] = None, **modifiers: Unpack[ModifierKwargs], -) -> ASyncFunctionSyncDefault[P, T]:... +) -> ASyncFunctionSyncDefault[P, T]: ... + # @a_sync(default='async') # def some_fn(): @@ -59,51 +70,60 @@ def a_sync( # # NOTE These should output a decorator that will be applied to 'some_fn' -@overload -def a_sync( + +@overload +def a_sync( coro_fn: Literal[None], - default: Literal['async'], + default: Literal["async"], **modifiers: Unpack[ModifierKwargs], -) -> ASyncDecoratorAsyncDefault:... +) -> ASyncDecoratorAsyncDefault: ... + -@overload # if you try to use default as the only arg -def a_sync( - coro_fn: Literal['async'], +@overload # if you try to use default as the only arg +def a_sync( + coro_fn: Literal["async"], default: Literal[None], **modifiers: Unpack[ModifierKwargs], -) -> ASyncDecoratorAsyncDefault:... +) -> ASyncDecoratorAsyncDefault: ... + # a_sync(some_fn, default='async') -@overload # async def, async default -def a_sync( + +@overload # async def, async default +def a_sync( coro_fn: CoroFn[P, T], - default: Literal['async'], + default: Literal["async"], **modifiers: Unpack[ModifierKwargs], -) -> ASyncFunctionAsyncDefault[P, T]:... +) -> ASyncFunctionAsyncDefault[P, T]: ... + -@overload # sync def async default -def a_sync( +@overload # sync def async default +def a_sync( coro_fn: SyncFn[P, T], - default: Literal['async'], + default: Literal["async"], **modifiers: Unpack[ModifierKwargs], -) -> ASyncFunctionAsyncDefault[P, T]:... +) -> ASyncFunctionAsyncDefault[P, T]: ... + # a_sync(some_fn, default='sync') -@overload # async def, sync default -def a_sync( + +@overload # async def, sync default +def a_sync( coro_fn: CoroFn[P, T], - default: Literal['sync'], + default: Literal["sync"], **modifiers: Unpack[ModifierKwargs], -) -> ASyncFunctionSyncDefault:... +) -> ASyncFunctionSyncDefault: ... -@overload # sync def sync default -def a_sync( + +@overload # sync def sync default +def a_sync( coro_fn: SyncFn[P, T], - default: Literal['sync'], + default: Literal["sync"], **modifiers: Unpack[ModifierKwargs], -) -> ASyncFunctionSyncDefault:... +) -> ASyncFunctionSyncDefault: ... + # @a_sync(default='sync') # def some_fn(): @@ -115,32 +135,38 @@ def a_sync( # # NOTE These should output a decorator that will be applied to 'some_fn' -@overload -def a_sync( + +@overload +def a_sync( coro_fn: Literal[None], - default: Literal['sync'], + default: Literal["sync"], **modifiers: Unpack[ModifierKwargs], -) -> ASyncDecoratorSyncDefault:... +) -> ASyncDecoratorSyncDefault: ... -@overload # if you try to use default as the only arg -def a_sync( - coro_fn: Literal['sync'], + +@overload # if you try to use default as the only arg +def a_sync( + coro_fn: Literal["sync"], default: Literal[None] = None, **modifiers: Unpack[ModifierKwargs], -) -> ASyncDecoratorSyncDefault:... +) -> ASyncDecoratorSyncDefault: ... + -@overload # if you try to use default as the only arg -def a_sync( - coro_fn: Literal['sync'], +@overload # if you try to use default as the only arg +def a_sync( + coro_fn: Literal["sync"], default: Literal[None], **modifiers: Unpack[ModifierKwargs], -) -> ASyncDecoratorSyncDefault:... - +) -> ASyncDecoratorSyncDefault: ... + + # catchall -def a_sync( +def a_sync( coro_fn: Optional[AnyFn[P, T]] = None, default: DefaultMode = config.DEFAULT_MODE, - **modifiers: Unpack[ModifierKwargs], # default values are set by passing these kwargs into a ModifierManager object. + **modifiers: Unpack[ + ModifierKwargs + ], # default values are set by passing these kwargs into a ModifierManager object. ) -> Union[ASyncDecorator, ASyncFunction[P, T]]: """ A versatile decorator that enables both synchronous and asynchronous execution of functions. @@ -155,7 +181,7 @@ def a_sync( If None, the mode is inferred from the decorated function type. **modifiers: Additional keyword arguments to modify the behavior of the decorated function. See :class:`ModifierKwargs` for available options. - + Modifiers: lib defaults: async settings @@ -240,12 +266,12 @@ def a_sync( both synchronous and asynchronous usage, or for gradually migrating synchronous code to asynchronous without breaking existing interfaces. """ - + # If the dev tried passing a default as an arg instead of a kwarg, ie: @a_sync('sync')... - if coro_fn in ['async', 'sync']: + if coro_fn in ["async", "sync"]: default = coro_fn # type: ignore [assignment] coro_fn = None - + if default == "sync": deco = ASyncDecoratorSyncDefault(default=default, **modifiers) elif default == "async": @@ -254,4 +280,5 @@ def a_sync( deco = ASyncDecorator(default=default, **modifiers) return deco if coro_fn is None else deco(coro_fn) # type: ignore [arg-type] -# TODO: in a future release, I will make this usable with sync functions as well \ No newline at end of file + +# TODO: in a future release, I will make this usable with sync functions as well diff --git a/a_sync/a_sync/function.py b/a_sync/a_sync/function.py index 2114134b..7a75cca8 100644 --- a/a_sync/a_sync/function.py +++ b/a_sync/a_sync/function.py @@ -1,14 +1,11 @@ - import functools import inspect import logging import sys from async_lru import _LRUCacheWrapper -from async_property.base import \ - AsyncPropertyDescriptor # type: ignore [import] -from async_property.cached import \ - AsyncCachedPropertyDescriptor # type: ignore [import] +from async_property.base import AsyncPropertyDescriptor # type: ignore [import] +from async_property.cached import AsyncCachedPropertyDescriptor # type: ignore [import] from a_sync._typing import * from a_sync.a_sync import _flags, _helpers, _kwargs @@ -16,11 +13,15 @@ if TYPE_CHECKING: from a_sync import TaskMapping - from a_sync.a_sync.method import (ASyncBoundMethod, ASyncBoundMethodAsyncDefault, - ASyncBoundMethodSyncDefault) + from a_sync.a_sync.method import ( + ASyncBoundMethod, + ASyncBoundMethodAsyncDefault, + ASyncBoundMethodSyncDefault, + ) logger = logging.getLogger(__name__) + class ModifiedMixin: """ A mixin class that provides functionality for applying modifiers to functions. @@ -69,16 +70,19 @@ def default(self) -> DefaultMode: def _validate_wrapped_fn(fn: Callable) -> None: """Ensures 'fn' is an appropriate function for wrapping with a_sync.""" if isinstance(fn, (AsyncPropertyDescriptor, AsyncCachedPropertyDescriptor)): - return # These are always valid + return # These are always valid if not callable(fn): - raise TypeError(f'Input is not callable. Unable to decorate {fn}') + raise TypeError(f"Input is not callable. Unable to decorate {fn}") if isinstance(fn, _LRUCacheWrapper): fn = fn.__wrapped__ _check_not_genfunc(fn) fn_args = inspect.getfullargspec(fn)[0] for flag in _flags.VIABLE_FLAGS: if flag in fn_args: - raise RuntimeError(f"{fn} must not have any arguments with the following names: {_flags.VIABLE_FLAGS}") + raise RuntimeError( + f"{fn} must not have any arguments with the following names: {_flags.VIABLE_FLAGS}" + ) + class ASyncFunction(ModifiedMixin, Generic[P, T]): """ @@ -109,9 +113,13 @@ async def my_coroutine(x: int) -> str: # NOTE: We can't use __slots__ here because it breaks functools.update_wrapper @overload - def __init__(self, fn: CoroFn[P, T], **modifiers: Unpack[ModifierKwargs]) -> None:... + def __init__( + self, fn: CoroFn[P, T], **modifiers: Unpack[ModifierKwargs] + ) -> None: ... @overload - def __init__(self, fn: SyncFn[P, T], **modifiers: Unpack[ModifierKwargs]) -> None:... + def __init__( + self, fn: SyncFn[P, T], **modifiers: Unpack[ModifierKwargs] + ) -> None: ... def __init__(self, fn: AnyFn[P, T], **modifiers: Unpack[ModifierKwargs]) -> None: """ Initialize an ASyncFunction instance. @@ -132,18 +140,26 @@ def __init__(self, fn: AnyFn[P, T], **modifiers: Unpack[ModifierKwargs]) -> None if self.__doc__ is None: self.__doc__ = f"Since `{self.__name__}` is an {self.__docstring_append__}" else: - self.__doc__ += f"\n\nSince `{self.__name__}` is an {self.__docstring_append__}" + self.__doc__ += ( + f"\n\nSince `{self.__name__}` is an {self.__docstring_append__}" + ) @overload - def __call__(self, *args: P.args, sync: Literal[True], **kwargs: P.kwargs) -> T:... + def __call__(self, *args: P.args, sync: Literal[True], **kwargs: P.kwargs) -> T: ... @overload - def __call__(self, *args: P.args, sync: Literal[False], **kwargs: P.kwargs) -> Coroutine[Any, Any, T]:... + def __call__( + self, *args: P.args, sync: Literal[False], **kwargs: P.kwargs + ) -> Coroutine[Any, Any, T]: ... @overload - def __call__(self, *args: P.args, asynchronous: Literal[False], **kwargs: P.kwargs) -> T:... + def __call__( + self, *args: P.args, asynchronous: Literal[False], **kwargs: P.kwargs + ) -> T: ... @overload - def __call__(self, *args: P.args, asynchronous: Literal[True], **kwargs: P.kwargs) -> Coroutine[Any, Any, T]:... + def __call__( + self, *args: P.args, asynchronous: Literal[True], **kwargs: P.kwargs + ) -> Coroutine[Any, Any, T]: ... @overload - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> MaybeCoro[T]:... + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> MaybeCoro[T]: ... def __call__(self, *args: P.args, **kwargs: P.kwargs) -> MaybeCoro[T]: """ Call the wrapped function either synchronously or asynchronously. @@ -161,14 +177,18 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> MaybeCoro[T]: Raises: Exception: Any exception that may be raised by the wrapped function. """ - logger.debug("calling %s fn: %s with args: %s kwargs: %s", self, self.fn, args, kwargs) + logger.debug( + "calling %s fn: %s with args: %s kwargs: %s", self, self.fn, args, kwargs + ) return self.fn(*args, **kwargs) def __repr__(self) -> str: return f"<{self.__class__.__name__} {self.__module__}.{self.__name__} at {hex(id(self))}>" @functools.cached_property - def fn(self): # -> Union[SyncFn[[CoroFn[P, T]], MaybeAwaitable[T]], SyncFn[[SyncFn[P, T]], MaybeAwaitable[T]]]: + def fn( + self, + ): # -> Union[SyncFn[[CoroFn[P, T]], MaybeAwaitable[T]], SyncFn[[SyncFn[P, T]], MaybeAwaitable[T]]]: """ Returns the final wrapped version of :attr:`ASyncFunction._fn` decorated with all of the a_sync goodness. @@ -179,7 +199,13 @@ def fn(self): # -> Union[SyncFn[[CoroFn[P, T]], MaybeAwaitable[T]], SyncFn[[Sync if sys.version_info >= (3, 11) or TYPE_CHECKING: # we can specify P.args in python>=3.11 but in lower versions it causes a crash. Everything should still type check correctly on all versions. - def map(self, *iterables: AnyIterable[P.args], concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs) -> "TaskMapping[P, T]": + def map( + self, + *iterables: AnyIterable[P.args], + concurrency: Optional[int] = None, + task_name: str = "", + **function_kwargs: P.kwargs, + ) -> "TaskMapping[P, T]": """ Create a TaskMapping for the wrapped function with the given iterables. @@ -193,9 +219,22 @@ def map(self, *iterables: AnyIterable[P.args], concurrency: Optional[int] = None A TaskMapping object. """ from a_sync import TaskMapping - return TaskMapping(self, *iterables, concurrency=concurrency, name=task_name, **function_kwargs) - async def any(self, *iterables: AnyIterable[P.args], concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs) -> bool: + return TaskMapping( + self, + *iterables, + concurrency=concurrency, + name=task_name, + **function_kwargs, + ) + + async def any( + self, + *iterables: AnyIterable[P.args], + concurrency: Optional[int] = None, + task_name: str = "", + **function_kwargs: P.kwargs, + ) -> bool: """ Check if any result of the function applied to the iterables is truthy. @@ -208,9 +247,20 @@ async def any(self, *iterables: AnyIterable[P.args], concurrency: Optional[int] Returns: A boolean indicating if any result is truthy. """ - return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **function_kwargs).any(pop=True, sync=False) - - async def all(self, *iterables: AnyIterable[P.args], concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs) -> bool: + return await self.map( + *iterables, + concurrency=concurrency, + task_name=task_name, + **function_kwargs, + ).any(pop=True, sync=False) + + async def all( + self, + *iterables: AnyIterable[P.args], + concurrency: Optional[int] = None, + task_name: str = "", + **function_kwargs: P.kwargs, + ) -> bool: """ Check if all results of the function applied to the iterables are truthy. @@ -223,9 +273,20 @@ async def all(self, *iterables: AnyIterable[P.args], concurrency: Optional[int] Returns: A boolean indicating if all results are truthy. """ - return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **function_kwargs).all(pop=True, sync=False) - - async def min(self, *iterables: AnyIterable[P.args], concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs) -> T: + return await self.map( + *iterables, + concurrency=concurrency, + task_name=task_name, + **function_kwargs, + ).all(pop=True, sync=False) + + async def min( + self, + *iterables: AnyIterable[P.args], + concurrency: Optional[int] = None, + task_name: str = "", + **function_kwargs: P.kwargs, + ) -> T: """ Find the minimum result of the function applied to the iterables. @@ -238,9 +299,20 @@ async def min(self, *iterables: AnyIterable[P.args], concurrency: Optional[int] Returns: The minimum result. """ - return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **function_kwargs).min(pop=True, sync=False) - - async def max(self, *iterables: AnyIterable[P.args], concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs) -> T: + return await self.map( + *iterables, + concurrency=concurrency, + task_name=task_name, + **function_kwargs, + ).min(pop=True, sync=False) + + async def max( + self, + *iterables: AnyIterable[P.args], + concurrency: Optional[int] = None, + task_name: str = "", + **function_kwargs: P.kwargs, + ) -> T: """ Find the maximum result of the function applied to the iterables. @@ -253,9 +325,20 @@ async def max(self, *iterables: AnyIterable[P.args], concurrency: Optional[int] Returns: The maximum result. """ - return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **function_kwargs).max(pop=True, sync=False) - - async def sum(self, *iterables: AnyIterable[P.args], concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs) -> T: + return await self.map( + *iterables, + concurrency=concurrency, + task_name=task_name, + **function_kwargs, + ).max(pop=True, sync=False) + + async def sum( + self, + *iterables: AnyIterable[P.args], + concurrency: Optional[int] = None, + task_name: str = "", + **function_kwargs: P.kwargs, + ) -> T: """ Calculate the sum of the results of the function applied to the iterables. @@ -268,9 +351,22 @@ async def sum(self, *iterables: AnyIterable[P.args], concurrency: Optional[int] Returns: The sum of the results. """ - return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **function_kwargs).sum(pop=True, sync=False) + return await self.map( + *iterables, + concurrency=concurrency, + task_name=task_name, + **function_kwargs, + ).sum(pop=True, sync=False) + else: - def map(self, *iterables: AnyIterable[Any], concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs) -> "TaskMapping[P, T]": + + def map( + self, + *iterables: AnyIterable[Any], + concurrency: Optional[int] = None, + task_name: str = "", + **function_kwargs: P.kwargs, + ) -> "TaskMapping[P, T]": """ Create a TaskMapping for the wrapped function with the given iterables. @@ -284,9 +380,22 @@ def map(self, *iterables: AnyIterable[Any], concurrency: Optional[int] = None, t A TaskMapping object. """ from a_sync import TaskMapping - return TaskMapping(self, *iterables, concurrency=concurrency, name=task_name, **function_kwargs) - async def any(self, *iterables: AnyIterable[Any], concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs) -> bool: + return TaskMapping( + self, + *iterables, + concurrency=concurrency, + name=task_name, + **function_kwargs, + ) + + async def any( + self, + *iterables: AnyIterable[Any], + concurrency: Optional[int] = None, + task_name: str = "", + **function_kwargs: P.kwargs, + ) -> bool: """ Check if any result of the function applied to the iterables is truthy. @@ -299,9 +408,20 @@ async def any(self, *iterables: AnyIterable[Any], concurrency: Optional[int] = N Returns: A boolean indicating if any result is truthy. """ - return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **function_kwargs).any(pop=True, sync=False) - - async def all(self, *iterables: AnyIterable[Any], concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs) -> bool: + return await self.map( + *iterables, + concurrency=concurrency, + task_name=task_name, + **function_kwargs, + ).any(pop=True, sync=False) + + async def all( + self, + *iterables: AnyIterable[Any], + concurrency: Optional[int] = None, + task_name: str = "", + **function_kwargs: P.kwargs, + ) -> bool: """ Check if all results of the function applied to the iterables are truthy. @@ -314,9 +434,20 @@ async def all(self, *iterables: AnyIterable[Any], concurrency: Optional[int] = N Returns: A boolean indicating if all results are truthy. """ - return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **function_kwargs).all(pop=True, sync=False) - - async def min(self, *iterables: AnyIterable[Any], concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs) -> T: + return await self.map( + *iterables, + concurrency=concurrency, + task_name=task_name, + **function_kwargs, + ).all(pop=True, sync=False) + + async def min( + self, + *iterables: AnyIterable[Any], + concurrency: Optional[int] = None, + task_name: str = "", + **function_kwargs: P.kwargs, + ) -> T: """ Find the minimum result of the function applied to the iterables. @@ -329,9 +460,20 @@ async def min(self, *iterables: AnyIterable[Any], concurrency: Optional[int] = N Returns: The minimum result. """ - return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **function_kwargs).min(pop=True, sync=False) - - async def max(self, *iterables: AnyIterable[Any], concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs) -> T: + return await self.map( + *iterables, + concurrency=concurrency, + task_name=task_name, + **function_kwargs, + ).min(pop=True, sync=False) + + async def max( + self, + *iterables: AnyIterable[Any], + concurrency: Optional[int] = None, + task_name: str = "", + **function_kwargs: P.kwargs, + ) -> T: """ Find the maximum result of the function applied to the iterables. @@ -344,9 +486,20 @@ async def max(self, *iterables: AnyIterable[Any], concurrency: Optional[int] = N Returns: The maximum result. """ - return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **function_kwargs).max(pop=True, sync=False) - - async def sum(self, *iterables: AnyIterable[Any], concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs) -> T: + return await self.map( + *iterables, + concurrency=concurrency, + task_name=task_name, + **function_kwargs, + ).max(pop=True, sync=False) + + async def sum( + self, + *iterables: AnyIterable[Any], + concurrency: Optional[int] = None, + task_name: str = "", + **function_kwargs: P.kwargs, + ) -> T: """ Calculate the sum of the results of the function applied to the iterables. @@ -359,7 +512,12 @@ async def sum(self, *iterables: AnyIterable[Any], concurrency: Optional[int] = N Returns: The sum of the results. """ - return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **function_kwargs).sum(pop=True, sync=False) + return await self.map( + *iterables, + concurrency=concurrency, + task_name=task_name, + **function_kwargs, + ).sum(pop=True, sync=False) @functools.cached_property def _sync_default(self) -> bool: @@ -372,7 +530,11 @@ def _sync_default(self) -> bool: Returns: True if the default is sync, False if the default is async. """ - return True if self.default == 'sync' else False if self.default == 'async' else not self._async_def + return ( + True + if self.default == "sync" + else False if self.default == "async" else not self._async_def + ) @functools.cached_property def _async_def(self) -> bool: @@ -413,7 +575,9 @@ def _asyncified(self) -> CoroFn[P, T]: An asynchronous function with both sync and async modifiers applied. """ if self._async_def: - raise TypeError(f"Can only be applied to sync functions, not {self.__wrapped__}") + raise TypeError( + f"Can only be applied to sync functions, not {self.__wrapped__}" + ) return self._asyncify(self._modified_fn) # type: ignore [arg-type] @functools.cached_property @@ -432,7 +596,7 @@ def _modified_fn(self) -> AnyFn[P, T]: return self.modifiers.apply_sync_modifiers(self.__wrapped__) # type: ignore [return-value] @functools.cached_property - def _async_wrap(self): # -> SyncFn[[CoroFn[P, T]], MaybeAwaitable[T]]: + def _async_wrap(self): # -> SyncFn[[CoroFn[P, T]], MaybeAwaitable[T]]: """ The final wrapper if the wrapped function is an asynchronous function. @@ -441,15 +605,19 @@ def _async_wrap(self): # -> SyncFn[[CoroFn[P, T]], MaybeAwaitable[T]]: Returns: The final wrapped function. """ + @functools.wraps(self._modified_fn) def async_wrap(*args: P.args, **kwargs: P.kwargs) -> MaybeAwaitable[T]: # type: ignore [name-defined] - should_await = self._run_sync(kwargs) # Must take place before coro is created, we're popping a kwarg. + should_await = self._run_sync( + kwargs + ) # Must take place before coro is created, we're popping a kwarg. coro = self._modified_fn(*args, **kwargs) return self._await(coro) if should_await else coro + return async_wrap @functools.cached_property - def _sync_wrap(self): # -> SyncFn[[SyncFn[P, T]], MaybeAwaitable[T]]: + def _sync_wrap(self): # -> SyncFn[[SyncFn[P, T]], MaybeAwaitable[T]]: """ The final wrapper if the wrapped function is a synchronous function. @@ -458,36 +626,44 @@ def _sync_wrap(self): # -> SyncFn[[SyncFn[P, T]], MaybeAwaitable[T]]: Returns: The final wrapped function. """ + @functools.wraps(self._modified_fn) def sync_wrap(*args: P.args, **kwargs: P.kwargs) -> MaybeAwaitable[T]: # type: ignore [name-defined] if self._run_sync(kwargs): return self._modified_fn(*args, **kwargs) - return self._asyncified(*args, **kwargs) + return self._asyncified(*args, **kwargs) + return sync_wrap __docstring_append__ = ":class:`~a_sync.a_sync.function.ASyncFunction`, you can optionally pass either a `sync` or `asynchronous` kwarg with a boolean value." + if sys.version_info < (3, 10): _inherit = ASyncFunction[AnyFn[P, T], ASyncFunction[P, T]] else: _inherit = ASyncFunction[[AnyFn[P, T]], ASyncFunction[P, T]] - + + class ASyncDecorator(ModifiedMixin): def __init__(self, **modifiers: Unpack[ModifierKwargs]) -> None: - assert 'default' in modifiers, modifiers + assert "default" in modifiers, modifiers self.modifiers = ModifierManager(modifiers) self.validate_inputs() - + def validate_inputs(self) -> None: - if self.modifiers.default not in ['sync', 'async', None]: - raise ValueError(f"'default' must be either 'sync', 'async', or None. You passed {self.modifiers.default}.") - + if self.modifiers.default not in ["sync", "async", None]: + raise ValueError( + f"'default' must be either 'sync', 'async', or None. You passed {self.modifiers.default}." + ) + @overload def __call__(self, func: AnyFn[Concatenate[B, P], T]) -> "ASyncBoundMethod[B, P, T]": # type: ignore [override] ... + @overload def __call__(self, func: AnyFn[P, T]) -> ASyncFunction[P, T]: # type: ignore [override] ... + def __call__(self, func: AnyFn[P, T]) -> ASyncFunction[P, T]: # type: ignore [override] if self.default == "async": return ASyncFunctionAsyncDefault(func, **self.modifiers) @@ -498,6 +674,7 @@ def __call__(self, func: AnyFn[P, T]) -> ASyncFunction[P, T]: # type: ignore [o else: return ASyncFunctionSyncDefault(func, **self.modifiers) + def _check_not_genfunc(func: Callable) -> None: if inspect.isasyncgenfunction(func) or inspect.isgeneratorfunction(func): raise ValueError("unable to decorate generator functions with this decorator") @@ -505,6 +682,7 @@ def _check_not_genfunc(func: Callable) -> None: # Mypy helper classes + class ASyncFunctionSyncDefault(ASyncFunction[P, T]): """A specialized :class:`~ASyncFunction` that defaults to synchronous execution. @@ -528,16 +706,23 @@ async def my_function(x: int) -> str: result = await my_function(5, sync=False) # returns "5" ``` """ + @overload - def __call__(self, *args: P.args, sync: Literal[True], **kwargs: P.kwargs) -> T:... + def __call__(self, *args: P.args, sync: Literal[True], **kwargs: P.kwargs) -> T: ... @overload - def __call__(self, *args: P.args, sync: Literal[False], **kwargs: P.kwargs) -> Coroutine[Any, Any, T]:... + def __call__( + self, *args: P.args, sync: Literal[False], **kwargs: P.kwargs + ) -> Coroutine[Any, Any, T]: ... @overload - def __call__(self, *args: P.args, asynchronous: Literal[False], **kwargs: P.kwargs) -> T:... + def __call__( + self, *args: P.args, asynchronous: Literal[False], **kwargs: P.kwargs + ) -> T: ... @overload - def __call__(self, *args: P.args, asynchronous: Literal[True], **kwargs: P.kwargs) -> Coroutine[Any, Any, T]:... + def __call__( + self, *args: P.args, asynchronous: Literal[True], **kwargs: P.kwargs + ) -> Coroutine[Any, Any, T]: ... @overload - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:... + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ... def __call__(self, *args: P.args, **kwargs: P.kwargs) -> MaybeCoro[T]: """Call the wrapped function, defaulting to synchronous execution. @@ -557,13 +742,14 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> MaybeCoro[T]: return self.fn(*args, **kwargs) __docstring_append__ = ":class:`~a_sync.a_sync.function.ASyncFunctionSyncDefault`, you can optionally pass `sync=False` or `asynchronous=True` to force it to return a coroutine. Without either kwarg, it will run synchronously." - + + class ASyncFunctionAsyncDefault(ASyncFunction[P, T]): """ A specialized :class:`~ASyncFunction` that defaults to asynchronous execution. This class is used when the :func:`~a_sync` decorator is applied with `default='async'`. - It provides type hints to indicate that the default call behavior is asynchronous + It provides type hints to indicate that the default call behavior is asynchronous and supports IDE type checking for most use cases. The wrapped function can still be called synchronously by passing `sync=True` @@ -582,16 +768,23 @@ async def my_function(x: int) -> str: result = my_function(5, sync=True) # returns "5" ``` """ + @overload - def __call__(self, *args: P.args, sync: Literal[True], **kwargs: P.kwargs) -> T:... + def __call__(self, *args: P.args, sync: Literal[True], **kwargs: P.kwargs) -> T: ... @overload - def __call__(self, *args: P.args, sync: Literal[False], **kwargs: P.kwargs) -> Coroutine[Any, Any, T]:... + def __call__( + self, *args: P.args, sync: Literal[False], **kwargs: P.kwargs + ) -> Coroutine[Any, Any, T]: ... @overload - def __call__(self, *args: P.args, asynchronous: Literal[False], **kwargs: P.kwargs) -> T:... + def __call__( + self, *args: P.args, asynchronous: Literal[False], **kwargs: P.kwargs + ) -> T: ... @overload - def __call__(self, *args: P.args, asynchronous: Literal[True], **kwargs: P.kwargs) -> Coroutine[Any, Any, T]:... + def __call__( + self, *args: P.args, asynchronous: Literal[True], **kwargs: P.kwargs + ) -> Coroutine[Any, Any, T]: ... @overload - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T]:... + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T]: ... def __call__(self, *args: P.args, **kwargs: P.kwargs) -> MaybeCoro[T]: """Call the wrapped function, defaulting to asynchronous execution. @@ -612,29 +805,36 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> MaybeCoro[T]: __docstring_append__ = ":class:`~a_sync.a_sync.function.ASyncFunctionAsyncDefault`, you can optionally pass `sync=True` or `asynchronous=False` to force it to run synchronously and return a value. Without either kwarg, it will return a coroutine for you to await." + class ASyncDecoratorSyncDefault(ASyncDecorator): @overload def __call__(self, func: AnyFn[Concatenate[B, P], T]) -> "ASyncBoundMethodSyncDefault[P, T]": # type: ignore [override] ... + @overload def __call__(self, func: AnyBoundMethod[P, T]) -> ASyncFunctionSyncDefault[P, T]: # type: ignore [override] ... + @overload def __call__(self, func: AnyFn[P, T]) -> ASyncFunctionSyncDefault[P, T]: # type: ignore [override] ... + def __call__(self, func: AnyFn[P, T]) -> ASyncFunctionSyncDefault[P, T]: return ASyncFunctionSyncDefault(func, **self.modifiers) + class ASyncDecoratorAsyncDefault(ASyncDecorator): @overload def __call__(self, func: AnyFn[Concatenate[B, P], T]) -> "ASyncBoundMethodAsyncDefault[P, T]": # type: ignore [override] ... + @overload def __call__(self, func: AnyBoundMethod[P, T]) -> ASyncFunctionAsyncDefault[P, T]: # type: ignore [override] ... + @overload def __call__(self, func: AnyFn[P, T]) -> ASyncFunctionAsyncDefault[P, T]: # type: ignore [override] ... + def __call__(self, func: AnyFn[P, T]) -> ASyncFunctionAsyncDefault[P, T]: return ASyncFunctionAsyncDefault(func, **self.modifiers) - diff --git a/a_sync/a_sync/method.py b/a_sync/a_sync/method.py index f0242dc1..981cb3f9 100644 --- a/a_sync/a_sync/method.py +++ b/a_sync/a_sync/method.py @@ -16,7 +16,11 @@ from a_sync._typing import * from a_sync.a_sync import _helpers, _kwargs from a_sync.a_sync._descriptor import ASyncDescriptor -from a_sync.a_sync.function import ASyncFunction, ASyncFunctionAsyncDefault, ASyncFunctionSyncDefault +from a_sync.a_sync.function import ( + ASyncFunction, + ASyncFunctionAsyncDefault, + ASyncFunctionSyncDefault, +) if TYPE_CHECKING: from a_sync import TaskMapping @@ -24,9 +28,10 @@ logger = logging.getLogger(__name__) + class ASyncMethodDescriptor(ASyncDescriptor[I, P, T]): """ - This class provides the core functionality for creating :class:`ASyncBoundMethod` objects, + This class provides the core functionality for creating :class:`ASyncBoundMethod` objects, which can be used to define methods that can be called both synchronously and asynchronously. """ @@ -46,14 +51,22 @@ async def __call__(self, instance: I, *args: P.args, **kwargs: P.kwargs) -> T: The result of the method call. """ # NOTE: This is only used by TaskMapping atm # TODO: use it elsewhere - logger.debug("awaiting %s for instance: %s args: %s kwargs: %s", self, instance, args, kwargs) + logger.debug( + "awaiting %s for instance: %s args: %s kwargs: %s", + self, + instance, + args, + kwargs, + ) return await self.__get__(instance, None)(*args, **kwargs) @overload - def __get__(self, instance: None, owner: Type[I]) -> Self:... + def __get__(self, instance: None, owner: Type[I]) -> Self: ... @overload - def __get__(self, instance: I, owner: Type[I]) -> "ASyncBoundMethod[I, P, T]":... - def __get__(self, instance: Optional[I], owner: Type[I]) -> Union[Self, "ASyncBoundMethod[I, P, T]"]: + def __get__(self, instance: I, owner: Type[I]) -> "ASyncBoundMethod[I, P, T]": ... + def __get__( + self, instance: Optional[I], owner: Type[I] + ) -> Union[Self, "ASyncBoundMethod[I, P, T]"]: """ Get the bound method or the descriptor itself. @@ -72,20 +85,42 @@ def __get__(self, instance: Optional[I], owner: Type[I]) -> Union[Self, "ASyncBo bound._cache_handle.cancel() except KeyError: from a_sync.a_sync.abstract import ASyncABC + if self.default == "sync": - bound = ASyncBoundMethodSyncDefault(instance, self.__wrapped__, self.__is_async_def__, **self.modifiers) + bound = ASyncBoundMethodSyncDefault( + instance, self.__wrapped__, self.__is_async_def__, **self.modifiers + ) elif self.default == "async": - bound = ASyncBoundMethodAsyncDefault(instance, self.__wrapped__, self.__is_async_def__, **self.modifiers) + bound = ASyncBoundMethodAsyncDefault( + instance, self.__wrapped__, self.__is_async_def__, **self.modifiers + ) elif isinstance(instance, ASyncABC): try: if instance.__a_sync_instance_should_await__: - bound = ASyncBoundMethodSyncDefault(instance, self.__wrapped__, self.__is_async_def__, **self.modifiers) + bound = ASyncBoundMethodSyncDefault( + instance, + self.__wrapped__, + self.__is_async_def__, + **self.modifiers, + ) else: - bound = ASyncBoundMethodAsyncDefault(instance, self.__wrapped__, self.__is_async_def__, **self.modifiers) + bound = ASyncBoundMethodAsyncDefault( + instance, + self.__wrapped__, + self.__is_async_def__, + **self.modifiers, + ) except AttributeError: - bound = ASyncBoundMethod(instance, self.__wrapped__, self.__is_async_def__, **self.modifiers) + bound = ASyncBoundMethod( + instance, + self.__wrapped__, + self.__is_async_def__, + **self.modifiers, + ) else: - bound = ASyncBoundMethod(instance, self.__wrapped__, self.__is_async_def__, **self.modifiers) + bound = ASyncBoundMethod( + instance, self.__wrapped__, self.__is_async_def__, **self.modifiers + ) instance.__dict__[self.field_name] = bound logger.debug("new bound method: %s", bound) # Handler for popping unused bound methods from bound method cache @@ -103,7 +138,9 @@ def __set__(self, instance, value): Raises: :class:`RuntimeError`: Always raised to prevent setting. """ - raise RuntimeError(f"cannot set {self.field_name}, {self} is what you get. sorry.") + raise RuntimeError( + f"cannot set {self.field_name}, {self} is what you get. sorry." + ) def __delete__(self, instance): """ @@ -115,7 +152,9 @@ def __delete__(self, instance): Raises: :class:`RuntimeError`: Always raised to prevent deletion. """ - raise RuntimeError(f"cannot delete {self.field_name}, you're stuck with {self} forever. sorry.") + raise RuntimeError( + f"cannot delete {self.field_name}, you're stuck with {self} forever. sorry." + ) @functools.cached_property def __is_async_def__(self) -> bool: @@ -138,7 +177,10 @@ def _get_cache_handle(self, instance: I) -> asyncio.TimerHandle: A timer handle for cache management. """ # NOTE: use `instance.__dict__.pop` instead of `delattr` so we don't create a strong ref to `instance` - return asyncio.get_event_loop().call_later(300, instance.__dict__.pop, self.field_name) + return asyncio.get_event_loop().call_later( + 300, instance.__dict__.pop, self.field_name + ) + @final class ASyncMethodDescriptorSyncDefault(ASyncMethodDescriptor[I, P, T]): @@ -168,10 +210,18 @@ class ASyncMethodDescriptorSyncDefault(ASyncMethodDescriptor[I, P, T]): """Synchronous default version of the :meth:`~ASyncMethodDescriptor.sum` method.""" @overload - def __get__(self, instance: None, owner: Type[I] = None) -> "ASyncMethodDescriptorSyncDefault[I, P, T]":... + def __get__( + self, instance: None, owner: Type[I] = None + ) -> "ASyncMethodDescriptorSyncDefault[I, P, T]": ... @overload - def __get__(self, instance: I, owner: Type[I] = None) -> "ASyncBoundMethodSyncDefault[I, P, T]":... - def __get__(self, instance: Optional[I], owner: Type[I] = None) -> "Union[ASyncMethodDescriptorSyncDefault, ASyncBoundMethodSyncDefault[I, P, T]]": + def __get__( + self, instance: I, owner: Type[I] = None + ) -> "ASyncBoundMethodSyncDefault[I, P, T]": ... + def __get__( + self, instance: Optional[I], owner: Type[I] = None + ) -> ( + "Union[ASyncMethodDescriptorSyncDefault, ASyncBoundMethodSyncDefault[I, P, T]]" + ): """ Get the bound method or the descriptor itself. @@ -189,13 +239,16 @@ def __get__(self, instance: Optional[I], owner: Type[I] = None) -> "Union[ASyncM # we will set a new one in the finally block bound._cache_handle.cancel() except KeyError: - bound = ASyncBoundMethodSyncDefault(instance, self.__wrapped__, self.__is_async_def__, **self.modifiers) + bound = ASyncBoundMethodSyncDefault( + instance, self.__wrapped__, self.__is_async_def__, **self.modifiers + ) instance.__dict__[self.field_name] = bound logger.debug("new bound method: %s", bound) # Handler for popping unused bound methods from bound method cache bound._cache_handle = self._get_cache_handle(instance) return bound + @final class ASyncMethodDescriptorAsyncDefault(ASyncMethodDescriptor[I, P, T]): """ @@ -213,21 +266,27 @@ class ASyncMethodDescriptorAsyncDefault(ASyncMethodDescriptor[I, P, T]): all: ASyncFunctionAsyncDefault[Concatenate[AnyIterable[I], P], bool] """Asynchronous default version of the :meth:`~ASyncMethodDescriptor.all` method.""" - + min: ASyncFunctionAsyncDefault[Concatenate[AnyIterable[I], P], T] """Asynchronous default version of the :meth:`~ASyncMethodDescriptor.min` method.""" - + max: ASyncFunctionAsyncDefault[Concatenate[AnyIterable[I], P], T] """Asynchronous default version of the :meth:`~ASyncMethodDescriptor.max` method.""" - + sum: ASyncFunctionAsyncDefault[Concatenate[AnyIterable[I], P], T] """Asynchronous default version of the :meth:`~ASyncMethodDescriptor.sum` method.""" @overload - def __get__(self, instance: None, owner: Type[I]) -> "ASyncMethodDescriptorAsyncDefault[I, P, T]":... + def __get__( + self, instance: None, owner: Type[I] + ) -> "ASyncMethodDescriptorAsyncDefault[I, P, T]": ... @overload - def __get__(self, instance: I, owner: Type[I]) -> "ASyncBoundMethodAsyncDefault[I, P, T]":... - def __get__(self, instance: Optional[I], owner: Type[I]) -> "Union[ASyncMethodDescriptorAsyncDefault, ASyncBoundMethodAsyncDefault[I, P, T]]": + def __get__( + self, instance: I, owner: Type[I] + ) -> "ASyncBoundMethodAsyncDefault[I, P, T]": ... + def __get__( + self, instance: Optional[I], owner: Type[I] + ) -> "Union[ASyncMethodDescriptorAsyncDefault, ASyncBoundMethodAsyncDefault[I, P, T]]": """ Get the bound method or the descriptor itself. @@ -245,13 +304,16 @@ def __get__(self, instance: Optional[I], owner: Type[I]) -> "Union[ASyncMethodDe # we will set a new one in the finally block bound._cache_handle.cancel() except KeyError: - bound = ASyncBoundMethodAsyncDefault(instance, self.__wrapped__, self.__is_async_def__, **self.modifiers) + bound = ASyncBoundMethodAsyncDefault( + instance, self.__wrapped__, self.__is_async_def__, **self.modifiers + ) instance.__dict__[self.field_name] = bound logger.debug("new bound method: %s", bound) # Handler for popping unused bound methods from bound method cache bound._cache_handle = self._get_cache_handle(instance) return bound + class ASyncBoundMethod(ASyncFunction[P, T], Generic[I, P, T]): """ A bound method that can be called both synchronously and asynchronously. @@ -259,6 +321,7 @@ class ASyncBoundMethod(ASyncFunction[P, T], Generic[I, P, T]): This class represents a method bound to an instance, which can be called either synchronously or asynchronously based on various conditions. """ + # NOTE: this is created by the Descriptor _cache_handle: asyncio.TimerHandle @@ -273,9 +336,9 @@ class ASyncBoundMethod(ASyncFunction[P, T], Generic[I, P, T]): __slots__ = "_is_async_def", "__weakself__" def __init__( - self, - instance: I, - unbound: AnyFn[Concatenate[I, P], T], + self, + instance: I, + unbound: AnyFn[Concatenate[I, P], T], async_def: bool, **modifiers: Unpack[ModifierKwargs], ) -> None: @@ -312,15 +375,21 @@ def __repr__(self) -> str: return f"<{self.__class__.__name__} for function COLLECTED.COLLECTED.{self.__name__} bound to {self.__weakself__}>" @overload - def __call__(self, *args: P.args, sync: Literal[True], **kwargs: P.kwargs) -> T:... + def __call__(self, *args: P.args, sync: Literal[True], **kwargs: P.kwargs) -> T: ... @overload - def __call__(self, *args: P.args, sync: Literal[False], **kwargs: P.kwargs) -> Coroutine[Any, Any, T]:... + def __call__( + self, *args: P.args, sync: Literal[False], **kwargs: P.kwargs + ) -> Coroutine[Any, Any, T]: ... @overload - def __call__(self, *args: P.args, asynchronous: Literal[False], **kwargs: P.kwargs) -> T:... + def __call__( + self, *args: P.args, asynchronous: Literal[False], **kwargs: P.kwargs + ) -> T: ... @overload - def __call__(self, *args: P.args, asynchronous: Literal[True], **kwargs: P.kwargs) -> Coroutine[Any, Any, T]:... + def __call__( + self, *args: P.args, asynchronous: Literal[True], **kwargs: P.kwargs + ) -> Coroutine[Any, Any, T]: ... @overload - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> MaybeCoro[T]:... + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> MaybeCoro[T]: ... def __call__(self, *args: P.args, **kwargs: P.kwargs) -> MaybeCoro[T]: """ Call the bound method. @@ -345,9 +414,13 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> MaybeCoro[T]: pass elif self._should_await(kwargs): # The awaitable was not awaited, so now we need to check the flag as defined on 'self' and await if appropriate. - logger.debug("awaiting %s for %s args: %s kwargs: %s", coro, self, args, kwargs) + logger.debug( + "awaiting %s for %s args: %s kwargs: %s", coro, self, args, kwargs + ) retval = _helpers._await(coro) - logger.debug("returning %s for %s args: %s kwargs: %s", retval, self, args, kwargs) + logger.debug( + "returning %s for %s args: %s kwargs: %s", retval, self, args, kwargs + ) return retval # type: ignore [call-overload, return-value] @property @@ -375,9 +448,16 @@ def __bound_to_a_sync_instance__(self) -> bool: True if bound to an ASyncABC instance, False otherwise. """ from a_sync.a_sync.abstract import ASyncABC + return isinstance(self.__self__, ASyncABC) - def map(self, *iterables: AnyIterable[I], concurrency: Optional[int] = None, task_name: str = "", **kwargs: P.kwargs) -> "TaskMapping[I, T]": + def map( + self, + *iterables: AnyIterable[I], + concurrency: Optional[int] = None, + task_name: str = "", + **kwargs: P.kwargs, + ) -> "TaskMapping[I, T]": """ Create a TaskMapping for this method. @@ -391,9 +471,18 @@ def map(self, *iterables: AnyIterable[I], concurrency: Optional[int] = None, tas A TaskMapping instance for this method. """ from a_sync import TaskMapping - return TaskMapping(self, *iterables, concurrency=concurrency, name=task_name, **kwargs) - async def any(self, *iterables: AnyIterable[I], concurrency: Optional[int] = None, task_name: str = "", **kwargs: P.kwargs) -> bool: + return TaskMapping( + self, *iterables, concurrency=concurrency, name=task_name, **kwargs + ) + + async def any( + self, + *iterables: AnyIterable[I], + concurrency: Optional[int] = None, + task_name: str = "", + **kwargs: P.kwargs, + ) -> bool: """ Check if any of the results are truthy. @@ -406,9 +495,17 @@ async def any(self, *iterables: AnyIterable[I], concurrency: Optional[int] = Non Returns: True if any result is truthy, False otherwise. """ - return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **kwargs).any(pop=True, sync=False) + return await self.map( + *iterables, concurrency=concurrency, task_name=task_name, **kwargs + ).any(pop=True, sync=False) - async def all(self, *iterables: AnyIterable[I], concurrency: Optional[int] = None, task_name: str = "", **kwargs: P.kwargs) -> bool: + async def all( + self, + *iterables: AnyIterable[I], + concurrency: Optional[int] = None, + task_name: str = "", + **kwargs: P.kwargs, + ) -> bool: """ Check if all of the results are truthy. @@ -421,9 +518,17 @@ async def all(self, *iterables: AnyIterable[I], concurrency: Optional[int] = Non Returns: True if all results are truthy, False otherwise. """ - return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **kwargs).all(pop=True, sync=False) + return await self.map( + *iterables, concurrency=concurrency, task_name=task_name, **kwargs + ).all(pop=True, sync=False) - async def min(self, *iterables: AnyIterable[I], concurrency: Optional[int] = None, task_name: str = "", **kwargs: P.kwargs) -> T: + async def min( + self, + *iterables: AnyIterable[I], + concurrency: Optional[int] = None, + task_name: str = "", + **kwargs: P.kwargs, + ) -> T: """ Find the minimum result. @@ -436,9 +541,17 @@ async def min(self, *iterables: AnyIterable[I], concurrency: Optional[int] = Non Returns: The minimum result. """ - return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **kwargs).min(pop=True, sync=False) + return await self.map( + *iterables, concurrency=concurrency, task_name=task_name, **kwargs + ).min(pop=True, sync=False) - async def max(self, *iterables: AnyIterable[I], concurrency: Optional[int] = None, task_name: str = "", **kwargs: P.kwargs) -> T: + async def max( + self, + *iterables: AnyIterable[I], + concurrency: Optional[int] = None, + task_name: str = "", + **kwargs: P.kwargs, + ) -> T: """ Find the maximum result. @@ -451,9 +564,17 @@ async def max(self, *iterables: AnyIterable[I], concurrency: Optional[int] = Non Returns: The maximum result. """ - return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **kwargs).max(pop=True, sync=False) + return await self.map( + *iterables, concurrency=concurrency, task_name=task_name, **kwargs + ).max(pop=True, sync=False) - async def sum(self, *iterables: AnyIterable[I], concurrency: Optional[int] = None, task_name: str = "", **kwargs: P.kwargs) -> T: + async def sum( + self, + *iterables: AnyIterable[I], + concurrency: Optional[int] = None, + task_name: str = "", + **kwargs: P.kwargs, + ) -> T: """ Calculate the sum of the results. @@ -466,7 +587,9 @@ async def sum(self, *iterables: AnyIterable[I], concurrency: Optional[int] = Non Returns: The sum of the results. """ - return await self.map(*iterables, concurrency=concurrency, task_name=task_name, **kwargs).sum(pop=True, sync=False) + return await self.map( + *iterables, concurrency=concurrency, task_name=task_name, **kwargs + ).sum(pop=True, sync=False) def _should_await(self, kwargs: dict) -> bool: """ @@ -503,7 +626,9 @@ class ASyncBoundMethodSyncDefault(ASyncBoundMethod[I, P, T]): A bound method with synchronous default behavior. """ - def __get__(self, instance: Optional[I], owner: Type[I]) -> ASyncFunctionSyncDefault[P, T]: + def __get__( + self, instance: Optional[I], owner: Type[I] + ) -> ASyncFunctionSyncDefault[P, T]: """ Get the bound method or descriptor. @@ -517,15 +642,21 @@ def __get__(self, instance: Optional[I], owner: Type[I]) -> ASyncFunctionSyncDef return ASyncBoundMethod.__get__(self, instance, owner) @overload - def __call__(self, *args: P.args, sync: Literal[True], **kwargs: P.kwargs) -> T:... + def __call__(self, *args: P.args, sync: Literal[True], **kwargs: P.kwargs) -> T: ... @overload - def __call__(self, *args: P.args, sync: Literal[False], **kwargs: P.kwargs) -> Coroutine[Any, Any, T]:... + def __call__( + self, *args: P.args, sync: Literal[False], **kwargs: P.kwargs + ) -> Coroutine[Any, Any, T]: ... @overload - def __call__(self, *args: P.args, asynchronous: Literal[False], **kwargs: P.kwargs) -> T:... + def __call__( + self, *args: P.args, asynchronous: Literal[False], **kwargs: P.kwargs + ) -> T: ... @overload - def __call__(self, *args: P.args, asynchronous: Literal[True], **kwargs: P.kwargs) -> Coroutine[Any, Any, T]:... + def __call__( + self, *args: P.args, asynchronous: Literal[True], **kwargs: P.kwargs + ) -> Coroutine[Any, Any, T]: ... @overload - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:... + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ... def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: """ Call the bound method with synchronous default behavior. @@ -539,6 +670,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: """ return ASyncBoundMethod.__call__(self, *args, **kwargs) + class ASyncBoundMethodAsyncDefault(ASyncBoundMethod[I, P, T]): """ A bound method with asynchronous default behavior. @@ -558,15 +690,21 @@ def __get__(self, instance: I, owner: Type[I]) -> ASyncFunctionAsyncDefault[P, T return ASyncBoundMethod.__get__(self, instance, owner) @overload - def __call__(self, *args: P.args, sync: Literal[True], **kwargs: P.kwargs) -> T:... + def __call__(self, *args: P.args, sync: Literal[True], **kwargs: P.kwargs) -> T: ... @overload - def __call__(self, *args: P.args, sync: Literal[False], **kwargs: P.kwargs) -> Coroutine[Any, Any, T]:... + def __call__( + self, *args: P.args, sync: Literal[False], **kwargs: P.kwargs + ) -> Coroutine[Any, Any, T]: ... @overload - def __call__(self, *args: P.args, asynchronous: Literal[False], **kwargs: P.kwargs) -> T:... + def __call__( + self, *args: P.args, asynchronous: Literal[False], **kwargs: P.kwargs + ) -> T: ... @overload - def __call__(self, *args: P.args, asynchronous: Literal[True], **kwargs: P.kwargs) -> Coroutine[Any, Any, T]:... + def __call__( + self, *args: P.args, asynchronous: Literal[True], **kwargs: P.kwargs + ) -> Coroutine[Any, Any, T]: ... @overload - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T]:... + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T]: ... def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T]: """ Call the bound method with asynchronous default behavior. diff --git a/a_sync/a_sync/modifiers/__init__.py b/a_sync/a_sync/modifiers/__init__.py index 543e56fd..5f32da2d 100644 --- a/a_sync/a_sync/modifiers/__init__.py +++ b/a_sync/a_sync/modifiers/__init__.py @@ -11,9 +11,9 @@ def get_modifiers_from(thing: Union[dict, type, object]) -> ModifierKwargs: return ModifierKwargs({modifier: thing[modifier] for modifier in valid_modifiers if modifier in thing}) # type: ignore [misc] return ModifierKwargs({modifier: getattr(thing, modifier) for modifier in valid_modifiers if hasattr(thing, modifier)}) # type: ignore [misc] + def apply_class_defined_modifiers(attrs_from_metaclass: dict): - if 'semaphore' in attrs_from_metaclass and isinstance(val := attrs_from_metaclass['semaphore'], int): - attrs_from_metaclass['semaphore'] = ThreadsafeSemaphore(val) - if "runs_per_minute" in attrs_from_metaclass and isinstance(val := attrs_from_metaclass['runs_per_minute'], int): - attrs_from_metaclass['runs_per_minute'] = AsyncLimiter(val) - \ No newline at end of file + if isinstance(val := attrs_from_metaclass.get("semaphore"), int): + attrs_from_metaclass["semaphore"] = ThreadsafeSemaphore(val) + if isinstance(val := attrs_from_metaclass.get("runs_per_minute"), int): + attrs_from_metaclass["runs_per_minute"] = AsyncLimiter(val) diff --git a/a_sync/a_sync/modifiers/cache/__init__.py b/a_sync/a_sync/modifiers/cache/__init__.py index 92531c06..4605f2e1 100644 --- a/a_sync/a_sync/modifiers/cache/__init__.py +++ b/a_sync/a_sync/modifiers/cache/__init__.py @@ -13,48 +13,56 @@ class CacheArgs(TypedDict): ram_cache_maxsize: Optional[int] ram_cache_ttl: Optional[int] + @overload def apply_async_cache( coro_fn: Literal[None], **modifiers: Unpack[CacheArgs], -) -> AsyncDecorator[P, T]:... - +) -> AsyncDecorator[P, T]: ... + + @overload def apply_async_cache( coro_fn: int, **modifiers: Unpack[CacheArgs], -) -> AsyncDecorator[P, T]:... - +) -> AsyncDecorator[P, T]: ... + + @overload def apply_async_cache( coro_fn: CoroFn[P, T], **modifiers: Unpack[CacheArgs], -) -> CoroFn[P, T]:... - +) -> CoroFn[P, T]: ... + + def apply_async_cache( coro_fn: Union[CoroFn[P, T], CacheType, int] = None, - cache_type: CacheType = 'memory', + cache_type: CacheType = "memory", cache_typed: bool = False, ram_cache_maxsize: Optional[int] = None, ram_cache_ttl: Optional[int] = None, ) -> AsyncDecoratorOrCoroFn[P, T]: - + # Parse Inputs if isinstance(coro_fn, int): assert ram_cache_maxsize is None ram_cache_maxsize = coro_fn coro_fn = None - - # Validate + + # Validate elif coro_fn is None: if ram_cache_maxsize is not None and not isinstance(ram_cache_maxsize, int): - raise TypeError("'lru_cache_maxsize' must be an integer or None.", ram_cache_maxsize) + raise TypeError( + "'lru_cache_maxsize' must be an integer or None.", ram_cache_maxsize + ) elif not asyncio.iscoroutinefunction(coro_fn): raise exceptions.FunctionNotAsync(coro_fn) - - if cache_type == 'memory': - cache_decorator = apply_async_memory_cache(maxsize=ram_cache_maxsize, ttl=ram_cache_ttl, typed=cache_typed) + + if cache_type == "memory": + cache_decorator = apply_async_memory_cache( + maxsize=ram_cache_maxsize, ttl=ram_cache_ttl, typed=cache_typed + ) return cache_decorator if coro_fn is None else cache_decorator(coro_fn) - elif cache_type == 'disk': + elif cache_type == "disk": pass - raise NotImplementedError(f"cache_type: {cache_type}") \ No newline at end of file + raise NotImplementedError(f"cache_type: {cache_type}") diff --git a/a_sync/a_sync/modifiers/cache/memory.py b/a_sync/a_sync/modifiers/cache/memory.py index e9b6ebbb..5363394a 100644 --- a/a_sync/a_sync/modifiers/cache/memory.py +++ b/a_sync/a_sync/modifiers/cache/memory.py @@ -7,34 +7,36 @@ from a_sync import exceptions from a_sync._typing import * + class CacheKwargs(TypedDict): maxsize: Optional[int] ttl: Optional[int] typed: bool + @overload def apply_async_memory_cache( - coro_fn: Literal[None], - **kwargs: Unpack[CacheKwargs] -) -> AsyncDecorator[P, T]:... - + coro_fn: Literal[None], **kwargs: Unpack[CacheKwargs] +) -> AsyncDecorator[P, T]: ... + + @overload def apply_async_memory_cache( - coro_fn: int, - **kwargs: Unpack[CacheKwargs] -) -> AsyncDecorator[P, T]:... - + coro_fn: int, **kwargs: Unpack[CacheKwargs] +) -> AsyncDecorator[P, T]: ... + + @overload def apply_async_memory_cache( - coro_fn: CoroFn[P, T], - **kwargs: Unpack[CacheKwargs] -) -> CoroFn[P, T]:... + coro_fn: CoroFn[P, T], **kwargs: Unpack[CacheKwargs] +) -> CoroFn[P, T]: ... + @overload def apply_async_memory_cache( - coro_fn: Literal[None], - **kwargs: Unpack[CacheKwargs] -) -> AsyncDecorator[P, T]:... + coro_fn: Literal[None], **kwargs: Unpack[CacheKwargs] +) -> AsyncDecorator[P, T]: ... + def apply_async_memory_cache( coro_fn: Optional[Union[CoroFn[P, T], int]] = None, @@ -47,14 +49,16 @@ def apply_async_memory_cache( assert maxsize is None maxsize = coro_fn coro_fn = None - - # Validate + + # Validate elif coro_fn is None: if not (maxsize is None or isinstance(maxsize, int)): - raise TypeError("'lru_cache_maxsize' must be a positive integer or None.", maxsize) + raise TypeError( + "'lru_cache_maxsize' must be a positive integer or None.", maxsize + ) elif not asyncio.iscoroutinefunction(coro_fn): raise exceptions.FunctionNotAsync(coro_fn) - + if maxsize == -1: maxsize = None diff --git a/a_sync/a_sync/modifiers/limiter.py b/a_sync/a_sync/modifiers/limiter.py index 0c4c877f..864b62c3 100644 --- a/a_sync/a_sync/modifiers/limiter.py +++ b/a_sync/a_sync/modifiers/limiter.py @@ -12,20 +12,23 @@ def apply_rate_limit( coro_fn: Literal[None], runs_per_minute: int, -) -> AsyncDecorator[P, T]:... - +) -> AsyncDecorator[P, T]: ... + + @overload def apply_rate_limit( coro_fn: int, runs_per_minute: Literal[None], -) -> AsyncDecorator[P, T]:... - +) -> AsyncDecorator[P, T]: ... + + @overload def apply_rate_limit( coro_fn: CoroFn[P, T], runs_per_minute: Union[int, AsyncLimiter], -) -> CoroFn[P, T]:... - +) -> CoroFn[P, T]: ... + + def apply_rate_limit( coro_fn: Optional[Union[CoroFn[P, T], int]] = None, runs_per_minute: Optional[Union[int, AsyncLimiter]] = None, @@ -35,21 +38,25 @@ def apply_rate_limit( assert runs_per_minute is None runs_per_minute = coro_fn coro_fn = None - + elif coro_fn is None: if runs_per_minute is not None and not isinstance(runs_per_minute, int): raise TypeError("'runs_per_minute' must be an integer.", runs_per_minute) - + elif not asyncio.iscoroutinefunction(coro_fn): raise exceptions.FunctionNotAsync(coro_fn) - + def rate_limit_decorator(coro_fn: CoroFn[P, T]) -> CoroFn[P, T]: - limiter = runs_per_minute if isinstance(runs_per_minute, AsyncLimiter) else AsyncLimiter(runs_per_minute) if runs_per_minute else aliases.dummy + limiter = ( + runs_per_minute + if isinstance(runs_per_minute, AsyncLimiter) + else AsyncLimiter(runs_per_minute) if runs_per_minute else aliases.dummy + ) + async def rate_limit_wrap(*args: P.args, **kwargs: P.kwargs) -> T: async with limiter: # type: ignore [attr-defined] return await coro_fn(*args, **kwargs) + return rate_limit_wrap - + return rate_limit_decorator if coro_fn is None else rate_limit_decorator(coro_fn) - - \ No newline at end of file diff --git a/a_sync/a_sync/modifiers/manager.py b/a_sync/a_sync/modifiers/manager.py index 5f3e0ca1..96ed9ab0 100644 --- a/a_sync/a_sync/modifiers/manager.py +++ b/a_sync/a_sync/modifiers/manager.py @@ -5,7 +5,12 @@ from a_sync.a_sync.config import user_set_default_modifiers, null_modifiers from a_sync.a_sync.modifiers import cache, limiter, semaphores -valid_modifiers = [key for key in ModifierKwargs.__annotations__ if not key.startswith('_') and not key.endswith('_')] +valid_modifiers = [ + key + for key in ModifierKwargs.__annotations__ + if not key.startswith("_") and not key.endswith("_") +] + class ModifierManager(Dict[str, Any]): default: DefaultMode @@ -17,35 +22,43 @@ class ModifierManager(Dict[str, Any]): semaphore: SemaphoreSpec # sync modifiers executor: Executor - __slots__ = "_modifiers", + __slots__ = ("_modifiers",) + def __init__(self, modifiers: ModifierKwargs) -> None: for key in modifiers.keys(): if key not in valid_modifiers: raise ValueError(f"'{key}' is not a supported modifier.") self._modifiers = modifiers + def __repr__(self) -> str: return str(self._modifiers) + def __getattribute__(self, modifier_key: str) -> Any: if modifier_key not in valid_modifiers: return super().__getattribute__(modifier_key) - return self[modifier_key] if modifier_key in self else user_defaults[modifier_key] + return ( + self[modifier_key] if modifier_key in self else user_defaults[modifier_key] + ) - @property def use_limiter(self) -> bool: return self.runs_per_minute != nulls.runs_per_minute + @property def use_semaphore(self) -> bool: return self.semaphore != nulls.semaphore + @property def use_cache(self) -> bool: - return any([ - self.cache_type != nulls.cache_type, - self.ram_cache_maxsize != nulls.ram_cache_maxsize, - self.ram_cache_ttl != nulls.ram_cache_ttl, - self.cache_typed != nulls.cache_typed, - ]) - + return any( + [ + self.cache_type != nulls.cache_type, + self.ram_cache_maxsize != nulls.ram_cache_maxsize, + self.ram_cache_ttl != nulls.ram_cache_ttl, + self.cache_typed != nulls.cache_typed, + ] + ) + def apply_async_modifiers(self, coro_fn: CoroFn[P, T]) -> CoroFn[P, T]: # NOTE: THESE STACK IN REVERSE ORDER if self.use_limiter: @@ -55,35 +68,43 @@ def apply_async_modifiers(self, coro_fn: CoroFn[P, T]) -> CoroFn[P, T]: if self.use_cache: coro_fn = cache.apply_async_cache( coro_fn, - cache_type=self.cache_type or 'memory', + cache_type=self.cache_type or "memory", cache_typed=self.cache_typed, ram_cache_maxsize=self.ram_cache_maxsize, - ram_cache_ttl=self.ram_cache_ttl + ram_cache_ttl=self.ram_cache_ttl, ) return coro_fn - + def apply_sync_modifiers(self, function: SyncFn[P, T]) -> SyncFn[P, T]: @functools.wraps(function) def sync_modifier_wrap(*args: P.args, **kwargs: P.kwargs) -> T: return function(*args, **kwargs) + # NOTE There are no sync modifiers at this time but they will be added here for my convenience. return sync_modifier_wrap - + # Dictionary api def keys(self) -> KeysView[str]: # type: ignore [override] return self._modifiers.keys() + def values(self) -> ValuesView[Any]: # type: ignore [override] return self._modifiers.values() + def items(self) -> ItemsView[str, Any]: # type: ignore [override] return self._modifiers.items() + def __contains__(self, key: str) -> bool: # type: ignore [override] return key in self._modifiers + def __iter__(self) -> Iterator[str]: return self._modifiers.__iter__() + def __len__(self) -> int: return len(self._modifiers) + def __getitem__(self, modifier_key: str): return self._modifiers[modifier_key] # type: ignore [literal-required] + nulls = ModifierManager(null_modifiers) user_defaults = ModifierManager(user_set_default_modifiers) diff --git a/a_sync/a_sync/modifiers/semaphores.py b/a_sync/a_sync/modifiers/semaphores.py index c4ac9e89..e056f623 100644 --- a/a_sync/a_sync/modifiers/semaphores.py +++ b/a_sync/a_sync/modifiers/semaphores.py @@ -14,20 +14,23 @@ def apply_semaphore( # type: ignore [misc] coro_fn: Literal[None], semaphore: SemaphoreSpec, -) -> AsyncDecorator[P, T]:... +) -> AsyncDecorator[P, T]: ... + @overload def apply_semaphore( coro_fn: SemaphoreSpec, semaphore: Literal[None], -) -> AsyncDecorator[P, T]:... +) -> AsyncDecorator[P, T]: ... + @overload def apply_semaphore( coro_fn: CoroFn[P, T], semaphore: SemaphoreSpec, -) -> CoroFn[P, T]:... - +) -> CoroFn[P, T]: ... + + def apply_semaphore( coro_fn: Optional[Union[CoroFn[P, T], SemaphoreSpec]] = None, semaphore: SemaphoreSpec = None, @@ -38,31 +41,35 @@ def apply_semaphore( raise ValueError("You can only pass in one arg.") semaphore = coro_fn coro_fn = None - + elif not asyncio.iscoroutinefunction(coro_fn): raise exceptions.FunctionNotAsync(coro_fn) - + # Create the semaphore if necessary if isinstance(semaphore, int): semaphore = primitives.ThreadsafeSemaphore(semaphore) elif not isinstance(semaphore, asyncio.Semaphore): - raise TypeError(f"'semaphore' must either be an integer or a Semaphore object. You passed {semaphore}") - + raise TypeError( + f"'semaphore' must either be an integer or a Semaphore object. You passed {semaphore}" + ) + # Create and return the decorator if isinstance(semaphore, primitives.Semaphore): # NOTE: Our `Semaphore` primitive can be used as a decorator. # While you can use it the `async with` way like any other semaphore and we could make this code section cleaner, # applying it as a decorator adds some useful info to its debug logs so we do that here if we can. semaphore_decorator = semaphore - + else: + def semaphore_decorator(coro_fn: CoroFn[P, T]) -> CoroFn[P, T]: @functools.wraps(coro_fn) async def semaphore_wrap(*args, **kwargs) -> T: async with semaphore: # type: ignore [union-attr] return await coro_fn(*args, **kwargs) + return semaphore_wrap - + return semaphore_decorator if coro_fn is None else semaphore_decorator(coro_fn) diff --git a/a_sync/a_sync/property.py b/a_sync/a_sync/property.py index a320f2e0..478bf14b 100644 --- a/a_sync/a_sync/property.py +++ b/a_sync/a_sync/property.py @@ -1,4 +1,3 @@ - import functools import logging @@ -9,8 +8,15 @@ from a_sync._typing import * from a_sync.a_sync import _helpers, config from a_sync.a_sync._descriptor import ASyncDescriptor -from a_sync.a_sync.function import ASyncFunction, ASyncFunctionAsyncDefault, ASyncFunctionSyncDefault -from a_sync.a_sync.method import ASyncBoundMethodAsyncDefault, ASyncMethodDescriptorAsyncDefault +from a_sync.a_sync.function import ( + ASyncFunction, + ASyncFunctionAsyncDefault, + ASyncFunctionSyncDefault, +) +from a_sync.a_sync.method import ( + ASyncBoundMethodAsyncDefault, + ASyncMethodDescriptorAsyncDefault, +) if TYPE_CHECKING: from a_sync.task import TaskMapping @@ -18,17 +24,20 @@ logger = logging.getLogger(__name__) + class _ASyncPropertyDescriptorBase(ASyncDescriptor[I, Tuple[()], T]): any: ASyncFunction[AnyIterable[I], bool] all: ASyncFunction[AnyIterable[I], bool] min: ASyncFunction[AnyIterable[I], T] max: ASyncFunction[AnyIterable[I], T] sum: ASyncFunction[AnyIterable[I], T] + hidden_method_descriptor: "HiddenMethodDescriptor[T]" __wrapped__: Callable[[I], T] __slots__ = "hidden_method_name", "hidden_method_descriptor", "_fget" + def __init__( - self, - _fget: AsyncGetterFunction[I, T], + self, + _fget: AsyncGetterFunction[I, T], field_name: Optional[str] = None, **modifiers: Unpack[ModifierKwargs], ) -> None: @@ -36,45 +45,90 @@ def __init__( self.hidden_method_name = f"__{self.field_name}__" hidden_modifiers = dict(self.modifiers) hidden_modifiers["default"] = "async" - self.hidden_method_descriptor: HiddenMethodDescriptor[T] = HiddenMethodDescriptor(self.get, self.hidden_method_name, **hidden_modifiers) + self.hidden_method_descriptor = HiddenMethodDescriptor( + self.get, self.hidden_method_name, **hidden_modifiers + ) if asyncio.iscoroutinefunction(_fget): self._fget = self.__wrapped__ else: self._fget = _helpers._asyncify(self.__wrapped__, self.modifiers.executor) + @overload - def __get__(self, instance: None, owner: Type[I]) -> Self:... + def __get__(self, instance: None, owner: Type[I]) -> Self: ... @overload - def __get__(self, instance: I, owner: Type[I]) -> Awaitable[T]:... - def __get__(self, instance: Optional[I], owner: Type[I]) -> Union[Self, Awaitable[T]]: + def __get__(self, instance: I, owner: Type[I]) -> Awaitable[T]: ... + def __get__( + self, instance: Optional[I], owner: Type[I] + ) -> Union[Self, Awaitable[T]]: if instance is None: return self awaitable = super().__get__(instance, owner) # if the user didn't specify a default behavior, we will defer to the instance if _is_a_sync_instance(instance): - should_await = self.default == "sync" if self.default else instance.__a_sync_instance_should_await__ + should_await = ( + self.default == "sync" + if self.default + else instance.__a_sync_instance_should_await__ + ) else: - should_await = self.default == "sync" if self.default else not asyncio.get_event_loop().is_running() + should_await = ( + self.default == "sync" + if self.default + else not asyncio.get_event_loop().is_running() + ) if should_await: - logger.debug("awaiting awaitable for %s for instance: %s owner: %s", awaitable, self, instance, owner) + logger.debug( + "awaiting awaitable for %s for instance: %s owner: %s", + awaitable, + self, + instance, + owner, + ) retval = _helpers._await(awaitable) else: retval = awaitable - logger.debug("returning %s for %s for instance: %s owner: %s", retval, self, instance, owner) + logger.debug( + "returning %s for %s for instance: %s owner: %s", + retval, + self, + instance, + owner, + ) return retval + async def get(self, instance: I, owner: Optional[Type[I]] = None) -> T: if instance is None: raise ValueError(instance) logger.debug("awaiting %s for instance %s", self, instance) return await super().__get__(instance, owner) - def map(self, instances: AnyIterable[I], owner: Optional[Type[I]] = None, concurrency: Optional[int] = None, name: str = "") -> "TaskMapping[I, T]": + + def map( + self, + instances: AnyIterable[I], + owner: Optional[Type[I]] = None, + concurrency: Optional[int] = None, + name: str = "", + ) -> "TaskMapping[I, T]": from a_sync.task import TaskMapping - logger.debug("mapping %s to instances: %s owner: %s", self, instances, owner) - return TaskMapping(self, instances, owner=owner, name=name or self.field_name, concurrency=concurrency) -class ASyncPropertyDescriptor(_ASyncPropertyDescriptorBase[I, T], ap.base.AsyncPropertyDescriptor): + logger.debug("mapping %s to instances: %s owner: %s", self, instances, owner) + return TaskMapping( + self, + instances, + owner=owner, + name=name or self.field_name, + concurrency=concurrency, + ) + + +class ASyncPropertyDescriptor( + _ASyncPropertyDescriptorBase[I, T], ap.base.AsyncPropertyDescriptor +): pass -class property(ASyncPropertyDescriptor[I, T]):... + +class property(ASyncPropertyDescriptor[I, T]): ... + @final class ASyncPropertyDescriptorSyncDefault(property[I, T]): @@ -91,13 +145,15 @@ class ASyncPropertyDescriptorSyncDefault(property[I, T]): min: ASyncFunctionSyncDefault[AnyIterable[I], T] max: ASyncFunctionSyncDefault[AnyIterable[I], T] sum: ASyncFunctionSyncDefault[AnyIterable[I], T] + @overload - def __get__(self, instance: None, owner: Type[I]) -> Self:... + def __get__(self, instance: None, owner: Type[I]) -> Self: ... @overload - def __get__(self, instance: I, owner: Type[I]) -> T:... + def __get__(self, instance: I, owner: Type[I]) -> T: ... def __get__(self, instance: Optional[I], owner: Type[I]) -> Union[Self, T]: return _ASyncPropertyDescriptorBase.__get__(self, instance, owner) + @final class ASyncPropertyDescriptorAsyncDefault(property[I, T]): """ @@ -116,89 +172,105 @@ class ASyncPropertyDescriptorAsyncDefault(property[I, T]): ASyncPropertyDecorator = Callable[[AnyGetterFunction[I, T]], property[I, T]] -ASyncPropertyDecoratorSyncDefault = Callable[[AnyGetterFunction[I, T]], ASyncPropertyDescriptorSyncDefault[I, T]] -ASyncPropertyDecoratorAsyncDefault = Callable[[AnyGetterFunction[I, T]], ASyncPropertyDescriptorAsyncDefault[I, T]] +ASyncPropertyDecoratorSyncDefault = Callable[ + [AnyGetterFunction[I, T]], ASyncPropertyDescriptorSyncDefault[I, T] +] +ASyncPropertyDecoratorAsyncDefault = Callable[ + [AnyGetterFunction[I, T]], ASyncPropertyDescriptorAsyncDefault[I, T] +] + @overload def a_sync_property( # type: ignore [misc] func: Literal[None] = None, **modifiers: Unpack[ModifierKwargs], -) -> ASyncPropertyDecorator[I, T]:... +) -> ASyncPropertyDecorator[I, T]: ... + @overload def a_sync_property( # type: ignore [misc] func: AnyGetterFunction[I, T], **modifiers: Unpack[ModifierKwargs], -) -> ASyncPropertyDescriptor[I, T]:... +) -> ASyncPropertyDescriptor[I, T]: ... + @overload def a_sync_property( # type: ignore [misc] func: Literal[None], default: Literal["sync"], **modifiers: Unpack[ModifierKwargs], -) -> ASyncPropertyDecoratorSyncDefault[I, T]:... +) -> ASyncPropertyDecoratorSyncDefault[I, T]: ... + @overload def a_sync_property( # type: ignore [misc] func: Literal[None], default: Literal["sync"], **modifiers: Unpack[ModifierKwargs], -) -> ASyncPropertyDecoratorSyncDefault[I, T]:... +) -> ASyncPropertyDecoratorSyncDefault[I, T]: ... + @overload def a_sync_property( # type: ignore [misc] func: Literal[None], default: Literal["async"], **modifiers: Unpack[ModifierKwargs], -) -> ASyncPropertyDecoratorAsyncDefault[I, T]:... +) -> ASyncPropertyDecoratorAsyncDefault[I, T]: ... + @overload def a_sync_property( # type: ignore [misc] func: Literal[None], default: DefaultMode = config.DEFAULT_MODE, **modifiers: Unpack[ModifierKwargs], -) -> ASyncPropertyDecorator[I, T]:... - +) -> ASyncPropertyDecorator[I, T]: ... + + @overload def a_sync_property( # type: ignore [misc] default: Literal["sync"], **modifiers: Unpack[ModifierKwargs], -) -> ASyncPropertyDecoratorSyncDefault[I, T]:... - +) -> ASyncPropertyDecoratorSyncDefault[I, T]: ... + + @overload def a_sync_property( # type: ignore [misc] default: Literal["async"], **modifiers: Unpack[ModifierKwargs], -) -> ASyncPropertyDecoratorAsyncDefault[I, T]:... - +) -> ASyncPropertyDecoratorAsyncDefault[I, T]: ... + + @overload def a_sync_property( # type: ignore [misc] func: AnyGetterFunction[I, T], default: Literal["sync"], **modifiers: Unpack[ModifierKwargs], -) -> ASyncPropertyDescriptorSyncDefault[I, T]:... - +) -> ASyncPropertyDescriptorSyncDefault[I, T]: ... + + @overload def a_sync_property( # type: ignore [misc] func: AnyGetterFunction[I, T], default: Literal["async"], **modifiers: Unpack[ModifierKwargs], -) -> ASyncPropertyDescriptorAsyncDefault[I, T]:... - +) -> ASyncPropertyDescriptorAsyncDefault[I, T]: ... + + @overload def a_sync_property( # type: ignore [misc] func: AnyGetterFunction[I, T], default: DefaultMode = config.DEFAULT_MODE, **modifiers: Unpack[ModifierKwargs], -) -> ASyncPropertyDescriptor[I, T]:... - +) -> ASyncPropertyDescriptor[I, T]: ... + + def a_sync_property( # type: ignore [misc] func: Union[AnyGetterFunction[I, T], DefaultMode] = None, **modifiers: Unpack[ModifierKwargs], ) -> Union[ ASyncPropertyDescriptor[I, T], - ASyncPropertyDescriptorSyncDefault[I, T], - ASyncPropertyDescriptorAsyncDefault[I, T], + ASyncPropertyDescriptorSyncDefault[I, T], + ASyncPropertyDescriptorAsyncDefault[I, T], ASyncPropertyDecorator[I, T], ASyncPropertyDecoratorSyncDefault[I, T], ASyncPropertyDecoratorAsyncDefault[I, T], @@ -214,7 +286,9 @@ def a_sync_property( # type: ignore [misc] return decorator if func is None else decorator(func) -class ASyncCachedPropertyDescriptor(_ASyncPropertyDescriptorBase[I, T], ap.cached.AsyncCachedPropertyDescriptor): +class ASyncCachedPropertyDescriptor( + _ASyncPropertyDescriptorBase[I, T], ap.cached.AsyncCachedPropertyDescriptor +): """ A descriptor class for dual-function sync/async cached properties. @@ -223,17 +297,18 @@ class ASyncCachedPropertyDescriptor(_ASyncPropertyDescriptorBase[I, T], ap.cache """ __slots__ = "_fset", "_fdel", "__async_property__" + def __init__( - self, - _fget: AsyncGetterFunction[I, T], - _fset = None, - _fdel = None, - field_name=None, + self, + _fget: AsyncGetterFunction[I, T], + _fset=None, + _fdel=None, + field_name=None, **modifiers: Unpack[ModifierKwargs], ) -> None: super().__init__(_fget, field_name, **modifiers) - self._check_method_sync(_fset, 'setter') - self._check_method_sync(_fdel, 'deleter') + self._check_method_sync(_fset, "setter") + self._check_method_sync(_fdel, "deleter") self._fset = _fset self._fdel = _fdel @@ -245,7 +320,7 @@ def get_lock(self, instance: I) -> "asyncio.Task[T]": task = asyncio.create_task(self._fget(instance)) instance_state.lock[self.field_name] = task return task - + def pop_lock(self, instance: I) -> None: self.get_instance_state(instance).lock.pop(self.field_name, None) @@ -261,9 +336,12 @@ async def load_value(): self.__set__(instance, value) self.pop_lock(instance) return value + return load_value - -class cached_property(ASyncCachedPropertyDescriptor[I, T]):... + + +class cached_property(ASyncCachedPropertyDescriptor[I, T]): ... + @final class ASyncCachedPropertyDescriptorSyncDefault(cached_property[I, T]): @@ -275,10 +353,11 @@ class ASyncCachedPropertyDescriptorSyncDefault(cached_property[I, T]): """ default: Literal["sync"] + @overload - def __get__(self, instance: None, owner: Type[I]) -> Self:... + def __get__(self, instance: None, owner: Type[I]) -> Self: ... @overload - def __get__(self, instance: I, owner: Type[I]) -> T:... + def __get__(self, instance: I, owner: Type[I]) -> T: ... def __get__(self, instance: Optional[I], owner: Type[I]) -> Union[Self, T]: return _ASyncPropertyDescriptorBase.__get__(self, instance, owner) @@ -294,83 +373,101 @@ class ASyncCachedPropertyDescriptorAsyncDefault(cached_property[I, T]): default: Literal["async"] -ASyncCachedPropertyDecorator = Callable[[AnyGetterFunction[I, T]], cached_property[I, T]] -ASyncCachedPropertyDecoratorSyncDefault = Callable[[AnyGetterFunction[I, T]], ASyncCachedPropertyDescriptorSyncDefault[I, T]] -ASyncCachedPropertyDecoratorAsyncDefault = Callable[[AnyGetterFunction[I, T]], ASyncCachedPropertyDescriptorAsyncDefault[I, T]] + +ASyncCachedPropertyDecorator = Callable[ + [AnyGetterFunction[I, T]], cached_property[I, T] +] +ASyncCachedPropertyDecoratorSyncDefault = Callable[ + [AnyGetterFunction[I, T]], ASyncCachedPropertyDescriptorSyncDefault[I, T] +] +ASyncCachedPropertyDecoratorAsyncDefault = Callable[ + [AnyGetterFunction[I, T]], ASyncCachedPropertyDescriptorAsyncDefault[I, T] +] + @overload def a_sync_cached_property( # type: ignore [misc] func: Literal[None] = None, **modifiers: Unpack[ModifierKwargs], -) -> ASyncCachedPropertyDecorator[I, T]:... +) -> ASyncCachedPropertyDecorator[I, T]: ... + @overload def a_sync_cached_property( # type: ignore [misc] func: AnyGetterFunction[I, T], **modifiers: Unpack[ModifierKwargs], -) -> ASyncCachedPropertyDescriptor[I, T]:... +) -> ASyncCachedPropertyDescriptor[I, T]: ... + @overload def a_sync_cached_property( # type: ignore [misc] func: Literal[None], default: Literal["sync"], **modifiers: Unpack[ModifierKwargs], -) -> ASyncCachedPropertyDecoratorSyncDefault[I, T]:... +) -> ASyncCachedPropertyDecoratorSyncDefault[I, T]: ... + @overload def a_sync_cached_property( # type: ignore [misc] func: Literal[None], default: Literal["async"], **modifiers: Unpack[ModifierKwargs], -) -> ASyncCachedPropertyDecoratorAsyncDefault[I, T]:... +) -> ASyncCachedPropertyDecoratorAsyncDefault[I, T]: ... + @overload def a_sync_cached_property( # type: ignore [misc] func: Literal[None], default: DefaultMode, **modifiers: Unpack[ModifierKwargs], -) -> ASyncCachedPropertyDecorator[I, T]:... +) -> ASyncCachedPropertyDecorator[I, T]: ... + @overload def a_sync_cached_property( # type: ignore [misc] default: Literal["sync"], **modifiers: Unpack[ModifierKwargs], -) -> ASyncCachedPropertyDecoratorSyncDefault[I, T]:... +) -> ASyncCachedPropertyDecoratorSyncDefault[I, T]: ... + @overload def a_sync_cached_property( # type: ignore [misc] default: Literal["async"], **modifiers: Unpack[ModifierKwargs], -) -> ASyncCachedPropertyDecoratorAsyncDefault[I, T]:... - +) -> ASyncCachedPropertyDecoratorAsyncDefault[I, T]: ... + + @overload def a_sync_cached_property( # type: ignore [misc] func: AnyGetterFunction[I, T], default: Literal["sync"], **modifiers: Unpack[ModifierKwargs], -) -> ASyncCachedPropertyDescriptorSyncDefault[I, T]:... +) -> ASyncCachedPropertyDescriptorSyncDefault[I, T]: ... + @overload def a_sync_cached_property( # type: ignore [misc] func: AnyGetterFunction[I, T], default: Literal["async"], **modifiers: Unpack[ModifierKwargs], -) -> ASyncCachedPropertyDescriptorAsyncDefault[I, T]:... +) -> ASyncCachedPropertyDescriptorAsyncDefault[I, T]: ... + @overload def a_sync_cached_property( # type: ignore [misc] func: AnyGetterFunction[I, T], default: DefaultMode = config.DEFAULT_MODE, **modifiers: Unpack[ModifierKwargs], -) -> ASyncCachedPropertyDescriptor[I, T]:... - +) -> ASyncCachedPropertyDescriptor[I, T]: ... + + def a_sync_cached_property( # type: ignore [misc] func: Optional[AnyGetterFunction[I, T]] = None, **modifiers: Unpack[ModifierKwargs], ) -> Union[ ASyncCachedPropertyDescriptor[I, T], - ASyncCachedPropertyDescriptorSyncDefault[I, T], - ASyncCachedPropertyDescriptorAsyncDefault[I, T], + ASyncCachedPropertyDescriptorSyncDefault[I, T], + ASyncCachedPropertyDescriptorAsyncDefault[I, T], ASyncCachedPropertyDecorator[I, T], ASyncCachedPropertyDecoratorSyncDefault[I, T], ASyncCachedPropertyDecoratorAsyncDefault[I, T], @@ -385,35 +482,40 @@ def a_sync_cached_property( # type: ignore [misc] decorator = functools.partial(descriptor_class, **modifiers) return decorator if func is None else decorator(func) + @final class HiddenMethod(ASyncBoundMethodAsyncDefault[I, Tuple[()], T]): def __init__( - self, - instance: I, - unbound: AnyFn[Concatenate[I, P], T], + self, + instance: I, + unbound: AnyFn[Concatenate[I, P], T], async_def: bool, field_name: str, **modifiers: Unpack[ModifierKwargs], ) -> None: super().__init__(instance, unbound, async_def, **modifiers) self.__name__ = field_name + def __repr__(self) -> str: instance_type = type(self.__self__) return f"<{self.__class__.__name__} for property {instance_type.__module__}.{instance_type.__name__}.{self.__name__[2:-2]} bound to {self.__self__}>" + def _should_await(self, kwargs: dict) -> bool: try: return self.__self__.__a_sync_should_await_from_kwargs__(kwargs) except (AttributeError, exceptions.NoFlagsFound): return False + def __await__(self) -> Generator[Any, None, T]: return self(sync=False).__await__() + @final class HiddenMethodDescriptor(ASyncMethodDescriptorAsyncDefault[I, Tuple[()], T]): def __init__( - self, - _fget: AnyFn[Concatenate[I, P], Awaitable[T]], - field_name: Optional[str] = None, + self, + _fget: AnyFn[Concatenate[I, P], Awaitable[T]], + field_name: Optional[str] = None, **modifiers: Unpack[ModifierKwargs], ) -> None: """ @@ -434,6 +536,7 @@ def __init__( self.__doc__ += f"A :class:`HiddenMethodDescriptor` for :meth:`{self.__wrapped__.__qualname__}`." if self.__wrapped__.__doc__: self.__doc__ += f"\n\nThe original docstring for :meth:`~{self.__wrapped__.__qualname__}` is shown below:\n\n{self.__wrapped__.__doc__}" + def __get__(self, instance: I, owner: Type[I]) -> HiddenMethod[I, T]: if instance is None: return self @@ -441,23 +544,34 @@ def __get__(self, instance: I, owner: Type[I]) -> HiddenMethod[I, T]: bound = instance.__dict__[self.field_name] bound._cache_handle.cancel() except KeyError: - bound = HiddenMethod(instance, self.__wrapped__, self.__is_async_def__, self.field_name, **self.modifiers) + bound = HiddenMethod( + instance, + self.__wrapped__, + self.__is_async_def__, + self.field_name, + **self.modifiers, + ) instance.__dict__[self.field_name] = bound logger.debug("new hidden method: %s", bound) bound._cache_handle = self._get_cache_handle(instance) return bound + def _is_a_sync_instance(instance: object) -> bool: try: return instance.__is_a_sync_instance__ # type: ignore [attr-defined] except AttributeError: from a_sync.a_sync.abstract import ASyncABC + is_a_sync = isinstance(instance, ASyncABC) instance.__is_a_sync_instance__ = is_a_sync return is_a_sync -def _parse_args(func: Union[None, DefaultMode, AsyncGetterFunction[I, T]], modifiers: ModifierKwargs) -> Tuple[Optional[AsyncGetterFunction[I, T]], ModifierKwargs]: - if func in ['sync', 'async']: - modifiers['default'] = func + +def _parse_args( + func: Union[None, DefaultMode, AsyncGetterFunction[I, T]], modifiers: ModifierKwargs +) -> Tuple[Optional[AsyncGetterFunction[I, T]], ModifierKwargs]: + if func in ["sync", "async"]: + modifiers["default"] = func func = None return func, modifiers diff --git a/a_sync/a_sync/singleton.py b/a_sync/a_sync/singleton.py index 3d7d8387..74a8ccda 100644 --- a/a_sync/a_sync/singleton.py +++ b/a_sync/a_sync/singleton.py @@ -1,6 +1,7 @@ from a_sync.a_sync._meta import ASyncSingletonMeta from a_sync.a_sync.base import ASyncGenericBase + class ASyncGenericSingleton(ASyncGenericBase, metaclass=ASyncSingletonMeta): """ A base class for creating singleton-esque ASync classes. diff --git a/a_sync/aliases.py b/a_sync/aliases.py index ddc57f8c..d4802a28 100644 --- a/a_sync/aliases.py +++ b/a_sync/aliases.py @@ -1,4 +1,3 @@ - from a_sync.a_sync.modifiers.semaphores import dummy_semaphore as dummy from a_sync.a_sync.property import a_sync_cached_property as cached_property from a_sync.a_sync.property import a_sync_property as property diff --git a/a_sync/asyncio/as_completed.py b/a_sync/asyncio/as_completed.py index 7d6ef950..1bb83d6d 100644 --- a/a_sync/asyncio/as_completed.py +++ b/a_sync/asyncio/as_completed.py @@ -7,26 +7,65 @@ try: from tqdm.asyncio import tqdm_asyncio except ImportError as e: + class tqdm_asyncio: # type: ignore [no-redef] def as_completed(*args, **kwargs): raise ImportError("You must have tqdm installed to use this feature") - + + from a_sync._typing import * from a_sync.iter import ASyncIterator + @overload -def as_completed(fs: Iterable[Awaitable[T]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: Literal[False] = False, tqdm: bool = False, **tqdm_kwargs: Any) -> Iterator[Coroutine[Any, Any, T]]: - ... +def as_completed( + fs: Iterable[Awaitable[T]], + *, + timeout: Optional[float] = None, + return_exceptions: bool = False, + aiter: Literal[False] = False, + tqdm: bool = False, + **tqdm_kwargs: Any +) -> Iterator[Coroutine[Any, Any, T]]: ... @overload -def as_completed(fs: Iterable[Awaitable[T]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: Literal[True] = True, tqdm: bool = False, **tqdm_kwargs: Any) -> ASyncIterator[T]: - ... +def as_completed( + fs: Iterable[Awaitable[T]], + *, + timeout: Optional[float] = None, + return_exceptions: bool = False, + aiter: Literal[True] = True, + tqdm: bool = False, + **tqdm_kwargs: Any +) -> ASyncIterator[T]: ... @overload -def as_completed(fs: Mapping[K, Awaitable[V]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: Literal[False] = False, tqdm: bool = False, **tqdm_kwargs: Any) -> Iterator[Coroutine[Any, Any, Tuple[K, V]]]: - ... +def as_completed( + fs: Mapping[K, Awaitable[V]], + *, + timeout: Optional[float] = None, + return_exceptions: bool = False, + aiter: Literal[False] = False, + tqdm: bool = False, + **tqdm_kwargs: Any +) -> Iterator[Coroutine[Any, Any, Tuple[K, V]]]: ... @overload -def as_completed(fs: Mapping[K, Awaitable[V]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: Literal[True] = True, tqdm: bool = False, **tqdm_kwargs: Any) -> ASyncIterator[Tuple[K, V]]: - ... -def as_completed(fs, *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: bool = False, tqdm: bool = False, **tqdm_kwargs: Any): +def as_completed( + fs: Mapping[K, Awaitable[V]], + *, + timeout: Optional[float] = None, + return_exceptions: bool = False, + aiter: Literal[True] = True, + tqdm: bool = False, + **tqdm_kwargs: Any +) -> ASyncIterator[Tuple[K, V]]: ... +def as_completed( + fs, + *, + timeout: Optional[float] = None, + return_exceptions: bool = False, + aiter: bool = False, + tqdm: bool = False, + **tqdm_kwargs: Any +): """ Concurrently awaits a list of awaitable objects or mappings of awaitables and returns an iterator of results. @@ -56,40 +95,76 @@ def as_completed(fs, *, timeout: Optional[float] = None, return_exceptions: bool for coro in as_completed(awaitables): val = await coro ... - + async for val in as_completed(awaitables, aiter=True): ... ``` - + Awaiting mappings of awaitables: ``` mapping = {'key1': async_function1(), 'key2': async_function2()} - + for coro in as_completed(mapping): k, v = await coro ... - + async for k, v in as_completed(mapping, aiter=True): ... ``` """ if isinstance(fs, Mapping): - return as_completed_mapping(fs, timeout=timeout, return_exceptions=return_exceptions, aiter=aiter, tqdm=tqdm, **tqdm_kwargs) + return as_completed_mapping( + fs, + timeout=timeout, + return_exceptions=return_exceptions, + aiter=aiter, + tqdm=tqdm, + **tqdm_kwargs + ) if return_exceptions: fs = [_exc_wrap(f) for f in fs] return ( - ASyncIterator(__yield_as_completed(fs, timeout=timeout, tqdm=tqdm, **tqdm_kwargs)) if aiter - else tqdm_asyncio.as_completed(fs, timeout=timeout, **tqdm_kwargs) if tqdm - else asyncio.as_completed(fs, timeout=timeout) + ASyncIterator( + __yield_as_completed(fs, timeout=timeout, tqdm=tqdm, **tqdm_kwargs) + ) + if aiter + else ( + tqdm_asyncio.as_completed(fs, timeout=timeout, **tqdm_kwargs) + if tqdm + else asyncio.as_completed(fs, timeout=timeout) + ) ) + @overload -def as_completed_mapping(mapping: Mapping[K, Awaitable[V]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: Literal[True] = True, tqdm: bool = False, **tqdm_kwargs: Any) -> ASyncIterator[Tuple[K, V]]: - ... +def as_completed_mapping( + mapping: Mapping[K, Awaitable[V]], + *, + timeout: Optional[float] = None, + return_exceptions: bool = False, + aiter: Literal[True] = True, + tqdm: bool = False, + **tqdm_kwargs: Any +) -> ASyncIterator[Tuple[K, V]]: ... @overload -def as_completed_mapping(mapping: Mapping[K, Awaitable[V]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: Literal[False] = False, tqdm: bool = False, **tqdm_kwargs: Any) -> Iterator[Coroutine[Any, Any, Tuple[K, V]]]: - ... -def as_completed_mapping(mapping: Mapping[K, Awaitable[V]], *, timeout: Optional[float] = None, return_exceptions: bool = False, aiter: bool = False, tqdm: bool = False, **tqdm_kwargs: Any) -> Union[Iterator[Coroutine[Any, Any, Tuple[K, V]]], ASyncIterator[Tuple[K, V]]]: +def as_completed_mapping( + mapping: Mapping[K, Awaitable[V]], + *, + timeout: Optional[float] = None, + return_exceptions: bool = False, + aiter: Literal[False] = False, + tqdm: bool = False, + **tqdm_kwargs: Any +) -> Iterator[Coroutine[Any, Any, Tuple[K, V]]]: ... +def as_completed_mapping( + mapping: Mapping[K, Awaitable[V]], + *, + timeout: Optional[float] = None, + return_exceptions: bool = False, + aiter: bool = False, + tqdm: bool = False, + **tqdm_kwargs: Any +) -> Union[Iterator[Coroutine[Any, Any, Tuple[K, V]]], ASyncIterator[Tuple[K, V]]]: """ Concurrently awaits a mapping of awaitable objects and returns an iterator or async iterator of results. @@ -109,32 +184,63 @@ def as_completed_mapping(mapping: Mapping[K, Awaitable[V]], *, timeout: Optional Example: ``` mapping = {'key1': async_function1(), 'key2': async_function2()} - + for coro in as_completed_mapping(mapping): k, v = await coro ... - + async for k, v in as_completed_mapping(mapping, aiter=True): ... ``` """ - return as_completed([__mapping_wrap(k, v, return_exceptions=return_exceptions) for k, v in mapping.items()], timeout=timeout, aiter=aiter, tqdm=tqdm, **tqdm_kwargs) + return as_completed( + [ + __mapping_wrap(k, v, return_exceptions=return_exceptions) + for k, v in mapping.items() + ], + timeout=timeout, + aiter=aiter, + tqdm=tqdm, + **tqdm_kwargs + ) + async def _exc_wrap(awaitable: Awaitable[T]) -> Union[T, Exception]: try: return await awaitable except Exception as e: return e - -async def __yield_as_completed(futs: Iterable[Awaitable[T]], *, timeout: Optional[float] = None, return_exceptions: bool = False, tqdm: bool = False, **tqdm_kwargs: Any) -> AsyncIterator[T]: - for fut in as_completed(futs, timeout=timeout, return_exceptions=return_exceptions, tqdm=tqdm, **tqdm_kwargs): + + +async def __yield_as_completed( + futs: Iterable[Awaitable[T]], + *, + timeout: Optional[float] = None, + return_exceptions: bool = False, + tqdm: bool = False, + **tqdm_kwargs: Any +) -> AsyncIterator[T]: + for fut in as_completed( + futs, + timeout=timeout, + return_exceptions=return_exceptions, + tqdm=tqdm, + **tqdm_kwargs + ): yield await fut + @overload -async def __mapping_wrap(k: K, v: Awaitable[V], return_exceptions: Literal[True] = True) -> Union[V, Exception]:... +async def __mapping_wrap( + k: K, v: Awaitable[V], return_exceptions: Literal[True] = True +) -> Union[V, Exception]: ... @overload -async def __mapping_wrap(k: K, v: Awaitable[V], return_exceptions: Literal[False] = False) -> V:... -async def __mapping_wrap(k: K, v: Awaitable[V], return_exceptions: bool = False) -> Union[V, Exception]: +async def __mapping_wrap( + k: K, v: Awaitable[V], return_exceptions: Literal[False] = False +) -> V: ... +async def __mapping_wrap( + k: K, v: Awaitable[V], return_exceptions: bool = False +) -> Union[V, Exception]: try: return k, await v except Exception as e: @@ -142,4 +248,5 @@ async def __mapping_wrap(k: K, v: Awaitable[V], return_exceptions: bool = False) return k, e raise -__all__ = ["as_completed", "as_completed_mapping"] \ No newline at end of file + +__all__ = ["as_completed", "as_completed_mapping"] diff --git a/a_sync/asyncio/create_task.py b/a_sync/asyncio/create_task.py index 64cbfc74..1f7fe57f 100644 --- a/a_sync/asyncio/create_task.py +++ b/a_sync/asyncio/create_task.py @@ -12,10 +12,11 @@ logger = logging.getLogger(__name__) + def create_task( - coro: Awaitable[T], - *, - name: Optional[str] = None, + coro: Awaitable[T], + *, + name: Optional[str] = None, skip_gc_until_done: bool = False, log_destroy_pending: bool = True, ) -> "asyncio.Task[T]": @@ -50,6 +51,7 @@ def create_task( __persisted_tasks: Set["asyncio.Task[Any]"] = set() + async def __await(awaitable: Awaitable[T]) -> T: """Wait for the completion of an Awaitable.""" try: @@ -60,6 +62,7 @@ async def __await(awaitable: Awaitable[T]) -> T: args.append(awaitable._children) raise RuntimeError(*args) from None + def __prune_persisted_tasks(): """Remove completed tasks from the set of persisted tasks.""" for task in tuple(__persisted_tasks): @@ -68,18 +71,19 @@ def __prune_persisted_tasks(): if not isinstance(e, exceptions.PersistedTaskException): logger.exception(e) raise e - # we have to manually log the traceback that asyncio would usually log + # we have to manually log the traceback that asyncio would usually log # since we already got the exception from the task and the usual handler will now not run context = { - 'message': f'{task.__class__.__name__} exception was never retrieved', - 'exception': e, - 'future': task, + "message": f"{task.__class__.__name__} exception was never retrieved", + "exception": e, + "future": task, } if task._source_traceback: - context['source_traceback'] = task._source_traceback + context["source_traceback"] = task._source_traceback task._loop.call_exception_handler(context) __persisted_tasks.discard(task) + async def __persisted_task_exc_wrap(task: "asyncio.Task[T]") -> T: """ Wrap a task to handle its exception in a specialized manner. diff --git a/a_sync/asyncio/gather.py b/a_sync/asyncio/gather.py index a9a66d32..db53b0c5 100644 --- a/a_sync/asyncio/gather.py +++ b/a_sync/asyncio/gather.py @@ -3,15 +3,18 @@ """ import asyncio -from typing import (Any, Awaitable, Dict, List, Mapping, TypeVar, Union, - overload) +from typing import Any, Awaitable, Dict, List, Mapping, TypeVar, Union, overload try: from tqdm.asyncio import tqdm_asyncio except ImportError as e: + class tqdm_asyncio: # type: ignore [no-redef] async def gather(*args, **kwargs): - raise ImportError("You must have tqdm installed in order to use this feature") + raise ImportError( + "You must have tqdm installed in order to use this feature" + ) + from a_sync._typing import * from a_sync.asyncio.as_completed import as_completed_mapping, _exc_wrap @@ -19,29 +22,28 @@ async def gather(*args, **kwargs): Excluder = Callable[[T], bool] + @overload async def gather( - *awaitables: Mapping[K, Awaitable[V]], - return_exceptions: bool = False, - exclude_if: Optional[Excluder[V]] = None, - tqdm: bool = False, + *awaitables: Mapping[K, Awaitable[V]], + return_exceptions: bool = False, + exclude_if: Optional[Excluder[V]] = None, + tqdm: bool = False, **tqdm_kwargs: Any, -) -> Dict[K, V]: - ... +) -> Dict[K, V]: ... @overload async def gather( - *awaitables: Awaitable[T], - return_exceptions: bool = False, + *awaitables: Awaitable[T], + return_exceptions: bool = False, exclude_if: Optional[Excluder[T]] = None, tqdm: bool = False, **tqdm_kwargs: Any, -) -> List[T]: - ... +) -> List[T]: ... async def gather( - *awaitables: Union[Awaitable[T], Mapping[K, Awaitable[V]]], - return_exceptions: bool = False, - exclude_if: Optional[Excluder[T]] = None, - tqdm: bool = False, + *awaitables: Union[Awaitable[T], Mapping[K, Awaitable[V]]], + return_exceptions: bool = False, + exclude_if: Optional[Excluder[T]] = None, + tqdm: bool = False, **tqdm_kwargs: Any, ) -> Union[List[T], Dict[K, V]]: """ @@ -53,7 +55,7 @@ async def gather( - Uses type hints for use with static type checkers. - Supports gathering either individual awaitables or a k:v mapping of awaitables. - Provides progress reporting using tqdm if 'tqdm' is set to True. - + Args: *awaitables: The awaitables to await concurrently. It can be a single awaitable or a mapping of awaitables. return_exceptions (optional): If True, exceptions are returned as results instead of raising them. Defaults to False. @@ -65,7 +67,7 @@ async def gather( Examples: Awaiting individual awaitables: - + - Results will be a list containing the result of each awaitable in sequential order. ``` @@ -75,9 +77,9 @@ async def gather( ``` Awaiting mappings of awaitables: - + - Results will be a dictionary with 'key1' mapped to the result of thing1() and 'key2' mapped to the result of thing2. - + ``` >>> mapping = {'key1': thing1(), 'key2': thing2()} >>> results = await gather(mapping) @@ -87,19 +89,37 @@ async def gather( """ is_mapping = _is_mapping(awaitables) results = await ( - gather_mapping(awaitables[0], return_exceptions=return_exceptions, exclude_if=exclude_if, tqdm=tqdm, **tqdm_kwargs) if is_mapping - else tqdm_asyncio.gather(*(_exc_wrap(a) for a in awaitables) if return_exceptions else awaitables, **tqdm_kwargs) if tqdm - else asyncio.gather(*awaitables, return_exceptions=return_exceptions) # type: ignore [arg-type] + gather_mapping( + awaitables[0], + return_exceptions=return_exceptions, + exclude_if=exclude_if, + tqdm=tqdm, + **tqdm_kwargs, + ) + if is_mapping + else ( + tqdm_asyncio.gather( + *( + (_exc_wrap(a) for a in awaitables) + if return_exceptions + else awaitables + ), + **tqdm_kwargs, + ) + if tqdm + else asyncio.gather(*awaitables, return_exceptions=return_exceptions) + ) # type: ignore [arg-type] ) if exclude_if and not is_mapping: results = [r for r in results if not exclude_if(r)] return results - + + async def gather_mapping( - mapping: Mapping[K, Awaitable[V]], - return_exceptions: bool = False, + mapping: Mapping[K, Awaitable[V]], + return_exceptions: bool = False, exclude_if: Optional[Excluder[V]] = None, - tqdm: bool = False, + tqdm: bool = False, **tqdm_kwargs: Any, ) -> Dict[K, V]: """ @@ -126,14 +146,22 @@ async def gather_mapping( ``` """ results = { - k: v - async for k, v in as_completed_mapping(mapping, return_exceptions=return_exceptions, aiter=True, tqdm=tqdm, **tqdm_kwargs) + k: v + async for k, v in as_completed_mapping( + mapping, + return_exceptions=return_exceptions, + aiter=True, + tqdm=tqdm, + **tqdm_kwargs, + ) if exclude_if is None or not exclude_if(v) } # return data in same order as input mapping - return {k: results[k] for k in mapping} + return {k: results[k] for k in mapping} -_is_mapping = lambda awaitables: len(awaitables) == 1 and isinstance(awaitables[0], Mapping) +_is_mapping = lambda awaitables: len(awaitables) == 1 and isinstance( + awaitables[0], Mapping +) __all__ = ["gather", "gather_mapping"] diff --git a/a_sync/asyncio/utils.py b/a_sync/asyncio/utils.py index 1e98f4c4..16b31911 100644 --- a/a_sync/asyncio/utils.py +++ b/a_sync/asyncio/utils.py @@ -1,12 +1,12 @@ - import asyncio + def get_event_loop() -> asyncio.AbstractEventLoop: try: loop = asyncio.get_event_loop() - except RuntimeError as e: # Necessary for use with multi-threaded applications. + except RuntimeError as e: # Necessary for use with multi-threaded applications. if not str(e).startswith("There is no current event loop in thread"): raise loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - return loop \ No newline at end of file + return loop diff --git a/a_sync/exceptions.py b/a_sync/exceptions.py index 3cf21428..b2df5d21 100644 --- a/a_sync/exceptions.py +++ b/a_sync/exceptions.py @@ -15,6 +15,7 @@ class ASyncFlagException(ValueError): """ Base exception class for flag-related errors in the a_sync library. """ + @property def viable_flags(self) -> Set[str]: """ @@ -23,19 +24,21 @@ def viable_flags(self) -> Set[str]: return VIABLE_FLAGS def desc(self, target) -> str: - if target == 'kwargs': + if target == "kwargs": return "flags present in 'kwargs'" else: - return f'flag attributes defined on {target}' + return f"flag attributes defined on {target}" + class NoFlagsFound(ASyncFlagException): """ Raised when no viable flags are found in the target. """ + def __init__(self, target, kwargs_keys=None): """ Initializes the NoFlagsFound exception. - + Args: target: The target object where flags were expected. kwargs_keys: Optional; keys in the kwargs if applicable. @@ -47,14 +50,16 @@ def __init__(self, target, kwargs_keys=None): err += "\nThis is likely an issue with a custom subclass definition." super().__init__(err) + class TooManyFlags(ASyncFlagException): """ Raised when multiple flags are found, but only one was expected. """ + def __init__(self, target, present_flags): """ Initializes the TooManyFlags exception. - + Args: target: The target object where flags were found. present_flags: The flags that were found. @@ -64,14 +69,16 @@ def __init__(self, target, present_flags): err += "This is likely an issue with a custom subclass definition." super().__init__(err) + class InvalidFlag(ASyncFlagException): """ Raised when an invalid flag is encountered. """ + def __init__(self, flag: Optional[str]): """ Initializes the InvalidFlag exception. - + Args: flag: The invalid flag. """ @@ -79,28 +86,32 @@ def __init__(self, flag: Optional[str]): err += "\nThis code should not be reached and likely indicates an issue with a custom subclass definition." super().__init__(err) + class InvalidFlagValue(ASyncFlagException): """ Raised when a flag has an invalid value. """ + def __init__(self, flag: str, flag_value: Any): """ Initializes the InvalidFlagValue exception. - + Args: flag: The flag with an invalid value. flag_value: The invalid value of the flag. """ super().__init__(f"'{flag}' should be boolean. You passed {flag_value}.") + class FlagNotDefined(ASyncFlagException): """ Raised when a flag is not defined on an object. """ + def __init__(self, obj: Type, flag: str): """ Initializes the FlagNotDefined exception. - + Args: obj: The object where the flag is not defined. flag: The undefined flag. @@ -113,47 +124,58 @@ class ImproperFunctionType(ValueError): Raised when a function that should be sync is async or vice-versa. """ + class FunctionNotAsync(ImproperFunctionType): """ Raised when a function expected to be async is not. """ + def __init__(self, fn): """ Initializes the FunctionNotAsync exception. - + Args: fn: The function that is not async. """ - super().__init__(f"`coro_fn` must be a coroutine function defined with `async def`. You passed {fn}.") + super().__init__( + f"`coro_fn` must be a coroutine function defined with `async def`. You passed {fn}." + ) + class FunctionNotSync(ImproperFunctionType): """ Raised when a function expected to be sync is not. """ + def __init__(self, fn): """ Initializes the FunctionNotSync exception. - + Args: fn: The function that is not sync. """ - super().__init__(f"`func` must be a coroutine function defined with `def`. You passed {fn}.") - + super().__init__( + f"`func` must be a coroutine function defined with `def`. You passed {fn}." + ) + + class ASyncRuntimeError(RuntimeError): def __init__(self, e: RuntimeError): """ Initializes the ASyncRuntimeError exception. - + Args: e: The original runtime error. """ super().__init__(str(e)) + class SyncModeInAsyncContextError(ASyncRuntimeError): """ Raised when synchronous code is used within an asynchronous context. """ - def __init__(self, err: str = ''): + + def __init__(self, err: str = ""): """ Initializes the SyncModeInAsyncContextError exception. """ @@ -163,16 +185,18 @@ def __init__(self, err: str = ''): err += f"{VIABLE_FLAGS}" super().__init__(err) + class MappingError(Exception): """ Base class for errors related to :class:`~TaskMapping`. """ + _msg: str - def __init__(self, mapping: "TaskMapping", msg: str = ''): + def __init__(self, mapping: "TaskMapping", msg: str = ""): """ Initializes the MappingError exception. - + Args: mapping: The TaskMapping where the error occurred. msg: An optional message describing the error. @@ -183,24 +207,30 @@ def __init__(self, mapping: "TaskMapping", msg: str = ''): super().__init__(msg) self.mapping = mapping + class MappingIsEmptyError(MappingError): """ Raised when a TaskMapping is empty and an operation requires it to have items. """ + _msg = "TaskMapping does not contain anything to yield" + class MappingNotEmptyError(MappingError): """ Raised when a TaskMapping is not empty and an operation requires it to be empty. """ + _msg = "TaskMapping already contains some data. In order to use `map`, you need a fresh one" + class PersistedTaskException(Exception): def __init__(self, exc: E, task: asyncio.Task) -> None: super().__init__(f"{exc.__class__.__name__}: {exc}", task) self.exception = exc self.task = task + class EmptySequenceError(ValueError): """ Raised when an operation is attempted on an empty sequence but items are expected. diff --git a/a_sync/executor.py b/a_sync/executor.py index 29a1f6e6..0e85587b 100644 --- a/a_sync/executor.py +++ b/a_sync/executor.py @@ -27,10 +27,12 @@ Initializer = Callable[..., object] + class _AsyncExecutorMixin(cf.Executor, _DebugDaemonMixin): """ A mixin for Executors to provide asynchronous run and submit methods. """ + _max_workers: int _workers: str __slots__ = "_max_workers", "_initializer", "_initargs", "_broken", "_shutdown_lock" @@ -39,7 +41,7 @@ async def run(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: """ A shorthand way to call `await asyncio.get_event_loop().run_in_executor(this_executor, fn, *args)` Doesn't `await this_executor.run(fn, *args)` look so much better? - + Oh, and you can also use kwargs! Args: @@ -50,7 +52,11 @@ async def run(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: Returns: T: The result of the function. """ - return fn(*args, **kwargs) if self.sync_mode else await self.submit(fn, *args, **kwargs) + return ( + fn(*args, **kwargs) + if self.sync_mode + else await self.submit(fn, *args, **kwargs) + ) def submit(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[T]": # type: ignore [override] """ @@ -114,7 +120,7 @@ async def _debug_daemon(self, fut: asyncio.Future, fn, *args, **kwargs) -> None: """ # TODO: make prettier strings for other types if type(fn).__name__ == "function": - fnid = getattr(fn, '__qualname__', fn.__name__) + fnid = getattr(fn, "__qualname__", fn.__name__) if fn.__module__: fnid = f"{fn.__module__}.{fnid}" else: @@ -125,27 +131,39 @@ async def _debug_daemon(self, fut: asyncio.Future, fn, *args, **kwargs) -> None: msg = f"{msg[:-1]} {', '.join(f'{k}={v}' for k, v in kwargs.items())})" else: msg = f"{msg[:-2]})" - + while not fut.done(): await asyncio.sleep(15) if not fut.done(): self.logger.debug(msg, self, fnid) - + + # Process + class AsyncProcessPoolExecutor(_AsyncExecutorMixin, cf.ProcessPoolExecutor): """ An async process pool executor that allows use of kwargs. """ + _workers = "processes" - __slots__ = ("_mp_context", "_processes", "_pending_work_items", "_call_queue", "_result_queue", - "_queue_management_thread", "_queue_count", "_shutdown_thread", "_work_ids", - "_queue_management_thread_wakeup") + __slots__ = ( + "_mp_context", + "_processes", + "_pending_work_items", + "_call_queue", + "_result_queue", + "_queue_management_thread", + "_queue_count", + "_shutdown_thread", + "_work_ids", + "_queue_management_thread_wakeup", + ) def __init__( - self, - max_workers: Optional[int] = None, - mp_context: Optional[multiprocessing.context.BaseContext] = None, + self, + max_workers: Optional[int] = None, + mp_context: Optional[multiprocessing.context.BaseContext] = None, initializer: Optional[Initializer] = None, initargs: Tuple[Any, ...] = (), ) -> None: @@ -164,19 +182,28 @@ def __init__( else: super().__init__(max_workers, mp_context, initializer, initargs) + # Thread + class AsyncThreadPoolExecutor(_AsyncExecutorMixin, cf.ThreadPoolExecutor): """ An async thread pool executor that allows use of kwargs. """ + _workers = "threads" - __slots__ = "_work_queue", "_idle_semaphore", "_threads", "_shutdown", "_thread_name_prefix" + __slots__ = ( + "_work_queue", + "_idle_semaphore", + "_threads", + "_shutdown", + "_thread_name_prefix", + ) def __init__( - self, - max_workers: Optional[int] = None, - thread_name_prefix: str = '', + self, + max_workers: Optional[int] = None, + thread_name_prefix: str = "", initializer: Optional[Initializer] = None, initargs: Tuple[Any, ...] = (), ) -> None: @@ -194,14 +221,18 @@ def __init__( self._max_workers = 0 else: super().__init__(max_workers, thread_name_prefix, initializer, initargs) - + + # For backward-compatibility ProcessPoolExecutor = AsyncProcessPoolExecutor ThreadPoolExecutor = AsyncThreadPoolExecutor # Pruning thread pool -def _worker(executor_reference, work_queue, initializer, initargs, timeout): # NOTE: NEW 'timeout' + +def _worker( + executor_reference, work_queue, initializer, initargs, timeout +): # NOTE: NEW 'timeout' """ Worker function for the PruningThreadPoolExecutor. @@ -216,22 +247,21 @@ def _worker(executor_reference, work_queue, initializer, initargs, timeout): # try: initializer(*initargs) except BaseException: - _base.LOGGER.critical('Exception in initializer:', exc_info=True) + _base.LOGGER.critical("Exception in initializer:", exc_info=True) executor = executor_reference() if executor is not None: executor._initializer_failed() return - + try: while True: try: # NOTE: NEW - work_item = work_queue.get(block=True, - timeout=timeout) # NOTE: NEW + work_item = work_queue.get(block=True, timeout=timeout) # NOTE: NEW except queue.Empty: # NOTE: NEW # Its been 'timeout' seconds and there are no new work items. # NOTE: NEW # Let's suicide the thread. # NOTE: NEW executor = executor_reference() # NOTE: NEW - + with executor._adjusting_lock: # NOTE: NEW # NOTE: We keep a minimum of one thread active to prevent locks if len(executor) > 1: # NOTE: NEW @@ -240,9 +270,9 @@ def _worker(executor_reference, work_queue, initializer, initargs, timeout): # thread._threads_queues.pop(t) # NOTE: NEW # Let the executor know we have one less idle thread available executor._idle_semaphore.acquire(blocking=False) # NOTE: NEW - return # NOTE: NEW + return # NOTE: NEW continue - + if work_item is not None: work_item.run() # Delete references to object. See issue16284 @@ -269,17 +299,25 @@ def _worker(executor_reference, work_queue, initializer, initargs, timeout): # return del executor except BaseException: - _base.LOGGER.critical('Exception in worker', exc_info=True) + _base.LOGGER.critical("Exception in worker", exc_info=True) + class PruningThreadPoolExecutor(AsyncThreadPoolExecutor): """ This `AsyncThreadPoolExecutor` implementation prunes inactive threads after 'timeout' seconds without a work item. Pruned threads will be automatically recreated as needed for future workloads. Up to 'max_threads' can be active at any one time. """ + __slots__ = "_timeout", "_adjusting_lock" - def __init__(self, max_workers=None, thread_name_prefix='', - initializer=None, initargs=(), timeout=TEN_MINUTES): + def __init__( + self, + max_workers=None, + thread_name_prefix="", + initializer=None, + initargs=(), + timeout=TEN_MINUTES, + ): """ Initializes the PruningThreadPoolExecutor. @@ -290,13 +328,13 @@ def __init__(self, max_workers=None, thread_name_prefix='', initargs (Tuple[Any, ...], optional): Arguments for the initializer. Defaults to (). timeout (int, optional): Timeout duration for pruning inactive threads. Defaults to TEN_MINUTES. """ - self._timeout=timeout + self._timeout = timeout self._adjusting_lock = threading.Lock() super().__init__(max_workers, thread_name_prefix, initializer, initargs) - + def __len__(self) -> int: return len(self._threads) - + def _adjust_thread_count(self): """ Adjusts the number of threads based on workload and idle threads. @@ -313,19 +351,24 @@ def weakref_cb(_, q=self._work_queue): num_threads = len(self._threads) if num_threads < self._max_workers: - thread_name = '%s_%d' % (self._thread_name_prefix or self, - num_threads) - t = threading.Thread(name=thread_name, target=_worker, - args=(weakref.ref(self, weakref_cb), - self._work_queue, - self._initializer, - self._initargs, - self._timeout)) + thread_name = "%s_%d" % (self._thread_name_prefix or self, num_threads) + t = threading.Thread( + name=thread_name, + target=_worker, + args=( + weakref.ref(self, weakref_cb), + self._work_queue, + self._initializer, + self._initargs, + self._timeout, + ), + ) t.daemon = True t.start() self._threads.add(t) thread._threads_queues[t] = self._work_queue + executor = PruningThreadPoolExecutor(128) __all__ = [ diff --git a/a_sync/future.py b/a_sync/future.py index fe78face..70259f6c 100644 --- a/a_sync/future.py +++ b/a_sync/future.py @@ -8,27 +8,41 @@ from a_sync._typing import * -def future(callable: Union[Callable[P, Awaitable[T]], Callable[P, T]] = None, **kwargs: Unpack[ModifierKwargs]) -> Callable[P, Union[T, "ASyncFuture[T]"]]: +def future( + callable: AnyFn[P, T] = None, + **kwargs: Unpack[ModifierKwargs], +) -> Callable[P, Union[T, "ASyncFuture[T]"]]: return _ASyncFutureWrappedFn(callable, **kwargs) + async def _gather_check_and_materialize(*things: Unpack[MaybeAwaitable[T]]) -> List[T]: return await asyncio.gather(*[_check_and_materialize(thing) for thing in things]) + async def _check_and_materialize(thing: T) -> T: return await thing if isawaitable(thing) else thing - + + def _materialize(meta: "ASyncFuture[T]") -> T: try: return asyncio.get_event_loop().run_until_complete(meta) except RuntimeError as e: - raise RuntimeError(f"{meta} result is not set and the event loop is running, you will need to await it first") from e + raise RuntimeError( + f"{meta} result is not set and the event loop is running, you will need to await it first" + ) from e + + +MetaNumeric = Union[ + Numeric, "ASyncFuture[int]", "ASyncFuture[float]", "ASyncFuture[Decimal]" +] -MetaNumeric = Union[Numeric, "ASyncFuture[int]", "ASyncFuture[float]", "ASyncFuture[Decimal]"] class ASyncFuture(concurrent.futures.Future, Awaitable[T]): __slots__ = "__awaitable__", "__dependencies", "__dependants", "__task" - - def __init__(self, awaitable: Awaitable[T], dependencies: List["ASyncFuture"] = []) -> None: + + def __init__( + self, awaitable: Awaitable[T], dependencies: List["ASyncFuture"] = [] + ) -> None: self.__awaitable__ = awaitable self.__dependencies = dependencies for dependency in dependencies: @@ -37,19 +51,27 @@ def __init__(self, awaitable: Awaitable[T], dependencies: List["ASyncFuture"] = self.__dependants: List[ASyncFuture] = [] self.__task = None super().__init__() + def __hash__(self) -> int: return hash(self.__awaitable__) + def __repr__(self) -> str: string = f"<{self.__class__.__name__} {self._state} for {self.__awaitable__}" if self.cancelled(): pass elif self.done(): - string += f" exception={self.exception()}" if self.exception() else f" result={super().result()}" + string += ( + f" exception={self.exception()}" + if self.exception() + else f" result={super().result()}" + ) return string + ">" + def __list_dependencies(self, other) -> List["ASyncFuture"]: if isinstance(other, ASyncFuture): return [self, other] return [self] + @property def result(self) -> Union[Callable[[], T], Any]: """ @@ -58,407 +80,665 @@ def result(self) -> Union[Callable[[], T], Any]: If this future is done and the result does NOT have attribute `results`, will again work like cf.Future.result """ if self.done(): - if hasattr(r := super().result(), 'result'): + if hasattr(r := super().result(), "result"): # can be property, method, whatever. should work. return r.result # the result should be callable like an asyncio.Future return super().result return lambda: _materialize(self) + def __getattr__(self, attr: str) -> Any: return getattr(_materialize(self), attr) + def __getitem__(self, key) -> Any: return _materialize(self)[key] + # NOTE: broken, do not use. I think def __setitem__(self, key, value) -> None: _materialize(self)[key] = value + # not sure what to call these def __contains__(self, key: Any) -> bool: - return _materialize(ASyncFuture(self.__contains(key), dependencies=self.__list_dependencies(key))) + return _materialize( + ASyncFuture( + self.__contains(key), dependencies=self.__list_dependencies(key) + ) + ) + def __await__(self) -> Generator[Any, None, T]: return self.__await().__await__() + async def __await(self) -> T: if not self.done(): self.set_result(await self.__task__) return self._result + @property def __task__(self) -> "asyncio.Task[T]": if self.__task is None: self.__task = asyncio.create_task(self.__awaitable__) return self.__task + def __iter__(self): return _materialize(self).__iter__() + def __next__(self): return _materialize(self).__next__() + def __enter__(self): return _materialize(self).__enter__() + def __exit__(self, *args): return _materialize(self).__exit__(*args) + @overload - def __add__(self: "ASyncFuture[int]", other: int) -> "ASyncFuture[int]":... + def __add__(self: "ASyncFuture[int]", other: int) -> "ASyncFuture[int]": ... @overload - def __add__(self: "ASyncFuture[float]", other: float) -> "ASyncFuture[float]":... + def __add__(self: "ASyncFuture[float]", other: float) -> "ASyncFuture[float]": ... @overload - def __add__(self: "ASyncFuture[float]", other: int) -> "ASyncFuture[float]":... + def __add__(self: "ASyncFuture[float]", other: int) -> "ASyncFuture[float]": ... @overload - def __add__(self: "ASyncFuture[int]", other: float) -> "ASyncFuture[float]":... + def __add__(self: "ASyncFuture[int]", other: float) -> "ASyncFuture[float]": ... @overload - def __add__(self: "ASyncFuture[Decimal]", other: Decimal) -> "ASyncFuture[Decimal]":... + def __add__( + self: "ASyncFuture[Decimal]", other: Decimal + ) -> "ASyncFuture[Decimal]": ... @overload - def __add__(self: "ASyncFuture[Decimal]", other: int) -> "ASyncFuture[Decimal]":... + def __add__(self: "ASyncFuture[Decimal]", other: int) -> "ASyncFuture[Decimal]": ... @overload - def __add__(self: "ASyncFuture[int]", other: Decimal) -> "ASyncFuture[Decimal]":... + def __add__(self: "ASyncFuture[int]", other: Decimal) -> "ASyncFuture[Decimal]": ... @overload - def __add__(self: "ASyncFuture[int]", other: Awaitable[int]) -> "ASyncFuture[int]":... + def __add__( + self: "ASyncFuture[int]", other: Awaitable[int] + ) -> "ASyncFuture[int]": ... @overload - def __add__(self: "ASyncFuture[float]", other: Awaitable[float]) -> "ASyncFuture[float]":... + def __add__( + self: "ASyncFuture[float]", other: Awaitable[float] + ) -> "ASyncFuture[float]": ... @overload - def __add__(self: "ASyncFuture[float]", other: Awaitable[int]) -> "ASyncFuture[float]":... + def __add__( + self: "ASyncFuture[float]", other: Awaitable[int] + ) -> "ASyncFuture[float]": ... @overload - def __add__(self: "ASyncFuture[int]", other: Awaitable[float]) -> "ASyncFuture[float]":... + def __add__( + self: "ASyncFuture[int]", other: Awaitable[float] + ) -> "ASyncFuture[float]": ... @overload - def __add__(self: "ASyncFuture[Decimal]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":... + def __add__( + self: "ASyncFuture[Decimal]", other: Awaitable[Decimal] + ) -> "ASyncFuture[Decimal]": ... @overload - def __add__(self: "ASyncFuture[Decimal]", other: Awaitable[int]) -> "ASyncFuture[Decimal]":... + def __add__( + self: "ASyncFuture[Decimal]", other: Awaitable[int] + ) -> "ASyncFuture[Decimal]": ... @overload - def __add__(self: "ASyncFuture[int]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":... + def __add__( + self: "ASyncFuture[int]", other: Awaitable[Decimal] + ) -> "ASyncFuture[Decimal]": ... def __add__(self, other: MetaNumeric) -> "ASyncFuture": - return ASyncFuture(self.__add(other), dependencies=self.__list_dependencies(other)) + return ASyncFuture( + self.__add(other), dependencies=self.__list_dependencies(other) + ) + @overload - def __sub__(self: "ASyncFuture[int]", other: int) -> "ASyncFuture[int]":... + def __sub__(self: "ASyncFuture[int]", other: int) -> "ASyncFuture[int]": ... @overload - def __sub__(self: "ASyncFuture[float]", other: float) -> "ASyncFuture[float]":... + def __sub__(self: "ASyncFuture[float]", other: float) -> "ASyncFuture[float]": ... @overload - def __sub__(self: "ASyncFuture[float]", other: int) -> "ASyncFuture[float]":... + def __sub__(self: "ASyncFuture[float]", other: int) -> "ASyncFuture[float]": ... @overload - def __sub__(self: "ASyncFuture[int]", other: float) -> "ASyncFuture[float]":... + def __sub__(self: "ASyncFuture[int]", other: float) -> "ASyncFuture[float]": ... @overload - def __sub__(self: "ASyncFuture[Decimal]", other: Decimal) -> "ASyncFuture[Decimal]":... + def __sub__( + self: "ASyncFuture[Decimal]", other: Decimal + ) -> "ASyncFuture[Decimal]": ... @overload - def __sub__(self: "ASyncFuture[Decimal]", other: int) -> "ASyncFuture[Decimal]":... + def __sub__(self: "ASyncFuture[Decimal]", other: int) -> "ASyncFuture[Decimal]": ... @overload - def __sub__(self: "ASyncFuture[int]", other: Decimal) -> "ASyncFuture[Decimal]":... + def __sub__(self: "ASyncFuture[int]", other: Decimal) -> "ASyncFuture[Decimal]": ... @overload - def __sub__(self: "ASyncFuture[int]", other: Awaitable[int]) -> "ASyncFuture[int]":... + def __sub__( + self: "ASyncFuture[int]", other: Awaitable[int] + ) -> "ASyncFuture[int]": ... @overload - def __sub__(self: "ASyncFuture[float]", other: Awaitable[float]) -> "ASyncFuture[float]":... + def __sub__( + self: "ASyncFuture[float]", other: Awaitable[float] + ) -> "ASyncFuture[float]": ... @overload - def __sub__(self: "ASyncFuture[float]", other: Awaitable[int]) -> "ASyncFuture[float]":... + def __sub__( + self: "ASyncFuture[float]", other: Awaitable[int] + ) -> "ASyncFuture[float]": ... @overload - def __sub__(self: "ASyncFuture[int]", other: Awaitable[float]) -> "ASyncFuture[float]":... + def __sub__( + self: "ASyncFuture[int]", other: Awaitable[float] + ) -> "ASyncFuture[float]": ... @overload - def __sub__(self: "ASyncFuture[Decimal]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":... + def __sub__( + self: "ASyncFuture[Decimal]", other: Awaitable[Decimal] + ) -> "ASyncFuture[Decimal]": ... @overload - def __sub__(self: "ASyncFuture[Decimal]", other: Awaitable[int]) -> "ASyncFuture[Decimal]":... + def __sub__( + self: "ASyncFuture[Decimal]", other: Awaitable[int] + ) -> "ASyncFuture[Decimal]": ... @overload - def __sub__(self: "ASyncFuture[int]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":... + def __sub__( + self: "ASyncFuture[int]", other: Awaitable[Decimal] + ) -> "ASyncFuture[Decimal]": ... def __sub__(self, other: MetaNumeric) -> "ASyncFuture": - return ASyncFuture(self.__sub(other), dependencies=self.__list_dependencies(other)) + return ASyncFuture( + self.__sub(other), dependencies=self.__list_dependencies(other) + ) + def __mul__(self, other) -> "ASyncFuture": - return ASyncFuture(self.__mul(other), dependencies=self.__list_dependencies(other)) + return ASyncFuture( + self.__mul(other), dependencies=self.__list_dependencies(other) + ) + def __pow__(self, other) -> "ASyncFuture": - return ASyncFuture(self.__pow(other), dependencies=self.__list_dependencies(other)) + return ASyncFuture( + self.__pow(other), dependencies=self.__list_dependencies(other) + ) + def __truediv__(self, other) -> "ASyncFuture": - return ASyncFuture(self.__truediv(other), dependencies=self.__list_dependencies(other)) + return ASyncFuture( + self.__truediv(other), dependencies=self.__list_dependencies(other) + ) + def __floordiv__(self, other) -> "ASyncFuture": - return ASyncFuture(self.__floordiv(other), dependencies=self.__list_dependencies(other)) + return ASyncFuture( + self.__floordiv(other), dependencies=self.__list_dependencies(other) + ) + def __pow__(self, other) -> "ASyncFuture": - return ASyncFuture(self.__pow(other), dependencies=self.__list_dependencies(other)) + return ASyncFuture( + self.__pow(other), dependencies=self.__list_dependencies(other) + ) + @overload - def __radd__(self: "ASyncFuture[int]", other: int) -> "ASyncFuture[int]":... + def __radd__(self: "ASyncFuture[int]", other: int) -> "ASyncFuture[int]": ... @overload - def __radd__(self: "ASyncFuture[float]", other: float) -> "ASyncFuture[float]":... + def __radd__(self: "ASyncFuture[float]", other: float) -> "ASyncFuture[float]": ... @overload - def __radd__(self: "ASyncFuture[float]", other: int) -> "ASyncFuture[float]":... + def __radd__(self: "ASyncFuture[float]", other: int) -> "ASyncFuture[float]": ... @overload - def __radd__(self: "ASyncFuture[int]", other: float) -> "ASyncFuture[float]":... + def __radd__(self: "ASyncFuture[int]", other: float) -> "ASyncFuture[float]": ... @overload - def __radd__(self: "ASyncFuture[Decimal]", other: Decimal) -> "ASyncFuture[Decimal]":... + def __radd__( + self: "ASyncFuture[Decimal]", other: Decimal + ) -> "ASyncFuture[Decimal]": ... @overload - def __radd__(self: "ASyncFuture[Decimal]", other: int) -> "ASyncFuture[Decimal]":... + def __radd__( + self: "ASyncFuture[Decimal]", other: int + ) -> "ASyncFuture[Decimal]": ... @overload - def __radd__(self: "ASyncFuture[int]", other: Decimal) -> "ASyncFuture[Decimal]":... + def __radd__( + self: "ASyncFuture[int]", other: Decimal + ) -> "ASyncFuture[Decimal]": ... @overload - def __radd__(self: "ASyncFuture[int]", other: Awaitable[int]) -> "ASyncFuture[int]":... + def __radd__( + self: "ASyncFuture[int]", other: Awaitable[int] + ) -> "ASyncFuture[int]": ... @overload - def __radd__(self: "ASyncFuture[float]", other: Awaitable[float]) -> "ASyncFuture[float]":... + def __radd__( + self: "ASyncFuture[float]", other: Awaitable[float] + ) -> "ASyncFuture[float]": ... @overload - def __radd__(self: "ASyncFuture[float]", other: Awaitable[int]) -> "ASyncFuture[float]":... + def __radd__( + self: "ASyncFuture[float]", other: Awaitable[int] + ) -> "ASyncFuture[float]": ... @overload - def __radd__(self: "ASyncFuture[int]", other: Awaitable[float]) -> "ASyncFuture[float]":... + def __radd__( + self: "ASyncFuture[int]", other: Awaitable[float] + ) -> "ASyncFuture[float]": ... @overload - def __radd__(self: "ASyncFuture[Decimal]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":... + def __radd__( + self: "ASyncFuture[Decimal]", other: Awaitable[Decimal] + ) -> "ASyncFuture[Decimal]": ... @overload - def __radd__(self: "ASyncFuture[Decimal]", other: Awaitable[int]) -> "ASyncFuture[Decimal]":... + def __radd__( + self: "ASyncFuture[Decimal]", other: Awaitable[int] + ) -> "ASyncFuture[Decimal]": ... @overload - def __radd__(self: "ASyncFuture[int]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":... + def __radd__( + self: "ASyncFuture[int]", other: Awaitable[Decimal] + ) -> "ASyncFuture[Decimal]": ... def __radd__(self, other) -> "ASyncFuture": - return ASyncFuture(self.__radd(other), dependencies=self.__list_dependencies(other)) - + return ASyncFuture( + self.__radd(other), dependencies=self.__list_dependencies(other) + ) + @overload - def __rsub__(self: "ASyncFuture[int]", other: int) -> "ASyncFuture[int]":... + def __rsub__(self: "ASyncFuture[int]", other: int) -> "ASyncFuture[int]": ... @overload - def __rsub__(self: "ASyncFuture[float]", other: float) -> "ASyncFuture[float]":... + def __rsub__(self: "ASyncFuture[float]", other: float) -> "ASyncFuture[float]": ... @overload - def __rsub__(self: "ASyncFuture[float]", other: int) -> "ASyncFuture[float]":... + def __rsub__(self: "ASyncFuture[float]", other: int) -> "ASyncFuture[float]": ... @overload - def __rsub__(self: "ASyncFuture[int]", other: float) -> "ASyncFuture[float]":... + def __rsub__(self: "ASyncFuture[int]", other: float) -> "ASyncFuture[float]": ... @overload - def __rsub__(self: "ASyncFuture[Decimal]", other: Decimal) -> "ASyncFuture[Decimal]":... + def __rsub__( + self: "ASyncFuture[Decimal]", other: Decimal + ) -> "ASyncFuture[Decimal]": ... @overload - def __rsub__(self: "ASyncFuture[Decimal]", other: int) -> "ASyncFuture[Decimal]":... + def __rsub__( + self: "ASyncFuture[Decimal]", other: int + ) -> "ASyncFuture[Decimal]": ... @overload - def __rsub__(self: "ASyncFuture[int]", other: Decimal) -> "ASyncFuture[Decimal]":... + def __rsub__( + self: "ASyncFuture[int]", other: Decimal + ) -> "ASyncFuture[Decimal]": ... @overload - def __rsub__(self: "ASyncFuture[int]", other: Awaitable[int]) -> "ASyncFuture[int]":... + def __rsub__( + self: "ASyncFuture[int]", other: Awaitable[int] + ) -> "ASyncFuture[int]": ... @overload - def __rsub__(self: "ASyncFuture[float]", other: Awaitable[float]) -> "ASyncFuture[float]":... + def __rsub__( + self: "ASyncFuture[float]", other: Awaitable[float] + ) -> "ASyncFuture[float]": ... @overload - def __rsub__(self: "ASyncFuture[float]", other: Awaitable[int]) -> "ASyncFuture[float]":... + def __rsub__( + self: "ASyncFuture[float]", other: Awaitable[int] + ) -> "ASyncFuture[float]": ... @overload - def __rsub__(self: "ASyncFuture[int]", other: Awaitable[float]) -> "ASyncFuture[float]":... + def __rsub__( + self: "ASyncFuture[int]", other: Awaitable[float] + ) -> "ASyncFuture[float]": ... @overload - def __rsub__(self: "ASyncFuture[Decimal]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":... + def __rsub__( + self: "ASyncFuture[Decimal]", other: Awaitable[Decimal] + ) -> "ASyncFuture[Decimal]": ... @overload - def __rsub__(self: "ASyncFuture[Decimal]", other: Awaitable[int]) -> "ASyncFuture[Decimal]":... + def __rsub__( + self: "ASyncFuture[Decimal]", other: Awaitable[int] + ) -> "ASyncFuture[Decimal]": ... @overload - def __rsub__(self: "ASyncFuture[int]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":... + def __rsub__( + self: "ASyncFuture[int]", other: Awaitable[Decimal] + ) -> "ASyncFuture[Decimal]": ... def __rsub__(self, other) -> "ASyncFuture": - return ASyncFuture(self.__rsub(other), dependencies=self.__list_dependencies(other)) + return ASyncFuture( + self.__rsub(other), dependencies=self.__list_dependencies(other) + ) + def __rmul__(self, other) -> "ASyncFuture": - return ASyncFuture(self.__rmul(other), dependencies=self.__list_dependencies(other)) + return ASyncFuture( + self.__rmul(other), dependencies=self.__list_dependencies(other) + ) + def __rtruediv__(self, other) -> "ASyncFuture": - return ASyncFuture(self.__rtruediv(other), dependencies=self.__list_dependencies(other)) + return ASyncFuture( + self.__rtruediv(other), dependencies=self.__list_dependencies(other) + ) + def __rfloordiv__(self, other) -> "ASyncFuture": - return ASyncFuture(self.__rfloordiv(other), dependencies=self.__list_dependencies(other)) + return ASyncFuture( + self.__rfloordiv(other), dependencies=self.__list_dependencies(other) + ) + def __rpow__(self, other) -> "ASyncFuture": - return ASyncFuture(self.__rpow(other), dependencies=self.__list_dependencies(other)) + return ASyncFuture( + self.__rpow(other), dependencies=self.__list_dependencies(other) + ) + def __eq__(self, other) -> "ASyncFuture": - return bool(ASyncFuture(self.__eq(other), dependencies=self.__list_dependencies(other))) + return bool( + ASyncFuture(self.__eq(other), dependencies=self.__list_dependencies(other)) + ) + def __gt__(self, other) -> "ASyncFuture": - return ASyncFuture(self.__gt(other), dependencies=self.__list_dependencies(other)) + return ASyncFuture( + self.__gt(other), dependencies=self.__list_dependencies(other) + ) + def __ge__(self, other) -> "ASyncFuture": - return ASyncFuture(self.__ge(other), dependencies=self.__list_dependencies(other)) + return ASyncFuture( + self.__ge(other), dependencies=self.__list_dependencies(other) + ) + def __lt__(self, other) -> "ASyncFuture": - return ASyncFuture(self.__lt(other), dependencies=self.__list_dependencies(other)) + return ASyncFuture( + self.__lt(other), dependencies=self.__list_dependencies(other) + ) + def __le__(self, other) -> "ASyncFuture": - return ASyncFuture(self.__le(other), dependencies=self.__list_dependencies(other)) - + return ASyncFuture( + self.__le(other), dependencies=self.__list_dependencies(other) + ) + # Maths - + @overload - async def __add(self: "ASyncFuture[int]", other: int) -> "ASyncFuture[int]":... + async def __add(self: "ASyncFuture[int]", other: int) -> "ASyncFuture[int]": ... @overload - async def __add(self: "ASyncFuture[float]", other: float) -> "ASyncFuture[float]":... + async def __add( + self: "ASyncFuture[float]", other: float + ) -> "ASyncFuture[float]": ... @overload - async def __add(self: "ASyncFuture[float]", other: int) -> "ASyncFuture[float]":... + async def __add(self: "ASyncFuture[float]", other: int) -> "ASyncFuture[float]": ... @overload - async def __add(self: "ASyncFuture[int]", other: float) -> "ASyncFuture[float]":... + async def __add(self: "ASyncFuture[int]", other: float) -> "ASyncFuture[float]": ... @overload - async def __add(self: "ASyncFuture[Decimal]", other: Decimal) -> "ASyncFuture[Decimal]":... + async def __add( + self: "ASyncFuture[Decimal]", other: Decimal + ) -> "ASyncFuture[Decimal]": ... @overload - async def __add(self: "ASyncFuture[Decimal]", other: int) -> "ASyncFuture[Decimal]":... + async def __add( + self: "ASyncFuture[Decimal]", other: int + ) -> "ASyncFuture[Decimal]": ... @overload - async def __add(self: "ASyncFuture[int]", other: Decimal) -> "ASyncFuture[Decimal]":... + async def __add( + self: "ASyncFuture[int]", other: Decimal + ) -> "ASyncFuture[Decimal]": ... @overload - async def __add(self: "ASyncFuture[int]", other: Awaitable[int]) -> "ASyncFuture[int]":... + async def __add( + self: "ASyncFuture[int]", other: Awaitable[int] + ) -> "ASyncFuture[int]": ... @overload - async def __add(self: "ASyncFuture[float]", other: Awaitable[float]) -> "ASyncFuture[float]":... + async def __add( + self: "ASyncFuture[float]", other: Awaitable[float] + ) -> "ASyncFuture[float]": ... @overload - async def __add(self: "ASyncFuture[float]", other: Awaitable[int]) -> "ASyncFuture[float]":... + async def __add( + self: "ASyncFuture[float]", other: Awaitable[int] + ) -> "ASyncFuture[float]": ... @overload - async def __add(self: "ASyncFuture[int]", other: Awaitable[float]) -> "ASyncFuture[float]":... + async def __add( + self: "ASyncFuture[int]", other: Awaitable[float] + ) -> "ASyncFuture[float]": ... @overload - async def __add(self: "ASyncFuture[Decimal]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":... + async def __add( + self: "ASyncFuture[Decimal]", other: Awaitable[Decimal] + ) -> "ASyncFuture[Decimal]": ... @overload - async def __add(self: "ASyncFuture[Decimal]", other: Awaitable[int]) -> "ASyncFuture[Decimal]":... + async def __add( + self: "ASyncFuture[Decimal]", other: Awaitable[int] + ) -> "ASyncFuture[Decimal]": ... @overload - async def __add(self: "ASyncFuture[int]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":... + async def __add( + self: "ASyncFuture[int]", other: Awaitable[Decimal] + ) -> "ASyncFuture[Decimal]": ... async def __add(self, other) -> "Any": a, b = await _gather_check_and_materialize(self, other) return a + b + @overload - async def __sub(self: "ASyncFuture[int]", other: int) -> "ASyncFuture[int]":... + async def __sub(self: "ASyncFuture[int]", other: int) -> "ASyncFuture[int]": ... @overload - async def __sub(self: "ASyncFuture[float]", other: float) -> "ASyncFuture[float]":... + async def __sub( + self: "ASyncFuture[float]", other: float + ) -> "ASyncFuture[float]": ... @overload - async def __sub(self: "ASyncFuture[float]", other: int) -> "ASyncFuture[float]":... + async def __sub(self: "ASyncFuture[float]", other: int) -> "ASyncFuture[float]": ... @overload - async def __sub(self: "ASyncFuture[int]", other: float) -> "ASyncFuture[float]":... + async def __sub(self: "ASyncFuture[int]", other: float) -> "ASyncFuture[float]": ... @overload - async def __sub(self: "ASyncFuture[Decimal]", other: Decimal) -> "ASyncFuture[Decimal]":... + async def __sub( + self: "ASyncFuture[Decimal]", other: Decimal + ) -> "ASyncFuture[Decimal]": ... @overload - async def __sub(self: "ASyncFuture[Decimal]", other: int) -> "ASyncFuture[Decimal]":... + async def __sub( + self: "ASyncFuture[Decimal]", other: int + ) -> "ASyncFuture[Decimal]": ... @overload - async def __sub(self: "ASyncFuture[int]", other: Decimal) -> "ASyncFuture[Decimal]":... + async def __sub( + self: "ASyncFuture[int]", other: Decimal + ) -> "ASyncFuture[Decimal]": ... @overload - async def __sub(self: "ASyncFuture[int]", other: Awaitable[int]) -> "ASyncFuture[int]":... + async def __sub( + self: "ASyncFuture[int]", other: Awaitable[int] + ) -> "ASyncFuture[int]": ... @overload - async def __sub(self: "ASyncFuture[float]", other: Awaitable[float]) -> "ASyncFuture[float]":... + async def __sub( + self: "ASyncFuture[float]", other: Awaitable[float] + ) -> "ASyncFuture[float]": ... @overload - async def __sub(self: "ASyncFuture[float]", other: Awaitable[int]) -> "ASyncFuture[float]":... + async def __sub( + self: "ASyncFuture[float]", other: Awaitable[int] + ) -> "ASyncFuture[float]": ... @overload - async def __sub(self: "ASyncFuture[int]", other: Awaitable[float]) -> "ASyncFuture[float]":... + async def __sub( + self: "ASyncFuture[int]", other: Awaitable[float] + ) -> "ASyncFuture[float]": ... @overload - async def __sub(self: "ASyncFuture[Decimal]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":... + async def __sub( + self: "ASyncFuture[Decimal]", other: Awaitable[Decimal] + ) -> "ASyncFuture[Decimal]": ... @overload - async def __sub(self: "ASyncFuture[Decimal]", other: Awaitable[int]) -> "ASyncFuture[Decimal]":... + async def __sub( + self: "ASyncFuture[Decimal]", other: Awaitable[int] + ) -> "ASyncFuture[Decimal]": ... @overload - async def __sub(self: "ASyncFuture[int]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":... + async def __sub( + self: "ASyncFuture[int]", other: Awaitable[Decimal] + ) -> "ASyncFuture[Decimal]": ... async def __sub(self, other) -> "Any": a, b = await _gather_check_and_materialize(self, other) return a - b + async def __mul(self, other) -> "Any": a, b = await _gather_check_and_materialize(self, other) return a * b + async def __truediv(self, other) -> "Any": a, b = await _gather_check_and_materialize(self, other) return a / b + async def __floordiv(self, other) -> "Any": a, b = await _gather_check_and_materialize(self, other) return a // b + async def __pow(self, other) -> "Any": a, b = await _gather_check_and_materialize(self, other) - return a ** b - + return a**b + # rMaths @overload - async def __radd(self: "ASyncFuture[int]", other: int) -> "ASyncFuture[int]":... + async def __radd(self: "ASyncFuture[int]", other: int) -> "ASyncFuture[int]": ... @overload - async def __radd(self: "ASyncFuture[float]", other: float) -> "ASyncFuture[float]":... + async def __radd( + self: "ASyncFuture[float]", other: float + ) -> "ASyncFuture[float]": ... @overload - async def __radd(self: "ASyncFuture[float]", other: int) -> "ASyncFuture[float]":... + async def __radd( + self: "ASyncFuture[float]", other: int + ) -> "ASyncFuture[float]": ... @overload - async def __radd(self: "ASyncFuture[int]", other: float) -> "ASyncFuture[float]":... + async def __radd( + self: "ASyncFuture[int]", other: float + ) -> "ASyncFuture[float]": ... @overload - async def __radd(self: "ASyncFuture[Decimal]", other: Decimal) -> "ASyncFuture[Decimal]":... + async def __radd( + self: "ASyncFuture[Decimal]", other: Decimal + ) -> "ASyncFuture[Decimal]": ... @overload - async def __radd(self: "ASyncFuture[Decimal]", other: int) -> "ASyncFuture[Decimal]":... + async def __radd( + self: "ASyncFuture[Decimal]", other: int + ) -> "ASyncFuture[Decimal]": ... @overload - async def __radd(self: "ASyncFuture[int]", other: Decimal) -> "ASyncFuture[Decimal]":... + async def __radd( + self: "ASyncFuture[int]", other: Decimal + ) -> "ASyncFuture[Decimal]": ... @overload - async def __radd(self: "ASyncFuture[int]", other: Awaitable[int]) -> "ASyncFuture[int]":... + async def __radd( + self: "ASyncFuture[int]", other: Awaitable[int] + ) -> "ASyncFuture[int]": ... @overload - async def __radd(self: "ASyncFuture[float]", other: Awaitable[float]) -> "ASyncFuture[float]":... + async def __radd( + self: "ASyncFuture[float]", other: Awaitable[float] + ) -> "ASyncFuture[float]": ... @overload - async def __radd(self: "ASyncFuture[float]", other: Awaitable[int]) -> "ASyncFuture[float]":... + async def __radd( + self: "ASyncFuture[float]", other: Awaitable[int] + ) -> "ASyncFuture[float]": ... @overload - async def __radd(self: "ASyncFuture[int]", other: Awaitable[float]) -> "ASyncFuture[float]":... + async def __radd( + self: "ASyncFuture[int]", other: Awaitable[float] + ) -> "ASyncFuture[float]": ... @overload - async def __radd(self: "ASyncFuture[Decimal]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":... + async def __radd( + self: "ASyncFuture[Decimal]", other: Awaitable[Decimal] + ) -> "ASyncFuture[Decimal]": ... @overload - async def __radd(self: "ASyncFuture[Decimal]", other: Awaitable[int]) -> "ASyncFuture[Decimal]":... + async def __radd( + self: "ASyncFuture[Decimal]", other: Awaitable[int] + ) -> "ASyncFuture[Decimal]": ... @overload - async def __radd(self: "ASyncFuture[int]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":... + async def __radd( + self: "ASyncFuture[int]", other: Awaitable[Decimal] + ) -> "ASyncFuture[Decimal]": ... async def __radd(self, other) -> "Any": a, b = await _gather_check_and_materialize(other, self) return a + b + @overload - async def __rsub(self: "ASyncFuture[int]", other: int) -> "ASyncFuture[int]":... + async def __rsub(self: "ASyncFuture[int]", other: int) -> "ASyncFuture[int]": ... @overload - async def __rsub(self: "ASyncFuture[float]", other: float) -> "ASyncFuture[float]":... + async def __rsub( + self: "ASyncFuture[float]", other: float + ) -> "ASyncFuture[float]": ... @overload - async def __rsub(self: "ASyncFuture[float]", other: int) -> "ASyncFuture[float]":... + async def __rsub( + self: "ASyncFuture[float]", other: int + ) -> "ASyncFuture[float]": ... @overload - async def __rsub(self: "ASyncFuture[int]", other: float) -> "ASyncFuture[float]":... + async def __rsub( + self: "ASyncFuture[int]", other: float + ) -> "ASyncFuture[float]": ... @overload - async def __rsub(self: "ASyncFuture[Decimal]", other: Decimal) -> "ASyncFuture[Decimal]":... + async def __rsub( + self: "ASyncFuture[Decimal]", other: Decimal + ) -> "ASyncFuture[Decimal]": ... @overload - async def __rsub(self: "ASyncFuture[Decimal]", other: int) -> "ASyncFuture[Decimal]":... + async def __rsub( + self: "ASyncFuture[Decimal]", other: int + ) -> "ASyncFuture[Decimal]": ... @overload - async def __rsub(self: "ASyncFuture[int]", other: Decimal) -> "ASyncFuture[Decimal]":... + async def __rsub( + self: "ASyncFuture[int]", other: Decimal + ) -> "ASyncFuture[Decimal]": ... @overload - async def __rsub(self: "ASyncFuture[int]", other: Awaitable[int]) -> "ASyncFuture[int]":... + async def __rsub( + self: "ASyncFuture[int]", other: Awaitable[int] + ) -> "ASyncFuture[int]": ... @overload - async def __rsub(self: "ASyncFuture[float]", other: Awaitable[float]) -> "ASyncFuture[float]":... + async def __rsub( + self: "ASyncFuture[float]", other: Awaitable[float] + ) -> "ASyncFuture[float]": ... @overload - async def __rsub(self: "ASyncFuture[float]", other: Awaitable[int]) -> "ASyncFuture[float]":... + async def __rsub( + self: "ASyncFuture[float]", other: Awaitable[int] + ) -> "ASyncFuture[float]": ... @overload - async def __rsub(self: "ASyncFuture[int]", other: Awaitable[float]) -> "ASyncFuture[float]":... + async def __rsub( + self: "ASyncFuture[int]", other: Awaitable[float] + ) -> "ASyncFuture[float]": ... @overload - async def __rsub(self: "ASyncFuture[Decimal]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":... + async def __rsub( + self: "ASyncFuture[Decimal]", other: Awaitable[Decimal] + ) -> "ASyncFuture[Decimal]": ... @overload - async def __rsub(self: "ASyncFuture[Decimal]", other: Awaitable[int]) -> "ASyncFuture[Decimal]":... + async def __rsub( + self: "ASyncFuture[Decimal]", other: Awaitable[int] + ) -> "ASyncFuture[Decimal]": ... @overload - async def __rsub(self: "ASyncFuture[int]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":... + async def __rsub( + self: "ASyncFuture[int]", other: Awaitable[Decimal] + ) -> "ASyncFuture[Decimal]": ... async def __rsub(self, other) -> "Any": a, b = await _gather_check_and_materialize(other, self) return a - b + async def __rmul(self, other) -> "Any": a, b = await _gather_check_and_materialize(other, self) return a * b + async def __rtruediv(self, other) -> "Any": a, b = await _gather_check_and_materialize(other, self) return a / b + async def __rfloordiv(self, other) -> "Any": a, b = await _gather_check_and_materialize(other, self) return a // b + async def __rpow(self, other) -> "Any": a, b = await _gather_check_and_materialize(other, self) - return a ** b - + return a**b + async def __iadd(self, other) -> "Any": a, b = await _gather_check_and_materialize(self, other) self._result = a + b return self._result + async def __isub(self, other) -> "Any": a, b = await _gather_check_and_materialize(self, other) self._result = a - b return self._result + async def __imul(self, other) -> "Any": a, b = await _gather_check_and_materialize(self, other) self._result = a * b return self._result + async def __itruediv(self, other) -> "Any": a, b = await _gather_check_and_materialize(self, other) self._result = a / b return self._result + async def __ifloordiv(self, other) -> "Any": a, b = await _gather_check_and_materialize(self, other) self._result = a // b return self._result + async def __ipow(self, other) -> "Any": a, b = await _gather_check_and_materialize(self, other) - self._result = a ** b + self._result = a**b return self._result - + # Comparisons async def __eq(self, other) -> bool: a, b = await _gather_check_and_materialize(self, other) return a == b + async def __gt(self, other) -> bool: a, b = await _gather_check_and_materialize(self, other) return a > b + async def __ge(self, other) -> bool: a, b = await _gather_check_and_materialize(self, other) return a >= b + async def __lt(self, other) -> bool: a, b = await _gather_check_and_materialize(self, other) return a < b + async def __le(self, other) -> bool: a, b = await _gather_check_and_materialize(self, other) return a <= b - # not sure what to call these async def __contains(self, item: Any) -> bool: _self, _item = await _gather_check_and_materialize(self, item) return _item in _self - + # conversion # NOTE: We aren't allowed to return ASyncFutures here :( def __bool__(self) -> bool: return bool(_materialize(self)) + def __bytes__(self) -> bytes: return bytes(_materialize(self)) + def __str__(self) -> str: return str(_materialize(self)) + def __int__(self) -> int: return int(_materialize(self)) + def __float__(self) -> float: return float(_materialize(self)) - + # WIP internals - + @property def __dependants__(self) -> Set["ASyncFuture"]: dependants = set() @@ -466,6 +746,7 @@ def __dependants__(self) -> Set["ASyncFuture"]: dependants.add(dep) dependants.union(dep.__dependants__) return dependants + @property def __dependencies__(self) -> Set["ASyncFuture"]: dependencies = set() @@ -473,39 +754,56 @@ def __dependencies__(self) -> Set["ASyncFuture"]: dependencies.add(dep) dependencies.union(dep.__dependencies__) return dependencies + def __sizeof__(self) -> int: if isinstance(self.__awaitable__, Coroutine): - return sum(sys.getsizeof(v) for v in self.__awaitable__.cr_frame.f_locals.values()) + return sum( + sys.getsizeof(v) for v in self.__awaitable__.cr_frame.f_locals.values() + ) elif isinstance(self.__awaitable__, asyncio.Future): raise NotImplementedError raise NotImplementedError -@final +@final class _ASyncFutureWrappedFn(Callable[P, ASyncFuture[T]]): __slots__ = "callable", "wrapped", "_callable_name" - def __init__(self, callable: Union[Callable[P, Awaitable[T]], Callable[P, T]] = None, **kwargs: Unpack[ModifierKwargs]): + + def __init__( + self, + callable: AnyFn[P, T] = None, + **kwargs: Unpack[ModifierKwargs], + ): from a_sync import a_sync + if callable: self.callable = callable self._callable_name = callable.__name__ a_sync_callable = a_sync(callable, default="async", **kwargs) + @wraps(callable) def future_wrap(*args: P.args, **kwargs: P.kwargs) -> "ASyncFuture[T]": return ASyncFuture(a_sync_callable(*args, **kwargs, sync=False)) + self.wrapped = future_wrap else: self.wrapped = partial(_ASyncFutureWrappedFn, **kwargs) + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ASyncFuture[T]: return self.wrapped(*args, **kwargs) + def __repr__(self) -> str: return f"<{self.__class__.__name__} {self.callable}>" - def __get__(self, instance: I, owner: Type[I]) -> Union[Self, "_ASyncFutureInstanceMethod[I, P, T]"]: + + def __get__( + self, instance: I, owner: Type[I] + ) -> Union[Self, "_ASyncFutureInstanceMethod[I, P, T]"]: if owner is None: return self else: return _ASyncFutureInstanceMethod(self, instance) + @final class _ASyncFutureInstanceMethod(Generic[I, P, T]): # NOTE: probably could just replace this with functools.partial @@ -543,6 +841,7 @@ def __init__( def __repr__(self) -> str: return f"<{self.__class__.__name__} for {self.__wrapper.callable} bound to {self.__instance}>" + def __call__(self, /, *fn_args: P.args, **fn_kwargs: P.kwargs) -> T: return self.__wrapper(self.__instance, *fn_args, **fn_kwargs) diff --git a/a_sync/iter.py b/a_sync/iter.py index 0742750e..bf1e0c9a 100644 --- a/a_sync/iter.py +++ b/a_sync/iter.py @@ -1,4 +1,3 @@ - import asyncio import functools import inspect @@ -22,6 +21,7 @@ SortKey = SyncFn[[T], bool] ViewFn = AnyFn[[T], bool] + class _AwaitableAsyncIterableMixin(AsyncIterable[T]): """ A mixin class defining logic for making an AsyncIterable awaitable. @@ -35,19 +35,20 @@ class _AwaitableAsyncIterableMixin(AsyncIterable[T]): ... def __aiter__(self): ... for i in range(4): ... yield i - ... + ... >>> aiterable = MyAwaitableAIterable() >>> await aiterable [0, 1, 2, 3, 4] ``` """ + __wrapped__: AsyncIterable[T] - + def __await__(self) -> Generator[Any, Any, List[T]]: """ Asynchronously iterate through the {cls} and return all objects. - + Returns: A list of the objects yielded by the {cls}. """ @@ -57,13 +58,15 @@ def __await__(self) -> Generator[Any, Any, List[T]]: def materialized(self) -> List[T]: """ Synchronously iterate through the {cls} and return all objects. - + Returns: A list of the objects yielded by the {cls}. """ return _helpers._await(self._materialized) - def sort(self, *, key: SortKey[T] = None, reverse: bool = False) -> "ASyncSorter[T]": + def sort( + self, *, key: SortKey[T] = None, reverse: bool = False + ) -> "ASyncSorter[T]": """ Sort the contents of the {cls}. @@ -92,7 +95,7 @@ def filter(self, function: ViewFn[T]) -> "ASyncFilter[T]": async def _materialized(self) -> List[T]: """ Asynchronously iterate through the {cls} and return all objects. - + Returns: A list of the objects yielded by the {cls}. """ @@ -106,7 +109,7 @@ def __init_subclass__(cls, **kwargs) -> None: cls.__doc__ = new else: cls.__doc__ += f"\n\n{new}" - + # format the member docstrings for attr_name in dir(cls): attr = getattr(cls, attr_name, None) @@ -115,27 +118,35 @@ def __init_subclass__(cls, **kwargs) -> None: return super().__init_subclass__(**kwargs) - __slots__ = '__async_property__', - + __slots__ = ("__async_property__",) + + class ASyncIterable(_AwaitableAsyncIterableMixin[T], Iterable[T]): """ A hybrid Iterable/AsyncIterable implementation designed to offer dual compatibility with both synchronous and asynchronous iteration protocols. - + This class allows objects to be iterated over using either a standard `for` loop or an `async for` loop, making it versatile in scenarios where the mode of iteration (synchronous or asynchronous) needs to be flexible or is determined at runtime. The class achieves this by implementing both `__iter__` and `__aiter__` methods, enabling it to return appropriate iterator objects that can handle synchronous and asynchronous iteration, respectively. This dual functionality is particularly useful in codebases that are transitioning between synchronous and asynchronous code, or in libraries that aim to support both synchronous and asynchronous usage patterns without requiring the user to manage different types of iterable objects. """ + @classmethod def wrap(cls, wrapped: AsyncIterable[T]) -> "ASyncIterable[T]": "Class method to wrap an AsyncIterable for backward compatibility." - logger.warning("ASyncIterable.wrap will be removed soon. Please replace uses with simple instantiation ie `ASyncIterable(wrapped)`") + logger.warning( + "ASyncIterable.wrap will be removed soon. Please replace uses with simple instantiation ie `ASyncIterable(wrapped)`" + ) return cls(wrapped) + def __init__(self, async_iterable: AsyncIterable[T]): "Initializes the ASyncIterable with an async iterable." if not isinstance(async_iterable, AsyncIterable): - raise TypeError(f"`async_iterable` must be an AsyncIterable. You passed {async_iterable}") + raise TypeError( + f"`async_iterable` must be an AsyncIterable. You passed {async_iterable}" + ) self.__wrapped__ = async_iterable "The wrapped async iterable object." + def __repr__(self) -> str: start = f"<{type(self).__name__}" if wrapped := getattr(self, "__wrapped__", None): @@ -151,10 +162,13 @@ def __aiter__(self) -> AsyncIterator[T]: def __iter__(self) -> Iterator[T]: "Return an iterator that yields :obj:`T` objects from the {cls}." yield from ASyncIterator(self.__aiter__()) - __slots__ = "__wrapped__", + + __slots__ = ("__wrapped__",) + AsyncGenFunc = Callable[P, Union[AsyncGenerator[T, None], AsyncIterator[T]]] + class ASyncIterator(_AwaitableAsyncIterableMixin[T], Iterator[T]): """ A hybrid Iterator/AsyncIterator implementation that bridges the gap between synchronous and asynchronous iteration. This class provides a unified interface for iteration that can seamlessly operate in both synchronous (`for` loop) and asynchronous (`async for` loop) contexts. It allows the wrapping of asynchronous iterable objects or async generator functions, making them usable in synchronous code without explicitly managing event loops or asynchronous context switches. @@ -163,11 +177,11 @@ class ASyncIterator(_AwaitableAsyncIterableMixin[T], Iterator[T]): This class is particularly useful for library developers seeking to provide a consistent iteration interface across synchronous and asynchronous code, reducing the cognitive load on users and promoting code reusability and simplicity. """ - + def __next__(self) -> T: """ Synchronously fetch the next item from the {cls}. - + Raises: :class:`StopIteration`: Once all items have been fetched from the {cls}. """ @@ -177,34 +191,44 @@ def __next__(self) -> T: raise StopIteration from e except RuntimeError as e: if str(e) == "This event loop is already running": - raise SyncModeInAsyncContextError("The event loop is already running. Try iterating using `async for` instead of `for`.") from e + raise SyncModeInAsyncContextError( + "The event loop is already running. Try iterating using `async for` instead of `for`." + ) from e raise @overload - def wrap(cls, aiterator: AsyncIterator[T]) -> "ASyncIterator[T]":... + def wrap(cls, aiterator: AsyncIterator[T]) -> "ASyncIterator[T]": ... @overload - def wrap(cls, async_gen_func: AsyncGenFunc[P, T]) -> "ASyncGeneratorFunction[P, T]":... + def wrap( + cls, async_gen_func: AsyncGenFunc[P, T] + ) -> "ASyncGeneratorFunction[P, T]": ... @classmethod def wrap(cls, wrapped): "Class method to wrap either an AsyncIterator or an async generator function." if isinstance(wrapped, AsyncIterator): - logger.warning("This use case for ASyncIterator.wrap will be removed soon. Please replace uses with simple instantiation ie `ASyncIterator(wrapped)`") + logger.warning( + "This use case for ASyncIterator.wrap will be removed soon. Please replace uses with simple instantiation ie `ASyncIterator(wrapped)`" + ) return cls(wrapped) elif inspect.isasyncgenfunction(wrapped): return ASyncGeneratorFunction(wrapped) - raise TypeError(f"`wrapped` must be an AsyncIterator or an async generator function. You passed {wrapped}") + raise TypeError( + f"`wrapped` must be an AsyncIterator or an async generator function. You passed {wrapped}" + ) def __init__(self, async_iterator: AsyncIterator[T]): "Initializes the ASyncIterator with an async iterator." if not isinstance(async_iterator, AsyncIterator): - raise TypeError(f"`async_iterator` must be an AsyncIterator. You passed {async_iterator}") + raise TypeError( + f"`async_iterator` must be an AsyncIterator. You passed {async_iterator}" + ) self.__wrapped__ = async_iterator "The wrapped :class:`AsyncIterator`." async def __anext__(self) -> T: """ Asynchronously fetch the next item from the {cls}. - + Raises: :class:`StopAsyncIteration`: Once all items have been fetched from the {cls}. """ @@ -218,6 +242,7 @@ def __aiter__(self) -> Self: "Return the {cls} for aiteration." return self + class ASyncGeneratorFunction(Generic[P, T]): """ Encapsulates an asynchronous generator function, providing a mechanism to use it as an asynchronous iterator with enhanced capabilities. This class wraps an async generator function, allowing it to be called with parameters and return an :class:`~ASyncIterator` object. It is particularly useful for situations where an async generator function needs to be used in a manner that is consistent with both synchronous and asynchronous execution contexts. @@ -233,7 +258,9 @@ class ASyncGeneratorFunction(Generic[P, T]): __weakself__: "weakref.ref[object]" = None "A weak reference to the instance the function is bound to, if any." - def __init__(self, async_gen_func: AsyncGenFunc[P, T], instance: Any = None) -> None: + def __init__( + self, async_gen_func: AsyncGenFunc[P, T], instance: Any = None + ) -> None: """ Initializes the ASyncGeneratorFunction with the given async generator function and optionally an instance. @@ -241,7 +268,7 @@ def __init__(self, async_gen_func: AsyncGenFunc[P, T], instance: Any = None) -> async_gen_func: The async generator function to wrap. instance (optional): The object to bind to the function, if applicable. """ - + self.field_name = async_gen_func.__name__ "The name of the async generator function." @@ -263,7 +290,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ASyncIterator[T]: Args: *args: Positional arguments for the function. **kwargs: Keyword arguments for the function. - + Returns: An :class:`ASyncIterator` wrapping the :class:`AsyncIterator` returned from the wrapped function call. """ @@ -296,11 +323,14 @@ def __self__(self) -> object: def __get_cache_handle(self, instance: object) -> asyncio.TimerHandle: # NOTE: we create a strong reference to instance here. I'm not sure if this is good or not but its necessary for now. - return asyncio.get_event_loop().call_later(300, delattr, instance, self.field_name) + return asyncio.get_event_loop().call_later( + 300, delattr, instance, self.field_name + ) def __cancel_cache_handle(self, instance: object) -> None: self._cache_handle.cancel() + class _ASyncView(ASyncIterator[T]): """ Internal mixin class containing logic for creating specialized views for :class:`~ASyncIterable` objects. @@ -313,8 +343,8 @@ class _ASyncView(ASyncIterator[T]): """An optional iterator. If None, :attr:`~_ASyncView.__aiterator__` will have a value.""" def __init__( - self, - function: ViewFn[T], + self, + function: ViewFn[T], iterable: AnyIterable[T], ) -> None: """ @@ -331,18 +361,21 @@ def __init__( elif isinstance(iterable, Iterable): self.__iterator__ = iterable.__iter__() else: - raise TypeError(f"`iterable` must be AsyncIterable or Iterable, you passed {iterable}") + raise TypeError( + f"`iterable` must be AsyncIterable or Iterable, you passed {iterable}" + ) + -@final +@final class ASyncFilter(_ASyncView[T]): """ - An async filter class that filters items of an async iterable based on a provided function. - + An async filter class that filters items of an async iterable based on a provided function. + This class inherits from :class:`~_ASyncView` and provides the functionality to asynchronously iterate over items, applying the filter function to each item to determine if it should be included in the result. """ - + def __repr__(self) -> str: return f"" @@ -388,22 +421,24 @@ def _key_if_no_key(obj: T) -> T: """ return obj + @final class ASyncSorter(_ASyncView[T]): """ - An async sorter class that sorts items of an async iterable based on a provided key function. - + An async sorter class that sorts items of an async iterable based on a provided key function. + This class inherits from :class:`~_ASyncView` and provides the functionality to asynchronously iterate over items, applying the key function to each item for sorting. """ + reversed: bool = False _consumed: bool = False def __init__( - self, + self, iterable: AsyncIterable[T], *, - key: SortKey[T] = None, + key: SortKey[T] = None, reverse: bool = False, ) -> None: """ @@ -467,7 +502,9 @@ async def __sort(self, reverse: bool) -> AsyncIterator[T]: for obj in self.__iterator__: items.append(obj) sort_tasks.append(asyncio.create_task(self._function(obj))) - for sort_value, obj in sorted(zip(await asyncio.gather(*sort_tasks), items), reverse=reverse): + for sort_value, obj in sorted( + zip(await asyncio.gather(*sort_tasks), items), reverse=reverse + ): yield obj else: if self.__aiterator__: @@ -480,4 +517,10 @@ async def __sort(self, reverse: bool) -> AsyncIterator[T]: self._consumed = True -__all__ = ["ASyncIterable", "ASyncIterator", "ASyncFilter", "ASyncSorter", "ASyncGeneratorFunction"] +__all__ = [ + "ASyncIterable", + "ASyncIterator", + "ASyncFilter", + "ASyncSorter", + "ASyncGeneratorFunction", +] diff --git a/a_sync/primitives/__init__.py b/a_sync/primitives/__init__.py index 199b253f..58dae18a 100644 --- a/a_sync/primitives/__init__.py +++ b/a_sync/primitives/__init__.py @@ -1,4 +1,3 @@ - """ While not the focus of this lib, this module includes some new primitives and some modified versions of standard asyncio primitives. """ diff --git a/a_sync/primitives/_debug.py b/a_sync/primitives/_debug.py index f92860af..64be655c 100644 --- a/a_sync/primitives/_debug.py +++ b/a_sync/primitives/_debug.py @@ -14,17 +14,17 @@ class _DebugDaemonMixin(_LoggerMixin, metaclass=abc.ABCMeta): """ A mixin class that provides debugging capabilities using a daemon task. - + This mixin ensures that rich debug logs are automagically emitted from subclass instances whenever debug logging is enabled. """ - - __slots__ = "_daemon", + + __slots__ = ("_daemon",) @abc.abstractmethod async def _debug_daemon(self, fut: asyncio.Future, fn, *args, **kwargs) -> None: """ Abstract method to define the debug daemon's behavior. - + Args: fut: The future associated with the daemon. fn: The function to be debugged. @@ -35,11 +35,11 @@ async def _debug_daemon(self, fut: asyncio.Future, fn, *args, **kwargs) -> None: def _start_debug_daemon(self, *args, **kwargs) -> "asyncio.Future[None]": """ Starts the debug daemon task if debug logging is enabled and the event loop is running. - + Args: *args: Positional arguments for the debug daemon. **kwargs: Keyword arguments for the debug daemon. - + Returns: The debug daemon task, or a dummy future if debug logs are not enabled or if the daemon cannot be created. """ @@ -51,17 +51,17 @@ def _start_debug_daemon(self, *args, **kwargs) -> "asyncio.Future[None]": def _ensure_debug_daemon(self, *args, **kwargs) -> "asyncio.Future[None]": """ Ensures that the debug daemon task is running. - + Args: *args: Positional arguments for the debug daemon. **kwargs: Keyword arguments for the debug daemon. - + Returns: Either the debug daemon task or a dummy future if debug logging is not enabled. """ if not self.debug_logs_enabled: self._daemon = asyncio.get_event_loop().create_future() - if not hasattr(self, '_daemon') or self._daemon is None: + if not hasattr(self, "_daemon") or self._daemon is None: self._daemon = self._start_debug_daemon(*args, **kwargs) self._daemon.add_done_callback(self._stop_debug_daemon) return self._daemon @@ -69,10 +69,10 @@ def _ensure_debug_daemon(self, *args, **kwargs) -> "asyncio.Future[None]": def _stop_debug_daemon(self, t: Optional[asyncio.Task] = None) -> None: """ Stops the debug daemon task. - + Args: - t (optional): The task to be stopped, if any. - + t (optional): The task to be stopped, if any. + Raises: ValueError: If `t` is not the current daemon. """ diff --git a/a_sync/primitives/_loggable.py b/a_sync/primitives/_loggable.py index 85bf0d3a..caf464a9 100644 --- a/a_sync/primitives/_loggable.py +++ b/a_sync/primitives/_loggable.py @@ -9,22 +9,23 @@ class _LoggerMixin: """ A mixin class that adds logging capabilities to other classes. - + This mixin provides a cached property for accessing a logger instance and a property to check if debug logging is enabled. """ + @cached_property def logger(self) -> Logger: """ Returns a logger instance specific to the class using this mixin. - + The logger ID is constructed from the module and class name, and optionally includes an instance name if available. Returns: Logger: A logger instance for the class. """ logger_id = type(self).__qualname__ - if hasattr(self, '_name') and self._name: - logger_id += f'.{self._name}' + if hasattr(self, "_name") and self._name: + logger_id += f".{self._name}" return getLogger(logger_id) @property diff --git a/a_sync/primitives/locks/__init__.py b/a_sync/primitives/locks/__init__.py index 75d1dcac..e69674a0 100644 --- a/a_sync/primitives/locks/__init__.py +++ b/a_sync/primitives/locks/__init__.py @@ -1,5 +1,8 @@ - from a_sync.primitives.locks.counter import CounterLock from a_sync.primitives.locks.event import Event -from a_sync.primitives.locks.semaphore import DummySemaphore, Semaphore, ThreadsafeSemaphore +from a_sync.primitives.locks.semaphore import ( + DummySemaphore, + Semaphore, + ThreadsafeSemaphore, +) from a_sync.primitives.locks.prio_semaphore import PrioritySemaphore diff --git a/a_sync/primitives/locks/counter.py b/a_sync/primitives/locks/counter.py index 5789e18a..8bb332b3 100644 --- a/a_sync/primitives/locks/counter.py +++ b/a_sync/primitives/locks/counter.py @@ -16,13 +16,15 @@ class CounterLock(_DebugDaemonMixin): """ An async primitive that blocks until the internal counter has reached a specific value. - + A coroutine can `await counter.wait_for(3)` and it will block until the internal counter >= 3. If some other task executes `counter.value = 5` or `counter.set(5)`, the first coroutine will unblock as 5 >= 3. - + The internal counter can only increase. """ + __slots__ = "is_ready", "_name", "_value", "_events" + def __init__(self, start_value: int = 0, name: Optional[str] = None): """ Initializes the CounterLock with a starting value and an optional name. @@ -43,7 +45,7 @@ def __init__(self, start_value: int = 0, name: Optional[str] = None): self.is_ready = lambda v: self._value >= v """A lambda function that indicates whether a given value has already been surpassed.""" - + async def wait_for(self, value: int) -> bool: """ Waits until the counter reaches or exceeds the specified value. @@ -58,23 +60,23 @@ async def wait_for(self, value: int) -> bool: self._ensure_debug_daemon() await self._events[value].wait() return True - + def set(self, value: int) -> None: """ Sets the counter to the specified value. Args: value: The value to set the counter to. Must be >= the current value. - + Raises: ValueError: If the new value is less than the current value. """ self.value = value - + def __repr__(self) -> str: waiters = {v: len(self._events[v]._waiters) for v in sorted(self._events)} return f"" - + @property def value(self) -> int: """ @@ -84,7 +86,7 @@ def value(self) -> int: The current value of the counter. """ return self._value - + @value.setter def value(self, value: int) -> None: """ @@ -98,28 +100,37 @@ def value(self, value: int) -> None: """ if value > self._value: self._value = value - ready = [self._events.pop(key) for key in list(self._events.keys()) if key <= self._value] + ready = [ + self._events.pop(key) + for key in list(self._events.keys()) + if key <= self._value + ] for event in ready: event.set() elif value < self._value: raise ValueError("You cannot decrease the value.") - + async def _debug_daemon(self) -> None: """ Periodically logs debug information about the counter state and waiters. """ start = time() while self._events: - self.logger.debug("%s is still locked after %sm", self, round(time() - start / 60, 2)) + self.logger.debug( + "%s is still locked after %sm", self, round(time() - start / 60, 2) + ) await asyncio.sleep(300) + class CounterLockCluster: """ An asyncio primitive that represents 2 or more CounterLock objects. - + `wait_for(i)` will block until the value of all CounterLock objects is >= i. """ - __slots__ = "locks", + + __slots__ = ("locks",) + def __init__(self, counter_locks: Iterable[CounterLock]) -> None: """ Initializes the CounterLockCluster with a collection of CounterLock objects. @@ -128,7 +139,7 @@ def __init__(self, counter_locks: Iterable[CounterLock]) -> None: counter_locks: The CounterLock objects to manage. """ self.locks = list(counter_locks) - + async def wait_for(self, value: int) -> bool: """ Waits until the value of all CounterLock objects in the cluster reaches or exceeds the specified value. @@ -139,6 +150,7 @@ async def wait_for(self, value: int) -> bool: Returns: True when the value of all CounterLock objects reach or exceed the specified value. """ - await asyncio.gather(*[counter_lock.wait_for(value) for counter_lock in self.locks]) + await asyncio.gather( + *[counter_lock.wait_for(value) for counter_lock in self.locks] + ) return True - \ No newline at end of file diff --git a/a_sync/primitives/locks/event.py b/a_sync/primitives/locks/event.py index f3b3cbcd..f04f951d 100644 --- a/a_sync/primitives/locks/event.py +++ b/a_sync/primitives/locks/event.py @@ -8,14 +8,16 @@ from a_sync._typing import * from a_sync.primitives._debug import _DebugDaemonMixin + class Event(asyncio.Event, _DebugDaemonMixin): """ An asyncio.Event with additional debug logging to help detect deadlocks. - + This event class extends asyncio.Event by adding debug logging capabilities. It logs - detailed information about the event state and waiters, which can be useful for + detailed information about the event state and waiters, which can be useful for diagnosing and debugging potential deadlocks. """ + _value: bool _loop: asyncio.AbstractEventLoop _waiters: Deque["asyncio.Future[None]"] @@ -23,7 +25,14 @@ class Event(asyncio.Event, _DebugDaemonMixin): __slots__ = "_value", "_waiters", "_debug_daemon_interval" else: __slots__ = "_value", "_loop", "_waiters", "_debug_daemon_interval" - def __init__(self, name: str = "", debug_daemon_interval: int = 300, *, loop: Optional[asyncio.AbstractEventLoop] = None): + + def __init__( + self, + name: str = "", + debug_daemon_interval: int = 300, + *, + loop: Optional[asyncio.AbstractEventLoop] = None, + ): """ Initializes the Event. @@ -41,12 +50,14 @@ def __init__(self, name: str = "", debug_daemon_interval: int = 300, *, loop: Op if hasattr(self, "_loop"): self._loop = self._loop or asyncio.get_event_loop() self._debug_daemon_interval = debug_daemon_interval + def __repr__(self) -> str: - label = f'name={self._name}' if self._name else 'object' - status = 'set' if self._value else 'unset' + label = f"name={self._name}" if self._name else "object" + status = "set" if self._value else "unset" if self._waiters: - status += f', waiters:{len(self._waiters)}' + status += f", waiters:{len(self._waiters)}" return f"<{self.__class__.__module__}.{self.__class__.__name__} {label} at {hex(id(self))} [{status}]>" + async def wait(self) -> Literal[True]: """ Wait until the event is set. @@ -58,6 +69,7 @@ async def wait(self) -> Literal[True]: return True self._ensure_debug_daemon() return await super().wait() + async def _debug_daemon(self) -> None: """ Periodically logs debug information about the event state and waiters. @@ -68,4 +80,6 @@ async def _debug_daemon(self) -> None: del self # no need to hold a reference here await asyncio.sleep(self._debug_daemon_interval) if (self := weakself()) and not self.is_set(): - self.logger.debug("Waiting for %s for %sm", self, round((time() - start) / 60, 2)) + self.logger.debug( + "Waiting for %s for %sm", self, round((time() - start) / 60, 2) + ) diff --git a/a_sync/primitives/locks/prio_semaphore.py b/a_sync/primitives/locks/prio_semaphore.py index 60891c65..d17fd01b 100644 --- a/a_sync/primitives/locks/prio_semaphore.py +++ b/a_sync/primitives/locks/prio_semaphore.py @@ -1,4 +1,3 @@ - import asyncio import heapq import logging @@ -12,30 +11,41 @@ class Priority(Protocol): - def __lt__(self, other) -> bool: - ... + def __lt__(self, other) -> bool: ... + + +PT = TypeVar("PT", bound=Priority) + +CM = TypeVar("CM", bound="_AbstractPrioritySemaphoreContextManager[Priority]") -PT = TypeVar('PT', bound=Priority) - -CM = TypeVar('CM', bound="_AbstractPrioritySemaphoreContextManager[Priority]") class _AbstractPrioritySemaphore(Semaphore, Generic[PT, CM]): name: Optional[str] _value: int _waiters: List["_AbstractPrioritySemaphoreContextManager[PT]"] # type: ignore [assignment] - __slots__ = "name", "_value", "_waiters", "_context_managers", "_capacity", "_potential_lost_waiters" + _context_managers: Dict[PT, "_AbstractPrioritySemaphoreContextManager[PT]"] + __slots__ = ( + "name", + "_value", + "_waiters", + "_context_managers", + "_capacity", + "_potential_lost_waiters", + ) @property - def _context_manager_class(self) -> Type["_AbstractPrioritySemaphoreContextManager[PT]"]: + def _context_manager_class( + self, + ) -> Type["_AbstractPrioritySemaphoreContextManager[PT]"]: raise NotImplementedError - + @property def _top_priority(self) -> PT: # You can use this so you can set priorities with non numeric comparable values raise NotImplementedError def __init__(self, value: int = 1, *, name: Optional[str] = None) -> None: - self._context_managers: Dict[PT, _AbstractPrioritySemaphoreContextManager[PT]] = {} + self._context_managers = {} self._capacity = value super().__init__(value, name=name) self._waiters = [] @@ -50,14 +60,18 @@ async def __aenter__(self) -> None: async def __aexit__(self, *_) -> None: self[self._top_priority].release() - + async def acquire(self) -> Literal[True]: return await self[self._top_priority].acquire() - - def __getitem__(self, priority: Optional[PT]) -> "_AbstractPrioritySemaphoreContextManager[PT]": + + def __getitem__( + self, priority: Optional[PT] + ) -> "_AbstractPrioritySemaphoreContextManager[PT]": priority = self._top_priority if priority is None else priority if priority not in self._context_managers: - context_manager = self._context_manager_class(self, priority, name=self.name) + context_manager = self._context_manager_class( + self, priority, name=self.name + ) heapq.heappush(self._waiters, context_manager) # type: ignore [misc] self._context_managers[priority] = context_manager return self._context_managers[priority] @@ -66,30 +80,37 @@ def locked(self) -> bool: """Returns True if semaphore cannot be acquired immediately.""" return self._value == 0 or ( any( - cm._waiters and any(not w.cancelled() for w in cm._waiters) + cm._waiters and any(not w.cancelled() for w in cm._waiters) for cm in (self._context_managers.values() or ()) ) ) - + def _count_waiters(self) -> Dict[PT, int]: - return {manager._priority: len(manager.waiters) for manager in sorted(self._waiters, key=lambda m: m._priority)} - + return { + manager._priority: len(manager.waiters) + for manager in sorted(self._waiters, key=lambda m: m._priority) + } + def _wake_up_next(self) -> None: while self._waiters: manager = heapq.heappop(self._waiters) if len(manager) == 0: # There are no more waiters, get rid of the empty manager - logger.debug("manager %s has no more waiters, popping from %s", manager._repr_no_parent_(), self) + logger.debug( + "manager %s has no more waiters, popping from %s", + manager._repr_no_parent_(), + self, + ) self._context_managers.pop(manager._priority) continue logger.debug("waking up next for %s", manager._repr_no_parent_()) - + woke_up = False start_len = len(manager) - + if not manager._waiters: - logger.debug('not manager._waiters') - + logger.debug("not manager._waiters") + while manager._waiters: waiter = manager._waiters.popleft() self._potential_lost_waiters.remove(waiter) @@ -98,15 +119,15 @@ def _wake_up_next(self) -> None: logger.debug("woke up %s", waiter) woke_up = True break - + if not woke_up: self._context_managers.pop(manager._priority) continue - + end_len = len(manager) - + assert start_len > end_len, f"start {start_len} end {end_len}" - + if end_len: # There are still waiters, put the manager back heapq.heappush(self._waiters, manager) # type: ignore [misc] @@ -114,52 +135,58 @@ def _wake_up_next(self) -> None: # There are no more waiters, get rid of the empty manager self._context_managers.pop(manager._priority) return - - # emergency procedure (hopefully temporary): + + # emergency procedure (hopefully temporary): while self._potential_lost_waiters: waiter = self._potential_lost_waiters.pop(0) - logger.debug('we found a lost waiter %s', waiter) + logger.debug("we found a lost waiter %s", waiter) if not waiter.done(): waiter.set_result(None) logger.debug("woke up lost waiter %s", waiter) return logger.debug("%s has no waiters to wake", self) + class _AbstractPrioritySemaphoreContextManager(Semaphore, Generic[PT]): _loop: asyncio.AbstractEventLoop _waiters: Deque[asyncio.Future] # type: ignore [assignment] __slots__ = "_parent", "_priority" - + @property def _priority_name(self) -> str: raise NotImplementedError - - def __init__(self, parent: _AbstractPrioritySemaphore, priority: PT, name: Optional[str] = None) -> None: + + def __init__( + self, + parent: _AbstractPrioritySemaphore, + priority: PT, + name: Optional[str] = None, + ) -> None: self._parent = parent self._priority = priority super().__init__(0, name=name) def __repr__(self) -> str: return f"<{self.__class__.__name__} parent={self._parent} {self._priority_name}={self._priority} waiters={len(self)}>" - + def _repr_no_parent_(self) -> str: return f"<{self.__class__.__name__} parent_name={self._parent.name} {self._priority_name}={self._priority} waiters={len(self)}>" - + def __lt__(self, other) -> bool: if type(other) is not type(self): raise TypeError(f"{other} is not type {self.__class__.__name__}") return self._priority < other._priority - + @cached_property def loop(self) -> asyncio.AbstractEventLoop: return self._loop or asyncio.get_event_loop() - + @property - def waiters (self) -> Deque[asyncio.Future]: + def waiters(self) -> Deque[asyncio.Future]: if self._waiters is None: self._waiters = deque() return self._waiters - + async def acquire(self) -> Literal[True]: """Acquire a semaphore. @@ -185,12 +212,17 @@ async def acquire(self) -> Literal[True]: raise self._parent._value -= 1 return True + def release(self) -> None: self._parent.release() - -class _PrioritySemaphoreContextManager(_AbstractPrioritySemaphoreContextManager[Numeric]): + + +class _PrioritySemaphoreContextManager( + _AbstractPrioritySemaphoreContextManager[Numeric] +): _priority_name = "priority" + class PrioritySemaphore(_AbstractPrioritySemaphore[Numeric, _PrioritySemaphoreContextManager]): # type: ignore [type-var] _context_manager_class = _PrioritySemaphoreContextManager _top_priority = -1 diff --git a/a_sync/primitives/locks/semaphore.py b/a_sync/primitives/locks/semaphore.py index 6b6f5989..d7bac45c 100644 --- a/a_sync/primitives/locks/semaphore.py +++ b/a_sync/primitives/locks/semaphore.py @@ -10,17 +10,18 @@ logger = logging.getLogger(__name__) + class Semaphore(asyncio.Semaphore, _DebugDaemonMixin): """ A semaphore with additional debugging capabilities. - + This semaphore includes debug logging. - + Also, it can be used to decorate coroutine functions so you can rewrite this pattern: ``` semaphore = Semaphore(5) - + async def limited(): async with semaphore: return 1 @@ -37,23 +38,24 @@ async def limited(): return 1 ``` """ + if sys.version_info >= (3, 10): __slots__ = "name", "_value", "_waiters", "_decorated" else: __slots__ = "name", "_value", "_waiters", "_loop", "_decorated" - + def __init__(self, value: int, name=None, **kwargs) -> None: """ Initialize the semaphore with a given value and optional name for debugging. - + Args: value: The initial value for the semaphore. name (optional): An optional name used only to provide useful context in debug logs. """ super().__init__(value, **kwargs) - self.name = name or self.__origin__ if hasattr(self, '__origin__') else None + self.name = name or self.__origin__ if hasattr(self, "__origin__") else None self._decorated: Set[str] = set() - + # Dank new functionality def __call__(self, fn: CoroFn[P, T]) -> CoroFn[P, T]: @@ -62,7 +64,7 @@ def __call__(self, fn: CoroFn[P, T]) -> CoroFn[P, T]: ``` semaphore = Semaphore(5) - + async def limited(): async with semaphore: return 1 @@ -80,26 +82,26 @@ async def limited(): ``` """ return self.decorate(fn) # type: ignore [arg-type, return-value] - + def __repr__(self) -> str: representation = f"<{self.__class__.__name__} name={self.name} value={self._value} waiters={len(self)}>" if self._decorated: representation = f"{representation[:-1]} decorates={self._decorated}" return representation - + def __len__(self) -> int: return len(self._waiters) if self._waiters else 0 - + def decorate(self, fn: CoroFn[P, T]) -> CoroFn[P, T]: """ Wrap a coroutine function to ensure it runs with the semaphore. - + Example: Now you can rewrite this pattern: ``` semaphore = Semaphore(5) - + async def limited(): async with semaphore: return 1 @@ -118,10 +120,12 @@ async def limited(): """ if not asyncio.iscoroutinefunction(fn): raise TypeError(f"{fn} must be a coroutine function") + @functools.wraps(fn) async def semaphore_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: async with self: return await fn(*args, **kwargs) + self._decorated.add(f"{fn.__module__}.{fn.__name__}") return semaphore_wrapper @@ -129,7 +133,7 @@ async def acquire(self) -> Literal[True]: if self._value <= 0: self._ensure_debug_daemon() return await super().acquire() - + # Everything below just adds some debug logs async def _debug_daemon(self) -> None: """ @@ -137,71 +141,70 @@ async def _debug_daemon(self) -> None: """ while self._waiters: await asyncio.sleep(60) - self.logger.debug(f"{self} has {len(self)} waiters for any of: {self._decorated}") - - + self.logger.debug( + f"{self} has {len(self)} waiters for any of: {self._decorated}" + ) + + class DummySemaphore(asyncio.Semaphore): """ A dummy semaphore that implements the standard :class:`asyncio.Semaphore` API but does nothing. """ __slots__ = "name", "_value" - + def __init__(self, name: Optional[str] = None): self.name = name self._value = 0 - + def __repr__(self) -> str: return f"<{self.__class__.__name__} name={self.name}>" - + async def acquire(self) -> Literal[True]: return True - - def release(self) -> None: - ... - - async def __aenter__(self): - ... - - async def __aexit__(self, *args): - ... - + + def release(self) -> None: ... + + async def __aenter__(self): ... + + async def __aexit__(self, *args): ... + class ThreadsafeSemaphore(Semaphore): """ - While its a bit weird to run multiple event loops, sometimes either you or a lib you're using must do so. + While its a bit weird to run multiple event loops, sometimes either you or a lib you're using must do so. When in use in threaded applications, this semaphore will not work as intended but at least your program will function. You may need to reduce the semaphore value for multi-threaded applications. - + # TL;DR it's a janky fix for an edge case problem and will otherwise function as a normal a_sync.Semaphore (which is just an asyncio.Semaphore with extra bells and whistles). """ + __slots__ = "semaphores", "dummy" - + def __init__(self, value: Optional[int], name: Optional[str] = None) -> None: assert isinstance(value, int), f"{value} should be an integer." super().__init__(value, name=name) self.semaphores: DefaultDict[Thread, Semaphore] = defaultdict(lambda: Semaphore(value, name=self.name)) # type: ignore [arg-type] self.dummy = DummySemaphore(name=name) - + def __len__(self) -> int: return sum(len(sem._waiters) for sem in self.semaphores.values()) - + @functools.cached_property def use_dummy(self) -> bool: return self._value is None - + @property def semaphore(self) -> Semaphore: """ Returns the appropriate semaphore for the current thread. - + NOTE: We can't cache this property because we need to check the current thread every time we access it. """ return self.dummy if self.use_dummy else self.semaphores[current_thread()] - + async def __aenter__(self): await self.semaphore.acquire() - + async def __aexit__(self, *args): self.semaphore.release() - \ No newline at end of file diff --git a/a_sync/primitives/queue.py b/a_sync/primitives/queue.py index 9bc06f94..447d6922 100644 --- a/a_sync/primitives/queue.py +++ b/a_sync/primitives/queue.py @@ -18,30 +18,45 @@ logger = logging.getLogger(__name__) if sys.version_info < (3, 9): + class _Queue(asyncio.Queue, Generic[T]): - __slots__ = "_maxsize", "_loop", "_getters", "_putters", "_unfinished_tasks", "_finished" + __slots__ = ( + "_maxsize", + "_loop", + "_getters", + "_putters", + "_unfinished_tasks", + "_finished", + ) + else: + class _Queue(asyncio.Queue[T]): __slots__ = "_maxsize", "_getters", "_putters", "_unfinished_tasks", "_finished" + class Queue(_Queue[T]): # for type hint support, no functional difference async def get(self) -> T: self._queue return await _Queue.get(self) + def get_nowait(self) -> T: return _Queue.get_nowait(self) + async def put(self, item: T) -> None: return _Queue.put(self, item) + def put_nowait(self, item: T) -> None: return _Queue.put_nowait(self, item) - + async def get_all(self) -> List[T]: """returns 1 or more items""" try: return self.get_all_nowait() except asyncio.QueueEmpty: return [await self.get()] + def get_all_nowait(self) -> List[T]: """returns 1 or more items, or raises asyncio.QueueEmpty""" values: List[T] = [] @@ -52,16 +67,19 @@ def get_all_nowait(self) -> List[T]: if not values: raise asyncio.QueueEmpty from e return values - + async def get_multi(self, i: int, can_return_less: bool = False) -> List[T]: _validate_args(i, can_return_less) items = [] while len(items) < i and not can_return_less: try: - items.extend(self.get_multi_nowait(i - len(items), can_return_less=True)) + items.extend( + self.get_multi_nowait(i - len(items), can_return_less=True) + ) except asyncio.QueueEmpty: items = [await self.get()] return items + def get_multi_nowait(self, i: int, can_return_less: bool = False) -> List[T]: """ Just like `asyncio.Queue.get_nowait`, but will return `i` items instead of 1. @@ -85,31 +103,37 @@ def get_multi_nowait(self, i: int, can_return_less: bool = False) -> List[T]: class ProcessingQueue(_Queue[Tuple[P, "asyncio.Future[V]"]], Generic[P, V]): _closed: bool = False __slots__ = "func", "num_workers", "_worker_coro" + def __init__( - self, - func: Callable[P, Awaitable[V]], - num_workers: int, - *, - return_data: bool = True, + self, + func: Callable[P, Awaitable[V]], + num_workers: int, + *, + return_data: bool = True, name: str = "", loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: if sys.version_info < (3, 10): super().__init__(loop=loop) elif loop: - raise NotImplementedError(f"You cannot pass a value for `loop` in python {sys.version_info}") + raise NotImplementedError( + f"You cannot pass a value for `loop` in python {sys.version_info}" + ) else: super().__init__() - + self.func = func self.num_workers = num_workers self._name = name self._no_futs = not return_data + @functools.wraps(func) async def _worker_coro() -> NoReturn: # we use this little helper so we can have context of `func` in any err logs return await self.__worker_coro() + self._worker_coro = _worker_coro + # NOTE: asyncio defines both this and __str__ def __repr__(self) -> str: repr_string = f"<{type(self).__name__} at {hex(id(self))}" @@ -119,6 +143,7 @@ def __repr__(self) -> str: if self._unfinished_tasks: repr_string += f" pending={self._unfinished_tasks}" return f"{repr_string}>" + # NOTE: asyncio defines both this and __repr__ def __str__(self) -> str: repr_string = f"<{type(self).__name__}" @@ -128,21 +153,26 @@ def __str__(self) -> str: if self._unfinished_tasks: repr_string += f" pending={self._unfinished_tasks}" return f"{repr_string}>" + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[V]": return self.put_nowait(*args, **kwargs) + def __del__(self) -> None: if self._closed: return if self._unfinished_tasks > 0: context = { - 'message': f'{self} was destroyed but has work pending!', + "message": f"{self} was destroyed but has work pending!", } asyncio.get_event_loop().call_exception_handler(context) + @property def name(self) -> str: return self._name or repr(self) + def close(self) -> None: self._closed = True + async def put(self, *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[V]": self._ensure_workers() if self._no_futs: @@ -150,6 +180,7 @@ async def put(self, *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[V]": fut = self._create_future() await super().put((args, kwargs, fut)) return fut + def put_nowait(self, *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[V]": self._ensure_workers() if self._no_futs: @@ -157,8 +188,10 @@ def put_nowait(self, *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[V]": fut = self._create_future() super().put_nowait((args, kwargs, weakref.proxy(fut))) return fut + def _create_future(self) -> "asyncio.Future[V]": return asyncio.get_event_loop().create_future() + def _ensure_workers(self) -> None: if self._closed: raise RuntimeError(f"{type(self).__name__} is closed: ", self) from None @@ -179,19 +212,26 @@ def _ensure_workers(self) -> None: raise type(exc)(*exc.args).with_traceback(exc.__traceback__) # type: ignore [union-attr] except TypeError: raise exc.with_traceback(exc.__traceback__) + @functools.cached_property def _workers(self) -> "asyncio.Task[NoReturn]": logger.debug("starting worker task for %s", self) workers = [ create_task( - coro=self._worker_coro(), + coro=self._worker_coro(), name=f"{self.name} [Task-{i}]", log_destroy_pending=False, - ) for i in range(self.num_workers) + ) + for i in range(self.num_workers) ] - task = create_task(asyncio.gather(*workers), name=f"{self.name} worker main Task", log_destroy_pending=False) + task = create_task( + asyncio.gather(*workers), + name=f"{self.name} worker main Task", + log_destroy_pending=False, + ) task._workers = workers return task + async def __worker_coro(self) -> NoReturn: args: P.args kwargs: P.kwargs @@ -217,15 +257,27 @@ async def __worker_coro(self) -> NoReturn: result = await self.func(*args, **kwargs) fut.set_result(result) except asyncio.exceptions.InvalidStateError: - logger.error("cannot set result for %s %s: %s", self.func.__name__, fut, result) + logger.error( + "cannot set result for %s %s: %s", + self.func.__name__, + fut, + result, + ) except Exception as e: try: fut.set_exception(e) except asyncio.exceptions.InvalidStateError: - logger.error("cannot set exception for %s %s: %s", self.func.__name__, fut, e) + logger.error( + "cannot set exception for %s %s: %s", + self.func.__name__, + fut, + e, + ) self.task_done() except Exception as e: - logger.error("%s for %s is broken!!!", type(self).__name__, self.func) + logger.error( + "%s for %s is broken!!!", type(self).__name__, self.func + ) logger.exception(e) raise @@ -245,18 +297,19 @@ def _validate_args(i: int, can_return_less: bool) -> None: if not isinstance(i, int): raise TypeError(f"`i` must be an integer greater than 1. You passed {i}") if not isinstance(can_return_less, bool): - raise TypeError(f"`can_return_less` must be boolean. You passed {can_return_less}") + raise TypeError( + f"`can_return_less` must be boolean. You passed {can_return_less}" + ) if i <= 1: raise ValueError(f"`i` must be an integer greater than 1. You passed {i}") - class _SmartFutureRef(weakref.ref, Generic[T]): def __lt__(self, other: "_SmartFutureRef[T]") -> bool: """ Compares two weak references to SmartFuture objects for ordering. - This comparison is used in priority queues to determine the order of processing. A SmartFuture + This comparison is used in priority queues to determine the order of processing. A SmartFuture reference is considered less than another if it has more waiters or if it has been garbage collected. Args: @@ -273,30 +326,41 @@ def __lt__(self, other: "_SmartFutureRef[T]") -> bool: return False return strong_self < strong_other + class _PriorityQueueMixin(Generic[T]): def _init(self, maxsize): self._queue: List[T] = [] + def _put(self, item, heappush=heapq.heappush): heappush(self._queue, item) + def _get(self, heappop=heapq.heappop): return heappop(self._queue) + class PriorityProcessingQueue(_PriorityQueueMixin[T], ProcessingQueue[T, V]): # NOTE: WIP - async def put(self, priority: Any, *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[V]": + async def put( + self, priority: Any, *args: P.args, **kwargs: P.kwargs + ) -> "asyncio.Future[V]": self._ensure_workers() fut = asyncio.get_event_loop().create_future() await super().put(self, (priority, args, kwargs, fut)) return fut - def put_nowait(self, priority: Any, *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[V]": + + def put_nowait( + self, priority: Any, *args: P.args, **kwargs: P.kwargs + ) -> "asyncio.Future[V]": self._ensure_workers() fut = self._create_future() super().put_nowait(self, (priority, args, kwargs, fut)) return fut + def _get(self, heappop=heapq.heappop): priority, args, kwargs, fut = heappop(self._queue) return args, kwargs, fut + class _VariablePriorityQueueMixin(_PriorityQueueMixin[T]): def _get(self, heapify=heapq.heapify, heappop=heapq.heappop): "Resort the heap to consider any changes in priorities and pop the smallest value" @@ -304,27 +368,38 @@ def _get(self, heapify=heapq.heapify, heappop=heapq.heappop): heapify(self._queue) # take the job with the most waiters return heappop(self._queue) + def _get_key(self, *args, **kwargs) -> _smart._Key: return (args, tuple((kwarg, kwargs[kwarg]) for kwarg in sorted(kwargs))) + class VariablePriorityQueue(_VariablePriorityQueueMixin[T], asyncio.PriorityQueue): """A PriorityQueue subclass that allows priorities to be updated (or computed) on the fly""" + # NOTE: WIP - -class SmartProcessingQueue(_VariablePriorityQueueMixin[T], ProcessingQueue[Concatenate[T, P], V]): + + +class SmartProcessingQueue( + _VariablePriorityQueueMixin[T], ProcessingQueue[Concatenate[T, P], V] +): """A PriorityProcessingQueue subclass that will execute jobs with the most waiters first""" + _no_futs = False _futs: "weakref.WeakValueDictionary[_smart._Key, _smart.SmartFuture[T]]" + def __init__( - self, - func: Callable[Concatenate[T, P], Awaitable[V]], - num_workers: int, - *, + self, + func: Callable[Concatenate[T, P], Awaitable[V]], + num_workers: int, + *, name: str = "", loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: super().__init__(func, num_workers, return_data=True, name=name, loop=loop) - self._futs: Dict[_smart._Key[T], _smart.SmartFuture[T]] = weakref.WeakValueDictionary() + self._futs: Dict[_smart._Key[T], _smart.SmartFuture[T]] = ( + weakref.WeakValueDictionary() + ) + async def put(self, *args: P.args, **kwargs: P.kwargs) -> _smart.SmartFuture[V]: self._ensure_workers() key = self._get_key(*args, **kwargs) @@ -334,6 +409,7 @@ async def put(self, *args: P.args, **kwargs: P.kwargs) -> _smart.SmartFuture[V]: self._futs[key] = fut await Queue.put(self, (_SmartFutureRef(fut), args, kwargs)) return fut + def put_nowait(self, *args: P.args, **kwargs: P.kwargs) -> _smart.SmartFuture[V]: self._ensure_workers() key = self._get_key(*args, **kwargs) @@ -343,11 +419,14 @@ def put_nowait(self, *args: P.args, **kwargs: P.kwargs) -> _smart.SmartFuture[V] self._futs[key] = fut Queue.put_nowait(self, (_SmartFutureRef(fut), args, kwargs)) return fut + def _create_future(self, key: _smart._Key) -> "asyncio.Future[V]": return _smart.create_future(queue=self, key=key, loop=self._loop) + def _get(self): fut, args, kwargs = super()._get() return args, kwargs, fut() + async def __worker_coro(self) -> NoReturn: args: P.args kwargs: P.kwargs @@ -364,13 +443,23 @@ async def __worker_coro(self) -> NoReturn: result = await self.func(*args, **kwargs) fut.set_result(result) except asyncio.exceptions.InvalidStateError: - logger.error("cannot set result for %s %s: %s", self.func.__name__, fut, result) + logger.error( + "cannot set result for %s %s: %s", + self.func.__name__, + fut, + result, + ) except Exception as e: logger.debug("%s: %s", type(e).__name__, e) try: fut.set_exception(e) except asyncio.exceptions.InvalidStateError: - logger.error("cannot set exception for %s %s: %s", self.func.__name__, fut, e) + logger.error( + "cannot set exception for %s %s: %s", + self.func.__name__, + fut, + e, + ) self.task_done() except Exception as e: logger.error("%s for %s is broken!!!", type(self).__name__, self.func) diff --git a/a_sync/sphinx/__init__.py b/a_sync/sphinx/__init__.py index 772384bf..25aae202 100644 --- a/a_sync/sphinx/__init__.py +++ b/a_sync/sphinx/__init__.py @@ -1,4 +1,3 @@ - from a_sync.sphinx import ext __all__ = ["ext"] diff --git a/a_sync/sphinx/ext.py b/a_sync/sphinx/ext.py index b974ad86..ea5840e7 100644 --- a/a_sync/sphinx/ext.py +++ b/a_sync/sphinx/ext.py @@ -32,6 +32,7 @@ Use ``.. autotask::`` to alternatively manually document a task. """ + from inspect import signature from docutils import nodes @@ -39,17 +40,22 @@ from sphinx.ext.autodoc import FunctionDocumenter, MethodDocumenter from a_sync.a_sync._descriptor import ASyncDescriptor -from a_sync.a_sync.function import ASyncFunction, ASyncFunctionAsyncDefault, ASyncFunctionSyncDefault +from a_sync.a_sync.function import ( + ASyncFunction, + ASyncFunctionAsyncDefault, + ASyncFunctionSyncDefault, +) from a_sync.iter import ASyncGeneratorFunction - class _ASyncWrapperDocumenter: typ: type @classmethod def can_document_member(cls, member, membername, isattr, parent): - return isinstance(member, cls.typ) and getattr(member, '__wrapped__') is not None + return ( + isinstance(member, cls.typ) and getattr(member, "__wrapped__") is not None + ) def document_members(self, all_members=False): pass @@ -59,98 +65,118 @@ def check_module(self): # given by *self.modname*. But since functions decorated with the @task # decorator are instances living in the celery.local, we have to check # the wrapped function instead. - wrapped = getattr(self.object, '__wrapped__', None) - if wrapped and getattr(wrapped, '__module__') == self.modname: + wrapped = getattr(self.object, "__wrapped__", None) + if wrapped and getattr(wrapped, "__module__") == self.modname: return True return super().check_module() + class _ASyncFunctionDocumenter(_ASyncWrapperDocumenter, FunctionDocumenter): def format_args(self): - wrapped = getattr(self.object, '__wrapped__', None) + wrapped = getattr(self.object, "__wrapped__", None) if wrapped is not None: sig = signature(wrapped) if "self" in sig.parameters or "cls" in sig.parameters: sig = sig.replace(parameters=list(sig.parameters.values())[1:]) return str(sig) - return '' + return "" + class _ASyncMethodDocumenter(_ASyncWrapperDocumenter, MethodDocumenter): def format_args(self): - wrapped = getattr(self.object, '__wrapped__') + wrapped = getattr(self.object, "__wrapped__") if wrapped is not None: return str(signature(wrapped)) - return '' - + return "" + + class _ASyncDirective: prefix_env: str + def get_signature_prefix(self, sig): return [nodes.Text(getattr(self.env.config, self.prefix_env))] + class _ASyncFunctionDirective(_ASyncDirective, PyFunction): pass + class _ASyncMethodDirective(_ASyncDirective, PyMethod): pass - + class ASyncFunctionDocumenter(_ASyncFunctionDocumenter): """Document ASyncFunction instance definitions.""" - objtype = 'a_sync_function' + + objtype = "a_sync_function" typ = ASyncFunction priority = 15 - #member_order = 11 + # member_order = 11 + class ASyncFunctionSyncDocumenter(_ASyncFunctionDocumenter): """Document ASyncFunction instance definitions.""" - objtype = 'a_sync_function_sync' + + objtype = "a_sync_function_sync" typ = ASyncFunctionSyncDefault priority = 14 - #member_order = 11 + # member_order = 11 + class ASyncFunctionAsyncDocumenter(_ASyncFunctionDocumenter): """Document ASyncFunction instance definitions.""" - objtype = 'a_sync_function_async' + + objtype = "a_sync_function_async" typ = ASyncFunctionAsyncDefault priority = 13 - #member_order = 11 + # member_order = 11 class ASyncFunctionDirective(_ASyncFunctionDirective): prefix_env = "a_sync_function_prefix" + class ASyncFunctionSyncDirective(_ASyncFunctionDirective): prefix_env = "a_sync_function_sync_prefix" + class ASyncFunctionAsyncDirective(_ASyncFunctionDirective): prefix_env = "a_sync_function_async_prefix" class ASyncDescriptorDocumenter(_ASyncMethodDocumenter): """Document ASyncDescriptor instance definitions.""" - objtype = 'a_sync_descriptor' + + objtype = "a_sync_descriptor" typ = ASyncDescriptor - #member_order = 11 + # member_order = 11 class ASyncDescriptorDirective(_ASyncMethodDirective): """Sphinx task directive.""" + prefix_env = "a_sync_descriptor_prefix" class ASyncGeneratorFunctionDocumenter(_ASyncFunctionDocumenter): """Document ASyncFunction instance definitions.""" - objtype = 'a_sync_generator_function' + + objtype = "a_sync_generator_function" typ = ASyncGeneratorFunction - #member_order = 11 + # member_order = 11 class ASyncGeneratorFunctionDirective(_ASyncFunctionDirective): """Sphinx task directive.""" + prefix_env = "a_sync_generator_function_prefix" + def autodoc_skip_member_handler(app, what, name, obj, skip, options): """Handler for autodoc-skip-member event.""" - if isinstance(obj, (ASyncFunction, ASyncDescriptor, ASyncGeneratorFunction)) and getattr(obj, '__wrapped__'): + if isinstance( + obj, (ASyncFunction, ASyncDescriptor, ASyncGeneratorFunction) + ) and getattr(obj, "__wrapped__"): if skip: return False return None @@ -158,32 +184,38 @@ def autodoc_skip_member_handler(app, what, name, obj, skip, options): def setup(app): """Setup Sphinx extension.""" - app.setup_extension('sphinx.ext.autodoc') - + app.setup_extension("sphinx.ext.autodoc") + # function app.add_autodocumenter(ASyncFunctionDocumenter) app.add_autodocumenter(ASyncFunctionSyncDocumenter) app.add_autodocumenter(ASyncFunctionAsyncDocumenter) - app.add_directive_to_domain('py', 'a_sync_function', ASyncFunctionDirective) - app.add_directive_to_domain('py', 'a_sync_function_sync', ASyncFunctionSyncDirective) - app.add_directive_to_domain('py', 'a_sync_function_async', ASyncFunctionAsyncDirective) - app.add_config_value('a_sync_function_sync_prefix', 'ASyncFunction (sync)', True) - app.add_config_value('a_sync_function_async_prefix', 'ASyncFunction (async)', True) - app.add_config_value('a_sync_function_prefix', 'ASyncFunction', True) + app.add_directive_to_domain("py", "a_sync_function", ASyncFunctionDirective) + app.add_directive_to_domain( + "py", "a_sync_function_sync", ASyncFunctionSyncDirective + ) + app.add_directive_to_domain( + "py", "a_sync_function_async", ASyncFunctionAsyncDirective + ) + app.add_config_value("a_sync_function_sync_prefix", "ASyncFunction (sync)", True) + app.add_config_value("a_sync_function_async_prefix", "ASyncFunction (async)", True) + app.add_config_value("a_sync_function_prefix", "ASyncFunction", True) # descriptor app.add_autodocumenter(ASyncDescriptorDocumenter) - app.add_directive_to_domain('py', 'a_sync_descriptor', ASyncDescriptorDirective) - app.add_config_value('a_sync_descriptor_prefix', 'ASyncDescriptor', True) + app.add_directive_to_domain("py", "a_sync_descriptor", ASyncDescriptorDirective) + app.add_config_value("a_sync_descriptor_prefix", "ASyncDescriptor", True) # generator - + app.add_autodocumenter(ASyncGeneratorFunctionDocumenter) - app.add_directive_to_domain('py', 'a_sync_generator_function', ASyncGeneratorFunctionDirective) - app.add_config_value('a_sync_generator_function_prefix', 'ASyncGeneratorFunction', True) + app.add_directive_to_domain( + "py", "a_sync_generator_function", ASyncGeneratorFunctionDirective + ) + app.add_config_value( + "a_sync_generator_function_prefix", "ASyncGeneratorFunction", True + ) - app.connect('autodoc-skip-member', autodoc_skip_member_handler) + app.connect("autodoc-skip-member", autodoc_skip_member_handler) - return { - 'parallel_read_safe': True - } + return {"parallel_read_safe": True} diff --git a/a_sync/task.py b/a_sync/task.py index 63d6d988..51992239 100644 --- a/a_sync/task.py +++ b/a_sync/task.py @@ -1,4 +1,3 @@ - import asyncio import contextlib import functools @@ -12,7 +11,11 @@ from a_sync.a_sync import _kwargs from a_sync.a_sync.base import ASyncGenericBase from a_sync.a_sync.function import ASyncFunction -from a_sync.a_sync.method import ASyncBoundMethod, ASyncMethodDescriptor, ASyncMethodDescriptorSyncDefault +from a_sync.a_sync.method import ( + ASyncBoundMethod, + ASyncMethodDescriptor, + ASyncMethodDescriptorSyncDefault, +) from a_sync.a_sync.property import _ASyncPropertyDescriptorBase from a_sync.asyncio.as_completed import as_completed from a_sync.asyncio.gather import Excluder, gather @@ -25,9 +28,9 @@ logger = logging.getLogger(__name__) - MappingFn = Callable[Concatenate[K, P], Awaitable[V]] + class TaskMapping(DefaultDict[K, "asyncio.Task[V]"], AsyncIterable[Tuple[K, V]]): """ A mapping of keys to asynchronous tasks with additional functionality. @@ -48,7 +51,7 @@ async def fetch_data(url: str) -> str: async for key, result in tasks: print(f"Data for {key}: {result}") """ - + concurrency: Optional[int] = None "The max number of tasks that will run at one time." @@ -57,8 +60,10 @@ async def fetch_data(url: str) -> str: _init_loader: Optional["asyncio.Task[None]"] = None "An asyncio Task used to preload values from the iterables." - - _init_loader_next: Optional[Callable[[], Awaitable[Tuple[Tuple[K, "asyncio.Task[V]"]]]]] = None + + _init_loader_next: Optional[ + Callable[[], Awaitable[Tuple[Tuple[K, "asyncio.Task[V]"]]]] + ] = None "A coro function that blocks until the _init_loader starts a new task(s), and then returns a `Tuple[Tuple[K, asyncio.Task[V]]]` with all of the new tasks and the keys that started them." _name: Optional[str] = None @@ -71,19 +76,20 @@ async def fetch_data(url: str) -> str: "Additional keyword arguments passed to `_wrapped_func`." __iterables__: Tuple[AnyIterableOrAwaitableIterable[K], ...] = () - "The original iterables, if any, used to initialize this mapping.""" - + "The original iterables, if any, used to initialize this mapping." + __init_loader_coro: Optional[Awaitable[None]] = None """An optional asyncio Coroutine to be run by the `_init_loader`""" __slots__ = "_wrapped_func", "__wrapped__", "__dict__", "__weakref__" + # NOTE: maybe since we use so many classvars here we are better off getting rid of slots def __init__( - self, - wrapped_func: MappingFn[K, P, V] = None, - *iterables: AnyIterableOrAwaitableIterable[K], - name: str = '', - concurrency: Optional[int] = None, + self, + wrapped_func: MappingFn[K, P, V] = None, + *iterables: AnyIterableOrAwaitableIterable[K], + name: str = "", + concurrency: Optional[int] = None, **wrapped_func_kwargs: P.kwargs, ) -> None: """ @@ -101,7 +107,7 @@ def __init__( self.concurrency = concurrency self.__wrapped__ = wrapped_func - "The original callable used to initialize this mapping without any modifications.""" + "The original callable used to initialize this mapping without any modifications." if iterables: self.__iterables__ = iterables @@ -121,48 +127,67 @@ def __init__( if iterables: self._next = Event(name=f"{self} `_next`") + @functools.wraps(wrapped_func) - async def _wrapped_set_next(*args: P.args, __a_sync_recursion: int = 0, **kwargs: P.kwargs) -> V: + async def _wrapped_set_next( + *args: P.args, __a_sync_recursion: int = 0, **kwargs: P.kwargs + ) -> V: try: return await wrapped_func(*args, **kwargs) except exceptions.SyncModeInAsyncContextError as e: raise Exception(e, self.__wrapped__) except TypeError as e: - if __a_sync_recursion > 2 or not (str(e).startswith(wrapped_func.__name__) and "got multiple values for argument" in str(e)): + if __a_sync_recursion > 2 or not ( + str(e).startswith(wrapped_func.__name__) + and "got multiple values for argument" in str(e) + ): raise # NOTE: args ordering is clashing with provided kwargs. We can handle this in a hacky way. # TODO: perform this check earlier and pre-prepare the args/kwargs ordering new_args = list(args) new_kwargs = dict(kwargs) try: - for i, arg in enumerate(inspect.getfullargspec(self.__wrapped__).args): + for i, arg in enumerate( + inspect.getfullargspec(self.__wrapped__).args + ): if arg in kwargs: new_args.insert(i, new_kwargs.pop(arg)) else: break - return await _wrapped_set_next(*new_args, **new_kwargs, __a_sync_recursion=__a_sync_recursion+1) + return await _wrapped_set_next( + *new_args, + **new_kwargs, + __a_sync_recursion=__a_sync_recursion + 1, + ) except TypeError as e2: - raise e.with_traceback(e.__traceback__) if str(e2) == "unsupported callable" else e2.with_traceback(e2.__traceback__) + raise ( + e.with_traceback(e.__traceback__) + if str(e2) == "unsupported callable" + else e2.with_traceback(e2.__traceback__) + ) finally: self._next.set() self._next.clear() + self._wrapped_func = _wrapped_set_next init_loader_queue: Queue[Tuple[K, "asyncio.Future[V]"]] = Queue() - self.__init_loader_coro = exhaust_iterator(self._tasks_for_iterables(*iterables), queue=init_loader_queue) + self.__init_loader_coro = exhaust_iterator( + self._tasks_for_iterables(*iterables), queue=init_loader_queue + ) with contextlib.suppress(_NoRunningLoop): # its okay if we get this exception, we can start the task as soon as the loop starts self._init_loader self._init_loader_next = init_loader_queue.get_all - + def __repr__(self) -> str: return f"<{type(self).__name__} for {self._wrapped_func} kwargs={self._wrapped_func_kwargs} tasks={len(self)} at {hex(id(self))}>" - + def __hash__(self) -> int: return id(self) - + def __setitem__(self, item: Any, value: Any) -> None: raise NotImplementedError("You cannot manually set items in a TaskMapping") - + def __getitem__(self, item: K) -> "asyncio.Task[V]": try: return super().__getitem__(item) @@ -172,11 +197,11 @@ def __getitem__(self, item: K) -> "asyncio.Task[V]": fut = self._queue.put_nowait(item) else: coro = self._wrapped_func(item, **self._wrapped_func_kwargs) - name = f"{self._name}[{item}]" if self._name else f"{item}", + name = (f"{self._name}[{item}]" if self._name else f"{item}",) fut = create_task(coro=coro, name=name) super().__setitem__(item, fut) return fut - + def __await__(self) -> Generator[Any, None, Dict[K, V]]: """Wait for all tasks to complete and return a dictionary of the results.""" return self.gather(sync=False).__await__() @@ -195,7 +220,9 @@ async def __aiter__(self, pop: bool = False) -> AsyncIterator[Tuple[K, V]]: while not self._init_loader.done(): await self._wait_for_next_key() while unyielded := [key for key in self if key not in yielded]: - if ready := {key: task for key in unyielded if (task:=self[key]).done()}: + if ready := { + key: task for key in unyielded if (task := self[key]).done() + }: if pop: for key, task in ready.items(): yield key, await self.pop(key) @@ -231,34 +258,41 @@ def keys(self, pop: bool = False) -> "TaskMappingKeys[K, V]": def values(self, pop: bool = False) -> "TaskMappingValues[K, V]": return TaskMappingValues(super().values(), self, pop=pop) - + def items(self, pop: bool = False) -> "TaskMappingValues[K, V]": return TaskMappingItems(super().items(), self, pop=pop) - + async def close(self) -> None: await self._if_pop_clear(True) @ASyncGeneratorFunction - async def map(self, *iterables: AnyIterableOrAwaitableIterable[K], pop: bool = True, yields: Literal['keys', 'both'] = 'both') -> AsyncIterator[Tuple[K, V]]: + async def map( + self, + *iterables: AnyIterableOrAwaitableIterable[K], + pop: bool = True, + yields: Literal["keys", "both"] = "both", + ) -> AsyncIterator[Tuple[K, V]]: """ - Asynchronously map iterables to tasks and yield their results. + Asynchronously map iterables to tasks and yield their results. - Args: - *iterables: Iterables to map over. - pop: Whether to remove tasks from the internal storage once they are completed. - yields: Whether to yield 'keys', 'values', or 'both' (key-value pairs). - - Yields: - Depending on `yields`, either keys, values, - or tuples of key-value pairs representing the results of completed tasks. + Args: + *iterables: Iterables to map over. + pop: Whether to remove tasks from the internal storage once they are completed. + yields: Whether to yield 'keys', 'values', or 'both' (key-value pairs). + + Yields: + Depending on `yields`, either keys, values, + or tuples of key-value pairs representing the results of completed tasks. """ self._if_pop_check_destroyed(pop) - + # make sure the init loader is started if needed init_loader = self._init_loader if iterables and init_loader: - raise ValueError("You cannot pass `iterables` to map if the TaskMapping was initialized with an (a)iterable.") - + raise ValueError( + "You cannot pass `iterables` to map if the TaskMapping was initialized with an (a)iterable." + ) + try: if iterables: self._raise_if_not_empty() @@ -269,15 +303,19 @@ async def map(self, *iterables: AnyIterableOrAwaitableIterable[K], pop: bool = T except _EmptySequenceError: if len(iterables) > 1: # TODO gotta handle this situation - raise exceptions.EmptySequenceError("bob needs to code something so you can do this, go tell him") from None + raise exceptions.EmptySequenceError( + "bob needs to code something so you can do this, go tell him" + ) from None # just pass thru - + elif init_loader: # check for exceptions if you passed an iterable(s) into the class init await init_loader - + else: - self._raise_if_empty("You must either initialize your TaskMapping with an iterable(s) or provide them during your call to map") + self._raise_if_empty( + "You must either initialize your TaskMapping with an iterable(s) or provide them during your call to map" + ) if self: if pop: @@ -289,7 +327,7 @@ async def map(self, *iterables: AnyIterableOrAwaitableIterable[K], pop: bool = T yield _yield(key, value, yields) finally: await self._if_pop_clear(pop) - + @ASyncMethodDescriptorSyncDefault async def all(self, pop: bool = True) -> bool: try: @@ -301,7 +339,7 @@ async def all(self, pop: bool = True) -> bool: return True finally: await self._if_pop_clear(pop) - + @ASyncMethodDescriptorSyncDefault async def any(self, pop: bool = True) -> bool: try: @@ -313,7 +351,7 @@ async def any(self, pop: bool = True) -> bool: return False finally: await self._if_pop_clear(pop) - + @ASyncMethodDescriptorSyncDefault async def max(self, pop: bool = True) -> V: max = None @@ -322,11 +360,15 @@ async def max(self, pop: bool = True) -> V: if max is None or result > max: max = result except _EmptySequenceError: - raise exceptions.EmptySequenceError("max() arg is an empty sequence") from None + raise exceptions.EmptySequenceError( + "max() arg is an empty sequence" + ) from None if max is None: - raise exceptions.EmptySequenceError("max() arg is an empty sequence") from None + raise exceptions.EmptySequenceError( + "max() arg is an empty sequence" + ) from None return max - + @ASyncMethodDescriptorSyncDefault async def min(self, pop: bool = True) -> V: min = None @@ -335,11 +377,15 @@ async def min(self, pop: bool = True) -> V: if min is None or result < min: min = result except _EmptySequenceError: - raise exceptions.EmptySequenceError("min() arg is an empty sequence") from None + raise exceptions.EmptySequenceError( + "min() arg is an empty sequence" + ) from None if min is None: - raise exceptions.EmptySequenceError("min() arg is an empty sequence") from None + raise exceptions.EmptySequenceError( + "min() arg is an empty sequence" + ) from None return min - + @ASyncMethodDescriptorSyncDefault async def sum(self, pop: bool = False) -> V: retval = 0 @@ -369,11 +415,11 @@ async def yield_completed(self, pop: bool = True) -> AsyncIterator[Tuple[K, V]]: for k, task in dict(self).items(): if task.done(): yield k, await task - + @ASyncMethodDescriptorSyncDefault async def gather( - self, - return_exceptions: bool = False, + self, + return_exceptions: bool = False, exclude_if: Excluder[V] = None, tqdm: bool = False, **tqdm_kwargs: Any, @@ -382,18 +428,30 @@ async def gather( if self._init_loader: await self._init_loader self._raise_if_empty() - return await gather(self, return_exceptions=return_exceptions, exclude_if=exclude_if, tqdm=tqdm, **tqdm_kwargs) - + return await gather( + self, + return_exceptions=return_exceptions, + exclude_if=exclude_if, + tqdm=tqdm, + **tqdm_kwargs, + ) + @overload - def pop(self, item: K, cancel: bool = False) -> "Union[asyncio.Task[V], asyncio.Future[V]]":... + def pop( + self, item: K, cancel: bool = False + ) -> "Union[asyncio.Task[V], asyncio.Future[V]]": ... @overload - def pop(self, item: K, default: K, cancel: bool = False) -> "Union[asyncio.Task[V], asyncio.Future[V]]":... - def pop(self, *args: K, cancel: bool = False) -> "Union[asyncio.Task[V], asyncio.Future[V]]": + def pop( + self, item: K, default: K, cancel: bool = False + ) -> "Union[asyncio.Task[V], asyncio.Future[V]]": ... + def pop( + self, *args: K, cancel: bool = False + ) -> "Union[asyncio.Task[V], asyncio.Future[V]]": fut_or_task = super().pop(*args) if cancel: fut_or_task.cancel() return fut_or_task - + def clear(self, cancel: bool = False) -> None: if cancel and self._init_loader and not self._init_loader.done(): logger.debug("cancelling %s", self._init_loader) @@ -412,106 +470,121 @@ def clear(self, cancel: bool = False) -> None: def _init_loader(self) -> Optional["asyncio.Task[None]"]: if self.__init_loader_coro: logger.debug("starting %s init loader", self) - name=f"{type(self).__name__} init loader loading {self.__iterables__} for {self}" + name = f"{type(self).__name__} init loader loading {self.__iterables__} for {self}" try: task = create_task(coro=self.__init_loader_coro, name=name) except RuntimeError as e: raise _NoRunningLoop if str(e) == "no running event loop" else e task.add_done_callback(self.__cleanup) return task - + @functools.cached_property def _queue(self) -> ProcessingQueue: fn = functools.partial(self._wrapped_func, **self._wrapped_func_kwargs) return ProcessingQueue(fn, self.concurrency, name=self._name) - - def _raise_if_empty(self, msg: str = '') -> None: + + def _raise_if_empty(self, msg: str = "") -> None: if not self: raise exceptions.MappingIsEmptyError(self, msg) - + def _raise_if_not_empty(self) -> None: if self: raise exceptions.MappingNotEmptyError(self) @ASyncGeneratorFunction - async def _tasks_for_iterables(self, *iterables: AnyIterableOrAwaitableIterable[K]) -> AsyncIterator[Tuple[K, "asyncio.Task[V]"]]: + async def _tasks_for_iterables( + self, *iterables: AnyIterableOrAwaitableIterable[K] + ) -> AsyncIterator[Tuple[K, "asyncio.Task[V]"]]: """Ensure tasks are running for each key in the provided iterables.""" # if we have any regular containers we can yield their contents right away - containers = [iterable for iterable in iterables if not isinstance(iterable, AsyncIterable) and isinstance(iterable, Iterable)] + containers = [ + iterable + for iterable in iterables + if not isinstance(iterable, AsyncIterable) + and isinstance(iterable, Iterable) + ] for iterable in containers: async for key in _yield_keys(iterable): yield key, self[key] - - if remaining := [iterable for iterable in iterables if iterable not in containers]: + + if remaining := [ + iterable for iterable in iterables if iterable not in containers + ]: try: - async for key in as_yielded(*[_yield_keys(iterable) for iterable in remaining]): # type: ignore [attr-defined] + async for key in as_yielded(*[_yield_keys(iterable) for iterable in remaining]): # type: ignore [attr-defined] yield key, self[key] # ensure task is running except _EmptySequenceError: if len(iterables) == 1: raise - raise RuntimeError("DEV: figure out how to handle this situation") from None - + raise RuntimeError( + "DEV: figure out how to handle this situation" + ) from None + def _if_pop_check_destroyed(self, pop: bool) -> None: if pop: if self._destroyed: raise RuntimeError(f"{self} has already been consumed") self._destroyed = True - + async def _if_pop_clear(self, pop: bool) -> None: if pop: self._destroyed = True # _queue is a cached_property, we don't want to create it if it doesn't exist - if self.concurrency and '_queue' in self.__dict__: + if self.concurrency and "_queue" in self.__dict__: self._queue.close() del self._queue self.clear(cancel=True) # we need to let the loop run once so the tasks can fully cancel await asyncio.sleep(0) - + async def _wait_for_next_key(self) -> None: # NOTE if `_init_loader` has an exception it will return first, otherwise `_init_loader_next` will return always done, pending = await asyncio.wait( - [create_task(self._init_loader_next(), log_destroy_pending=False), self._init_loader], - return_when=asyncio.FIRST_COMPLETED + [ + create_task(self._init_loader_next(), log_destroy_pending=False), + self._init_loader, + ], + return_when=asyncio.FIRST_COMPLETED, ) for task in done: # check for exceptions await task - + def __cleanup(self, t: "asyncio.Task[None]") -> None: # clear the slot and let the bound Queue die del self.__init_loader_coro -class _NoRunningLoop(Exception): - ... +class _NoRunningLoop(Exception): ... + @overload -def _yield(key: K, value: V, yields: Literal['keys']) -> K:... +def _yield(key: K, value: V, yields: Literal["keys"]) -> K: ... @overload -def _yield(key: K, value: V, yields: Literal['both']) -> Tuple[K, V]:... -def _yield(key: K, value: V, yields: Literal['keys', 'both']) -> Union[K, Tuple[K, V]]: +def _yield(key: K, value: V, yields: Literal["both"]) -> Tuple[K, V]: ... +def _yield(key: K, value: V, yields: Literal["keys", "both"]) -> Union[K, Tuple[K, V]]: """ Yield either the key, value, or both based on the 'yields' parameter. - + Args: key: The key of the task. value: The result of the task. yields: Determines what to yield; 'keys' for keys, 'both' for key-value pairs. - + Returns: The key, the value, or a tuple of both based on the 'yields' parameter. """ - if yields == 'both': + if yields == "both": return key, value - elif yields == 'keys': + elif yields == "keys": return key else: raise ValueError(f"`yields` must be 'keys' or 'both'. You passed {yields}") -class _EmptySequenceError(ValueError): - ... - + +class _EmptySequenceError(ValueError): ... + + async def _yield_keys(iterable: AnyIterableOrAwaitableIterable[K]) -> AsyncIterator[K]: """ Asynchronously yield keys from the provided iterable. @@ -536,9 +609,15 @@ async def _yield_keys(iterable: AnyIterableOrAwaitableIterable[K]) -> AsyncItera else: raise TypeError(iterable) + __unwrapped = weakref.WeakKeyDictionary() -def _unwrap(wrapped_func: Union[AnyFn[P, T], "ASyncMethodDescriptor[P, T]", _ASyncPropertyDescriptorBase[I, T]]) -> Callable[P, Awaitable[T]]: + +def _unwrap( + wrapped_func: Union[ + AnyFn[P, T], "ASyncMethodDescriptor[P, T]", _ASyncPropertyDescriptorBase[I, T] + ] +) -> Callable[P, Awaitable[T]]: if unwrapped := __unwrapped.get(wrapped_func): return unwrapped if isinstance(wrapped_func, (ASyncBoundMethod, ASyncMethodDescriptor)): @@ -548,7 +627,11 @@ def _unwrap(wrapped_func: Union[AnyFn[P, T], "ASyncMethodDescriptor[P, T]", _ASy elif isinstance(wrapped_func, ASyncFunction): # this speeds things up a bit by bypassing some logic # TODO implement it like this elsewhere if profilers suggest - unwrapped = wrapped_func._modified_fn if wrapped_func._async_def else wrapped_func._asyncified + unwrapped = ( + wrapped_func._modified_fn + if wrapped_func._async_def + else wrapped_func._asyncified + ) else: unwrapped = wrapped_func __unwrapped[wrapped_func] = unwrapped @@ -558,34 +641,50 @@ def _unwrap(wrapped_func: Union[AnyFn[P, T], "ASyncMethodDescriptor[P, T]", _ASy _get_key: Callable[[Tuple[K, V]], K] = lambda k_and_v: k_and_v[0] _get_value: Callable[[Tuple[K, V]], V] = lambda k_and_v: k_and_v[1] + class _TaskMappingView(ASyncGenericBase, Iterable[T], Generic[T, K, V]): _get_from_item: Callable[[Tuple[K, V]], T] _pop: bool = False - def __init__(self, view: Iterable[T], task_mapping: TaskMapping[K, V], pop: bool = False) -> None: + + def __init__( + self, view: Iterable[T], task_mapping: TaskMapping[K, V], pop: bool = False + ) -> None: self.__view__ = view self.__mapping__: TaskMapping = weakref.proxy(task_mapping) "actually a weakref.ProxyType[TaskMapping] but then type hints weren't working" if pop: self._pop = True + def __iter__(self) -> Iterator[T]: return iter(self.__view__) + def __await__(self) -> Generator[Any, None, List[T]]: return self._await().__await__() + def __len__(self) -> int: return len(self.__view__) + async def _await(self) -> List[T]: return [result async for result in self] + __slots__ = "__view__", "__mapping__" + async def aiterbykeys(self, reverse: bool = False) -> ASyncIterator[T]: - async for tup in ASyncSorter(self.__mapping__.items(pop=self._pop), key=_get_key, reverse=reverse): + async for tup in ASyncSorter( + self.__mapping__.items(pop=self._pop), key=_get_key, reverse=reverse + ): yield self._get_from_item(tup) + async def aiterbyvalues(self, reverse: bool = False) -> ASyncIterator[T]: - async for tup in ASyncSorter(self.__mapping__.items(pop=self._pop), key=_get_value, reverse=reverse): + async for tup in ASyncSorter( + self.__mapping__.items(pop=self._pop), key=_get_value, reverse=reverse + ): yield self._get_from_item(tup) class TaskMappingKeys(_TaskMappingView[K, K, V], Generic[K, V]): _get_from_item = lambda self, item: _get_key(item) + async def __aiter__(self) -> AsyncIterator[K]: # strongref mapping = self.__mapping__ @@ -610,6 +709,7 @@ async def __aiter__(self) -> AsyncIterator[K]: for key in self.__load_existing(): if key not in yielded: yield key + def __load_existing(self) -> Iterator[K]: # strongref mapping = self.__mapping__ @@ -620,6 +720,7 @@ def __load_existing(self) -> Iterator[K]: else: for key in tuple(mapping): yield key + async def __load_init_loader(self, yielded: Set[K]) -> AsyncIterator[K]: # strongref mapping = self.__mapping__ @@ -637,8 +738,10 @@ async def __load_init_loader(self, yielded: Set[K]) -> AsyncIterator[K]: # check for any exceptions await mapping._init_loader + class TaskMappingItems(_TaskMappingView[Tuple[K, V], K, V], Generic[K, V]): _get_from_item = lambda self, item: item + async def __aiter__(self) -> AsyncIterator[Tuple[K, V]]: # strongref mapping = self.__mapping__ @@ -649,9 +752,11 @@ async def __aiter__(self) -> AsyncIterator[Tuple[K, V]]: else: async for key in mapping.keys(): yield key, await mapping[key] - + + class TaskMappingValues(_TaskMappingView[V, K, V], Generic[K, V]): _get_from_item = lambda self, item: _get_value(item) + async def __aiter__(self) -> AsyncIterator[V]: # strongref mapping = self.__mapping__ @@ -664,4 +769,10 @@ async def __aiter__(self) -> AsyncIterator[V]: yield await mapping[key] -__all__ = ["create_task", "TaskMapping", "TaskMappingKeys", "TaskMappingValues", "TaskMappingItems"] +__all__ = [ + "create_task", + "TaskMapping", + "TaskMappingKeys", + "TaskMappingValues", + "TaskMappingItems", +] diff --git a/a_sync/utils/__init__.py b/a_sync/utils/__init__.py index 88431d4b..8b07e0d4 100644 --- a/a_sync/utils/__init__.py +++ b/a_sync/utils/__init__.py @@ -1,17 +1,17 @@ import asyncio -from a_sync.utils.iterators import (as_yielded, exhaust_iterator, - exhaust_iterators) +from a_sync.utils.iterators import as_yielded, exhaust_iterator, exhaust_iterators __all__ = [ - #"all", - #"any", + # "all", + # "any", "as_yielded", "exhaust_iterator", "exhaust_iterators", ] + async def any(*awaitables) -> bool: """ Asynchronously evaluates whether any of the given awaitables evaluates to True. @@ -56,12 +56,13 @@ async def any(*awaitables) -> bool: fut.cancel() return True return False - + + async def all(*awaitables) -> bool: """ Asynchronously evaluates whether all of the given awaitables evaluate to True. - This function takes multiple awaitable objects and returns True if all of them evaluate to True. It cancels + This function takes multiple awaitable objects and returns True if all of them evaluate to True. It cancels the remaining awaitables once a False result is found. Args: diff --git a/a_sync/utils/iterators.py b/a_sync/utils/iterators.py index 99ba1d2b..93184310 100644 --- a/a_sync/utils/iterators.py +++ b/a_sync/utils/iterators.py @@ -1,4 +1,3 @@ - import asyncio import asyncio.futures import logging @@ -11,7 +10,9 @@ logger = logging.getLogger(__name__) -async def exhaust_iterator(iterator: AsyncIterator[T], *, queue: Optional[asyncio.Queue] = None) -> None: +async def exhaust_iterator( + iterator: AsyncIterator[T], *, queue: Optional[asyncio.Queue] = None +) -> None: """ Asynchronously iterates over items from the given async iterator and optionally places them into a queue. @@ -26,12 +27,14 @@ async def exhaust_iterator(iterator: AsyncIterator[T], *, queue: Optional[asynci """ async for thing in iterator: if queue: - logger.debug('putting %s from %s to queue %s', thing, iterator, queue) + logger.debug("putting %s from %s to queue %s", thing, iterator, queue) queue.put_nowait(thing) -async def exhaust_iterators(iterators, *, queue: Optional[asyncio.Queue] = None, join: bool = False) -> None: - """ +async def exhaust_iterators( + iterators, *, queue: Optional[asyncio.Queue] = None, join: bool = False +) -> None: + """ Asynchronously iterates over multiple async iterators concurrently and optionally places their items into a queue. This function leverages asyncio.gather to concurrently exhaust multiple async iterators. It's useful in scenarios where items from multiple async sources need to be processed or collected together, supporting concurrent operations and efficient multitasking. @@ -44,7 +47,10 @@ async def exhaust_iterators(iterators, *, queue: Optional[asyncio.Queue] = None, Returns: None """ - for x in await asyncio.gather(*[exhaust_iterator(iterator, queue=queue) for iterator in iterators], return_exceptions=True): + for x in await asyncio.gather( + *[exhaust_iterator(iterator, queue=queue) for iterator in iterators], + return_exceptions=True, + ): if isinstance(x, Exception): # raise it with its original traceback instead of from here raise x.with_traceback(x.__traceback__) @@ -55,40 +61,108 @@ async def exhaust_iterators(iterators, *, queue: Optional[asyncio.Queue] = None, elif join: raise ValueError("You must provide a `queue` to use kwarg `join`") - -T0 = TypeVar('T0') -T1 = TypeVar('T1') -T2 = TypeVar('T2') -T3 = TypeVar('T3') -T4 = TypeVar('T4') -T5 = TypeVar('T5') -T6 = TypeVar('T6') -T7 = TypeVar('T7') -T8 = TypeVar('T8') -T9 = TypeVar('T9') + +T0 = TypeVar("T0") +T1 = TypeVar("T1") +T2 = TypeVar("T2") +T3 = TypeVar("T3") +T4 = TypeVar("T4") +T5 = TypeVar("T5") +T6 = TypeVar("T6") +T7 = TypeVar("T7") +T8 = TypeVar("T8") +T9 = TypeVar("T9") + @overload -def as_yielded(*iterators: AsyncIterator[T]) -> AsyncIterator[T]:... +def as_yielded(*iterators: AsyncIterator[T]) -> AsyncIterator[T]: ... @overload -def as_yielded(iterator0: AsyncIterator[T0], iterator1: AsyncIterator[T1], iterator2: AsyncIterator[T2], iterator3: AsyncIterator[T3], iterator4: AsyncIterator[T4], iterator5: AsyncIterator[T5], iterator6: AsyncIterator[T6], iterator7: AsyncIterator[T7], iterator8: AsyncIterator[T8], iterator9: AsyncIterator[T9]) -> AsyncIterator[Union[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]]:... +def as_yielded( + iterator0: AsyncIterator[T0], + iterator1: AsyncIterator[T1], + iterator2: AsyncIterator[T2], + iterator3: AsyncIterator[T3], + iterator4: AsyncIterator[T4], + iterator5: AsyncIterator[T5], + iterator6: AsyncIterator[T6], + iterator7: AsyncIterator[T7], + iterator8: AsyncIterator[T8], + iterator9: AsyncIterator[T9], +) -> AsyncIterator[Union[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]]: ... @overload -def as_yielded(iterator0: AsyncIterator[T0], iterator1: AsyncIterator[T1], iterator2: AsyncIterator[T2], iterator3: AsyncIterator[T3], iterator4: AsyncIterator[T4], iterator5: AsyncIterator[T5], iterator6: AsyncIterator[T6], iterator7: AsyncIterator[T7], iterator8: AsyncIterator[T8]) -> AsyncIterator[Union[T0, T1, T2, T3, T4, T5, T6, T7, T8]]:... +def as_yielded( + iterator0: AsyncIterator[T0], + iterator1: AsyncIterator[T1], + iterator2: AsyncIterator[T2], + iterator3: AsyncIterator[T3], + iterator4: AsyncIterator[T4], + iterator5: AsyncIterator[T5], + iterator6: AsyncIterator[T6], + iterator7: AsyncIterator[T7], + iterator8: AsyncIterator[T8], +) -> AsyncIterator[Union[T0, T1, T2, T3, T4, T5, T6, T7, T8]]: ... @overload -def as_yielded(iterator0: AsyncIterator[T0], iterator1: AsyncIterator[T1], iterator2: AsyncIterator[T2], iterator3: AsyncIterator[T3], iterator4: AsyncIterator[T4], iterator5: AsyncIterator[T5], iterator6: AsyncIterator[T6], iterator7: AsyncIterator[T7]) -> AsyncIterator[Union[T0, T1, T2, T3, T4, T5, T6, T7]]:... +def as_yielded( + iterator0: AsyncIterator[T0], + iterator1: AsyncIterator[T1], + iterator2: AsyncIterator[T2], + iterator3: AsyncIterator[T3], + iterator4: AsyncIterator[T4], + iterator5: AsyncIterator[T5], + iterator6: AsyncIterator[T6], + iterator7: AsyncIterator[T7], +) -> AsyncIterator[Union[T0, T1, T2, T3, T4, T5, T6, T7]]: ... @overload -def as_yielded(iterator0: AsyncIterator[T0], iterator1: AsyncIterator[T1], iterator2: AsyncIterator[T2], iterator3: AsyncIterator[T3], iterator4: AsyncIterator[T4], iterator5: AsyncIterator[T5], iterator6: AsyncIterator[T6]) -> AsyncIterator[Union[T0, T1, T2, T3, T4, T5, T6]]:... +def as_yielded( + iterator0: AsyncIterator[T0], + iterator1: AsyncIterator[T1], + iterator2: AsyncIterator[T2], + iterator3: AsyncIterator[T3], + iterator4: AsyncIterator[T4], + iterator5: AsyncIterator[T5], + iterator6: AsyncIterator[T6], +) -> AsyncIterator[Union[T0, T1, T2, T3, T4, T5, T6]]: ... @overload -def as_yielded(iterator0: AsyncIterator[T0], iterator1: AsyncIterator[T1], iterator2: AsyncIterator[T2], iterator3: AsyncIterator[T3], iterator4: AsyncIterator[T4], iterator5: AsyncIterator[T5]) -> AsyncIterator[Union[T0, T1, T2, T3, T4, T5]]:... +def as_yielded( + iterator0: AsyncIterator[T0], + iterator1: AsyncIterator[T1], + iterator2: AsyncIterator[T2], + iterator3: AsyncIterator[T3], + iterator4: AsyncIterator[T4], + iterator5: AsyncIterator[T5], +) -> AsyncIterator[Union[T0, T1, T2, T3, T4, T5]]: ... @overload -def as_yielded(iterator0: AsyncIterator[T0], iterator1: AsyncIterator[T1], iterator2: AsyncIterator[T2], iterator3: AsyncIterator[T3], iterator4: AsyncIterator[T4]) -> AsyncIterator[Union[T0, T1, T2, T3, T4]]:... +def as_yielded( + iterator0: AsyncIterator[T0], + iterator1: AsyncIterator[T1], + iterator2: AsyncIterator[T2], + iterator3: AsyncIterator[T3], + iterator4: AsyncIterator[T4], +) -> AsyncIterator[Union[T0, T1, T2, T3, T4]]: ... @overload -def as_yielded(iterator0: AsyncIterator[T0], iterator1: AsyncIterator[T1], iterator2: AsyncIterator[T2], iterator3: AsyncIterator[T3]) -> AsyncIterator[Union[T0, T1, T2, T3]]:... +def as_yielded( + iterator0: AsyncIterator[T0], + iterator1: AsyncIterator[T1], + iterator2: AsyncIterator[T2], + iterator3: AsyncIterator[T3], +) -> AsyncIterator[Union[T0, T1, T2, T3]]: ... @overload -def as_yielded(iterator0: AsyncIterator[T0], iterator1: AsyncIterator[T1], iterator2: AsyncIterator[T2]) -> AsyncIterator[Union[T0, T1, T2]]:... +def as_yielded( + iterator0: AsyncIterator[T0], + iterator1: AsyncIterator[T1], + iterator2: AsyncIterator[T2], +) -> AsyncIterator[Union[T0, T1, T2]]: ... @overload -def as_yielded(iterator0: AsyncIterator[T0], iterator1: AsyncIterator[T1]) -> AsyncIterator[Union[T0, T1]]:... +def as_yielded( + iterator0: AsyncIterator[T0], iterator1: AsyncIterator[T1] +) -> AsyncIterator[Union[T0, T1]]: ... @overload -def as_yielded(iterator0: AsyncIterator[T0], iterator1: AsyncIterator[T1], iterator2: AsyncIterator[T2], *iterators: AsyncIterator[T]) -> AsyncIterator[Union[T0, T1, T2, T]]:... +def as_yielded( + iterator0: AsyncIterator[T0], + iterator1: AsyncIterator[T1], + iterator2: AsyncIterator[T2], + *iterators: AsyncIterator[T], +) -> AsyncIterator[Union[T0, T1, T2, T]]: ... async def as_yielded(*iterators: AsyncIterator[T]) -> AsyncIterator[T]: # type: ignore [misc] """ Merges multiple async iterators into a single async iterator that yields items as they become available from any of the source iterators. @@ -108,20 +182,20 @@ async def as_yielded(*iterators: AsyncIterator[T]) -> AsyncIterator[T]: # type: """ # hypothesis idea: _Done should never be exposed to user, works for all desired input types queue: Queue[Union[T, _Done]] = Queue() - + def _as_yielded_done_callback(t: asyncio.Task) -> None: if t.cancelled(): return - if e := t.exception(): + if e := t.exception(): traceback.extract_stack traceback.clear_frames(e.__traceback__) queue.put_nowait(_Done(e)) task = asyncio.create_task( - coro=exhaust_iterators(iterators, queue=queue, join=True), + coro=exhaust_iterators(iterators, queue=queue, join=True), name=f"a_sync.as_yielded queue populating task for {iterators}", ) - + task.add_done_callback(_as_yielded_done_callback) while not task.done(): @@ -139,24 +213,29 @@ def _as_yielded_done_callback(t: asyncio.Task) -> None: del task del queue if item._exc: - raise type(item._exc)(*item._exc.args).with_traceback(item._tb) from item._exc.__cause__ + raise type(item._exc)(*item._exc.args).with_traceback( + item._tb + ) from item._exc.__cause__ return yield item # ensure it isn't done due to an internal exception await task - + class _Done: """ A sentinel class used to signal the completion of processing in the as_yielded function. This class acts as a marker to indicate that all items have been processed and the asynchronous iteration can be concluded. It is used internally within the implementation of as_yielded to efficiently manage the termination of the iteration process once all source iterators have been exhausted. """ + def __init__(self, exc: Optional[Exception] = None) -> None: self._exc = exc + @property def _tb(self) -> TracebackType: return self._exc.__traceback__ # type: ignore [union-attr] + __all__ = ["as_yielded", "exhaust_iterator", "exhaust_iterators"] diff --git a/docs/conf.py b/docs/conf.py index bffbce65..9e0eb7a8 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -9,64 +9,68 @@ import os import sys -project = 'ez-a-sync' -copyright = '2024, BobTheBuidler' -author = 'BobTheBuidler' +project = "ez-a-sync" +copyright = "2024, BobTheBuidler" +author = "BobTheBuidler" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.napoleon', - 'sphinx.ext.intersphinx', - 'sphinx.ext.viewcode', - 'a_sync.sphinx.ext', + "sphinx.ext.autodoc", + "sphinx.ext.napoleon", + "sphinx.ext.intersphinx", + "sphinx.ext.viewcode", + "a_sync.sphinx.ext", ] -templates_path = ['_templates'] -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +templates_path = ["_templates"] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] intersphinx_mapping = { - 'python': ('https://docs.python.org/3', None), - 'typing_extensions': ('https://typing-extensions.readthedocs.io/en/latest/', None), + "python": ("https://docs.python.org/3", None), + "typing_extensions": ("https://typing-extensions.readthedocs.io/en/latest/", None), } # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output -html_theme = 'sphinx_rtd_theme' -html_static_path = ['_static'] +html_theme = "sphinx_rtd_theme" +html_static_path = ["_static"] autodoc_default_options = { - 'undoc-members': True, - 'private-members': True, - 'special-members': ','.join([ - '__init__', - '__call__', - '__getitem__', - '__iter__', - '__aiter__', - '__next__', - '__anext__', - '_Done', - '_AsyncExecutorMixin', - ]), - 'inherited-members': True, - 'member-order': 'groupwise', + "undoc-members": True, + "private-members": True, + "special-members": ",".join( + [ + "__init__", + "__call__", + "__getitem__", + "__iter__", + "__aiter__", + "__next__", + "__anext__", + "_Done", + "_AsyncExecutorMixin", + ] + ), + "inherited-members": True, + "member-order": "groupwise", # hide private methods that aren't relevant to us here - 'exclude-members': ','.join([ - '__new__', - '_abc_impl', - '_fget', - '_fset', - '_fdel', - '_ASyncSingletonMeta__instances', - '_ASyncSingletonMeta__lock', - '_is_protocol', - '_is_runtime_protocol', - '_materialized', - ]), + "exclude-members": ",".join( + [ + "__new__", + "_abc_impl", + "_fget", + "_fset", + "_fdel", + "_ASyncSingletonMeta__instances", + "_ASyncSingletonMeta__lock", + "_is_protocol", + "_is_runtime_protocol", + "_materialized", + ] + ), } autodoc_typehints = "description" @@ -75,7 +79,7 @@ automodule_generate_module_stub = True -sys.path.insert(0, os.path.abspath('./a_sync')) +sys.path.insert(0, os.path.abspath("./a_sync")) SKIP_MODULES = [ "a_sync.a_sync._kwargs", @@ -87,16 +91,24 @@ "a_sync.utils.iterators", ] + def skip_undesired_members(app, what, name, obj, skip, options): # skip some submodules (not sure if this works right or if its even desired) - if what == "module" and getattr(obj, '__name__', None) in SKIP_MODULES: + if what == "module" and getattr(obj, "__name__", None) in SKIP_MODULES: skip = True - + # Skip the __init__, __str__, __getattribute__, args, and with_traceback members of all Exceptions - if issubclass(getattr(obj, '__objclass__', type), BaseException) and name in ["__init__", "__str__", "__getattribute__", "args", "with_traceback"]: + if issubclass(getattr(obj, "__objclass__", type), BaseException) and name in [ + "__init__", + "__str__", + "__getattribute__", + "args", + "with_traceback", + ]: return True - + return skip + def setup(sphinx): - sphinx.connect("autodoc-skip-member", skip_undesired_members) \ No newline at end of file + sphinx.connect("autodoc-skip-member", skip_undesired_members) diff --git a/pyproject.yaml b/pyproject.yaml new file mode 100644 index 00000000..aa4949aa --- /dev/null +++ b/pyproject.yaml @@ -0,0 +1,2 @@ +[tool.black] +line-length = 100 diff --git a/setup.py b/setup.py index aec248ba..e0feaaac 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ requirements = list(map(str.strip, f.read().split("\n")))[:-1] setup( - name='ez-a-sync', + name="ez-a-sync", packages=find_packages(), use_scm_version={ "root": ".", @@ -12,14 +12,14 @@ "local_scheme": "no-local-version", "version_scheme": "python-simplified-semver", }, - description='A library that makes it easy to define objects that can be used for both sync and async use cases.', - author='BobTheBuidler', - author_email='bobthebuidlerdefi@gmail.com', - url='https://github.com/BobTheBuidler/a-sync', - license='MIT', + description="A library that makes it easy to define objects that can be used for both sync and async use cases.", + author="BobTheBuidler", + author_email="bobthebuidlerdefi@gmail.com", + url="https://github.com/BobTheBuidler/a-sync", + license="MIT", install_requires=requirements, setup_requires=[ - 'setuptools_scm', + "setuptools_scm", ], python_requires=">=3.8,<3.13", package_data={ diff --git a/tests/conftest.py b/tests/conftest.py index 27c6d5f0..328fb1ca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,3 @@ - import sys, os -sys.path.insert(0, os.path.abspath('.')) \ No newline at end of file +sys.path.insert(0, os.path.abspath(".")) diff --git a/tests/executor.py b/tests/executor.py index d730d694..9a1b231a 100644 --- a/tests/executor.py +++ b/tests/executor.py @@ -1,9 +1,12 @@ def work(): import time + time.sleep(5) + from a_sync import ProcessPoolExecutor + @pytest.marks.asyncio async def test_executor(): executor = ProcessPoolExecutor(6) diff --git a/tests/fixtures.py b/tests/fixtures.py index 96d74638..2da6c7f0 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -1,4 +1,3 @@ - import asyncio import time from threading import current_thread, main_thread @@ -11,97 +10,124 @@ from a_sync.a_sync._meta import ASyncMeta, ASyncSingletonMeta from a_sync.a_sync.singleton import ASyncGenericSingleton -increment = pytest.mark.parametrize('i', range(10)) +increment = pytest.mark.parametrize("i", range(10)) + + +class WrongThreadError(Exception): ... -class WrongThreadError(Exception): - ... class TestClass(ASyncBase): def __init__(self, v: int, sync: bool = False): self.v = v self.sync = sync - + async def test_fn(self) -> int: if self.sync == False and main_thread() != current_thread(): - raise WrongThreadError('This should be running on an event loop in the main thread.') + raise WrongThreadError( + "This should be running on an event loop in the main thread." + ) elif self.sync == True and main_thread() != current_thread(): - raise WrongThreadError('This should be awaited in the main thread') + raise WrongThreadError("This should be awaited in the main thread") return self.v - + @a_sync.aka.property async def test_property(self) -> int: if self.sync == False and main_thread() != current_thread(): - raise WrongThreadError('This should be running on an event loop in the main thread.') + raise WrongThreadError( + "This should be running on an event loop in the main thread." + ) elif self.sync == True and main_thread() != current_thread(): - raise WrongThreadError('This should be awaited in the main thread') + raise WrongThreadError("This should be awaited in the main thread") return self.v * 2 - + @a_sync.alias.cached_property async def test_cached_property(self) -> int: if self.sync == False and main_thread() != current_thread(): - raise WrongThreadError('This should be running on an event loop in the main thread.') + raise WrongThreadError( + "This should be running on an event loop in the main thread." + ) elif self.sync == True and main_thread() != current_thread(): - raise WrongThreadError('This should be awaited in the main thread') + raise WrongThreadError("This should be awaited in the main thread") await asyncio.sleep(2) return self.v * 3 + class TestSync(ASyncBase): main = main_thread() + def __init__(self, v: int, sync: bool): self.v = v self.sync = sync - + def test_fn(self) -> int: # Sync bound methods are actually async functions that are run in an executor and awaited if self.sync == False and main_thread() == current_thread(): - raise WrongThreadError('This should be running in an executor, not the main thread.') + raise WrongThreadError( + "This should be running in an executor, not the main thread." + ) elif self.sync == True and main_thread() != current_thread(): - raise WrongThreadError('This should be running synchronously in the main thread') + raise WrongThreadError( + "This should be running synchronously in the main thread" + ) return self.v - + @a_sync.aka.property def test_property(self) -> int: if self.sync == False and main_thread() == current_thread(): - raise WrongThreadError('This should be running in an executor, not the main thread.') + raise WrongThreadError( + "This should be running in an executor, not the main thread." + ) if self.sync == True and main_thread() == current_thread(): # Sync properties are actually async functions that are run in an executor and awaited - raise WrongThreadError('This should be running in an executor, not the main thread.') + raise WrongThreadError( + "This should be running in an executor, not the main thread." + ) return self.v * 2 - + @a_sync.alias.cached_property def test_cached_property(self) -> int: if self.sync == False and main_thread() == current_thread(): - raise WrongThreadError('This should be running in an executor, not the main thread.') + raise WrongThreadError( + "This should be running in an executor, not the main thread." + ) if self.sync == True and main_thread() == current_thread(): # Sync properties are actually async functions that are run in an executor and awaited - raise WrongThreadError('This should be running in an executor, not the main thread.') + raise WrongThreadError( + "This should be running in an executor, not the main thread." + ) time.sleep(2) return self.v * 3 + class TestLimiter(TestClass): limiter = 1 - + + class TestInheritor(TestClass): pass + class TestMeta(TestClass, metaclass=ASyncMeta): pass + class TestSingleton(ASyncGenericSingleton, TestClass): runs_per_minute = 100 pass + class TestSingletonMeta(TestClass, metaclass=ASyncSingletonMeta): semaphore = 1 pass + class TestSemaphore(ASyncBase): - #semaphore=1 # NOTE: this is detected propely by undecorated test_fn but not the properties - + # semaphore=1 # NOTE: this is detected propely by undecorated test_fn but not the properties + def __init__(self, v: int, sync: bool): self.v = v self.sync = sync - + # spec on class and function both working @a_sync.a_sync(semaphore=1) async def test_fn(self) -> int: @@ -109,11 +135,11 @@ async def test_fn(self) -> int: return self.v # spec on class, function, property all working - @a_sync.aka.property('async', semaphore=1) + @a_sync.aka.property("async", semaphore=1) async def test_property(self) -> int: await asyncio.sleep(1) return self.v * 2 - + # spec on class, function, property all working @a_sync.alias.cached_property(semaphore=50) async def test_cached_property(self) -> int: @@ -121,7 +147,7 @@ async def test_cached_property(self) -> int: return self.v * 3 -def _test_kwargs(fn, default: Literal['sync','async',None]): +def _test_kwargs(fn, default: Literal["sync", "async", None]): # force async assert asyncio.get_event_loop().run_until_complete(fn(sync=False)) == 2 assert asyncio.get_event_loop().run_until_complete(fn(asynchronous=True)) == 2 @@ -132,18 +158,21 @@ def _test_kwargs(fn, default: Literal['sync','async',None]): assert asyncio.get_event_loop().run_until_complete(fn(asynchronous=False)) == 2 assert fn(sync=True) == 2 assert fn(asynchronous=False) == 2 - if default == 'sync': + if default == "sync": assert fn() == 2 - elif default == 'async': + elif default == "async": assert asyncio.get_event_loop().run_until_complete(fn()) == 2 + async def sample_task(n): await asyncio.sleep(0.01) return n + async def timeout_task(n): await asyncio.sleep(0.1) return n + async def sample_exc(n): - raise ValueError("Sample error") \ No newline at end of file + raise ValueError("Sample error") diff --git a/tests/test_abstract.py b/tests/test_abstract.py index d2b7f2af..6f67be58 100644 --- a/tests/test_abstract.py +++ b/tests/test_abstract.py @@ -1,17 +1,20 @@ - import sys import pytest from a_sync.a_sync.abstract import ASyncABC -_methods = '__a_sync_default_mode__', '__a_sync_flag_name__', '__a_sync_flag_value__' +_methods = "__a_sync_default_mode__", "__a_sync_flag_name__", "__a_sync_flag_value__" if sys.version_info >= (3, 12): _MIDDLE = "without an implementation for abstract methods" _methods = (f"'{method}'" for method in _methods) else: _MIDDLE = "with abstract methods" + def test_abc_direct_init(): - with pytest.raises(TypeError, match=f"Can't instantiate abstract class ASyncABC {_MIDDLE} {', '.join(_methods)}"): - ASyncABC() \ No newline at end of file + with pytest.raises( + TypeError, + match=f"Can't instantiate abstract class ASyncABC {_MIDDLE} {', '.join(_methods)}", + ): + ASyncABC() diff --git a/tests/test_as_completed.py b/tests/test_as_completed.py index eeda2a68..346f388a 100644 --- a/tests/test_as_completed.py +++ b/tests/test_as_completed.py @@ -1,4 +1,3 @@ - import asyncio import a_sync @@ -11,7 +10,10 @@ async def test_as_completed_with_awaitables(): tasks = [sample_task(i) for i in range(5)] results = [await result for result in a_sync.as_completed(tasks, aiter=False)] - assert sorted(results) == list(range(5)), "Results should be in ascending order from 0 to 4" + assert sorted(results) == list( + range(5) + ), "Results should be in ascending order from 0 to 4" + @pytest.mark.asyncio_cooperative async def test_as_completed_with_awaitables_aiter(): @@ -19,42 +21,61 @@ async def test_as_completed_with_awaitables_aiter(): results = [] async for result in a_sync.as_completed(tasks, aiter=True): results.append(result) - assert sorted(results) == list(range(5)), "Results should be in ascending order from 0 to 4" + assert sorted(results) == list( + range(5) + ), "Results should be in ascending order from 0 to 4" + @pytest.mark.asyncio_cooperative async def test_as_completed_with_mapping(): - tasks = {'task1': sample_task(1), 'task2': sample_task(2)} + tasks = {"task1": sample_task(1), "task2": sample_task(2)} results = {} for result in a_sync.as_completed(tasks, aiter=False): key, value = await result results[key] = value - assert results == {'task1': 1, 'task2': 2}, "Results should match the input mapping" + assert results == {"task1": 1, "task2": 2}, "Results should match the input mapping" + @pytest.mark.asyncio_cooperative async def test_as_completed_with_mapping_aiter(): - tasks = {'task1': sample_task(1), 'task2': sample_task(2)} + tasks = {"task1": sample_task(1), "task2": sample_task(2)} results = {} async for key, result in a_sync.as_completed(tasks, aiter=True): results[key] = result - assert results == {'task1': 1, 'task2': 2}, "Results should match the input mapping" + assert results == {"task1": 1, "task2": 2}, "Results should match the input mapping" + @pytest.mark.asyncio_cooperative async def test_as_completed_with_timeout(): tasks = [timeout_task(i) for i in range(2)] with pytest.raises(asyncio.TimeoutError): - [await result for result in a_sync.as_completed(tasks, aiter=False, timeout=0.05)] + [ + await result + for result in a_sync.as_completed(tasks, aiter=False, timeout=0.05) + ] + @pytest.mark.asyncio_cooperative async def test_as_completed_with_timeout_aiter(): tasks = [timeout_task(i) for i in range(2)] with pytest.raises(asyncio.TimeoutError): - [result async for result in a_sync.as_completed(tasks, aiter=True, timeout=0.05)] + [ + result + async for result in a_sync.as_completed(tasks, aiter=True, timeout=0.05) + ] + @pytest.mark.asyncio_cooperative async def test_as_completed_return_exceptions(): tasks = [sample_exc(i) for i in range(1)] - results = [await result for result in a_sync.as_completed(tasks, aiter=False, return_exceptions=True)] - assert isinstance(results[0], ValueError), f"The result should be an exception {results}" + results = [ + await result + for result in a_sync.as_completed(tasks, aiter=False, return_exceptions=True) + ] + assert isinstance( + results[0], ValueError + ), f"The result should be an exception {results}" + @pytest.mark.asyncio_cooperative async def test_as_completed_return_exceptions_aiter(): @@ -64,11 +85,17 @@ async def test_as_completed_return_exceptions_aiter(): results.append(result) assert isinstance(results[0], ValueError), "The result should be an exception" + @pytest.mark.asyncio_cooperative async def test_as_completed_with_tqdm_disabled(): tasks = [sample_task(i) for i in range(5)] - results = [await result for result in a_sync.as_completed(tasks, aiter=False, tqdm=False)] - assert sorted(results) == list(range(5)), "Results should be in ascending order from 0 to 4" + results = [ + await result for result in a_sync.as_completed(tasks, aiter=False, tqdm=False) + ] + assert sorted(results) == list( + range(5) + ), "Results should be in ascending order from 0 to 4" + @pytest.mark.asyncio_cooperative async def test_as_completed_with_tqdm_disabled_aiter(): @@ -76,23 +103,29 @@ async def test_as_completed_with_tqdm_disabled_aiter(): results = [] async for result in a_sync.as_completed(tasks, aiter=True, tqdm=False): results.append(result) - assert sorted(results) == list(range(5)), "Results should be in ascending order from 0 to 4" + assert sorted(results) == list( + range(5) + ), "Results should be in ascending order from 0 to 4" + @pytest.mark.asyncio_cooperative async def test_as_completed_with_mapping_and_return_exceptions(): - tasks = {'task1': sample_exc(1), 'task2': sample_task(2)} + tasks = {"task1": sample_exc(1), "task2": sample_task(2)} results = {} for result in a_sync.as_completed(tasks, return_exceptions=True, aiter=False): key, value = await result results[key] = value - assert isinstance(results['task1'], ValueError), "Result should be ValueError" - assert results['task2'] == 2, "Results should match the input mapping" + assert isinstance(results["task1"], ValueError), "Result should be ValueError" + assert results["task2"] == 2, "Results should match the input mapping" + @pytest.mark.asyncio_cooperative async def test_as_completed_with_mapping_and_return_exceptions_aiter(): - tasks = {'task1': sample_exc(1), 'task2': sample_task(2)} + tasks = {"task1": sample_exc(1), "task2": sample_task(2)} results = {} - async for key, result in a_sync.as_completed(tasks, return_exceptions=True, aiter=True): + async for key, result in a_sync.as_completed( + tasks, return_exceptions=True, aiter=True + ): results[key] = result - assert isinstance(results['task1'], ValueError), "Result should be ValueError" - assert results['task2'] == 2, "Results should match the input mapping" \ No newline at end of file + assert isinstance(results["task1"], ValueError), "Result should be ValueError" + assert results["task2"] == 2, "Results should match the input mapping" diff --git a/tests/test_base.py b/tests/test_base.py index b9c431f1..bc842c8f 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,4 +1,3 @@ - import asyncio import time @@ -9,7 +8,14 @@ from a_sync.a_sync._meta import ASyncMeta from a_sync.a_sync.method import ASyncBoundMethodAsyncDefault from a_sync.exceptions import SyncModeInAsyncContextError -from tests.fixtures import TestClass, TestInheritor, TestMeta, increment, TestSync, WrongThreadError +from tests.fixtures import ( + TestClass, + TestInheritor, + TestMeta, + increment, + TestSync, + WrongThreadError, +) def test_base_direct_init(): @@ -17,9 +23,10 @@ def test_base_direct_init(): ASyncGenericBase() -classes = pytest.mark.parametrize('cls', [TestClass, TestSync, TestInheritor, TestMeta]) +classes = pytest.mark.parametrize("cls", [TestClass, TestSync, TestInheritor, TestMeta]) + +both_modes = pytest.mark.parametrize("sync", [True, False]) -both_modes = pytest.mark.parametrize('sync', [True, False]) @classes @both_modes @@ -28,12 +35,13 @@ def test_inheritance(cls, sync: bool): assert isinstance(instance, ASyncGenericBase) assert isinstance(instance.__class__, ASyncMeta) + @classes @increment def test_method_sync(cls: type, i: int): sync_instance = cls(i, sync=True) assert sync_instance.test_fn() == i - + # Can we override with kwargs? assert sync_instance.test_fn(sync=True) == i assert sync_instance.test_fn(asynchronous=False) == i @@ -42,12 +50,27 @@ def test_method_sync(cls: type, i: int): if isinstance(sync_instance, TestSync): # this raises an assertion error inside of the test_fn execution. this is okay. with pytest.raises(WrongThreadError): - asyncio.get_event_loop().run_until_complete(sync_instance.test_fn(sync=False)) + asyncio.get_event_loop().run_until_complete( + sync_instance.test_fn(sync=False) + ) with pytest.raises(WrongThreadError): - asyncio.get_event_loop().run_until_complete(sync_instance.test_fn(asynchronous=True)) + asyncio.get_event_loop().run_until_complete( + sync_instance.test_fn(asynchronous=True) + ) else: - assert isinstance(asyncio.get_event_loop().run_until_complete(sync_instance.test_fn(sync=False)), int) - assert isinstance(asyncio.get_event_loop().run_until_complete(sync_instance.test_fn(asynchronous=True)), int) + assert isinstance( + asyncio.get_event_loop().run_until_complete( + sync_instance.test_fn(sync=False) + ), + int, + ) + assert isinstance( + asyncio.get_event_loop().run_until_complete( + sync_instance.test_fn(asynchronous=True) + ), + int, + ) + @classes @increment @@ -58,7 +81,7 @@ async def test_method_async(cls: type, i: int): # this raises an assertion error inside of the test_fn execution. this is okay. with pytest.raises(WrongThreadError): assert await async_instance.test_fn() == i - + # Can we override with kwargs? with pytest.raises(WrongThreadError): async_instance.test_fn(sync=True) @@ -84,7 +107,7 @@ async def test_method_async(cls: type, i: int): def test_property_sync(cls: type, i: int): sync_instance = cls(i, sync=True) assert sync_instance.test_property == i * 2 - + # Can we access hidden methods for properties? getter = sync_instance.__test_property__ assert isinstance(getter, HiddenMethod), getter @@ -92,6 +115,7 @@ def test_property_sync(cls: type, i: int): assert asyncio.iscoroutine(getter_coro), getter_coro assert asyncio.get_event_loop().run_until_complete(getter_coro) == i * 2 + @classes @increment @pytest.mark.asyncio_cooperative @@ -117,7 +141,9 @@ def test_cached_property_sync(cls: type, i: int): assert sync_instance.test_cached_property == i * 3 assert isinstance(sync_instance.test_cached_property, int) duration = time.time() - start - assert duration < 3, "There is a 2 second sleep in 'test_cached_property' but it should only run once." + assert ( + duration < 3 + ), "There is a 2 second sleep in 'test_cached_property' but it should only run once." # Can we access hidden methods for properties? start = time.time() @@ -128,9 +154,17 @@ def test_cached_property_sync(cls: type, i: int): assert asyncio.get_event_loop().run_until_complete(getter_coro) == i * 3 # Can we override them too? - assert asyncio.get_event_loop().run_until_complete(sync_instance.__test_cached_property__(sync=False)) == i * 3 + assert ( + asyncio.get_event_loop().run_until_complete( + sync_instance.__test_cached_property__(sync=False) + ) + == i * 3 + ) duration = time.time() - start - assert duration < 3, "There is a 2 second sleep in 'test_cached_property' but it should only run once." + assert ( + duration < 3 + ), "There is a 2 second sleep in 'test_cached_property' but it should only run once." + @classes @increment @@ -146,7 +180,7 @@ async def test_cached_property_async(cls: type, i: int): getter_coro = getter() assert asyncio.iscoroutine(getter_coro), getter_coro assert await getter_coro == i * 3 - + # Can we override them too? with pytest.raises(SyncModeInAsyncContextError): getter(sync=True) @@ -156,7 +190,9 @@ async def test_cached_property_async(cls: type, i: int): duration = time.time() - start # For TestSync, the duration can be higher because the calls execute inside of a threadpool which limits the amount of concurrency. target_duration = 5 if isinstance(async_instance, TestSync) else 2.1 - assert duration < target_duration, "There is a 2 second sleep in 'test_cached_property' but it should only run once." + assert ( + duration < target_duration + ), "There is a 2 second sleep in 'test_cached_property' but it should only run once." @pytest.mark.asyncio_cooperative @@ -166,6 +202,7 @@ class AsyncContextManager(ASyncGenericBase): async def __aenter__(self): self.entered = True return self + async def __aexit__(self, exc_type, exc_val, exc_tb): self.exited = True @@ -173,6 +210,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): assert cm.entered assert cm.exited + def test_synchronous_context_manager(): # Can the implementation work with a context manager? @@ -180,6 +218,7 @@ class SyncContextManager(ASyncGenericBase): def __enter__(self): self.entered = True return self + def __exit__(self, exc_type, exc_val, exc_tb): self.exited = True @@ -187,33 +226,41 @@ def __exit__(self, exc_type, exc_val, exc_tb): assert cm.entered assert cm.exited + @pytest.mark.asyncio_cooperative async def test_asynchronous_iteration(): # Does the implementation screw anything up with aiteration? class ASyncObjectWithAiter(ASyncGenericBase): def __init__(self): self.count = 0 + def __aiter__(self): return self + async def __anext__(self): if self.count < 3: self.count += 1 return self.count raise StopAsyncIteration + assert [item async for item in ASyncObjectWithAiter()] == [1, 2, 3] + def test_synchronous_iteration(): # Does the implementation screw anything up with iteration? class ASyncObjectWithIter(ASyncGenericBase): def __init__(self): self.count = 0 + def __iter__(self): return self + def __next__(self): if self.count < 3: self.count += 1 return self.count raise StopIteration + assert list(ASyncObjectWithIter()) == [1, 2, 3] @@ -223,13 +270,15 @@ async def generate(self): yield 1 yield 2 + def test_bound_generator_meta_sync(): """Does the metaclass handle generator functions correctly?""" for _ in ClassWithGenFunc().generate(): assert isinstance(_, int) + @pytest.mark.asyncio_cooperative async def test_bound_generator_meta_async(): """Does the metaclass handle generator functions correctly?""" async for _ in ClassWithGenFunc().generate(): - assert isinstance(_, int) \ No newline at end of file + assert isinstance(_, int) diff --git a/tests/test_cache.py b/tests/test_cache.py index e52fa4aa..a72d1e84 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -8,63 +8,79 @@ def test_decorator_async_with_cache_type(): - @a_sync.a_sync(default='async', cache_type='memory') + @a_sync.a_sync(default="async", cache_type="memory") async def some_test_fn() -> int: return 2 + start = time() assert asyncio.get_event_loop().run_until_complete(some_test_fn()) == 2 assert asyncio.get_event_loop().run_until_complete(some_test_fn()) == 2 duration = time() - start - assert duration < 1.5, "There is a 1 second sleep in this function but it should only run once." - _test_kwargs(some_test_fn, 'async') + assert ( + duration < 1.5 + ), "There is a 1 second sleep in this function but it should only run once." + _test_kwargs(some_test_fn, "async") + def test_decorator_async_with_cache_maxsize(): - @a_sync.a_sync(default='async', ram_cache_maxsize=100) + @a_sync.a_sync(default="async", ram_cache_maxsize=100) def some_test_fn() -> int: sleep(1) return 2 + start = time() assert asyncio.get_event_loop().run_until_complete(some_test_fn()) == 2 assert asyncio.get_event_loop().run_until_complete(some_test_fn()) == 2 duration = time() - start - assert duration < 1.5, "There is a 1 second sleep in this function but it should only run once." - _test_kwargs(some_test_fn, 'async') + assert ( + duration < 1.5 + ), "There is a 1 second sleep in this function but it should only run once." + _test_kwargs(some_test_fn, "async") + # This will never succeed due to some task the ttl kwargs creates def test_decorator_async_with_cache_ttl(): - @a_sync.a_sync(default='async', cache_type='memory', ram_cache_ttl=5) + @a_sync.a_sync(default="async", cache_type="memory", ram_cache_ttl=5) async def some_test_fn() -> int: return 2 + assert asyncio.get_event_loop().run_until_complete(some_test_fn()) == 2 - _test_kwargs(some_test_fn, 'async') + _test_kwargs(some_test_fn, "async") def test_decorator_sync_with_cache_type(): # Fails - @a_sync.a_sync(default='sync', cache_type='memory') + @a_sync.a_sync(default="sync", cache_type="memory") async def some_test_fn() -> int: sleep(1) return 2 + start = time() assert some_test_fn() == 2 assert some_test_fn() == 2 assert some_test_fn() == 2 duration = time() - start - assert duration < 1.5, "There is a 1 second sleep in this function but it should only run once." - _test_kwargs(some_test_fn, 'sync') + assert ( + duration < 1.5 + ), "There is a 1 second sleep in this function but it should only run once." + _test_kwargs(some_test_fn, "sync") -@pytest.mark.skipif(True, reason='skipped manually') + +@pytest.mark.skipif(True, reason="skipped manually") def test_decorator_sync_with_cache_maxsize(): # Fails # TODO diagnose and fix - @a_sync.a_sync(default="sync", cache_type='memory') + @a_sync.a_sync(default="sync", cache_type="memory") def some_test_fn() -> int: sleep(1) return 2 + start = time() assert some_test_fn() == 2 assert some_test_fn() == 2 assert some_test_fn() == 2 duration = time() - start - assert duration < 1.5, "There is a 1 second sleep in this function but it should only run once." - _test_kwargs(some_test_fn, 'sync') \ No newline at end of file + assert ( + duration < 1.5 + ), "There is a 1 second sleep in this function but it should only run once." + _test_kwargs(some_test_fn, "sync") diff --git a/tests/test_decorator.py b/tests/test_decorator.py index a071aaf7..01c8fa27 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -1,4 +1,3 @@ - import asyncio import pytest @@ -6,66 +5,79 @@ from tests.fixtures import _test_kwargs - # ASYNC DEF def test_decorator_no_args(): @a_sync.a_sync async def some_test_fn() -> int: return 2 + assert asyncio.get_event_loop().run_until_complete(some_test_fn()) == 2 _test_kwargs(some_test_fn, None) - + @a_sync.a_sync() async def some_test_fn() -> int: return 2 + assert asyncio.get_event_loop().run_until_complete(some_test_fn()) == 2 _test_kwargs(some_test_fn, None) - + + def test_decorator_default_none_arg(): @a_sync.a_sync(None) async def some_test_fn() -> int: return 2 + assert asyncio.get_event_loop().run_until_complete(some_test_fn()) == 2 _test_kwargs(some_test_fn, None) - + + def test_decorator_default_none_kwarg(): @a_sync.a_sync(default=None) async def some_test_fn() -> int: return 2 + assert asyncio.get_event_loop().run_until_complete(some_test_fn()) == 2 _test_kwargs(some_test_fn, None) + def test_decorator_default_sync_arg(): - @a_sync.a_sync('sync') + @a_sync.a_sync("sync") async def some_test_fn() -> int: return 2 + with pytest.raises(TypeError): asyncio.get_event_loop().run_until_complete(some_test_fn()) assert some_test_fn() == 2 - _test_kwargs(some_test_fn, 'sync') + _test_kwargs(some_test_fn, "sync") + def test_decorator_default_sync_kwarg(): - @a_sync.a_sync(default='sync') + @a_sync.a_sync(default="sync") async def some_test_fn() -> int: return 2 + with pytest.raises(TypeError): asyncio.get_event_loop().run_until_complete(some_test_fn()) assert some_test_fn() == 2 - _test_kwargs(some_test_fn, 'sync') - + _test_kwargs(some_test_fn, "sync") + + def test_decorator_default_async_arg(): - @a_sync.a_sync('async') + @a_sync.a_sync("async") async def some_test_fn() -> int: return 2 + assert asyncio.get_event_loop().run_until_complete(some_test_fn()) == 2 - _test_kwargs(some_test_fn, 'async') - + _test_kwargs(some_test_fn, "async") + + def test_decorator_default_async_kwarg(): - @a_sync.a_sync(default='async') + @a_sync.a_sync(default="async") async def some_test_fn() -> int: return 2 + assert asyncio.get_event_loop().run_until_complete(some_test_fn()) == 2 - _test_kwargs(some_test_fn, 'async') + _test_kwargs(some_test_fn, "async") # SYNC DEF @@ -73,53 +85,67 @@ def test_sync_decorator_no_args(): @a_sync.a_sync def some_test_fn() -> int: return 2 + assert some_test_fn() == 2 _test_kwargs(some_test_fn, None) - + @a_sync.a_sync() def some_test_fn() -> int: return 2 + assert some_test_fn() == 2 _test_kwargs(some_test_fn, None) - + + def test_sync_decorator_default_none_arg(): @a_sync.a_sync(None) def some_test_fn() -> int: return 2 + assert some_test_fn() == 2 _test_kwargs(some_test_fn, None) - + + def test_sync_decorator_default_none_kwarg(): @a_sync.a_sync(default=None) def some_test_fn() -> int: return 2 + assert some_test_fn() == 2 _test_kwargs(some_test_fn, None) + def test_sync_decorator_default_sync_arg(): - @a_sync.a_sync('sync') + @a_sync.a_sync("sync") def some_test_fn() -> int: return 2 + assert some_test_fn() == 2 - _test_kwargs(some_test_fn, 'sync') + _test_kwargs(some_test_fn, "sync") + def test_sync_decorator_default_sync_kwarg(): - @a_sync.a_sync(default='sync') + @a_sync.a_sync(default="sync") def some_test_fn() -> int: return 2 + assert some_test_fn() == 2 - _test_kwargs(some_test_fn, 'sync') - + _test_kwargs(some_test_fn, "sync") + + def test_sync_decorator_default_async_arg(): - @a_sync.a_sync('async') + @a_sync.a_sync("async") def some_test_fn() -> int: return 2 + assert asyncio.get_event_loop().run_until_complete(some_test_fn()) == 2 - _test_kwargs(some_test_fn, 'async') - + _test_kwargs(some_test_fn, "async") + + def test_sync_decorator_default_async_kwarg(): - @a_sync.a_sync(default='async') + @a_sync.a_sync(default="async") def some_test_fn() -> int: return 2 + assert asyncio.get_event_loop().run_until_complete(some_test_fn()) == 2 - _test_kwargs(some_test_fn, 'async') + _test_kwargs(some_test_fn, "async") diff --git a/tests/test_executor.py b/tests/test_executor.py index 4f4dcf9e..621cbbca 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -3,15 +3,15 @@ import pytest -from a_sync.executor import (AsyncProcessPoolExecutor, - ProcessPoolExecutor) +from a_sync.executor import AsyncProcessPoolExecutor, ProcessPoolExecutor def do_work(i, kwarg=None): time.sleep(i) if kwarg: assert kwarg == i - + + def test_executor(): executor = ProcessPoolExecutor(1) assert isinstance(executor, AsyncProcessPoolExecutor) @@ -21,20 +21,15 @@ def test_executor(): fut = executor.submit(do_work, 3) assert isinstance(fut, asyncio.Future), fut.__dict__ asyncio.get_event_loop().run_until_complete(fut) - + # asyncio implementation cant handle kwargs :( with pytest.raises(TypeError): asyncio.get_event_loop().run_until_complete( asyncio.get_event_loop().run_in_executor(executor, do_work, 3, kwarg=3) ) - + # but our clean implementation can :) fut = executor.submit(do_work, 3, kwarg=3) - - asyncio.get_event_loop().run_until_complete( - executor.run(do_work, 3, kwarg=3) - ) - asyncio.get_event_loop().run_until_complete( - fut - ) - \ No newline at end of file + + asyncio.get_event_loop().run_until_complete(executor.run(do_work, 3, kwarg=3)) + asyncio.get_event_loop().run_until_complete(fut) diff --git a/tests/test_future.py b/tests/test_future.py index 0bace418..5cbe0c3c 100644 --- a/tests/test_future.py +++ b/tests/test_future.py @@ -1,4 +1,3 @@ - from a_sync.future import ASyncFuture @@ -8,90 +7,129 @@ def do_stuff(self, smth=None): raise ValueError return 1 + async def dct(): - return {1:2} + return {1: 2} + async def one() -> int: return 1 + async def two() -> int: return 2 + async def zero(): return 0 + def test_result(): assert ASyncFuture(one()).result() == 1 + + def test_add(): assert ASyncFuture(one()) + ASyncFuture(two()) == 3 assert ASyncFuture(one()) + 2 == 3 assert 1 + ASyncFuture(two()) == 3 + + def test_sum(): assert sum([ASyncFuture(one()), ASyncFuture(two())]) == 3 assert sum([ASyncFuture(one()), 2]) == 3 assert sum([1, ASyncFuture(two())]) == 3 + + def test_sub(): assert ASyncFuture(one()) - ASyncFuture(two()) == -1 assert ASyncFuture(one()) - 2 == -1 assert 1 - ASyncFuture(two()) == -1 + + def test_mul(): assert ASyncFuture(one()) * ASyncFuture(two()) == 2 assert ASyncFuture(one()) * 2 == 2 assert 1 * ASyncFuture(two()) == 2 + + def test_pow(): assert ASyncFuture(two()) ** ASyncFuture(two()) == 4 assert ASyncFuture(two()) ** 2 == 4 assert 2 ** ASyncFuture(two()) == 4 + + def test_realdiv(): assert ASyncFuture(one()) / ASyncFuture(two()) == 0.5 assert ASyncFuture(one()) / 2 == 0.5 assert 1 / ASyncFuture(two()) == 0.5 + + def test_floordiv(): assert ASyncFuture(one()) // ASyncFuture(two()) == 0 assert ASyncFuture(one()) // 2 == 0 assert 1 // ASyncFuture(two()) == 0 + + def test_gt(): assert not ASyncFuture(one()) > ASyncFuture(two()) assert ASyncFuture(two()) > ASyncFuture(one()) assert not ASyncFuture(one()) > ASyncFuture(one()) + + def test_gte(): assert not ASyncFuture(one()) >= ASyncFuture(two()) assert ASyncFuture(two()) >= ASyncFuture(one()) assert ASyncFuture(one()) >= ASyncFuture(one()) + + def test_lt(): assert ASyncFuture(one()) < ASyncFuture(two()) assert not ASyncFuture(two()) < ASyncFuture(one()) assert not ASyncFuture(one()) < ASyncFuture(one()) + + def test_lte(): assert ASyncFuture(one()) <= ASyncFuture(two()) assert not ASyncFuture(two()) <= ASyncFuture(one()) assert ASyncFuture(one()) <= ASyncFuture(one()) + + def test_bool(): assert bool(ASyncFuture(one())) == True assert bool(ASyncFuture(zero())) == False + + def test_float(): assert float(ASyncFuture(one())) == float(1) + + def test_str(): assert str(ASyncFuture(one())) == "1" + def test_getitem(): assert ASyncFuture(dct())[1] == 2 + # NOTE: does not work def test_setitem(): meta = ASyncFuture(dct()) print(meta) meta[3] = 4 - assert meta == {1:2, 3:4} + assert meta == {1: 2, 3: 4} assert meta[3] == 4 - + + async def stuff_doer(): return StuffDoer() + def test_getattr(): assert ASyncFuture(stuff_doer()).do_stuff() -import pytest + +import pytest + def test_multi_maths(): some_cool_evm_value = ASyncFuture(two()) @@ -100,8 +138,8 @@ def test_multi_maths(): other = ASyncFuture(two()) stuff = ASyncFuture(two()) idrk = ASyncFuture(one()) - output0 = some_cool_evm_value / scale ** some # 2 - output1 = other + stuff - idrk # 3 + output0 = some_cool_evm_value / scale**some # 2 + output1 = other + stuff - idrk # 3 output = output0 * output1 assert output0 < output1 assert not output0 > output1 @@ -113,4 +151,3 @@ def test_multi_maths(): assert output == 6 with pytest.raises(ValueError): ASyncFuture(stuff_doer()).do_stuff(6) - \ No newline at end of file diff --git a/tests/test_gather.py b/tests/test_gather.py index 8c6ccace..8a38bd8e 100644 --- a/tests/test_gather.py +++ b/tests/test_gather.py @@ -10,32 +10,38 @@ async def sample_task(number): await asyncio.sleep(0.1) return number * 2 + @pytest.mark.asyncio_cooperative async def test_gather_with_awaitables(): results = await gather(sample_task(1), sample_task(2), sample_task(3)) assert results == [2, 4, 6] + @pytest.mark.asyncio_cooperative async def test_gather_with_awaitables_and_tqdm(): results = await gather(sample_task(1), sample_task(2), sample_task(3), tqdm=True) assert results == [2, 4, 6] + @pytest.mark.asyncio_cooperative async def test_gather_with_mapping_and_tqdm(): - tasks = {'a': sample_task(1), 'b': sample_task(2), 'c': sample_task(3)} + tasks = {"a": sample_task(1), "b": sample_task(2), "c": sample_task(3)} results = await gather(tasks, tqdm=True) - assert results == {'a': 2, 'b': 4, 'c': 6} + assert results == {"a": 2, "b": 4, "c": 6} + @pytest.mark.asyncio_cooperative async def test_gather_with_exceptions(): with pytest.raises(ValueError): await gather(sample_exc(None)) + @pytest.mark.asyncio_cooperative async def test_gather_with_return_exceptions(): results = await gather(sample_exc(None), return_exceptions=True) assert isinstance(results[0], ValueError) + @pytest.mark.asyncio_cooperative async def test_gather_with_return_exceptions_and_tqdm(): results = await gather(sample_exc(None), return_exceptions=True, tqdm=True) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 736df860..e5da9e41 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -6,8 +6,10 @@ def test_get_event_loop(): assert get_event_loop() == asyncio.get_event_loop() + def test_get_event_loop_in_thread(): def task(): assert get_event_loop() == asyncio.get_event_loop() + loop = get_event_loop() loop.run_until_complete(loop.run_in_executor(None, task)) diff --git a/tests/test_iter.py b/tests/test_iter.py index 1143b562..41781914 100644 --- a/tests/test_iter.py +++ b/tests/test_iter.py @@ -1,4 +1,3 @@ - import asyncio import contextlib import pytest @@ -10,30 +9,35 @@ from a_sync.iter import ASyncIterable, ASyncIterator - test_both = pytest.mark.parametrize("cls_to_test", [ASyncIterable, ASyncIterator]) + @pytest.fixture def async_generator(): async def async_gen(i: int = 3): for i in range(i): yield i + yield async_gen + @pytest.fixture def async_generator_empty(): async def async_gen_empty(): if True: return yield + yield async_gen_empty + @pytest.fixture def async_error_generator(): async def async_err_gen(): yield 0 yield 1 raise ValueError("Simulated error") + return async_err_gen @@ -42,6 +46,7 @@ def test_wrap_types(cls_to_test, async_generator): assert isinstance(cls_to_test(async_generator()), cls_to_test) assert isinstance(cls_to_test.wrap(async_generator()), cls_to_test) + @test_both def test_sync(cls_to_test, async_generator): # sourcery skip: identity-comprehension, list-comprehension @@ -64,13 +69,17 @@ def test_sync(cls_to_test, async_generator): assert cls_to_test(async_generator()).materialized == [0, 1, 2] assert cls_to_test.wrap(async_generator()).materialized == [0, 1, 2] + @test_both @pytest.mark.asyncio_cooperative async def test_async(cls_to_test, async_generator): ait = cls_to_test(async_generator()) # comprehension - with pytest.raises(SyncModeInAsyncContextError, match="The event loop is already running. Try iterating using `async for` instead of `for`."): + with pytest.raises( + SyncModeInAsyncContextError, + match="The event loop is already running. Try iterating using `async for` instead of `for`.", + ): list(ait) assert [i async for i in ait] == [0, 1, 2] @@ -79,7 +88,10 @@ async def test_async(cls_to_test, async_generator): async for item in cls_to_test(async_generator()): result.append(item) assert result == [0, 1, 2] - with pytest.raises(SyncModeInAsyncContextError, match="The event loop is already running. Try iterating using `async for` instead of `for`."): + with pytest.raises( + SyncModeInAsyncContextError, + match="The event loop is already running. Try iterating using `async for` instead of `for`.", + ): for _ in cls_to_test(async_generator()): pass @@ -95,11 +107,15 @@ async def test_async(cls_to_test, async_generator): def test_sync_empty(cls_to_test, async_generator_empty): assert not list(cls_to_test(async_generator_empty())) + @test_both @pytest.mark.asyncio_cooperative async def test_async_empty(cls_to_test, async_generator_empty): ait = cls_to_test(async_generator_empty()) - with pytest.raises(SyncModeInAsyncContextError, match="The event loop is already running. Try iterating using `async for` instead of `for`."): + with pytest.raises( + SyncModeInAsyncContextError, + match="The event loop is already running. Try iterating using `async for` instead of `for`.", + ): list(ait) assert not [i async for i in ait] @@ -118,6 +134,7 @@ def test_sync_partial(cls_to_test, async_generator): remaining = list(iterator) assert remaining == [3, 4] if cls_to_test is ASyncIterator else [0, 1, 2, 3, 4] + @test_both @pytest.mark.asyncio_cooperative async def test_async_partial(cls_to_test, async_generator): @@ -146,6 +163,7 @@ def test_stop_iteration_sync(cls_to_test, async_generator): with pytest.raises(StopIteration): next(it) + @test_both @pytest.mark.asyncio_cooperative async def test_stop_iteration_async(cls_to_test, async_generator): @@ -162,15 +180,23 @@ async def test_stop_iteration_async(cls_to_test, async_generator): # Test decorator + def test_aiterable_decorated_func_sync(): - with pytest.raises(TypeError, match="`async_iterable` must be an AsyncIterable. You passed "): + with pytest.raises( + TypeError, match="`async_iterable` must be an AsyncIterable. You passed " + ): + @ASyncIterable.wrap async def decorated(): yield 0 - + + @pytest.mark.asyncio_cooperative async def test_aiterable_decorated_func_async(async_generator): - with pytest.raises(TypeError, match="`async_iterable` must be an AsyncIterable. You passed "): + with pytest.raises( + TypeError, match="`async_iterable` must be an AsyncIterable. You passed " + ): + @ASyncIterable.wrap async def decorated(): yield 0 @@ -181,16 +207,19 @@ def test_aiterator_decorated_func_sync(async_generator): async def decorated(): async for i in async_generator(): yield i + retval = decorated() assert isinstance(retval, ASyncIterator) assert list(retval) == [0, 1, 2] - + + @pytest.mark.asyncio_cooperative async def test_aiterator_decorated_func_async(async_generator): @ASyncIterator.wrap async def decorated(): async for i in async_generator(): yield i + retval = decorated() assert isinstance(retval, ASyncIterator) assert await retval == [0, 1, 2] @@ -198,14 +227,17 @@ async def decorated(): def test_aiterable_decorated_method_sync(): with pytest.raises(TypeError, match=""): + class Test: @ASyncIterable.wrap async def decorated(self): yield 0 - + + @pytest.mark.asyncio_cooperative async def test_aiterable_decorated_method_async(): with pytest.raises(TypeError, match=""): + class Test: @ASyncIterable.wrap async def decorated(self): @@ -218,10 +250,12 @@ class Test: async def decorated(self): async for i in async_generator(): yield i + retval = Test().decorated() assert isinstance(retval, ASyncIterator) assert list(retval) == [0, 1, 2] - + + @pytest.mark.asyncio_cooperative async def test_aiterator_decorated_method_async(async_generator): class Test: @@ -229,10 +263,11 @@ class Test: async def decorated(self): async for i in async_generator(): yield i + retval = Test().decorated() assert isinstance(retval, ASyncIterator) assert await retval == [0, 1, 2] - + @test_both def test_sync_error_handling(cls_to_test, async_error_generator): @@ -243,6 +278,7 @@ def test_sync_error_handling(cls_to_test, async_error_generator): # we still got some results though assert results == [0, 1] + @test_both @pytest.mark.asyncio_cooperative async def test_async_error_handling(cls_to_test, async_error_generator): @@ -257,11 +293,13 @@ async def test_async_error_handling(cls_to_test, async_error_generator): # Test failures + @test_both def test_sync_with_iterable(cls_to_test): with pytest.raises(TypeError): cls_to_test([0, 1, 2]) + @test_both @pytest.mark.asyncio_cooperative async def test_async_with_iterable(cls_to_test): @@ -271,24 +309,28 @@ async def test_async_with_iterable(cls_to_test): # Type check dunder methods + def test_async_iterable_iter_method(async_generator): ait = ASyncIterable(async_generator()) iterator = iter(ait) assert isinstance(iterator, Iterator) + def test_async_iterator_iter_method(async_generator): ait = ASyncIterator(async_generator()) iterator = iter(ait) assert iterator is ait # Should return self + @pytest.mark.asyncio_cooperative async def test_async_aiter_method(async_generator): ait = ASyncIterable(async_generator()) async_iterator = ait.__aiter__() assert isinstance(async_iterator, AsyncIterator) + @pytest.mark.asyncio_cooperative async def test_async_iterator_aiter_method(async_generator): ait = ASyncIterator(async_generator()) async_iterator = ait.__aiter__() - assert async_iterator is ait # Should return self \ No newline at end of file + assert async_iterator is ait # Should return self diff --git a/tests/test_limiter.py b/tests/test_limiter.py index 7ffaab12..e14b7d51 100644 --- a/tests/test_limiter.py +++ b/tests/test_limiter.py @@ -1,4 +1,3 @@ - import pytest from time import time from tests.fixtures import TestLimiter, increment @@ -9,18 +8,21 @@ # Maybe a problem with test suite interaction? # - semaphore modifier works fine with integer inputs + @increment @pytest.mark.asyncio_cooperative async def test_semaphore(i: int): instance = TestLimiter(i, sync=False) assert await instance.test_fn() == i - + + @increment @pytest.mark.asyncio_cooperative async def test_semaphore_property(i: int): instance = TestLimiter(i, sync=False) assert await instance.test_property == i * 2 - + + @increment @pytest.mark.asyncio_cooperative async def test_semaphore_cached_property(i: int): diff --git a/tests/test_meta.py b/tests/test_meta.py index 8fe53822..8ff41385 100644 --- a/tests/test_meta.py +++ b/tests/test_meta.py @@ -9,7 +9,8 @@ from a_sync.a_sync.singleton import ASyncGenericSingleton from tests.fixtures import TestSingleton, TestSingletonMeta, increment -classes = pytest.mark.parametrize('cls', [TestSingleton, TestSingletonMeta]) +classes = pytest.mark.parametrize("cls", [TestSingleton, TestSingletonMeta]) + @classes @increment @@ -24,7 +25,9 @@ def test_singleton_meta_sync(cls: type, i: int): assert sync_instance.test_cached_property == 0 assert isinstance(sync_instance.test_cached_property, int) duration = time.time() - start - assert duration < 3, "There is a 2 second sleep in 'test_cached_property' but it should only run once." + assert ( + duration < 3 + ), "There is a 2 second sleep in 'test_cached_property' but it should only run once." # Can we override with kwargs? val = asyncio.get_event_loop().run_until_complete(sync_instance.test_fn(sync=False)) @@ -41,9 +44,17 @@ def test_singleton_meta_sync(cls: type, i: int): getter_coro = getter() assert asyncio.get_event_loop().run_until_complete(getter_coro) == 0 # Can we override them too? - assert asyncio.get_event_loop().run_until_complete(sync_instance.__test_cached_property__(sync=False)) == 0 + assert ( + asyncio.get_event_loop().run_until_complete( + sync_instance.__test_cached_property__(sync=False) + ) + == 0 + ) duration = time.time() - start - assert duration < 3, "There is a 2 second sleep in 'test_cached_property' but it should only run once." + assert ( + duration < 3 + ), "There is a 2 second sleep in 'test_cached_property' but it should only run once." + @classes @increment @@ -59,12 +70,14 @@ async def test_singleton_meta_async(cls: type, i: int): assert await async_instance.test_cached_property == 0 assert isinstance(await async_instance.test_cached_property, int) duration = time.time() - start - assert duration < 3, "There is a 2 second sleep in 'test_cached_property' but it should only run once." + assert ( + duration < 3 + ), "There is a 2 second sleep in 'test_cached_property' but it should only run once." # Can we override with kwargs? with pytest.raises(RuntimeError): async_instance.test_fn(sync=True) - + # Can we access hidden methods for properties? assert await async_instance.__test_property__() == 0 assert await async_instance.__test_cached_property__() == 0 @@ -73,24 +86,25 @@ async def test_singleton_meta_async(cls: type, i: int): async_instance.__test_cached_property__(sync=True) - class TestUnspecified(ASyncGenericSingleton): def __init__(self, sync=True): self.sync = sync + def test_singleton_unspecified(): obj = TestUnspecified() assert obj.sync == True obj.test_attr = True newobj = TestUnspecified() - assert hasattr(newobj, 'test_attr') + assert hasattr(newobj, "test_attr") assert TestUnspecified(sync=True).sync == True assert TestUnspecified(sync=False).sync == False + def test_singleton_switching(): obj = TestUnspecified() assert obj.sync == True obj.test_attr = True newobj = TestUnspecified(sync=False) - assert not hasattr(newobj, 'test_attr') \ No newline at end of file + assert not hasattr(newobj, "test_attr") diff --git a/tests/test_modified.py b/tests/test_modified.py index 015136a8..a373abe3 100644 --- a/tests/test_modified.py +++ b/tests/test_modified.py @@ -1,16 +1,19 @@ - import a_sync + @a_sync.a_sync def sync_def(): pass + @a_sync.a_sync async def async_def(): pass + def test_sync_def_repr(): assert sync_def.__name__ == "sync_def" + def test_async_def_repr(): assert async_def.__name__ == "async_def" diff --git a/tests/test_semaphore.py b/tests/test_semaphore.py index 7161b218..450cfa9d 100644 --- a/tests/test_semaphore.py +++ b/tests/test_semaphore.py @@ -1,4 +1,3 @@ - import pytest from time import time from tests.fixtures import TestSemaphore, increment @@ -12,30 +11,38 @@ instance = TestSemaphore(1, sync=False) + @increment @pytest.mark.asyncio_cooperative async def test_semaphore(i: int): start = time() assert await instance.test_fn() == 1 duration = time() - start - assert i < 3 or duration > i # There is a 1 second sleep in this fn. If the semaphore is not working, all tests will complete in 1 second. - - + assert ( + i < 3 or duration > i + ) # There is a 1 second sleep in this fn. If the semaphore is not working, all tests will complete in 1 second. + + @increment @pytest.mark.asyncio_cooperative async def test_semaphore_property(i: int): start = time() assert await instance.test_property == 2 duration = time() - start - assert i < 3 or duration > i # There is a 1 second sleep in this fn. If the semaphore is not working, all tests will complete in 1 second. - + assert ( + i < 3 or duration > i + ) # There is a 1 second sleep in this fn. If the semaphore is not working, all tests will complete in 1 second. + + @increment @pytest.mark.asyncio_cooperative async def test_semaphore_cached_property(i: int): start = time() assert await instance.test_cached_property == 3 duration = time() - start - # There is a 1 second sleep in this fn but a semaphore override with a value of 50. + # There is a 1 second sleep in this fn but a semaphore override with a value of 50. # You can tell it worked correctly because the class-defined semaphore value is just one, whch would cause this test to fail if it were used. # If the override is not working, all tests will complete in just over 1 second. - assert i == 1 or duration < 1.4 # We increased the threshold from 1.05 to 1.4 to help tests pass on slow github runners + assert ( + i == 1 or duration < 1.4 + ) # We increased the threshold from 1.05 to 1.4 to help tests pass on slow github runners diff --git a/tests/test_singleton.py b/tests/test_singleton.py index 4d11c912..5a599ffb 100644 --- a/tests/test_singleton.py +++ b/tests/test_singleton.py @@ -1,20 +1,25 @@ - from a_sync.a_sync.singleton import ASyncGenericSingleton + def test_flag_predefined(): """We had a failure case where the subclass implementation assigned the flag value to the class and did not allow user to determine at init time""" + class Test(ASyncGenericSingleton): - sync=True - def __init__(self): - ... + sync = True + + def __init__(self): ... + Test() - class TestInherit(Test): - ... + + class TestInherit(Test): ... + TestInherit() class Test(ASyncGenericSingleton): - sync=False + sync = False + Test() - class TestInherit(Test): - ... - TestInherit() \ No newline at end of file + + class TestInherit(Test): ... + + TestInherit() diff --git a/tests/test_task.py b/tests/test_task.py index e2d942b4..a35dfa49 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -3,41 +3,51 @@ from a_sync import TaskMapping, create_task, exceptions + @pytest.mark.asyncio_cooperative async def test_create_task(): - await create_task(coro=asyncio.sleep(0), name='test') + await create_task(coro=asyncio.sleep(0), name="test") + @pytest.mark.asyncio_cooperative async def test_persistent_task(): check = False + async def task(): await asyncio.sleep(1) nonlocal check check = True + create_task(coro=task(), skip_gc_until_done=True) - # there is no local reference to the newly created task. does it still complete? + # there is no local reference to the newly created task. does it still complete? await asyncio.sleep(2) assert check is True + @pytest.mark.asyncio_cooperative async def test_pruning(): async def task(): return + create_task(coro=task(), skip_gc_until_done=True) await asyncio.sleep(0) # previously, it failed here create_task(coro=task(), skip_gc_until_done=True) + @pytest.mark.asyncio_cooperative async def test_task_mapping_init(): tasks = TaskMapping(_coro_fn) - assert tasks._wrapped_func is _coro_fn, f"{tasks._wrapped_func} , {_coro_fn}, {tasks._wrapped_func == _coro_fn}" + assert ( + tasks._wrapped_func is _coro_fn + ), f"{tasks._wrapped_func} , {_coro_fn}, {tasks._wrapped_func == _coro_fn}" assert tasks._wrapped_func_kwargs == {} assert tasks._name is None - tasks = TaskMapping(_coro_fn, name='test', kwarg0=1, kwarg1=None) - assert tasks._wrapped_func_kwargs == {'kwarg0': 1, 'kwarg1': None} + tasks = TaskMapping(_coro_fn, name="test", kwarg0=1, kwarg1=None) + assert tasks._wrapped_func_kwargs == {"kwarg0": 1, "kwarg1": None} assert tasks._name == "test" + @pytest.mark.asyncio_cooperative async def test_task_mapping(): tasks = TaskMapping(_coro_fn) @@ -52,9 +62,16 @@ async def test_task_mapping(): # can we await the mapping? assert await tasks == {0: "1", 1: "22"} # can we await one from scratch? - assert await TaskMapping(_coro_fn, range(5)) == {0: "1", 1: "22", 2: "333", 3: "4444", 4: "55555"} + assert await TaskMapping(_coro_fn, range(5)) == { + 0: "1", + 1: "22", + 2: "333", + 3: "4444", + 4: "55555", + } assert len(tasks) == 2 - + + @pytest.mark.asyncio_cooperative async def test_task_mapping_map_with_sync_iter(): tasks = TaskMapping(_coro_fn) @@ -69,9 +86,9 @@ async def test_task_mapping_map_with_sync_iter(): ... i += 1 tasks = TaskMapping(_coro_fn) - async for k in tasks.map(range(5), pop=False, yields='keys'): + async for k in tasks.map(range(5), pop=False, yields="keys"): assert isinstance(k, int) - + # test keys for k in tasks.keys(): assert isinstance(k, int) @@ -81,7 +98,7 @@ async def test_task_mapping_map_with_sync_iter(): assert isinstance(k, int) async for k in tasks.keys(): assert isinstance(k, int) - + # test values for v in tasks.values(): assert isinstance(v, asyncio.Future) @@ -92,7 +109,7 @@ async def test_task_mapping_map_with_sync_iter(): assert isinstance(v, str) async for v in tasks.values(): assert isinstance(v, str) - + # test items for k, v in tasks.items(): assert isinstance(k, int) @@ -106,12 +123,14 @@ async def test_task_mapping_map_with_sync_iter(): async for k, v in tasks.items(): assert isinstance(k, int) assert isinstance(v, str) - + + @pytest.mark.asyncio_cooperative async def test_task_mapping_map_with_async_iter(): async def async_iter(): for i in range(5): yield i + tasks = TaskMapping(_coro_fn) i = 0 async for k, v in tasks.map(async_iter()): @@ -124,9 +143,9 @@ async def async_iter(): ... i += 1 tasks = TaskMapping(_coro_fn) - async for k in tasks.map(async_iter(), pop=False, yields='keys'): + async for k in tasks.map(async_iter(), pop=False, yields="keys"): assert isinstance(k, int) - + # test keys for k in tasks.keys(): assert isinstance(k, int) @@ -138,9 +157,13 @@ async def async_iter(): assert isinstance(k, int) assert await tasks.keys().aiterbykeys() == list(range(5)) assert await tasks.keys().aiterbyvalues() == list(range(5)) - assert await tasks.keys().aiterbykeys(reverse=True) == sorted(range(5), reverse=True) - assert await tasks.keys().aiterbyvalues(reverse=True) == sorted(range(5), reverse=True) - + assert await tasks.keys().aiterbykeys(reverse=True) == sorted( + range(5), reverse=True + ) + assert await tasks.keys().aiterbyvalues(reverse=True) == sorted( + range(5), reverse=True + ) + # test values for v in tasks.values(): assert isinstance(v, asyncio.Future) @@ -153,9 +176,13 @@ async def async_iter(): assert isinstance(v, str) assert await tasks.values().aiterbykeys() == [str(i) * i for i in range(1, 6)] assert await tasks.values().aiterbyvalues() == [str(i) * i for i in range(1, 6)] - assert await tasks.values().aiterbykeys(reverse=True) == [str(i) * i for i in sorted(range(1, 6), reverse=True)] - assert await tasks.values().aiterbyvalues(reverse=True) == [str(i) * i for i in sorted(range(1, 6), reverse=True)] - + assert await tasks.values().aiterbykeys(reverse=True) == [ + str(i) * i for i in sorted(range(1, 6), reverse=True) + ] + assert await tasks.values().aiterbyvalues(reverse=True) == [ + str(i) * i for i in sorted(range(1, 6), reverse=True) + ] + # test items for k, v in tasks.items(): assert isinstance(k, int) @@ -169,20 +196,29 @@ async def async_iter(): async for k, v in tasks.items(): assert isinstance(k, int) assert isinstance(v, str) - assert await tasks.items().aiterbykeys() == [(i, str(i+1) * (i+1)) for i in range(5)] - assert await tasks.items().aiterbyvalues() == [(i, str(i+1) * (i+1)) for i in range(5)] - assert await tasks.items().aiterbykeys(reverse=True) == [(i, str(i+1) * (i+1)) for i in sorted(range(5), reverse=True)] - assert await tasks.items(pop=True).aiterbyvalues(reverse=True) == [(i, str(i+1) * (i+1)) for i in sorted(range(5), reverse=True)] + assert await tasks.items().aiterbykeys() == [ + (i, str(i + 1) * (i + 1)) for i in range(5) + ] + assert await tasks.items().aiterbyvalues() == [ + (i, str(i + 1) * (i + 1)) for i in range(5) + ] + assert await tasks.items().aiterbykeys(reverse=True) == [ + (i, str(i + 1) * (i + 1)) for i in sorted(range(5), reverse=True) + ] + assert await tasks.items(pop=True).aiterbyvalues(reverse=True) == [ + (i, str(i + 1) * (i + 1)) for i in sorted(range(5), reverse=True) + ] assert not tasks # did pop work? + def test_taskmapping_views_sync(): tasks = TaskMapping(_coro_fn, range(5)) - + # keys are currently empty until the loop has a chance to run assert len(tasks.keys()) == 0 assert len(tasks.values()) == 0 assert len(tasks.items()) == 0 - + tasks.gather() assert len(tasks.keys()) == 5 @@ -195,16 +231,17 @@ def test_taskmapping_views_sync(): # test values for v in tasks.values(): assert isinstance(v, asyncio.Future) - + # test items for k, v in tasks.items(): assert isinstance(k, int) assert isinstance(v, asyncio.Future) - + assert len(tasks.keys()) == 5 for k in tasks.keys(): assert isinstance(k, int) + async def _coro_fn(i: int) -> str: i += 1 return str(i) * i