Skip to content

Commit

Permalink
chore: fix mypy errs
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler committed Feb 11, 2024
1 parent a9c1828 commit 5db7515
Show file tree
Hide file tree
Showing 10 changed files with 31 additions and 25 deletions.
8 changes: 4 additions & 4 deletions a_sync/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from concurrent.futures._base import Executor
from decimal import Decimal
from typing import (TYPE_CHECKING, Any, AsyncIterable, AsyncIterator, Awaitable,
Callable, DefaultDict, Dict, Generic, Iterable, Iterator,
List, Literal, Optional, Protocol, Set, Tuple, Type, TypedDict,
TypeVar, Union, final, overload)
Callable, DefaultDict, Deque, Dict, Generic, ItemsView, Iterable,
Iterator, KeysView, List, Literal, Optional, Protocol, Set, Tuple,
Type, TypedDict, TypeVar, Union, ValuesView, final, overload)

from typing_extensions import Concatenate, ParamSpec, Self, Unpack

Expand Down Expand Up @@ -47,7 +47,7 @@ class ModifierKwargs(TypedDict, total=False):
cache_type: CacheType
cache_typed: bool
ram_cache_maxsize: Optional[int]
ram_cache_ttl: Optional[int]
ram_cache_ttl: Optional[Numeric]
runs_per_minute: Optional[int]
semaphore: SemaphoreSpec
# sync modifiers
Expand Down
4 changes: 2 additions & 2 deletions a_sync/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,6 @@ def __a_sync_flag_name__(self) -> str:
def __a_sync_flag_value__(self) -> bool:
pass

@abc.abstractclassmethod # type: ignore [misc]
@abc.abstractclassmethod # type: ignore [arg-type, misc]
def __a_sync_default_mode__(cls) -> bool:
...
pass
14 changes: 7 additions & 7 deletions a_sync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __a_sync_default_mode__(cls) -> bool:
return sync

@classmethod
def __get_a_sync_flag_name_from_signature(cls) -> Optional[str]:
def __get_a_sync_flag_name_from_signature(cls) -> str:
logger.debug("Searching for flags defined on %s.__init__", cls)
if cls.__name__ == "ASyncGenericBase":
logger.debug("There are no flags defined on the base class, this is expected. Skipping.")
Expand All @@ -69,7 +69,7 @@ def __get_a_sync_flag_name_from_signature(cls) -> Optional[str]:
return cls.__parse_flag_name_from_list(parameters)

@classmethod
def __get_a_sync_flag_name_from_class_def(cls) -> Optional[str]:
def __get_a_sync_flag_name_from_class_def(cls) -> str:
logger.debug("Searching for flags defined on %s", cls)
try:
return cls.__parse_flag_name_from_list(cls.__dict__)
Expand All @@ -93,14 +93,14 @@ def __a_sync_flag_default_value_from_signature(cls) -> bool:
return flag_value

@classmethod
def __get_a_sync_flag_value_from_class_def(cls, flag: Optional[str]) -> Optional[bool]:
def __get_a_sync_flag_value_from_class_def(cls, flag: str) -> bool:
for spec in [cls, *cls.__bases__]:
flag_value = spec.__dict__.get(flag)
if flag_value is not None:
return flag_value
if flag in spec.__dict__:
return spec.__dict__[flag]
raise exceptions.FlagNotDefined(cls, flag)

@classmethod
def __parse_flag_name_from_list(cls, items: Dict[str, Any]) -> Optional[str]:
def __parse_flag_name_from_list(cls, items: Dict[str, Any]) -> str:
present_flags = [flag for flag in _flags.VIABLE_FLAGS if flag in items]
if len(present_flags) == 0:
logger.debug("There are too many flags defined on %s", cls)
Expand Down
5 changes: 4 additions & 1 deletion a_sync/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

from typing import Any, Optional, Set
from typing import Any, Optional, Set, Type


class ASyncFlagException(ValueError):
Expand Down Expand Up @@ -40,6 +40,9 @@ class InvalidFlagValue(ASyncFlagException):
def __init__(self, flag: str, flag_value: Any):
super().__init__(f"'{flag}' should be boolean. You passed {flag_value}.")

class FlagNotDefined(ASyncFlagException):
def __init__(self, obj: Type, flag: str):
super().__init__(f"{obj} flag {flag} is not defined.")


class ImproperFunctionType(ValueError):
Expand Down
2 changes: 1 addition & 1 deletion a_sync/modifiers/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class ModifierManager(Dict[str, Any]):
# sync modifiers
executor: Executor

