diff --git a/a_sync/_bound.py b/a_sync/_bound.py index 0f908e1c..51b44ddb 100644 --- a/a_sync/_bound.py +++ b/a_sync/_bound.py @@ -63,7 +63,7 @@ def __init__(self, coro: Awaitable[T], property: Union[AsyncPropertyDescriptor[T self._property = property def __repr__(self) -> str: return f"<_PropertyGetter for {self._property}._get at {hex(id(self))}>" - def __await__(self) -> T: + def __await__(self) -> Generator[Any, None, T]: return self._coro.__await__() @overload diff --git a/a_sync/_typing.py b/a_sync/_typing.py index 61975886..b7846fbf 100644 --- a/a_sync/_typing.py +++ b/a_sync/_typing.py @@ -3,9 +3,10 @@ from concurrent.futures._base import Executor from decimal import Decimal from typing import (TYPE_CHECKING, Any, AsyncIterable, AsyncIterator, Awaitable, - Callable, DefaultDict, Dict, Generic, Iterable, Iterator, - List, Literal, Optional, Set, Tuple, Type, TypedDict, TypeVar, - Union, final, overload) + Callable, DefaultDict, Deque, Dict, Generator, Generic, + ItemsView, Iterable, Iterator, KeysView, List, Literal, + Optional, Protocol, Set, Tuple, Type, TypedDict, TypeVar, + Union, ValuesView, final, overload) from typing_extensions import Concatenate, ParamSpec, Self, Unpack @@ -17,6 +18,8 @@ V = TypeVar("V") P = ParamSpec("P") +Numeric = Union[int, float, Decimal] + MaybeAwaitable = Union[Awaitable[T], T] Property = Callable[["ASyncABC"], T] @@ -45,7 +48,7 @@ class ModifierKwargs(TypedDict, total=False): cache_type: CacheType cache_typed: bool ram_cache_maxsize: Optional[int] - ram_cache_ttl: Optional[int] + ram_cache_ttl: Optional[Numeric] runs_per_minute: Optional[int] semaphore: SemaphoreSpec # sync modifiers diff --git a/a_sync/abstract.py b/a_sync/abstract.py index 3f2e3c66..c7566df7 100644 --- a/a_sync/abstract.py +++ b/a_sync/abstract.py @@ -68,6 +68,6 @@ def __a_sync_flag_name__(self) -> str: def __a_sync_flag_value__(self) -> bool: pass - @abc.abstractclassmethod # type: ignore [misc] + @abc.abstractclassmethod # type: ignore [arg-type, misc] def __a_sync_default_mode__(cls) -> bool: - ... \ No newline at end of file + pass \ No newline at end of file diff --git a/a_sync/base.py b/a_sync/base.py index 47636fdd..e3c8c533 100644 --- a/a_sync/base.py +++ b/a_sync/base.py @@ -59,7 +59,7 @@ def __a_sync_default_mode__(cls) -> bool: return sync @classmethod - def __get_a_sync_flag_name_from_signature(cls) -> Optional[str]: + def __get_a_sync_flag_name_from_signature(cls) -> str: logger.debug("Searching for flags defined on %s.__init__", cls) if cls.__name__ == "ASyncGenericBase": logger.debug("There are no flags defined on the base class, this is expected. Skipping.") @@ -69,7 +69,7 @@ def __get_a_sync_flag_name_from_signature(cls) -> Optional[str]: return cls.__parse_flag_name_from_list(parameters) @classmethod - def __get_a_sync_flag_name_from_class_def(cls) -> Optional[str]: + def __get_a_sync_flag_name_from_class_def(cls) -> str: logger.debug("Searching for flags defined on %s", cls) try: return cls.__parse_flag_name_from_list(cls.__dict__) @@ -93,14 +93,14 @@ def __a_sync_flag_default_value_from_signature(cls) -> bool: return flag_value @classmethod - def __get_a_sync_flag_value_from_class_def(cls, flag: str) -> Optional[bool]: + def __get_a_sync_flag_value_from_class_def(cls, flag: str) -> bool: for spec in [cls, *cls.__bases__]: - flag_value = spec.__dict__.get(flag) - if flag_value is not None: - return flag_value + if flag in spec.__dict__: + return spec.__dict__[flag] + raise exceptions.FlagNotDefined(cls, flag) @classmethod - def __parse_flag_name_from_list(cls, items: Dict[str, Any]) -> Optional[str]: + def __parse_flag_name_from_list(cls, items: Dict[str, Any]) -> str: present_flags = [flag for flag in _flags.VIABLE_FLAGS if flag in items] if len(present_flags) == 0: logger.debug("There are too many flags defined on %s", cls) diff --git a/a_sync/exceptions.py b/a_sync/exceptions.py index 9ad1277a..172f8801 100644 --- a/a_sync/exceptions.py +++ b/a_sync/exceptions.py @@ -1,5 +1,5 @@ -from typing import Any, Set +from typing import Any, Optional, Set, Type class ASyncFlagException(ValueError): @@ -31,7 +31,7 @@ def __init__(self, target, present_flags): super().__init__(err) class InvalidFlag(ASyncFlagException): - def __init__(self, flag: str): + def __init__(self, flag: Optional[str]): err = f"'flag' must be one of: {self.viable_flags}. You passed {flag}." err += "\nThis code should not be reached and likely indicates an issue with a custom subclass definition." super().__init__(err) @@ -40,6 +40,9 @@ class InvalidFlagValue(ASyncFlagException): def __init__(self, flag: str, flag_value: Any): super().__init__(f"'{flag}' should be boolean. You passed {flag_value}.") +class FlagNotDefined(ASyncFlagException): + def __init__(self, obj: Type, flag: str): + super().__init__(f"{obj} flag {flag} is not defined.") class ImproperFunctionType(ValueError): diff --git a/a_sync/future.py b/a_sync/future.py index 35e70595..1225a87e 100644 --- a/a_sync/future.py +++ b/a_sync/future.py @@ -23,7 +23,7 @@ def _materialize(meta: "ASyncFuture[T]") -> T: except RuntimeError as e: raise RuntimeError(f"{meta} result is not set and the event loop is running, you will need to await it first") from e -Numeric = Union[int, float, Decimal, "ASyncFuture[int]", "ASyncFuture[float]", "ASyncFuture[Decimal]"] +MetaNumeric = Union[Numeric, "ASyncFuture[int]", "ASyncFuture[float]", "ASyncFuture[Decimal]"] class ASyncFuture(concurrent.futures.Future, Awaitable[T]): __slots__ = "__awaitable__", "__dependencies", "__dependants", "__task" @@ -120,7 +120,7 @@ def __add__(self: "ASyncFuture[Decimal]", other: Awaitable[Decimal]) -> "ASyncFu def __add__(self: "ASyncFuture[Decimal]", other: Awaitable[int]) -> "ASyncFuture[Decimal]":... @overload def __add__(self: "ASyncFuture[int]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":... - def __add__(self, other: Numeric) -> "ASyncFuture": + def __add__(self, other: MetaNumeric) -> "ASyncFuture": return ASyncFuture(self.__add(other), dependencies=self.__list_dependencies(other)) @overload def __sub__(self: "ASyncFuture[int]", other: int) -> "ASyncFuture[int]":... @@ -150,7 +150,7 @@ def __sub__(self: "ASyncFuture[Decimal]", other: Awaitable[Decimal]) -> "ASyncFu def __sub__(self: "ASyncFuture[Decimal]", other: Awaitable[int]) -> "ASyncFuture[Decimal]":... @overload def __sub__(self: "ASyncFuture[int]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":... - def __sub__(self, other: Numeric) -> "ASyncFuture": + def __sub__(self, other: MetaNumeric) -> "ASyncFuture": return ASyncFuture(self.__sub(other), dependencies=self.__list_dependencies(other)) def __mul__(self, other) -> "ASyncFuture": return ASyncFuture(self.__mul(other), dependencies=self.__list_dependencies(other)) diff --git a/a_sync/modifiers/manager.py b/a_sync/modifiers/manager.py index 0dd8e3f8..97f9ec3f 100644 --- a/a_sync/modifiers/manager.py +++ b/a_sync/modifiers/manager.py @@ -18,7 +18,7 @@ class ModifierManager(Dict[str, Any]): # sync modifiers executor: Executor - def __init__(self, **modifiers: ModifierKwargs) -> None: + def __init__(self, **modifiers: Unpack[ModifierKwargs]) -> None: for key in modifiers.keys(): if key not in valid_modifiers: raise ValueError(f"'{key}' is not a supported modifier.") @@ -70,13 +70,13 @@ def sync_modifier_wrap(*args: P.args, **kwargs: P.kwargs) -> T: return sync_modifier_wrap # Dictionary api - def items(self) -> List[Tuple[str, Any]]: + def items(self) -> ItemsView[str, Any]: return self._modifiers.items() - def keys(self) -> List[str]: + def keys(self) -> KeysView[str]: return self._modifiers.keys() - def values(self) -> List[Any]: + def values(self) -> ValuesView[Any]: return self._modifiers.values() - def __contains__(self, key: str) -> bool: + def __contains__(self, key: str) -> bool: # type: ignore [override] return key in self._modifiers def __iter__(self) -> Iterator[str]: return self._modifiers.__iter__() diff --git a/a_sync/primitives/executor.py b/a_sync/primitives/executor.py index c6c1c32f..daaa2f9f 100644 --- a/a_sync/primitives/executor.py +++ b/a_sync/primitives/executor.py @@ -20,6 +20,8 @@ TEN_MINUTES = 60 * 10 +Initializer = Callable[..., object] + class _AsyncExecutorMixin(cf.Executor, _DebugDaemonMixin): _max_workers: int _workers: str @@ -31,7 +33,7 @@ async def run(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: Oh, and you can also use kwargs! """ return fn(*args, **kwargs) if self.sync_mode else await self.submit(fn, *args, **kwargs) - def submit(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[T]": + def submit(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[T]": # type: ignore [override] """Submits a job to the executor and returns an `asyncio.Future` that can be awaited for the result without blocking.""" if self.sync_mode: fut = asyncio.ensure_future(self._exec_sync(fn, *args, **kwargs)) @@ -49,7 +51,7 @@ def sync_mode(self) -> bool: return self._max_workers == 0 @property def worker_count_current(self) -> int: - len(getattr(self, f"_{self._workers}")) + return len(getattr(self, f"_{self._workers}")) async def _exec_sync(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: """Just wraps a fn and its args into an awaitable.""" return fn(*args, **kwargs) @@ -68,7 +70,7 @@ def __init__( self, max_workers: Optional[int] = None, mp_context: Optional[multiprocessing.context.BaseContext] = None, - initializer: Callable[..., object] = None, + initializer: Optional[Initializer] = None, initargs: Tuple[Any, ...] = (), ) -> None: if max_workers == 0: @@ -85,7 +87,7 @@ def __init__( self, max_workers: Optional[int] = None, thread_name_prefix: str = '', - initializer: Callable[..., object] = None, + initializer: Optional[Initializer] = None, initargs: Tuple[Any, ...] = (), ) -> None: if max_workers == 0: diff --git a/a_sync/primitives/locks/event.py b/a_sync/primitives/locks/event.py index 01b74719..557b4f9a 100644 --- a/a_sync/primitives/locks/event.py +++ b/a_sync/primitives/locks/event.py @@ -1,13 +1,17 @@ -import asyncio, sys -from functools import cached_property -from typing import Optional +import asyncio +import sys +from a_sync._typing import * from a_sync.primitives._debug import _DebugDaemonMixin class Event(asyncio.Event, _DebugDaemonMixin): """asyncio.Event but with some additional debug logging to help detect deadlocks.""" + _value: bool + _loop: asyncio.AbstractEventLoop + _waiters: Deque["asyncio.Future[None]"] + def __init__(self, name: str = "", debug_daemon_interval: int = 300, *, loop: Optional[asyncio.AbstractEventLoop] = None): if sys.version_info >= (3, 10): super().__init__() @@ -22,7 +26,7 @@ def __repr__(self) -> str: if self._waiters: status += f', waiters:{len(self._waiters)}' return f"<{self.__class__.__module__}.{self.__class__.__name__} {label} at {hex(id(self))} [{status}]>" - async def wait(self) -> bool: + async def wait(self) -> Literal[True]: if self.is_set(): return True self._ensure_debug_daemon() diff --git a/a_sync/primitives/locks/prio_semaphore.py b/a_sync/primitives/locks/prio_semaphore.py index cb0f5626..b9fdbfa2 100644 --- a/a_sync/primitives/locks/prio_semaphore.py +++ b/a_sync/primitives/locks/prio_semaphore.py @@ -4,14 +4,12 @@ import logging from collections import deque from functools import cached_property -from typing import (Deque, Dict, Generic, List, Literal, Optional, Protocol, - Type, TypeVar) +from a_sync._typing import * from a_sync.primitives.locks.semaphore import Semaphore logger = logging.getLogger(__name__) -T = TypeVar('T', covariant=True) class Priority(Protocol): def __lt__(self, other) -> bool: @@ -41,7 +39,7 @@ def __init__(self, value: int = 1, *, name: Optional[str] = None) -> None: super().__init__(value, name=name) self._waiters = [] # NOTE: This should (hopefully) be temporary - self._potential_lost_waiters = [] + self._potential_lost_waiters: List["asyncio.Future[None]"] = [] def __repr__(self) -> str: return f"<{self.__class__.__name__} name={self.name} capacity={self._capacity} value={self._value} waiters={self._count_waiters()}>" @@ -188,10 +186,10 @@ async def acquire(self) -> Literal[True]: def release(self) -> None: self._parent.release() -class _PrioritySemaphoreContextManager(_AbstractPrioritySemaphoreContextManager[int]): +class _PrioritySemaphoreContextManager(_AbstractPrioritySemaphoreContextManager[Numeric]): _priority_name = "priority" -class PrioritySemaphore(_AbstractPrioritySemaphore[int, _PrioritySemaphoreContextManager]): # type: ignore [type-var] +class PrioritySemaphore(_AbstractPrioritySemaphore[Numeric, _PrioritySemaphoreContextManager]): # type: ignore [type-var] _context_manager_class = _PrioritySemaphoreContextManager _top_priority = -1 """ diff --git a/a_sync/primitives/locks/semaphore.py b/a_sync/primitives/locks/semaphore.py index 3a3cca76..f22ea4fb 100644 --- a/a_sync/primitives/locks/semaphore.py +++ b/a_sync/primitives/locks/semaphore.py @@ -16,7 +16,7 @@ def __init__(self, value: int, name=None, **kwargs) -> None: """ super().__init__(value, **kwargs) self.name = name or self.__origin__ if hasattr(self, '__origin__') else None - self._decorated = set() + self._decorated: Set[str] = set() # Dank new functionality def __call__(self, fn: Callable[P, T]) -> Callable[P, T]: @@ -31,7 +31,7 @@ def __repr__(self) -> str: def __len__(self) -> int: return len(self._waiters) if self._waiters else 0 - def decorate(self, fn: Callable[P, T]) -> Callable[P, T]: + def decorate(self, fn: CoroFn[P, T]) -> CoroFn[P, T]: if not asyncio.iscoroutinefunction(fn): raise TypeError(f"{fn} must be a coroutine function") @functools.wraps(fn) @@ -41,7 +41,7 @@ async def semaphore_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: self._decorated.add(f"{fn.__module__}.{fn.__name__}") return semaphore_wrapper - async def acquire(self) -> bool: + async def acquire(self) -> Literal[True]: if self._value <= 0: self._ensure_debug_daemon() return await super().acquire() diff --git a/a_sync/task.py b/a_sync/task.py index 502c1fce..90f283f9 100644 --- a/a_sync/task.py +++ b/a_sync/task.py @@ -19,7 +19,7 @@ def create_task(coro: Awaitable[T], *, name: Optional[str] = None, skip_gc_until __persist(task) return task -class TaskMapping(DefaultDict[K, "asyncio.Task[_V]"]): +class TaskMapping(DefaultDict[K, "asyncio.Task[V]"]): def __init__(self, coro_fn: Callable[Concatenate[K, P], Awaitable[V]] = None, *, name: str = '', **coro_fn_kwargs: P.kwargs) -> None: self._coro_fn = coro_fn self._coro_fn_kwargs = coro_fn_kwargs @@ -38,7 +38,7 @@ def __getitem__(self, item: K) -> "asyncio.Task[V]": return task async def map(self, *iterables: AnyIterable[K], pop: bool = True, yields: Literal['keys', 'both'] = 'both') -> AsyncIterator[Tuple[K, V]]: assert not self, "must be empty" - async for key in as_yielded(*[__yield_keys(iterable) for iterable in iterables]): + async for key in as_yielded(*[__yield_keys(iterable) for iterable in iterables]): # type: ignore [attr-defined] self[key] # ensure task is running async for key, value in self.yield_completed(pop=pop): yield __yield(key, value, yields) @@ -47,7 +47,9 @@ async def map(self, *iterables: AnyIterable[K], pop: bool = True, yields: Litera async def yield_completed(self, pop: bool = True) -> AsyncIterator[Tuple[K, V]]: for k, task in dict(self).items(): if task.done(): - yield k, await self.pop(k) if pop else task + if pop: + task = self.pop(k) + yield k, await task __persisted_tasks: Set["asyncio.Task[Any]"] = set() diff --git a/a_sync/utils/iterators.py b/a_sync/utils/iterators.py index a94abeb0..778f9148 100644 --- a/a_sync/utils/iterators.py +++ b/a_sync/utils/iterators.py @@ -1,7 +1,7 @@ import asyncio import logging -from asyncio.futures import _chain_future +from asyncio.futures import _chain_future # type: ignore [attr-defined] from typing import AsyncIterator, Optional, TypeVar, Union, overload from a_sync.primitives.queue import Queue @@ -58,8 +58,8 @@ async def as_yielded(*iterators: AsyncIterator[T]) -> AsyncIterator[T]: # type: queue = Queue() task = asyncio.create_task(exhaust_iterators(iterators, queue=queue)) def done_callback(t: asyncio.Task) -> None: - if t.exception() and not next_fut.done(): - next_fut.set_exception(t.exception()) + if (e := t.exception()) and not next_fut.done(): + next_fut.set_exception(e) task.add_done_callback(done_callback) while not task.done(): diff --git a/a_sync/utils/map.py b/a_sync/utils/map.py index e13e1f00..0e7e613b 100644 --- a/a_sync/utils/map.py +++ b/a_sync/utils/map.py @@ -1,6 +1,4 @@ -from typing import Awaitable, Callable, Literal, Tuple, Union, overload - from a_sync._typing import * from a_sync.iter import ASyncIterator