Skip to content

Commit

Permalink
chore: fix mypy errs (#114)
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler authored Feb 11, 2024
1 parent c44e616 commit ae3ae84
Show file tree
Hide file tree
Showing 14 changed files with 59 additions and 49 deletions.
2 changes: 1 addition & 1 deletion a_sync/_bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(self, coro: Awaitable[T], property: Union[AsyncPropertyDescriptor[T
self._property = property
def __repr__(self) -> str:
return f"<_PropertyGetter for {self._property}._get at {hex(id(self))}>"
def __await__(self) -> T:
def __await__(self) -> Generator[Any, None, T]:
return self._coro.__await__()

@overload
Expand Down
11 changes: 7 additions & 4 deletions a_sync/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from concurrent.futures._base import Executor
from decimal import Decimal
from typing import (TYPE_CHECKING, Any, AsyncIterable, AsyncIterator, Awaitable,
Callable, DefaultDict, Dict, Generic, Iterable, Iterator,
List, Literal, Optional, Set, Tuple, Type, TypedDict, TypeVar,
Union, final, overload)
Callable, DefaultDict, Deque, Dict, Generator, Generic,
ItemsView, Iterable, Iterator, KeysView, List, Literal,
Optional, Protocol, Set, Tuple, Type, TypedDict, TypeVar,
Union, ValuesView, final, overload)

from typing_extensions import Concatenate, ParamSpec, Self, Unpack

Expand All @@ -17,6 +18,8 @@
V = TypeVar("V")
P = ParamSpec("P")

Numeric = Union[int, float, Decimal]

MaybeAwaitable = Union[Awaitable[T], T]

Property = Callable[["ASyncABC"], T]
Expand Down Expand Up @@ -45,7 +48,7 @@ class ModifierKwargs(TypedDict, total=False):
cache_type: CacheType
cache_typed: bool
ram_cache_maxsize: Optional[int]
ram_cache_ttl: Optional[int]
ram_cache_ttl: Optional[Numeric]
runs_per_minute: Optional[int]
semaphore: SemaphoreSpec
# sync modifiers
Expand Down
4 changes: 2 additions & 2 deletions a_sync/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,6 @@ def __a_sync_flag_name__(self) -> str:
def __a_sync_flag_value__(self) -> bool:
pass

@abc.abstractclassmethod # type: ignore [misc]
@abc.abstractclassmethod # type: ignore [arg-type, misc]
def __a_sync_default_mode__(cls) -> bool:
...
pass
14 changes: 7 additions & 7 deletions a_sync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __a_sync_default_mode__(cls) -> bool:
return sync

@classmethod
def __get_a_sync_flag_name_from_signature(cls) -> Optional[str]:
def __get_a_sync_flag_name_from_signature(cls) -> str:
logger.debug("Searching for flags defined on %s.__init__", cls)
if cls.__name__ == "ASyncGenericBase":
logger.debug("There are no flags defined on the base class, this is expected. Skipping.")
Expand All @@ -69,7 +69,7 @@ def __get_a_sync_flag_name_from_signature(cls) -> Optional[str]:
return cls.__parse_flag_name_from_list(parameters)

@classmethod
def __get_a_sync_flag_name_from_class_def(cls) -> Optional[str]:
def __get_a_sync_flag_name_from_class_def(cls) -> str:
logger.debug("Searching for flags defined on %s", cls)
try:
return cls.__parse_flag_name_from_list(cls.__dict__)
Expand All @@ -93,14 +93,14 @@ def __a_sync_flag_default_value_from_signature(cls) -> bool:
return flag_value

@classmethod
def __get_a_sync_flag_value_from_class_def(cls, flag: str) -> Optional[bool]:
def __get_a_sync_flag_value_from_class_def(cls, flag: str) -> bool:
for spec in [cls, *cls.__bases__]:
flag_value = spec.__dict__.get(flag)
if flag_value is not None:
return flag_value
if flag in spec.__dict__:
return spec.__dict__[flag]
raise exceptions.FlagNotDefined(cls, flag)

@classmethod
def __parse_flag_name_from_list(cls, items: Dict[str, Any]) -> Optional[str]:
def __parse_flag_name_from_list(cls, items: Dict[str, Any]) -> str:
present_flags = [flag for flag in _flags.VIABLE_FLAGS if flag in items]
if len(present_flags) == 0:
logger.debug("There are too many flags defined on %s", cls)
Expand Down
7 changes: 5 additions & 2 deletions a_sync/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

from typing import Any, Set
from typing import Any, Optional, Set, Type


class ASyncFlagException(ValueError):
Expand Down Expand Up @@ -31,7 +31,7 @@ def __init__(self, target, present_flags):
super().__init__(err)

class InvalidFlag(ASyncFlagException):
def __init__(self, flag: str):
def __init__(self, flag: Optional[str]):
err = f"'flag' must be one of: {self.viable_flags}. You passed {flag}."
err += "\nThis code should not be reached and likely indicates an issue with a custom subclass definition."
super().__init__(err)
Expand All @@ -40,6 +40,9 @@ class InvalidFlagValue(ASyncFlagException):
def __init__(self, flag: str, flag_value: Any):
super().__init__(f"'{flag}' should be boolean. You passed {flag_value}.")

class FlagNotDefined(ASyncFlagException):
def __init__(self, obj: Type, flag: str):
super().__init__(f"{obj} flag {flag} is not defined.")


class ImproperFunctionType(ValueError):
Expand Down
6 changes: 3 additions & 3 deletions a_sync/future.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _materialize(meta: "ASyncFuture[T]") -> T:
except RuntimeError as e:
raise RuntimeError(f"{meta} result is not set and the event loop is running, you will need to await it first") from e

Numeric = Union[int, float, Decimal, "ASyncFuture[int]", "ASyncFuture[float]", "ASyncFuture[Decimal]"]
MetaNumeric = Union[Numeric, "ASyncFuture[int]", "ASyncFuture[float]", "ASyncFuture[Decimal]"]

class ASyncFuture(concurrent.futures.Future, Awaitable[T]):
__slots__ = "__awaitable__", "__dependencies", "__dependants", "__task"
Expand Down Expand Up @@ -120,7 +120,7 @@ def __add__(self: "ASyncFuture[Decimal]", other: Awaitable[Decimal]) -> "ASyncFu
def __add__(self: "ASyncFuture[Decimal]", other: Awaitable[int]) -> "ASyncFuture[Decimal]":...
@overload
def __add__(self: "ASyncFuture[int]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":...
def __add__(self, other: Numeric) -> "ASyncFuture":
def __add__(self, other: MetaNumeric) -> "ASyncFuture":
return ASyncFuture(self.__add(other), dependencies=self.__list_dependencies(other))
@overload
def __sub__(self: "ASyncFuture[int]", other: int) -> "ASyncFuture[int]":...
Expand Down Expand Up @@ -150,7 +150,7 @@ def __sub__(self: "ASyncFuture[Decimal]", other: Awaitable[Decimal]) -> "ASyncFu
def __sub__(self: "ASyncFuture[Decimal]", other: Awaitable[int]) -> "ASyncFuture[Decimal]":...
@overload
def __sub__(self: "ASyncFuture[int]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":...
def __sub__(self, other: Numeric) -> "ASyncFuture":
def __sub__(self, other: MetaNumeric) -> "ASyncFuture":
return ASyncFuture(self.__sub(other), dependencies=self.__list_dependencies(other))
def __mul__(self, other) -> "ASyncFuture":
return ASyncFuture(self.__mul(other), dependencies=self.__list_dependencies(other))
Expand Down
10 changes: 5 additions & 5 deletions a_sync/modifiers/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class ModifierManager(Dict[str, Any]):
# sync modifiers
executor: Executor

def __init__(self, **modifiers: ModifierKwargs) -> None:
def __init__(self, **modifiers: Unpack[ModifierKwargs]) -> None:
for key in modifiers.keys():
if key not in valid_modifiers:
raise ValueError(f"'{key}' is not a supported modifier.")
Expand Down Expand Up @@ -70,13 +70,13 @@ def sync_modifier_wrap(*args: P.args, **kwargs: P.kwargs) -> T:
return sync_modifier_wrap

# Dictionary api
def items(self) -> List[Tuple[str, Any]]:
def items(self) -> ItemsView[str, Any]:
return self._modifiers.items()
def keys(self) -> List[str]:
def keys(self) -> KeysView[str]:
return self._modifiers.keys()
def values(self) -> List[Any]:
def values(self) -> ValuesView[Any]:
return self._modifiers.values()
def __contains__(self, key: str) -> bool:
def __contains__(self, key: str) -> bool: # type: ignore [override]
return key in self._modifiers
def __iter__(self) -> Iterator[str]:
return self._modifiers.__iter__()
Expand Down
10 changes: 6 additions & 4 deletions a_sync/primitives/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

TEN_MINUTES = 60 * 10

Initializer = Callable[..., object]

class _AsyncExecutorMixin(cf.Executor, _DebugDaemonMixin):
_max_workers: int
_workers: str
Expand All @@ -31,7 +33,7 @@ async def run(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
Oh, and you can also use kwargs!
"""
return fn(*args, **kwargs) if self.sync_mode else await self.submit(fn, *args, **kwargs)
def submit(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[T]":
def submit(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[T]": # type: ignore [override]
"""Submits a job to the executor and returns an `asyncio.Future` that can be awaited for the result without blocking."""
if self.sync_mode:
fut = asyncio.ensure_future(self._exec_sync(fn, *args, **kwargs))
Expand All @@ -49,7 +51,7 @@ def sync_mode(self) -> bool:
return self._max_workers == 0
@property
def worker_count_current(self) -> int:
len(getattr(self, f"_{self._workers}"))
return len(getattr(self, f"_{self._workers}"))
async def _exec_sync(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
"""Just wraps a fn and its args into an awaitable."""
return fn(*args, **kwargs)
Expand All @@ -68,7 +70,7 @@ def __init__(
self,
max_workers: Optional[int] = None,
mp_context: Optional[multiprocessing.context.BaseContext] = None,
initializer: Callable[..., object] = None,
initializer: Optional[Initializer] = None,
initargs: Tuple[Any, ...] = (),
) -> None:
if max_workers == 0:
Expand All @@ -85,7 +87,7 @@ def __init__(
self,
max_workers: Optional[int] = None,
thread_name_prefix: str = '',
initializer: Callable[..., object] = None,
initializer: Optional[Initializer] = None,
initargs: Tuple[Any, ...] = (),
) -> None:
if max_workers == 0:
Expand Down
12 changes: 8 additions & 4 deletions a_sync/primitives/locks/event.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@

import asyncio, sys
from functools import cached_property
from typing import Optional
import asyncio
import sys

from a_sync._typing import *
from a_sync.primitives._debug import _DebugDaemonMixin


class Event(asyncio.Event, _DebugDaemonMixin):
"""asyncio.Event but with some additional debug logging to help detect deadlocks."""
_value: bool
_loop: asyncio.AbstractEventLoop
_waiters: Deque["asyncio.Future[None]"]

def __init__(self, name: str = "", debug_daemon_interval: int = 300, *, loop: Optional[asyncio.AbstractEventLoop] = None):
if sys.version_info >= (3, 10):
super().__init__()
Expand All @@ -22,7 +26,7 @@ def __repr__(self) -> str:
if self._waiters:
status += f', waiters:{len(self._waiters)}'
return f"<{self.__class__.__module__}.{self.__class__.__name__} {label} at {hex(id(self))} [{status}]>"
async def wait(self) -> bool:
async def wait(self) -> Literal[True]:
if self.is_set():
return True
self._ensure_debug_daemon()
Expand Down
10 changes: 4 additions & 6 deletions a_sync/primitives/locks/prio_semaphore.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@
import logging
from collections import deque
from functools import cached_property
from typing import (Deque, Dict, Generic, List, Literal, Optional, Protocol,
Type, TypeVar)

from a_sync._typing import *
from a_sync.primitives.locks.semaphore import Semaphore

logger = logging.getLogger(__name__)

T = TypeVar('T', covariant=True)

class Priority(Protocol):
def __lt__(self, other) -> bool:
Expand Down Expand Up @@ -41,7 +39,7 @@ def __init__(self, value: int = 1, *, name: Optional[str] = None) -> None:
super().__init__(value, name=name)
self._waiters = []
# NOTE: This should (hopefully) be temporary
self._potential_lost_waiters = []
self._potential_lost_waiters: List["asyncio.Future[None]"] = []

def __repr__(self) -> str:
return f"<{self.__class__.__name__} name={self.name} capacity={self._capacity} value={self._value} waiters={self._count_waiters()}>"
Expand Down Expand Up @@ -188,10 +186,10 @@ async def acquire(self) -> Literal[True]:
def release(self) -> None:
self._parent.release()

class _PrioritySemaphoreContextManager(_AbstractPrioritySemaphoreContextManager[int]):
class _PrioritySemaphoreContextManager(_AbstractPrioritySemaphoreContextManager[Numeric]):
_priority_name = "priority"

class PrioritySemaphore(_AbstractPrioritySemaphore[int, _PrioritySemaphoreContextManager]): # type: ignore [type-var]
class PrioritySemaphore(_AbstractPrioritySemaphore[Numeric, _PrioritySemaphoreContextManager]): # type: ignore [type-var]
_context_manager_class = _PrioritySemaphoreContextManager
_top_priority = -1
"""
Expand Down
6 changes: 3 additions & 3 deletions a_sync/primitives/locks/semaphore.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, value: int, name=None, **kwargs) -> None:
"""
super().__init__(value, **kwargs)
self.name = name or self.__origin__ if hasattr(self, '__origin__') else None
self._decorated = set()
self._decorated: Set[str] = set()

# Dank new functionality
def __call__(self, fn: Callable[P, T]) -> Callable[P, T]:
Expand All @@ -31,7 +31,7 @@ def __repr__(self) -> str:
def __len__(self) -> int:
return len(self._waiters) if self._waiters else 0

def decorate(self, fn: Callable[P, T]) -> Callable[P, T]:
def decorate(self, fn: CoroFn[P, T]) -> CoroFn[P, T]:
if not asyncio.iscoroutinefunction(fn):
raise TypeError(f"{fn} must be a coroutine function")
@functools.wraps(fn)
Expand All @@ -41,7 +41,7 @@ async def semaphore_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
self._decorated.add(f"{fn.__module__}.{fn.__name__}")
return semaphore_wrapper

async def acquire(self) -> bool:
async def acquire(self) -> Literal[True]:
if self._value <= 0:
self._ensure_debug_daemon()
return await super().acquire()
Expand Down
8 changes: 5 additions & 3 deletions a_sync/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def create_task(coro: Awaitable[T], *, name: Optional[str] = None, skip_gc_until
__persist(task)
return task

class TaskMapping(DefaultDict[K, "asyncio.Task[_V]"]):
class TaskMapping(DefaultDict[K, "asyncio.Task[V]"]):
def __init__(self, coro_fn: Callable[Concatenate[K, P], Awaitable[V]] = None, *, name: str = '', **coro_fn_kwargs: P.kwargs) -> None:
self._coro_fn = coro_fn
self._coro_fn_kwargs = coro_fn_kwargs
Expand All @@ -38,7 +38,7 @@ def __getitem__(self, item: K) -> "asyncio.Task[V]":
return task
async def map(self, *iterables: AnyIterable[K], pop: bool = True, yields: Literal['keys', 'both'] = 'both') -> AsyncIterator[Tuple[K, V]]:
assert not self, "must be empty"
async for key in as_yielded(*[__yield_keys(iterable) for iterable in iterables]):
async for key in as_yielded(*[__yield_keys(iterable) for iterable in iterables]): # type: ignore [attr-defined]
self[key] # ensure task is running
async for key, value in self.yield_completed(pop=pop):
yield __yield(key, value, yields)
Expand All @@ -47,7 +47,9 @@ async def map(self, *iterables: AnyIterable[K], pop: bool = True, yields: Litera
async def yield_completed(self, pop: bool = True) -> AsyncIterator[Tuple[K, V]]:
for k, task in dict(self).items():
if task.done():
yield k, await self.pop(k) if pop else task
if pop:
task = self.pop(k)
yield k, await task


__persisted_tasks: Set["asyncio.Task[Any]"] = set()
Expand Down
6 changes: 3 additions & 3 deletions a_sync/utils/iterators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

import asyncio
import logging
from asyncio.futures import _chain_future
from asyncio.futures import _chain_future # type: ignore [attr-defined]
from typing import AsyncIterator, Optional, TypeVar, Union, overload

from a_sync.primitives.queue import Queue
Expand Down Expand Up @@ -58,8 +58,8 @@ async def as_yielded(*iterators: AsyncIterator[T]) -> AsyncIterator[T]: # type:
queue = Queue()
task = asyncio.create_task(exhaust_iterators(iterators, queue=queue))
def done_callback(t: asyncio.Task) -> None:
if t.exception() and not next_fut.done():
next_fut.set_exception(t.exception())
if (e := t.exception()) and not next_fut.done():
next_fut.set_exception(e)

task.add_done_callback(done_callback)
while not task.done():
Expand Down
2 changes: 0 additions & 2 deletions a_sync/utils/map.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@

from typing import Awaitable, Callable, Literal, Tuple, Union, overload

from a_sync._typing import *
from a_sync.iter import ASyncIterator

Expand Down

0 comments on commit ae3ae84

Please sign in to comment.