diff --git a/a_sync/a_sync/_kwargs.pyx b/a_sync/a_sync/_kwargs.pyx index e3e3a5dc..78bcdadf 100644 --- a/a_sync/a_sync/_kwargs.pyx +++ b/a_sync/a_sync/_kwargs.pyx @@ -55,7 +55,8 @@ cdef bint is_sync(str flag, dict kwargs, bint pop_flag): :func:`get_flag_name`: Retrieves the name of the flag present in the kwargs. """ if pop_flag: - # NOTE: we should techincally raise InvalidFlagValue here but I dont want to set flag_value to a var + # NOTE: we should techincally raise InvalidFlagValue here but I dont want + # to set flag_value to a var and it will raise a TypeError anyway. return negate_if_necessary(flag, kwargs.pop(flag)) else: try: diff --git a/a_sync/a_sync/base.pyx b/a_sync/a_sync/base.pyx index 9ed8a5d5..cf0d8b8a 100644 --- a/a_sync/a_sync/base.pyx +++ b/a_sync/a_sync/base.pyx @@ -92,7 +92,7 @@ class ASyncGenericBase(ASyncABC): return flag @functools.cached_property - def __a_sync_flag_value__(self) -> bool: + def __a_sync_flag_value__(self) -> bint: # TODO: cythonize this cache """If you wish to be able to hotswap default modes, just duplicate this def as a non-cached property.""" if c_logger.isEnabledFor(logging.DEBUG): @@ -202,9 +202,8 @@ cdef str _get_a_sync_flag_name_from_signature(object cls): cdef str _parse_flag_name_from_list(object cls, object items): - cdef list[str] present_flags cdef str flag - present_flags = [flag for flag in VIABLE_FLAGS if flag in items] + cdef list[str] present_flags = [flag for flag in VIABLE_FLAGS if flag in items] if not present_flags: c_logger.debug("There are too many flags defined on %s", cls) raise exceptions.NoFlagsFound(cls, items.keys()) @@ -215,8 +214,7 @@ cdef str _parse_flag_name_from_list(object cls, object items): flag = present_flags[0] c_logger._log(logging.DEBUG, "found flag %s", flag) return flag - else: - return present_flags[0] + return present_flags[0] cdef bint _get_a_sync_flag_value_from_class_def(object cls, str flag): diff --git a/a_sync/a_sync/function.pyx b/a_sync/a_sync/function.pyx index 6ac29873..3f7c1432 100644 --- a/a_sync/a_sync/function.pyx +++ b/a_sync/a_sync/function.pyx @@ -219,7 +219,7 @@ class ASyncFunction(_ModifiedMixin, Generic[P, T]): def __init__( self, fn: AnyFn[P, T], - _skip_validate: bool = False, + _skip_validate: bint = False, **modifiers: Unpack[ModifierKwargs], ) -> None: """ @@ -416,7 +416,7 @@ class ASyncFunction(_ModifiedMixin, Generic[P, T]): concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs, - ) -> bool: + ) -> bint: """ Checks if any result of the function applied to the iterables is truthy. @@ -445,7 +445,7 @@ class ASyncFunction(_ModifiedMixin, Generic[P, T]): concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs, - ) -> bool: + ) -> bint: """ Checks if all results of the function applied to the iterables are truthy. @@ -595,7 +595,7 @@ class ASyncFunction(_ModifiedMixin, Generic[P, T]): concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs, - ) -> bool: + ) -> bint: """ Checks if any result of the function applied to the iterables is truthy. @@ -624,7 +624,7 @@ class ASyncFunction(_ModifiedMixin, Generic[P, T]): concurrency: Optional[int] = None, task_name: str = "", **function_kwargs: P.kwargs, - ) -> bool: + ) -> bint: """ Checks if all results of the function applied to the iterables are truthy. @@ -735,7 +735,7 @@ class ASyncFunction(_ModifiedMixin, Generic[P, T]): ).sum(pop=True, sync=False) @functools.cached_property - def _sync_default(self) -> bool: + def _sync_default(self) -> bint: """ Determines the default execution mode (sync or async) for the function. @@ -755,7 +755,7 @@ class ASyncFunction(_ModifiedMixin, Generic[P, T]): ) @functools.cached_property - def _async_def(self) -> bool: + def _async_def(self) -> bint: """ Checks if the wrapped function is an asynchronous function. diff --git a/a_sync/primitives/locks/counter.pxd b/a_sync/primitives/locks/counter.pxd index 87362ce3..4441766d 100644 --- a/a_sync/primitives/locks/counter.pxd +++ b/a_sync/primitives/locks/counter.pxd @@ -6,6 +6,7 @@ cdef class CounterLock(_DebugDaemonMixin): cdef char* __name cdef int _value cdef dict[int, Event] _events - cdef object is_ready + cpdef bint is_ready(self, int v) + cdef bint c_is_ready(self, int v) cpdef void set(self, int value) cdef void c_set(self, int value) \ No newline at end of file diff --git a/a_sync/primitives/locks/counter.pyx b/a_sync/primitives/locks/counter.pyx index b8c25b63..d5cbab8b 100644 --- a/a_sync/primitives/locks/counter.pyx +++ b/a_sync/primitives/locks/counter.pyx @@ -31,6 +31,10 @@ cdef class CounterLock(_DebugDaemonMixin): :class:`CounterLockCluster` for managing multiple :class:`CounterLock` instances. """ + def __cinit__(self): + self._events = {} + """A defaultdict that maps each awaited value to an :class:`Event` that manages the waiters for that value.""" + def __init__(self, start_value: int = 0, str name = ""): """ Initializes the :class:`CounterLock` with a starting value and an optional name. @@ -60,12 +64,6 @@ cdef class CounterLock(_DebugDaemonMixin): self._value = start_value """The current value of the counter.""" - self._events: DefaultDict[int, Event] = defaultdict(Event) - """A defaultdict that maps each awaited value to an :class:`Event` that manages the waiters for that value.""" - - self.is_ready = lambda v: self._value >= v - """A lambda function that indicates whether the current counter value is greater than or equal to a given value.""" - def __dealloc__(self): # Free the memory allocated for __name if self.__name is not NULL: @@ -85,7 +83,14 @@ cdef class CounterLock(_DebugDaemonMixin): cdef dict[int, Py_ssize_t] waiters = {v: len(self._events[v]._waiters) for v in sorted(self._events)} return "".format(self.__name.decode("utf-8"), self._value, waiters) - async def wait_for(self, value: int) -> bool: + cpdef bint is_ready(self, int v): + """A function that indicates whether the current counter value is greater than or equal to a given value.""" + return self._value >= v + + cdef bint c_is_ready(self, int v): + return self._value >= v + + async def wait_for(self, int value) -> bint: """ Waits until the counter reaches or exceeds the specified value. @@ -101,7 +106,9 @@ cdef class CounterLock(_DebugDaemonMixin): See Also: :meth:`CounterLock.set` to set the counter value. """ - if not self.is_ready(value): + if not self.c_is_ready(value): + if value not in self._events: + self._events[value] = Event() self._c_ensure_debug_daemon((),{}) await self._events[value].c_wait() return True @@ -142,7 +149,7 @@ cdef class CounterLock(_DebugDaemonMixin): return self._value @value.setter - def value(self, value: int) -> None: + def value(self, int value) -> None: """ Sets the counter to a new value, waking up any waiters if the value increases beyond the value they are awaiting. @@ -228,7 +235,7 @@ class CounterLockCluster: """ self.locks = list(counter_locks) - async def wait_for(self, value: int) -> bool: + async def wait_for(self, value: int) -> bint: """ Waits until the value of all :class:`CounterLock` objects in the cluster reaches or exceeds the specified value.