Skip to content

Commit

Permalink
fix: dict type cant hold defaultdict (#420)
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler authored Nov 22, 2024
1 parent 17da9c4 commit 8e0e9d2
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 24 deletions.
3 changes: 2 additions & 1 deletion a_sync/a_sync/_kwargs.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ cdef bint is_sync(str flag, dict kwargs, bint pop_flag):
:func:`get_flag_name`: Retrieves the name of the flag present in the kwargs.
"""
if pop_flag:
# NOTE: we should techincally raise InvalidFlagValue here but I dont want to set flag_value to a var
# NOTE: we should techincally raise InvalidFlagValue here but I dont want
# to set flag_value to a var and it will raise a TypeError anyway.
return negate_if_necessary(flag, kwargs.pop(flag))
else:
try:
Expand Down
8 changes: 3 additions & 5 deletions a_sync/a_sync/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class ASyncGenericBase(ASyncABC):
return flag

@functools.cached_property
def __a_sync_flag_value__(self) -> bool:
def __a_sync_flag_value__(self) -> bint:
# TODO: cythonize this cache
"""If you wish to be able to hotswap default modes, just duplicate this def as a non-cached property."""
if c_logger.isEnabledFor(logging.DEBUG):
Expand Down Expand Up @@ -202,9 +202,8 @@ cdef str _get_a_sync_flag_name_from_signature(object cls):


cdef str _parse_flag_name_from_list(object cls, object items):
cdef list[str] present_flags
cdef str flag
present_flags = [flag for flag in VIABLE_FLAGS if flag in items]
cdef list[str] present_flags = [flag for flag in VIABLE_FLAGS if flag in items]
if not present_flags:
c_logger.debug("There are too many flags defined on %s", cls)
raise exceptions.NoFlagsFound(cls, items.keys())
Expand All @@ -215,8 +214,7 @@ cdef str _parse_flag_name_from_list(object cls, object items):
flag = present_flags[0]
c_logger._log(logging.DEBUG, "found flag %s", flag)
return flag
else:
return present_flags[0]
return present_flags[0]


cdef bint _get_a_sync_flag_value_from_class_def(object cls, str flag):
Expand Down
14 changes: 7 additions & 7 deletions a_sync/a_sync/function.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ class ASyncFunction(_ModifiedMixin, Generic[P, T]):
def __init__(
self,
fn: AnyFn[P, T],
_skip_validate: bool = False,
_skip_validate: bint = False,
**modifiers: Unpack[ModifierKwargs],
) -> None:
"""
Expand Down Expand Up @@ -416,7 +416,7 @@ class ASyncFunction(_ModifiedMixin, Generic[P, T]):
concurrency: Optional[int] = None,
task_name: str = "",
**function_kwargs: P.kwargs,
) -> bool:
) -> bint:
"""
Checks if any result of the function applied to the iterables is truthy.

Expand Down Expand Up @@ -445,7 +445,7 @@ class ASyncFunction(_ModifiedMixin, Generic[P, T]):
concurrency: Optional[int] = None,
task_name: str = "",
**function_kwargs: P.kwargs,
) -> bool:
) -> bint:
"""
Checks if all results of the function applied to the iterables are truthy.

Expand Down Expand Up @@ -595,7 +595,7 @@ class ASyncFunction(_ModifiedMixin, Generic[P, T]):
concurrency: Optional[int] = None,
task_name: str = "",
**function_kwargs: P.kwargs,
) -> bool:
) -> bint:
"""
Checks if any result of the function applied to the iterables is truthy.

Expand Down Expand Up @@ -624,7 +624,7 @@ class ASyncFunction(_ModifiedMixin, Generic[P, T]):
concurrency: Optional[int] = None,
task_name: str = "",
**function_kwargs: P.kwargs,
) -> bool:
) -> bint:
"""
Checks if all results of the function applied to the iterables are truthy.