def __init__(self, **modifiers: ModifierKwargs) -> None:
def __init__(self, **modifiers: Unpack[ModifierKwargs]) -> None:
for key in modifiers.keys():
if key not in valid_modifiers:
raise ValueError(f"'{key}' is not a supported modifier.")
Expand Down
6 changes: 4 additions & 2 deletions a_sync/primitives/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

TEN_MINUTES = 60 * 10

Initializer = Callable[..., object]

class _AsyncExecutorMixin(cf.Executor, _DebugDaemonMixin):
_max_workers: int
_workers: str
Expand Down Expand Up @@ -68,7 +70,7 @@ def __init__(
self,
max_workers: Optional[int] = None,
mp_context: Optional[multiprocessing.context.BaseContext] = None,
initializer: Callable[..., object] = None,
initializer: Optional[Initializer] = None,
initargs: Tuple[Any, ...] = (),
) -> None:
if max_workers == 0:
Expand All @@ -85,7 +87,7 @@ def __init__(
self,
max_workers: Optional[int] = None,
thread_name_prefix: str = '',
initializer: Callable[..., object] = None,
initializer: Optional[Initializer] = None,
initargs: Tuple[Any, ...] = (),
) -> None:
if max_workers == 0:
Expand Down
2 changes: 1 addition & 1 deletion a_sync/primitives/locks/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
class Event(asyncio.Event, _DebugDaemonMixin):
"""asyncio.Event but with some additional debug logging to help detect deadlocks."""
_value: bool
_loop: asyncio.BaseEventLoop
_loop: asyncio.AbstractEventLoop
_waiters: Deque["asyncio.Future[None]"]

def __init__(self, name: str = "", debug_daemon_interval: int = 300, *, loop: Optional[asyncio.AbstractEventLoop] = None):
Expand Down
3 changes: 1 addition & 2 deletions a_sync/primitives/locks/prio_semaphore.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

logger = logging.getLogger(__name__)

T = TypeVar('T', covariant=True)

class Priority(Protocol):
def __lt__(self, other) -> bool:
Expand Down Expand Up @@ -40,7 +39,7 @@ def __init__(self, value: int = 1, *, name: Optional[str] = None) -> None:
super().__init__(value, name=name)
self._waiters = []
# NOTE: This should (hopefully) be temporary
self._potential_lost_waiters = []
self._potential_lost_waiters: List["asyncio.Future[None]"] = []

def __repr__(self) -> str:
return f"<{self.__class__.__name__} name={self.name} capacity={self._capacity} value={self._value} waiters={self._count_waiters()}>"
Expand Down
6 changes: 4 additions & 2 deletions a_sync/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __getitem__(self, item: K) -> "asyncio.Task[V]":
return task
async def map(self, *iterables: AnyIterable[K], pop: bool = True, yields: Literal['keys', 'both'] = 'both') -> AsyncIterator[Tuple[K, V]]:
assert not self, "must be empty"
async for key in as_yielded(*[__yield_keys(iterable) for iterable in iterables]):
async for key in as_yielded(*[__yield_keys(iterable) for iterable in iterables]): # type: ignore [attr-defined]
self[key] # ensure task is running
async for key, value in self.yield_completed(pop=pop):
yield __yield(key, value, yields)
Expand All @@ -47,7 +47,9 @@ async def map(self, *iterables: AnyIterable[K], pop: bool = True, yields: Litera
async def yield_completed(self, pop: bool = True) -> AsyncIterator[Tuple[K, V]]:
for k, task in dict(self).items():
if task.done():
yield k, await self.pop(k) if pop else task
if pop:
task = self.pop(k)
yield k, await task


__persisted_tasks: Set["asyncio.Task[Any]"] = set()
Expand Down
6 changes: 3 additions & 3 deletions a_sync/utils/iterators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

import asyncio
import logging
from asyncio.futures import _chain_future
from asyncio.futures import _chain_future # type: ignore [attr-defined]
from typing import AsyncIterator, Optional, TypeVar, Union, overload

from a_sync.primitives.queue import Queue
Expand Down Expand Up @@ -58,8 +58,8 @@ async def as_yielded(*iterators: AsyncIterator[T]) -> AsyncIterator[T]: # type:
queue = Queue()
task = asyncio.create_task(exhaust_iterators(iterators, queue=queue))
def done_callback(t: asyncio.Task) -> None:
if t.exception() and not next_fut.done():
next_fut.set_exception(t.exception())
if (e := t.exception()) and not next_fut.done():
next_fut.set_exception(e)

task.add_done_callback(done_callback)
while not task.done():
Expand Down

0 comments on commit 5db7515

Please sign in to comment.