From 5db75155cd7354dff2e13922ab2f40b90dfa18c3 Mon Sep 17 00:00:00 2001 From: BobTheBuidler Date: Sun, 11 Feb 2024 22:51:54 +0000 Subject: [PATCH] chore: fix mypy errs --- a_sync/_typing.py | 8 ++++---- a_sync/abstract.py | 4 ++-- a_sync/base.py | 14 +++++++------- a_sync/exceptions.py | 5 ++++- a_sync/modifiers/manager.py | 2 +- a_sync/primitives/executor.py | 6 ++++-- a_sync/primitives/locks/event.py | 2 +- a_sync/primitives/locks/prio_semaphore.py | 3 +-- a_sync/task.py | 6 ++++-- a_sync/utils/iterators.py | 6 +++--- 10 files changed, 31 insertions(+), 25 deletions(-) diff --git a/a_sync/_typing.py b/a_sync/_typing.py index bbc4c9fe..5a750066 100644 --- a/a_sync/_typing.py +++ b/a_sync/_typing.py @@ -3,9 +3,9 @@ from concurrent.futures._base import Executor from decimal import Decimal from typing import (TYPE_CHECKING, Any, AsyncIterable, AsyncIterator, Awaitable, - Callable, DefaultDict, Dict, Generic, Iterable, Iterator, - List, Literal, Optional, Protocol, Set, Tuple, Type, TypedDict, - TypeVar, Union, final, overload) + Callable, DefaultDict, Deque, Dict, Generic, ItemsView, Iterable, + Iterator, KeysView, List, Literal, Optional, Protocol, Set, Tuple, + Type, TypedDict, TypeVar, Union, ValuesView, final, overload) from typing_extensions import Concatenate, ParamSpec, Self, Unpack @@ -47,7 +47,7 @@ class ModifierKwargs(TypedDict, total=False): cache_type: CacheType cache_typed: bool ram_cache_maxsize: Optional[int] - ram_cache_ttl: Optional[int] + ram_cache_ttl: Optional[Numeric] runs_per_minute: Optional[int] semaphore: SemaphoreSpec # sync modifiers diff --git a/a_sync/abstract.py b/a_sync/abstract.py index 3f2e3c66..c7566df7 100644 --- a/a_sync/abstract.py +++ b/a_sync/abstract.py @@ -68,6 +68,6 @@ def __a_sync_flag_name__(self) -> str: def __a_sync_flag_value__(self) -> bool: pass - @abc.abstractclassmethod # type: ignore [misc] + @abc.abstractclassmethod # type: ignore [arg-type, misc] def __a_sync_default_mode__(cls) -> bool: - ... \ No newline at end of file + pass \ No newline at end of file diff --git a/a_sync/base.py b/a_sync/base.py index 8e2037e4..e3c8c533 100644 --- a/a_sync/base.py +++ b/a_sync/base.py @@ -59,7 +59,7 @@ def __a_sync_default_mode__(cls) -> bool: return sync @classmethod - def __get_a_sync_flag_name_from_signature(cls) -> Optional[str]: + def __get_a_sync_flag_name_from_signature(cls) -> str: logger.debug("Searching for flags defined on %s.__init__", cls) if cls.__name__ == "ASyncGenericBase": logger.debug("There are no flags defined on the base class, this is expected. Skipping.") @@ -69,7 +69,7 @@ def __get_a_sync_flag_name_from_signature(cls) -> Optional[str]: return cls.__parse_flag_name_from_list(parameters) @classmethod - def __get_a_sync_flag_name_from_class_def(cls) -> Optional[str]: + def __get_a_sync_flag_name_from_class_def(cls) -> str: logger.debug("Searching for flags defined on %s", cls) try: return cls.__parse_flag_name_from_list(cls.__dict__) @@ -93,14 +93,14 @@ def __a_sync_flag_default_value_from_signature(cls) -> bool: return flag_value @classmethod - def __get_a_sync_flag_value_from_class_def(cls, flag: Optional[str]) -> Optional[bool]: + def __get_a_sync_flag_value_from_class_def(cls, flag: str) -> bool: for spec in [cls, *cls.__bases__]: - flag_value = spec.__dict__.get(flag) - if flag_value is not None: - return flag_value + if flag in spec.__dict__: + return spec.__dict__[flag] + raise exceptions.FlagNotDefined(cls, flag) @classmethod - def __parse_flag_name_from_list(cls, items: Dict[str, Any]) -> Optional[str]: + def __parse_flag_name_from_list(cls, items: Dict[str, Any]) -> str: present_flags = [flag for flag in _flags.VIABLE_FLAGS if flag in items] if len(present_flags) == 0: logger.debug("There are too many flags defined on %s", cls) diff --git a/a_sync/exceptions.py b/a_sync/exceptions.py index 1a806efd..172f8801 100644 --- a/a_sync/exceptions.py +++ b/a_sync/exceptions.py @@ -1,5 +1,5 @@ -from typing import Any, Optional, Set +from typing import Any, Optional, Set, Type class ASyncFlagException(ValueError): @@ -40,6 +40,9 @@ class InvalidFlagValue(ASyncFlagException): def __init__(self, flag: str, flag_value: Any): super().__init__(f"'{flag}' should be boolean. You passed {flag_value}.") +class FlagNotDefined(ASyncFlagException): + def __init__(self, obj: Type, flag: str): + super().__init__(f"{obj} flag {flag} is not defined.") class ImproperFunctionType(ValueError): diff --git a/a_sync/modifiers/manager.py b/a_sync/modifiers/manager.py index d55573a6..97f9ec3f 100644 --- a/a_sync/modifiers/manager.py +++ b/a_sync/modifiers/manager.py @@ -18,7 +18,7 @@ class ModifierManager(Dict[str, Any]): # sync modifiers executor: Executor - def __init__(self, **modifiers: ModifierKwargs) -> None: + def __init__(self, **modifiers: Unpack[ModifierKwargs]) -> None: for key in modifiers.keys(): if key not in valid_modifiers: raise ValueError(f"'{key}' is not a supported modifier.") diff --git a/a_sync/primitives/executor.py b/a_sync/primitives/executor.py index 28373a0d..daaa2f9f 100644 --- a/a_sync/primitives/executor.py +++ b/a_sync/primitives/executor.py @@ -20,6 +20,8 @@ TEN_MINUTES = 60 * 10 +Initializer = Callable[..., object] + class _AsyncExecutorMixin(cf.Executor, _DebugDaemonMixin): _max_workers: int _workers: str @@ -68,7 +70,7 @@ def __init__( self, max_workers: Optional[int] = None, mp_context: Optional[multiprocessing.context.BaseContext] = None, - initializer: Callable[..., object] = None, + initializer: Optional[Initializer] = None, initargs: Tuple[Any, ...] = (), ) -> None: if max_workers == 0: @@ -85,7 +87,7 @@ def __init__( self, max_workers: Optional[int] = None, thread_name_prefix: str = '', - initializer: Callable[..., object] = None, + initializer: Optional[Initializer] = None, initargs: Tuple[Any, ...] = (), ) -> None: if max_workers == 0: diff --git a/a_sync/primitives/locks/event.py b/a_sync/primitives/locks/event.py index da821842..557b4f9a 100644 --- a/a_sync/primitives/locks/event.py +++ b/a_sync/primitives/locks/event.py @@ -9,7 +9,7 @@ class Event(asyncio.Event, _DebugDaemonMixin): """asyncio.Event but with some additional debug logging to help detect deadlocks.""" _value: bool - _loop: asyncio.BaseEventLoop + _loop: asyncio.AbstractEventLoop _waiters: Deque["asyncio.Future[None]"] def __init__(self, name: str = "", debug_daemon_interval: int = 300, *, loop: Optional[asyncio.AbstractEventLoop] = None): diff --git a/a_sync/primitives/locks/prio_semaphore.py b/a_sync/primitives/locks/prio_semaphore.py index 755818d9..b9fdbfa2 100644 --- a/a_sync/primitives/locks/prio_semaphore.py +++ b/a_sync/primitives/locks/prio_semaphore.py @@ -10,7 +10,6 @@ logger = logging.getLogger(__name__) -T = TypeVar('T', covariant=True) class Priority(Protocol): def __lt__(self, other) -> bool: @@ -40,7 +39,7 @@ def __init__(self, value: int = 1, *, name: Optional[str] = None) -> None: super().__init__(value, name=name) self._waiters = [] # NOTE: This should (hopefully) be temporary - self._potential_lost_waiters = [] + self._potential_lost_waiters: List["asyncio.Future[None]"] = [] def __repr__(self) -> str: return f"<{self.__class__.__name__} name={self.name} capacity={self._capacity} value={self._value} waiters={self._count_waiters()}>" diff --git a/a_sync/task.py b/a_sync/task.py index 31f58fbc..90f283f9 100644 --- a/a_sync/task.py +++ b/a_sync/task.py @@ -38,7 +38,7 @@ def __getitem__(self, item: K) -> "asyncio.Task[V]": return task async def map(self, *iterables: AnyIterable[K], pop: bool = True, yields: Literal['keys', 'both'] = 'both') -> AsyncIterator[Tuple[K, V]]: assert not self, "must be empty" - async for key in as_yielded(*[__yield_keys(iterable) for iterable in iterables]): + async for key in as_yielded(*[__yield_keys(iterable) for iterable in iterables]): # type: ignore [attr-defined] self[key] # ensure task is running async for key, value in self.yield_completed(pop=pop): yield __yield(key, value, yields) @@ -47,7 +47,9 @@ async def map(self, *iterables: AnyIterable[K], pop: bool = True, yields: Litera async def yield_completed(self, pop: bool = True) -> AsyncIterator[Tuple[K, V]]: for k, task in dict(self).items(): if task.done(): - yield k, await self.pop(k) if pop else task + if pop: + task = self.pop(k) + yield k, await task __persisted_tasks: Set["asyncio.Task[Any]"] = set() diff --git a/a_sync/utils/iterators.py b/a_sync/utils/iterators.py index a94abeb0..778f9148 100644 --- a/a_sync/utils/iterators.py +++ b/a_sync/utils/iterators.py @@ -1,7 +1,7 @@ import asyncio import logging -from asyncio.futures import _chain_future +from asyncio.futures import _chain_future # type: ignore [attr-defined] from typing import AsyncIterator, Optional, TypeVar, Union, overload from a_sync.primitives.queue import Queue @@ -58,8 +58,8 @@ async def as_yielded(*iterators: AsyncIterator[T]) -> AsyncIterator[T]: # type: queue = Queue() task = asyncio.create_task(exhaust_iterators(iterators, queue=queue)) def done_callback(t: asyncio.Task) -> None: - if t.exception() and not next_fut.done(): - next_fut.set_exception(t.exception()) + if (e := t.exception()) and not next_fut.done(): + next_fut.set_exception(e) task.add_done_callback(done_callback) while not task.done():