Expand Down Expand Up @@ -735,7 +735,7 @@ class ASyncFunction(_ModifiedMixin, Generic[P, T]):
).sum(pop=True, sync=False)

@functools.cached_property
def _sync_default(self) -> bool:
def _sync_default(self) -> bint:
"""
Determines the default execution mode (sync or async) for the function.

Expand All @@ -755,7 +755,7 @@ class ASyncFunction(_ModifiedMixin, Generic[P, T]):
)

@functools.cached_property
def _async_def(self) -> bool:
def _async_def(self) -> bint:
"""
Checks if the wrapped function is an asynchronous function.

Expand Down
3 changes: 2 additions & 1 deletion a_sync/primitives/locks/counter.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ cdef class CounterLock(_DebugDaemonMixin):
cdef char* __name
cdef int _value
cdef dict[int, Event] _events
cdef object is_ready
cpdef bint is_ready(self, int v)
cdef bint c_is_ready(self, int v)
cpdef void set(self, int value)
cdef void c_set(self, int value)
27 changes: 17 additions & 10 deletions a_sync/primitives/locks/counter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ cdef class CounterLock(_DebugDaemonMixin):
:class:`CounterLockCluster` for managing multiple :class:`CounterLock` instances.
"""

def __cinit__(self):
self._events = {}
"""A defaultdict that maps each awaited value to an :class:`Event` that manages the waiters for that value."""

def __init__(self, start_value: int = 0, str name = ""):
"""
Initializes the :class:`CounterLock` with a starting value and an optional name.
Expand Down Expand Up @@ -60,12 +64,6 @@ cdef class CounterLock(_DebugDaemonMixin):
self._value = start_value
"""The current value of the counter."""

self._events: DefaultDict[int, Event] = defaultdict(Event)
"""A defaultdict that maps each awaited value to an :class:`Event` that manages the waiters for that value."""

self.is_ready = lambda v: self._value >= v
"""A lambda function that indicates whether the current counter value is greater than or equal to a given value."""

def __dealloc__(self):
# Free the memory allocated for __name
if self.__name is not NULL:
Expand All @@ -85,7 +83,14 @@ cdef class CounterLock(_DebugDaemonMixin):
cdef dict[int, Py_ssize_t] waiters = {v: len(self._events[v]._waiters) for v in sorted(self._events)}
return "<CounterLock name={} value={} waiters={}>".format(self.__name.decode("utf-8"), self._value, waiters)

async def wait_for(self, value: int) -> bool:
cpdef bint is_ready(self, int v):
"""A function that indicates whether the current counter value is greater than or equal to a given value."""
return self._value >= v

cdef bint c_is_ready(self, int v):
return self._value >= v

async def wait_for(self, int value) -> bint:
"""
Waits until the counter reaches or exceeds the specified value.
Expand All @@ -101,7 +106,9 @@ cdef class CounterLock(_DebugDaemonMixin):
See Also:
:meth:`CounterLock.set` to set the counter value.
"""
if not self.is_ready(value):
if not self.c_is_ready(value):
if value not in self._events:
self._events[value] = Event()
self._c_ensure_debug_daemon((),{})
await self._events[value].c_wait()
return True
Expand Down Expand Up @@ -142,7 +149,7 @@ cdef class CounterLock(_DebugDaemonMixin):
return self._value

@value.setter
def value(self, value: int) -> None:
def value(self, int value) -> None:
"""
Sets the counter to a new value, waking up any waiters if the value increases beyond the value they are awaiting.

Expand Down Expand Up @@ -228,7 +235,7 @@ class CounterLockCluster:
"""
self.locks = list(counter_locks)
async def wait_for(self, value: int) -> bool:
async def wait_for(self, value: int) -> bint:
"""
Waits until the value of all :class:`CounterLock` objects in the cluster reaches or exceeds the specified value.

Expand Down

0 comments on commit 8e0e9d2

Please sign in to comment.