Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: fix mypy errs #114

Merged
merged 5 commits into from
Feb 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion a_sync/_bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(self, coro: Awaitable[T], property: Union[AsyncPropertyDescriptor[T
self._property = property
def __repr__(self) -> str:
return f"<_PropertyGetter for {self._property}._get at {hex(id(self))}>"
def __await__(self) -> T:
def __await__(self) -> Generator[Any, None, T]:
return self._coro.__await__()

@overload
Expand Down
11 changes: 7 additions & 4 deletions a_sync/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from concurrent.futures._base import Executor
from decimal import Decimal
from typing import (TYPE_CHECKING, Any, AsyncIterable, AsyncIterator, Awaitable,
Callable, DefaultDict, Dict, Generic, Iterable, Iterator,
List, Literal, Optional, Set, Tuple, Type, TypedDict, TypeVar,
Union, final, overload)
Callable, DefaultDict, Deque, Dict, Generator, Generic,
ItemsView, Iterable, Iterator, KeysView, List, Literal,
Optional, Protocol, Set, Tuple, Type, TypedDict, TypeVar,
Union, ValuesView, final, overload)

from typing_extensions import Concatenate, ParamSpec, Self, Unpack

Expand All @@ -17,6 +18,8 @@
V = TypeVar("V")
P = ParamSpec("P")

Numeric = Union[int, float, Decimal]

MaybeAwaitable = Union[Awaitable[T], T]

Property = Callable[["ASyncABC"], T]
Expand Down Expand Up @@ -45,7 +48,7 @@ class ModifierKwargs(TypedDict, total=False):
cache_type: CacheType
cache_typed: bool
ram_cache_maxsize: Optional[int]
ram_cache_ttl: Optional[int]
ram_cache_ttl: Optional[Numeric]
runs_per_minute: Optional[int]
semaphore: SemaphoreSpec
# sync modifiers
Expand Down
4 changes: 2 additions & 2 deletions a_sync/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,6 @@ def __a_sync_flag_name__(self) -> str:
def __a_sync_flag_value__(self) -> bool:
pass

@abc.abstractclassmethod # type: ignore [misc]
@abc.abstractclassmethod # type: ignore [arg-type, misc]
def __a_sync_default_mode__(cls) -> bool:
...
pass
14 changes: 7 additions & 7 deletions a_sync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __a_sync_default_mode__(cls) -> bool:
return sync

@classmethod
def __get_a_sync_flag_name_from_signature(cls) -> Optional[str]:
def __get_a_sync_flag_name_from_signature(cls) -> str:
logger.debug("Searching for flags defined on %s.__init__", cls)
if cls.__name__ == "ASyncGenericBase":
logger.debug("There are no flags defined on the base class, this is expected. Skipping.")
Expand All @@ -69,7 +69,7 @@ def __get_a_sync_flag_name_from_signature(cls) -> Optional[str]:
return cls.__parse_flag_name_from_list(parameters)

@classmethod
def __get_a_sync_flag_name_from_class_def(cls) -> Optional[str]:
def __get_a_sync_flag_name_from_class_def(cls) -> str:
logger.debug("Searching for flags defined on %s", cls)
try:
return cls.__parse_flag_name_from_list(cls.__dict__)
Expand All @@ -93,14 +93,14 @@ def __a_sync_flag_default_value_from_signature(cls) -> bool:
return flag_value

@classmethod
def __get_a_sync_flag_value_from_class_def(cls, flag: str) -> Optional[bool]:
def __get_a_sync_flag_value_from_class_def(cls, flag: str) -> bool:
for spec in [cls, *cls.__bases__]:
flag_value = spec.__dict__.get(flag)
if flag_value is not None:
return flag_value
if flag in spec.__dict__:
return spec.__dict__[flag]
raise exceptions.FlagNotDefined(cls, flag)

@classmethod
def __parse_flag_name_from_list(cls, items: Dict[str, Any]) -> Optional[str]:
def __parse_flag_name_from_list(cls, items: Dict[str, Any]) -> str:
present_flags = [flag for flag in _flags.VIABLE_FLAGS if flag in items]
if len(present_flags) == 0:
logger.debug("There are too many flags defined on %s", cls)
Expand Down
7 changes: 5 additions & 2 deletions a_sync/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

from typing import Any, Set
from typing import Any, Optional, Set, Type


class ASyncFlagException(ValueError):
Expand Down Expand Up @@ -31,7 +31,7 @@ def __init__(self, target, present_flags):
super().__init__(err)

class InvalidFlag(ASyncFlagException):
def __init__(self, flag: str):
def __init__(self, flag: Optional[str]):
err = f"'flag' must be one of: {self.viable_flags}. You passed {flag}."
err += "\nThis code should not be reached and likely indicates an issue with a custom subclass definition."
super().__init__(err)
Expand All @@ -40,6 +40,9 @@ class InvalidFlagValue(ASyncFlagException):
def __init__(self, flag: str, flag_value: Any):
super().__init__(f"'{flag}' should be boolean. You passed {flag_value}.")

class FlagNotDefined(ASyncFlagException):
def __init__(self, obj: Type, flag: str):
super().__init__(f"{obj} flag {flag} is not defined.")


class ImproperFunctionType(ValueError):
Expand Down
6 changes: 3 additions & 3 deletions a_sync/future.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _materialize(meta: "ASyncFuture[T]") -> T:
except RuntimeError as e:
raise RuntimeError(f"{meta} result is not set and the event loop is running, you will need to await it first") from e

Numeric = Union[int, float, Decimal, "ASyncFuture[int]", "ASyncFuture[float]", "ASyncFuture[Decimal]"]
MetaNumeric = Union[Numeric, "ASyncFuture[int]", "ASyncFuture[float]", "ASyncFuture[Decimal]"]

class ASyncFuture(concurrent.futures.Future, Awaitable[T]):
__slots__ = "__awaitable__", "__dependencies", "__dependants", "__task"
Expand Down Expand Up @@ -120,7 +120,7 @@ def __add__(self: "ASyncFuture[Decimal]", other: Awaitable[Decimal]) -> "ASyncFu
def __add__(self: "ASyncFuture[Decimal]", other: Awaitable[int]) -> "ASyncFuture[Decimal]":...
@overload
def __add__(self: "ASyncFuture[int]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":...
def __add__(self, other: Numeric) -> "ASyncFuture":
def __add__(self, other: MetaNumeric) -> "ASyncFuture":
return ASyncFuture(self.__add(other), dependencies=self.__list_dependencies(other))
@overload
def __sub__(self: "ASyncFuture[int]", other: int) -> "ASyncFuture[int]":...
Expand Down Expand Up @@ -150,7 +150,7 @@ def __sub__(self: "ASyncFuture[Decimal]", other: Awaitable[Decimal]) -> "ASyncFu
def __sub__(self: "ASyncFuture[Decimal]", other: Awaitable[int]) -> "ASyncFuture[Decimal]":...
@overload
def __sub__(self: "ASyncFuture[int]", other: Awaitable[Decimal]) -> "ASyncFuture[Decimal]":...
def __sub__(self, other: Numeric) -> "ASyncFuture":
def __sub__(self, other: MetaNumeric) -> "ASyncFuture":
return ASyncFuture(self.__sub(other), dependencies=self.__list_dependencies(other))
def __mul__(self, other) -> "ASyncFuture":
return ASyncFuture(self.__mul(other), dependencies=self.__list_dependencies(other))
Expand Down
10 changes: 5 additions & 5 deletions a_sync/modifiers/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class ModifierManager(Dict[str, Any]):
# sync modifiers
executor: Executor

def __init__(self, **modifiers: ModifierKwargs) -> None:
def __init__(self, **modifiers: Unpack[ModifierKwargs]) -> None:
for key in modifiers.keys():
if key not in valid_modifiers:
raise ValueError(f"'{key}' is not a supported modifier.")
Expand Down Expand Up @@ -70,13 +70,13 @@ def sync_modifier_wrap(*args: P.args, **kwargs: P.kwargs) -> T:
return sync_modifier_wrap

# Dictionary api
def items(self) -> List[Tuple[str, Any]]:
def items(self) -> ItemsView[str, Any]:
return self._modifiers.items()
def keys(self) -> List[str]:
def keys(self) -> KeysView[str]:
return self._modifiers.keys()
def values(self) -> List[Any]:
def values(self) -> ValuesView[Any]:
return self._modifiers.values()
def __contains__(self, key: str) -> bool:
def __contains__(self, key: str) -> bool: # type: ignore [override]
return key in self._modifiers
def __iter__(self) -> Iterator[str]:
return self._modifiers.__iter__()
Expand Down
10 changes: 6 additions & 4 deletions a_sync/primitives/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

TEN_MINUTES = 60 * 10

Initializer = Callable[..., object]

class _AsyncExecutorMixin(cf.Executor, _DebugDaemonMixin):
_max_workers: int
_workers: str
Expand All @@ -31,7 +33,7 @@ async def run(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
Oh, and you can also use kwargs!
"""
return fn(*args, **kwargs) if self.sync_mode else await self.submit(fn, *args, **kwargs)
def submit(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[T]":
def submit(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[T]": # type: ignore [override]
"""Submits a job to the executor and returns an `asyncio.Future` that can be awaited for the result without blocking."""
if self.sync_mode:
fut = asyncio.ensure_future(self._exec_sync(fn, *args, **kwargs))
Expand All @@ -49,7 +51,7 @@ def sync_mode(self) -> bool:
return self._max_workers == 0
@property
def worker_count_current(self) -> int:
len(getattr(self, f"_{self._workers}"))
return len(getattr(self, f"_{self._workers}"))
async def _exec_sync(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
"""Just wraps a fn and its args into an awaitable."""
return fn(*args, **kwargs)
Expand All @@ -68,7 +70,7 @@ def __init__(
self,
max_workers: Optional[int] = None,
mp_context: Optional[multiprocessing.context.BaseContext] = None,
initializer: Callable[..., object] = None,
initializer: Optional[Initializer] = None,
initargs: Tuple[Any, ...] = (),
) -> None:
if max_workers == 0:
Expand All @@ -85,7 +87,7 @@ def __init__(
self,
max_workers: Optional[int] = None,
thread_name_prefix: str = '',
initializer: Callable[..., object] = None,
initializer: Optional[Initializer] = None,
initargs: Tuple[Any, ...] = (),
) -> None:
if max_workers == 0:
Expand Down
12 changes: 8 additions & 4 deletions a_sync/primitives/locks/event.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@

import asyncio, sys
from functools import cached_property
from typing import Optional
import asyncio
import sys

from a_sync._typing import *
from a_sync.primitives._debug import _DebugDaemonMixin


class Event(asyncio.Event, _DebugDaemonMixin):
"""asyncio.Event but with some additional debug logging to help detect deadlocks."""
_value: bool
_loop: asyncio.AbstractEventLoop
_waiters: Deque["asyncio.Future[None]"]

def __init__(self, name: str = "", debug_daemon_interval: int = 300, *, loop: Optional[asyncio.AbstractEventLoop] = None):
if sys.version_info >= (3, 10):
super().__init__()
Expand All @@ -22,7 +26,7 @@ def __repr__(self) -> str:
if self._waiters:
status += f', waiters:{len(self._waiters)}'
return f"<{self.__class__.__module__}.{self.__class__.__name__} {label} at {hex(id(self))} [{status}]>"
async def wait(self) -> bool:
async def wait(self) -> Literal[True]:
if self.is_set():
return True
self._ensure_debug_daemon()
Expand Down
10 changes: 4 additions & 6 deletions a_sync/primitives/locks/prio_semaphore.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@
import logging
from collections import deque
from functools import cached_property
from typing import (Deque, Dict, Generic, List, Literal, Optional, Protocol,
Type, TypeVar)

from a_sync._typing import *
from a_sync.primitives.locks.semaphore import Semaphore

logger = logging.getLogger(__name__)

T = TypeVar('T', covariant=True)

class Priority(Protocol):
def __lt__(self, other) -> bool:
Expand Down Expand Up @@ -41,7 +39,7 @@ def __init__(self, value: int = 1, *, name: Optional[str] = None) -> None:
super().__init__(value, name=name)
self._waiters = []
# NOTE: This should (hopefully) be temporary
self._potential_lost_waiters = []
self._potential_lost_waiters: List["asyncio.Future[None]"] = []

def __repr__(self) -> str:
return f"<{self.__class__.__name__} name={self.name} capacity={self._capacity} value={self._value} waiters={self._count_waiters()}>"
Expand Down Expand Up @@ -188,10 +186,10 @@ async def acquire(self) -> Literal[True]:
def release(self) -> None:
self._parent.release()

class _PrioritySemaphoreContextManager(_AbstractPrioritySemaphoreContextManager[int]):
class _PrioritySemaphoreContextManager(_AbstractPrioritySemaphoreContextManager[Numeric]):
_priority_name = "priority"

class PrioritySemaphore(_AbstractPrioritySemaphore[int, _PrioritySemaphoreContextManager]): # type: ignore [type-var]
class PrioritySemaphore(_AbstractPrioritySemaphore[Numeric, _PrioritySemaphoreContextManager]): # type: ignore [type-var]
_context_manager_class = _PrioritySemaphoreContextManager
_top_priority = -1
"""
Expand Down
6 changes: 3 additions & 3 deletions a_sync/primitives/locks/semaphore.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, value: int, name=None, **kwargs) -> None:
"""
super().__init__(value, **kwargs)
self.name = name or self.__origin__ if hasattr(self, '__origin__') else None
self._decorated = set()
self._decorated: Set[str] = set()

# Dank new functionality
def __call__(self, fn: Callable[P, T]) -> Callable[P, T]:
Expand All @@ -31,7 +31,7 @@ def __repr__(self) -> str:
def __len__(self) -> int:
return len(self._waiters) if self._waiters else 0

def decorate(self, fn: Callable[P, T]) -> Callable[P, T]:
def decorate(self, fn: CoroFn[P, T]) -> CoroFn[P, T]:
if not asyncio.iscoroutinefunction(fn):
raise TypeError(f"{fn} must be a coroutine function")
@functools.wraps(fn)
Expand All @@ -41,7 +41,7 @@ async def semaphore_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
self._decorated.add(f"{fn.__module__}.{fn.__name__}")
return semaphore_wrapper

async def acquire(self) -> bool:
async def acquire(self) -> Literal[True]:
if self._value <= 0:
self._ensure_debug_daemon()
return await super().acquire()
Expand Down
8 changes: 5 additions & 3 deletions a_sync/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def create_task(coro: Awaitable[T], *, name: Optional[str] = None, skip_gc_until
__persist(task)
return task

class TaskMapping(DefaultDict[K, "asyncio.Task[_V]"]):
class TaskMapping(DefaultDict[K, "asyncio.Task[V]"]):
def __init__(self, coro_fn: Callable[Concatenate[K, P], Awaitable[V]] = None, *, name: str = '', **coro_fn_kwargs: P.kwargs) -> None:
self._coro_fn = coro_fn
self._coro_fn_kwargs = coro_fn_kwargs
Expand All @@ -38,7 +38,7 @@ def __getitem__(self, item: K) -> "asyncio.Task[V]":
return task
async def map(self, *iterables: AnyIterable[K], pop: bool = True, yields: Literal['keys', 'both'] = 'both') -> AsyncIterator[Tuple[K, V]]:
assert not self, "must be empty"
async for key in as_yielded(*[__yield_keys(iterable) for iterable in iterables]):
async for key in as_yielded(*[__yield_keys(iterable) for iterable in iterables]): # type: ignore [attr-defined]
self[key] # ensure task is running
async for key, value in self.yield_completed(pop=pop):
yield __yield(key, value, yields)
Expand All @@ -47,7 +47,9 @@ async def map(self, *iterables: AnyIterable[K], pop: bool = True, yields: Litera
async def yield_completed(self, pop: bool = True) -> AsyncIterator[Tuple[K, V]]:
for k, task in dict(self).items():
if task.done():
yield k, await self.pop(k) if pop else task
if pop:
task = self.pop(k)
yield k, await task


__persisted_tasks: Set["asyncio.Task[Any]"] = set()
Expand Down
6 changes: 3 additions & 3 deletions a_sync/utils/iterators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

import asyncio
import logging
from asyncio.futures import _chain_future
from asyncio.futures import _chain_future # type: ignore [attr-defined]
from typing import AsyncIterator, Optional, TypeVar, Union, overload

from a_sync.primitives.queue import Queue
Expand Down Expand Up @@ -58,8 +58,8 @@ async def as_yielded(*iterators: AsyncIterator[T]) -> AsyncIterator[T]: # type:
queue = Queue()
task = asyncio.create_task(exhaust_iterators(iterators, queue=queue))
def done_callback(t: asyncio.Task) -> None:
if t.exception() and not next_fut.done():
next_fut.set_exception(t.exception())
if (e := t.exception()) and not next_fut.done():
next_fut.set_exception(e)

task.add_done_callback(done_callback)
while not task.done():
Expand Down
2 changes: 0 additions & 2 deletions a_sync/utils/map.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@

from typing import Awaitable, Callable, Literal, Tuple, Union, overload

from a_sync._typing import *
from a_sync.iter import ASyncIterator

Expand Down
Loading