Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: dict type cant hold defaultdict #420

Merged
merged 1 commit into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion a_sync/a_sync/_kwargs.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ cdef bint is_sync(str flag, dict kwargs, bint pop_flag):
:func:`get_flag_name`: Retrieves the name of the flag present in the kwargs.
"""
if pop_flag:
# NOTE: we should techincally raise InvalidFlagValue here but I dont want to set flag_value to a var
# NOTE: we should techincally raise InvalidFlagValue here but I dont want
# to set flag_value to a var and it will raise a TypeError anyway.
return negate_if_necessary(flag, kwargs.pop(flag))
else:
try:
Expand Down
8 changes: 3 additions & 5 deletions a_sync/a_sync/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class ASyncGenericBase(ASyncABC):
return flag

@functools.cached_property
def __a_sync_flag_value__(self) -> bool:
def __a_sync_flag_value__(self) -> bint:
# TODO: cythonize this cache
"""If you wish to be able to hotswap default modes, just duplicate this def as a non-cached property."""
if c_logger.isEnabledFor(logging.DEBUG):
Expand Down Expand Up @@ -202,9 +202,8 @@ cdef str _get_a_sync_flag_name_from_signature(object cls):


cdef str _parse_flag_name_from_list(object cls, object items):
cdef list[str] present_flags
cdef str flag
present_flags = [flag for flag in VIABLE_FLAGS if flag in items]
cdef list[str] present_flags = [flag for flag in VIABLE_FLAGS if flag in items]
if not present_flags:
c_logger.debug("There are too many flags defined on %s", cls)
raise exceptions.NoFlagsFound(cls, items.keys())
Expand All @@ -215,8 +214,7 @@ cdef str _parse_flag_name_from_list(object cls, object items):
flag = present_flags[0]
c_logger._log(logging.DEBUG, "found flag %s", flag)
return flag
else:
return present_flags[0]
return present_flags[0]


cdef bint _get_a_sync_flag_value_from_class_def(object cls, str flag):
Expand Down
14 changes: 7 additions & 7 deletions a_sync/a_sync/function.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ class ASyncFunction(_ModifiedMixin, Generic[P, T]):
def __init__(
self,
fn: AnyFn[P, T],
_skip_validate: bool = False,
_skip_validate: bint = False,
**modifiers: Unpack[ModifierKwargs],
) -> None:
"""
Expand Down Expand Up @@ -416,7 +416,7 @@ class ASyncFunction(_ModifiedMixin, Generic[P, T]):
concurrency: Optional[int] = None,
task_name: str = "",
**function_kwargs: P.kwargs,
) -> bool:
) -> bint:
"""
Checks if any result of the function applied to the iterables is truthy.

Expand Down Expand Up @@ -445,7 +445,7 @@ class ASyncFunction(_ModifiedMixin, Generic[P, T]):
concurrency: Optional[int] = None,
task_name: str = "",
**function_kwargs: P.kwargs,
) -> bool:
) -> bint:
"""
Checks if all results of the function applied to the iterables are truthy.

Expand Down Expand Up @@ -595,7 +595,7 @@ class ASyncFunction(_ModifiedMixin, Generic[P, T]):
concurrency: Optional[int] = None,
task_name: str = "",
**function_kwargs: P.kwargs,
) -> bool:
) -> bint:
"""
Checks if any result of the function applied to the iterables is truthy.

Expand Down Expand Up @@ -624,7 +624,7 @@ class ASyncFunction(_ModifiedMixin, Generic[P, T]):
concurrency: Optional[int] = None,
task_name: str = "",
**function_kwargs: P.kwargs,
) -> bool:
) -> bint:
"""
Checks if all results of the function applied to the iterables are truthy.

Expand Down Expand Up @@ -735,7 +735,7 @@ class ASyncFunction(_ModifiedMixin, Generic[P, T]):
).sum(pop=True, sync=False)

@functools.cached_property
def _sync_default(self) -> bool:
def _sync_default(self) -> bint:
"""
Determines the default execution mode (sync or async) for the function.

Expand All @@ -755,7 +755,7 @@ class ASyncFunction(_ModifiedMixin, Generic[P, T]):
)

@functools.cached_property
def _async_def(self) -> bool:
def _async_def(self) -> bint:
"""
Checks if the wrapped function is an asynchronous function.

Expand Down
3 changes: 2 additions & 1 deletion a_sync/primitives/locks/counter.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ cdef class CounterLock(_DebugDaemonMixin):
cdef char* __name
cdef int _value
cdef dict[int, Event] _events
cdef object is_ready
cpdef bint is_ready(self, int v)
cdef bint c_is_ready(self, int v)
cpdef void set(self, int value)
cdef void c_set(self, int value)
27 changes: 17 additions & 10 deletions a_sync/primitives/locks/counter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ cdef class CounterLock(_DebugDaemonMixin):
:class:`CounterLockCluster` for managing multiple :class:`CounterLock` instances.
"""

def __cinit__(self):
self._events = {}
"""A defaultdict that maps each awaited value to an :class:`Event` that manages the waiters for that value."""

def __init__(self, start_value: int = 0, str name = ""):
"""
Initializes the :class:`CounterLock` with a starting value and an optional name.
Expand Down Expand Up @@ -60,12 +64,6 @@ cdef class CounterLock(_DebugDaemonMixin):
self._value = start_value
"""The current value of the counter."""

self._events: DefaultDict[int, Event] = defaultdict(Event)
"""A defaultdict that maps each awaited value to an :class:`Event` that manages the waiters for that value."""

self.is_ready = lambda v: self._value >= v
"""A lambda function that indicates whether the current counter value is greater than or equal to a given value."""

def __dealloc__(self):
# Free the memory allocated for __name
if self.__name is not NULL:
Expand All @@ -85,7 +83,14 @@ cdef class CounterLock(_DebugDaemonMixin):
cdef dict[int, Py_ssize_t] waiters = {v: len(self._events[v]._waiters) for v in sorted(self._events)}
return "<CounterLock name={} value={} waiters={}>".format(self.__name.decode("utf-8"), self._value, waiters)

async def wait_for(self, value: int) -> bool:
cpdef bint is_ready(self, int v):
"""A function that indicates whether the current counter value is greater than or equal to a given value."""
return self._value >= v

cdef bint c_is_ready(self, int v):
return self._value >= v

async def wait_for(self, int value) -> bint:
"""
Waits until the counter reaches or exceeds the specified value.

Expand All @@ -101,7 +106,9 @@ cdef class CounterLock(_DebugDaemonMixin):
See Also:
:meth:`CounterLock.set` to set the counter value.
"""
if not self.is_ready(value):
if not self.c_is_ready(value):
if value not in self._events:
self._events[value] = Event()
self._c_ensure_debug_daemon((),{})
await self._events[value].c_wait()
return True
Expand Down Expand Up @@ -142,7 +149,7 @@ cdef class CounterLock(_DebugDaemonMixin):
return self._value

@value.setter
def value(self, value: int) -> None:
def value(self, int value) -> None:
"""
Sets the counter to a new value, waking up any waiters if the value increases beyond the value they are awaiting.

Expand Down Expand Up @@ -228,7 +235,7 @@ class CounterLockCluster:
"""
self.locks = list(counter_locks)

async def wait_for(self, value: int) -> bool:
async def wait_for(self, value: int) -> bint:
"""
Waits until the value of all :class:`CounterLock` objects in the cluster reaches or exceeds the specified value.

Expand Down
Loading