diff --git a/a_sync/__init__.pxd b/a_sync/__init__.pxd new file mode 100644 index 00000000..01448a65 --- /dev/null +++ b/a_sync/__init__.pxd @@ -0,0 +1 @@ +from a_sync.primitives import * \ No newline at end of file diff --git a/a_sync/asyncio/gather.pyx b/a_sync/asyncio/gather.pyx index 3e694bf6..0766b702 100644 --- a/a_sync/asyncio/gather.pyx +++ b/a_sync/asyncio/gather.pyx @@ -5,7 +5,6 @@ This module provides an enhanced version of :func:`asyncio.gather`. from typing import Any, Awaitable, Dict, List, Mapping, Union, overload from a_sync._typing import * -from a_sync.asyncio.as_completed import _exc_wrap from a_sync.asyncio.as_completed cimport as_completed_mapping try: @@ -202,3 +201,18 @@ async def gather_mapping( cdef bint _is_mapping(object awaitables): return len(awaitables) == 1 and isinstance(awaitables[0], Mapping) + + +async def _exc_wrap(awaitable: Awaitable[T]) -> Union[T, Exception]: + """Wraps an awaitable to catch exceptions and return them instead of raising. + + Args: + awaitable: The awaitable to wrap. + + Returns: + The result of the awaitable or the exception if one is raised. + """ + try: + return await awaitable + except Exception as e: + return e \ No newline at end of file diff --git a/a_sync/executor.py b/a_sync/executor.py index c0249fd0..a3a7fb19 100644 --- a/a_sync/executor.py +++ b/a_sync/executor.py @@ -101,7 +101,7 @@ def submit(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> "asyn - :meth:`run` for running functions with the executor. """ if self.sync_mode: - fut = asyncio.get_event_loop().create_future() + fut = self._get_loop().create_future() try: fut.set_result(fn(*args, **kwargs)) except Exception as e: diff --git a/a_sync/iter.pyx b/a_sync/iter.pyx index cf4fee5e..0ecf0dab 100644 --- a/a_sync/iter.pyx +++ b/a_sync/iter.pyx @@ -726,7 +726,8 @@ class ASyncSorter(_ASyncView[T]): items.append(obj) sort_tasks.append(a_sync.asyncio.create_task(self._function(obj))) for sort_value, obj in sorted( - zip(await asyncio.gather(*sort_tasks), items), reverse=reverse + zip(await asyncio.gather(*sort_tasks), items), + reverse=reverse, ): yield obj else: diff --git a/a_sync/primitives/__init__.pxd b/a_sync/primitives/__init__.pxd new file mode 100644 index 00000000..22885239 --- /dev/null +++ b/a_sync/primitives/__init__.pxd @@ -0,0 +1 @@ +from a_sync.primitives.locks cimport * \ No newline at end of file diff --git a/a_sync/primitives/_debug.pxd b/a_sync/primitives/_debug.pxd new file mode 100644 index 00000000..164ee3a4 --- /dev/null +++ b/a_sync/primitives/_debug.pxd @@ -0,0 +1,9 @@ +from a_sync.primitives._loggable cimport _LoggerMixin + +cdef class _LoopBoundMixin(_LoggerMixin): + cdef object __loop + cpdef object _get_loop(self) + cdef object _c_get_loop(self) + +cdef class _DebugDaemonMixin(_LoopBoundMixin): + cdef object _daemon diff --git a/a_sync/primitives/_debug.pyi b/a_sync/primitives/_debug.pyi index 00823583..32ff0be6 100644 --- a/a_sync/primitives/_debug.pyi +++ b/a_sync/primitives/_debug.pyi @@ -6,7 +6,7 @@ The mixin provides a framework for managing a debug daemon task, which can be us import abc from a_sync.primitives._loggable import _LoggerMixin -class _DebugDaemonMixin(_LoggerMixin, metaclass=abc.ABCMeta): +class _DebugDaemonMixin(_LoggerMixin): """ A mixin class that provides a framework for debugging capabilities using a daemon task. diff --git a/a_sync/primitives/_debug.pyx b/a_sync/primitives/_debug.pyx index fa56253e..a9d3b33d 100644 --- a/a_sync/primitives/_debug.pyx +++ b/a_sync/primitives/_debug.pyx @@ -4,8 +4,9 @@ This module provides a mixin class used to facilitate the creation of debugging The mixin provides a framework for managing a debug daemon task, which can be used to emit rich debug logs from subclass instances whenever debug logging is enabled. Subclasses must implement the specific logging behavior. """ -import abc import asyncio +from asyncio.events import _running_loop +from threading import Lock from typing import Optional from a_sync.a_sync._helpers cimport get_event_loop @@ -13,7 +14,61 @@ from a_sync.asyncio.create_task cimport ccreate_task_simple from a_sync.primitives._loggable import _LoggerMixin -class _DebugDaemonMixin(_LoggerMixin, metaclass=abc.ABCMeta): +cdef extern from "unistd.h": + int getpid() + + +_global_lock = Lock() + + +cdef object _get_running_loop(): + """Return the running event loop or None. + + This is a low-level function intended to be used by event loops. + This function is thread-specific. + """ + cdef object running_loop, pid + running_loop, pid = _running_loop.loop_pid + if running_loop is not None and pid == getpid(): + return running_loop + + +cdef class _LoopBoundMixin(_LoggerMixin): + def __cinit__(self): + self.__loop = None + def __init__(self, *, loop=None): + if loop is not None: + raise TypeError( + 'The loop parameter is not supported. ' + 'As of 3.10, the *loop* parameter was removed' + '{}() since it is no longer necessary.'.format(type(self).__name__) + ) + @property + def _loop(self) -> asyncio.AbstractEventLoop: + return self.__loop + @_loop.setter + def _loop(self, loop: asyncio.AbstractEventLoop): + self.__loop = loop + cpdef object _get_loop(self): + return self._c_get_loop() + cdef object _c_get_loop(self): + cdef object loop = _get_running_loop() + if self.__loop is None: + with _global_lock: + if self.__loop is None: + self.__loop = loop + if loop is None: + return get_event_loop() + elif loop is not self.__loop: + raise RuntimeError( + f'{self!r} is bound to a different event loop', + "running loop: ".format(loop), + "bound to: ".format(self.__loop), + ) + return loop + + +cdef class _DebugDaemonMixin(_LoopBoundMixin): """ A mixin class that provides a framework for debugging capabilities using a daemon task. @@ -23,9 +78,6 @@ class _DebugDaemonMixin(_LoggerMixin, metaclass=abc.ABCMeta): :class:`_LoggerMixin` for logging capabilities. """ - __slots__ = ("_daemon",) - - @abc.abstractmethod async def _debug_daemon(self, fut: asyncio.Future, fn, *args, **kwargs) -> None: """ Abstract method to define the debug daemon's behavior. @@ -49,6 +101,7 @@ class _DebugDaemonMixin(_LoggerMixin, metaclass=abc.ABCMeta): self.logger.debug("Debugging...") await asyncio.sleep(1) """ + raise NotImplementedError def _start_debug_daemon(self, *args, **kwargs) -> "asyncio.Future[None]": """ @@ -74,8 +127,8 @@ class _DebugDaemonMixin(_LoggerMixin, metaclass=abc.ABCMeta): See Also: :meth:`_ensure_debug_daemon` for ensuring the daemon is running. """ - cdef object loop = get_event_loop() - if self.debug_logs_enabled and loop.is_running(): + cdef object loop = self._c_get_loop() + if self.check_debug_logs_enabled() and loop.is_running(): return ccreate_task_simple(self._debug_daemon(*args, **kwargs)) return loop.create_future() @@ -103,11 +156,13 @@ class _DebugDaemonMixin(_LoggerMixin, metaclass=abc.ABCMeta): See Also: :meth:`_start_debug_daemon` for starting the daemon. """ - if not self.debug_logs_enabled: - self._daemon = get_event_loop().create_future() - if not hasattr(self, "_daemon") or self._daemon is None: - self._daemon = self._start_debug_daemon(*args, **kwargs) - self._daemon.add_done_callback(self._stop_debug_daemon) + cdef object daemon = self._daemon + if daemon is None: + if self.check_debug_logs_enabled(): + self._daemon = self._start_debug_daemon(*args, **kwargs) + self._daemon.add_done_callback(self._stop_debug_daemon) + else: + self._daemon = get_event_loop().create_future() return self._daemon def _stop_debug_daemon(self, t: Optional[asyncio.Task] = None) -> None: diff --git a/a_sync/primitives/_loggable.pxd b/a_sync/primitives/_loggable.pxd new file mode 100644 index 00000000..2b62c610 --- /dev/null +++ b/a_sync/primitives/_loggable.pxd @@ -0,0 +1,4 @@ +cdef class _LoggerMixin: + cdef object _logger + cdef object get_logger(self) + cdef bint check_debug_logs_enabled(self) \ No newline at end of file diff --git a/a_sync/primitives/_loggable.pyx b/a_sync/primitives/_loggable.pyx index 56eec489..5f650efa 100644 --- a/a_sync/primitives/_loggable.pyx +++ b/a_sync/primitives/_loggable.pyx @@ -2,11 +2,10 @@ This module provides a mixin class to add debug logging capabilities to other classes. """ -from functools import cached_property from logging import Logger, getLogger, DEBUG -class _LoggerMixin: +cdef class _LoggerMixin: """ A mixin class that adds logging capabilities to other classes. @@ -17,7 +16,7 @@ class _LoggerMixin: - :class:`logging.Logger` """ - @cached_property + @property def logger(self) -> Logger: """ Provides a logger instance specific to the class using this mixin. @@ -48,10 +47,19 @@ class _LoggerMixin: - :func:`logging.getLogger` - :class:`logging.Logger` """ - logger_id = f"{type(self).__module__}.{type(self).__qualname__}" - if hasattr(self, "_name") and self._name: - logger_id += f".{self._name}" - return getLogger(logger_id) + return self.get_logger() + + cdef object get_logger(self): + cdef str logger_id + cdef object logger, cls + logger = self._logger + if not logger: + cls = type(self) + logger_id = "{}.{}".format(cls.__module__, cls.__qualname__) + if hasattr(self, "_name") and self._name: + logger_id += ".{}".format(self._name) + logger = getLogger(logger_id) + return logger @property def debug_logs_enabled(self) -> bool: @@ -69,4 +77,7 @@ class _LoggerMixin: See Also: - :attr:`logging.Logger.isEnabledFor` """ - return self.logger.isEnabledFor(DEBUG) + return self.get_logger().isEnabledFor(DEBUG) + + cdef bint check_debug_logs_enabled(self): + return self.get_logger().isEnabledFor(DEBUG) diff --git a/a_sync/primitives/locks/__init__.pxd b/a_sync/primitives/locks/__init__.pxd new file mode 100644 index 00000000..045c7d34 --- /dev/null +++ b/a_sync/primitives/locks/__init__.pxd @@ -0,0 +1,2 @@ +from a_sync.primitives.locks.counter import CounterLock +from a_sync.primitives.locks.event cimport CythonEvent as Event \ No newline at end of file diff --git a/a_sync/primitives/locks/__init__.py b/a_sync/primitives/locks/__init__.py index e69674a0..1cbec00e 100644 --- a/a_sync/primitives/locks/__init__.py +++ b/a_sync/primitives/locks/__init__.py @@ -1,5 +1,5 @@ from a_sync.primitives.locks.counter import CounterLock -from a_sync.primitives.locks.event import Event +from a_sync.primitives.locks.event import CythonEvent as Event from a_sync.primitives.locks.semaphore import ( DummySemaphore, Semaphore, diff --git a/a_sync/primitives/locks/counter.pxd b/a_sync/primitives/locks/counter.pxd new file mode 100644 index 00000000..87362ce3 --- /dev/null +++ b/a_sync/primitives/locks/counter.pxd @@ -0,0 +1,11 @@ + +from a_sync.primitives._debug cimport _DebugDaemonMixin +from a_sync.primitives.locks.event cimport CythonEvent as Event + +cdef class CounterLock(_DebugDaemonMixin): + cdef char* __name + cdef int _value + cdef dict[int, Event] _events + cdef object is_ready + cpdef void set(self, int value) + cdef void c_set(self, int value) \ No newline at end of file diff --git a/a_sync/primitives/locks/counter.py b/a_sync/primitives/locks/counter.pyx similarity index 78% rename from a_sync/primitives/locks/counter.py rename to a_sync/primitives/locks/counter.pyx index 3b4724a7..6e21ed71 100644 --- a/a_sync/primitives/locks/counter.py +++ b/a_sync/primitives/locks/counter.pyx @@ -6,14 +6,19 @@ import asyncio from collections import defaultdict -from time import time -from typing import DefaultDict, Iterable, Optional +from libc.string cimport strcpy +from libc.stdlib cimport malloc, free +from libc.time cimport time +from typing import DefaultDict, Iterable -from a_sync.primitives._debug import _DebugDaemonMixin -from a_sync.primitives.locks.event import Event +from a_sync.primitives._debug cimport _DebugDaemonMixin +from a_sync.primitives.locks.event cimport CythonEvent as Event +cdef extern from "time.h": + ctypedef long time_t -class CounterLock(_DebugDaemonMixin): + +cdef class CounterLock(_DebugDaemonMixin): """ An async primitive that uses an internal counter to manage task synchronization. @@ -26,9 +31,7 @@ class CounterLock(_DebugDaemonMixin): :class:`CounterLockCluster` for managing multiple :class:`CounterLock` instances. """ - __slots__ = "is_ready", "_name", "_value", "_events" - - def __init__(self, start_value: int = 0, name: Optional[str] = None): + def __init__(self, start_value: int = 0, str name = ""): """ Initializes the :class:`CounterLock` with a starting value and an optional name. @@ -41,9 +44,19 @@ def __init__(self, start_value: int = 0, name: Optional[str] = None): >>> counter.value 0 """ - self._name = name + # we need a constant to coerce to char* + cdef bytes encoded_name = name.encode("utf-8") + cdef Py_ssize_t length = len(encoded_name) + + # Allocate memory for the char* and add 1 for the null character + self.__name = malloc(length + 1) """An optional name for the counter, used in debug logs.""" + if self.__name == NULL: + raise MemoryError("Failed to allocate memory for __name.") + # Copy the bytes data into the char* + strcpy(self.__name, encoded_name) + self._value = start_value """The current value of the counter.""" @@ -53,6 +66,25 @@ def __init__(self, start_value: int = 0, name: Optional[str] = None): self.is_ready = lambda v: self._value >= v """A lambda function that indicates whether the current counter value is greater than or equal to a given value.""" + def __dealloc__(self): + # Free the memory allocated for __name + if self.__name is not NULL: + free(self.__name) + + def __repr__(self) -> str: + """ + Returns a string representation of the :class:`CounterLock` instance. + + The representation includes the name, current value, and the number of waiters for each awaited value. + + Examples: + >>> counter = CounterLock(start_value=0, name="example_counter") + >>> repr(counter) + '' + """ + cdef dict[int, Py_ssize_t] waiters = {v: len(self._events[v]._waiters) for v in sorted(self._events)} + return "".format(self.__name.decode("utf-8"), self._value, waiters) + async def wait_for(self, value: int) -> bool: """ Waits until the counter reaches or exceeds the specified value. @@ -71,10 +103,10 @@ async def wait_for(self, value: int) -> bool: """ if not self.is_ready(value): self._ensure_debug_daemon() - await self._events[value].wait() + await self._events[value].c_wait() return True - def set(self, value: int) -> None: + cpdef void set(self, int value): """ Sets the counter to the specified value. @@ -95,21 +127,7 @@ def set(self, value: int) -> None: See Also: :meth:`CounterLock.value` for direct value assignment. """ - self.value = value - - def __repr__(self) -> str: - """ - Returns a string representation of the :class:`CounterLock` instance. - - The representation includes the name, current value, and the number of waiters for each awaited value. - - Examples: - >>> counter = CounterLock(start_value=0, name="example_counter") - >>> repr(counter) - '' - """ - waiters = {v: len(self._events[v]._waiters) for v in sorted(self._events)} - return f"" + self.c_set(value) @property def value(self) -> int: @@ -144,6 +162,9 @@ def value(self, value: int) -> None: ... ValueError: You cannot decrease the value. """ + self.c_set(value) + + cdef void c_set(self, int value): if value > self._value: self._value = value ready = [ @@ -156,19 +177,30 @@ def value(self, value: int) -> None: elif value < self._value: raise ValueError("You cannot decrease the value.") + @property + def _name(self) -> str: + return self.__name.decode("utf-8") + async def _debug_daemon(self) -> None: """ Periodically logs debug information about the counter state and waiters. This method is used internally to provide debugging information when debug logging is enabled. """ - start = time() + cdef time_t start, now + start = time(NULL) while self._events: - self.logger.debug( - "%s is still locked after %sm", self, round(time() - start / 60, 2) + now = time(NULL) + self.get_logger().debug( + "%s is still locked after %sm", self, round(now - start / 60, 2) ) await asyncio.sleep(300) + def __dealloc__(self): + # Free the memory allocated for __name + if self.__name is not NULL: + free(self.__name) + class CounterLockCluster: """ diff --git a/a_sync/primitives/locks/event.pxd b/a_sync/primitives/locks/event.pxd new file mode 100644 index 00000000..5b80715e --- /dev/null +++ b/a_sync/primitives/locks/event.pxd @@ -0,0 +1,24 @@ +from libc.stdint cimport uint16_t + +from a_sync.primitives._debug cimport _DebugDaemonMixin + +cdef class CythonEvent(_DebugDaemonMixin): + """ + An asyncio.Event with additional debug logging to help detect deadlocks. + + This event class extends asyncio.Event by adding debug logging capabilities. It logs + detailed information about the event state and waiters, which can be useful for + diagnosing and debugging potential deadlocks. + """ + + cdef bint _value + cdef list _waiters + cdef char* __name + cdef uint16_t _debug_daemon_interval + cpdef bint is_set(self) + cpdef void set(self) + cpdef void clear(self) + cpdef object wait(self) + cdef void c_set(self) + cdef void c_clear(self) + cdef object c_wait(self) \ No newline at end of file diff --git a/a_sync/primitives/locks/event.py b/a_sync/primitives/locks/event.py deleted file mode 100644 index f04f951d..00000000 --- a/a_sync/primitives/locks/event.py +++ /dev/null @@ -1,85 +0,0 @@ -""" -This module provides an enhanced version of asyncio.Event with additional debug logging to help detect deadlocks. -""" - -import asyncio -import sys - -from a_sync._typing import * -from a_sync.primitives._debug import _DebugDaemonMixin - - -class Event(asyncio.Event, _DebugDaemonMixin): - """ - An asyncio.Event with additional debug logging to help detect deadlocks. - - This event class extends asyncio.Event by adding debug logging capabilities. It logs - detailed information about the event state and waiters, which can be useful for - diagnosing and debugging potential deadlocks. - """ - - _value: bool - _loop: asyncio.AbstractEventLoop - _waiters: Deque["asyncio.Future[None]"] - if sys.version_info >= (3, 10): - __slots__ = "_value", "_waiters", "_debug_daemon_interval" - else: - __slots__ = "_value", "_loop", "_waiters", "_debug_daemon_interval" - - def __init__( - self, - name: str = "", - debug_daemon_interval: int = 300, - *, - loop: Optional[asyncio.AbstractEventLoop] = None, - ): - """ - Initializes the Event. - - Args: - name (str): An optional name for the event, used in debug logs. - debug_daemon_interval (int): The interval in seconds for the debug daemon to log information. - loop (Optional[asyncio.AbstractEventLoop]): The event loop to use. - """ - if sys.version_info >= (3, 10): - super().__init__() - else: - super().__init__(loop=loop) - self._name = name - # backwards compatability - if hasattr(self, "_loop"): - self._loop = self._loop or asyncio.get_event_loop() - self._debug_daemon_interval = debug_daemon_interval - - def __repr__(self) -> str: - label = f"name={self._name}" if self._name else "object" - status = "set" if self._value else "unset" - if self._waiters: - status += f", waiters:{len(self._waiters)}" - return f"<{self.__class__.__module__}.{self.__class__.__name__} {label} at {hex(id(self))} [{status}]>" - - async def wait(self) -> Literal[True]: - """ - Wait until the event is set. - - Returns: - True when the event is set. - """ - if self.is_set(): - return True - self._ensure_debug_daemon() - return await super().wait() - - async def _debug_daemon(self) -> None: - """ - Periodically logs debug information about the event state and waiters. - """ - weakself = weakref.ref(self) - del self # no need to hold a reference here - while (self := weakself()) and not self.is_set(): - del self # no need to hold a reference here - await asyncio.sleep(self._debug_daemon_interval) - if (self := weakself()) and not self.is_set(): - self.logger.debug( - "Waiting for %s for %sm", self, round((time() - start) / 60, 2) - ) diff --git a/a_sync/primitives/locks/event.pyx b/a_sync/primitives/locks/event.pyx new file mode 100644 index 00000000..43b50f73 --- /dev/null +++ b/a_sync/primitives/locks/event.pyx @@ -0,0 +1,183 @@ +""" +This module provides an enhanced version of asyncio.Event with additional debug logging to help detect deadlocks. +""" + +import asyncio +import sys +import weakref +from libc.stdint cimport uint16_t +from libc.stdlib cimport malloc, free +from libc.string cimport strcpy +from libc.time cimport time + +from a_sync._typing import * +from a_sync.primitives._debug cimport _DebugDaemonMixin + +cdef extern from "time.h": + ctypedef long time_t + + +async def _return_true(): + return True + + +cdef class CythonEvent(_DebugDaemonMixin): + """ + An asyncio.Event with additional debug logging to help detect deadlocks. + + This event class extends asyncio.Event by adding debug logging capabilities. It logs + detailed information about the event state and waiters, which can be useful for + diagnosing and debugging potential deadlocks. + """ + def __init__( + self, + name: str = "", + debug_daemon_interval: int = 300, + *, + loop: Optional[asyncio.AbstractEventLoop] = None, + ): + """ + Initializes the Event. + + Args: + name (str): An optional name for the event, used in debug logs. + debug_daemon_interval (int): The interval in seconds for the debug daemon to log information. + loop (Optional[asyncio.AbstractEventLoop]): The event loop to use. + """ + self._waiters = [] + #self._value = False + if sys.version_info >= (3, 10): + super().__init__() + else: + super().__init__(loop=loop) + + # backwards compatability + if hasattr(self, "_loop"): + self._loop = self._loop or asyncio.get_event_loop() + if debug_daemon_interval > 65535: + raise ValueError(f"'debug_daemon_interval' is stored as a uint16 and must be less than 65535") + self._debug_daemon_interval = debug_daemon_interval + # we need a constant to coerce to char* + cdef bytes encoded_name = name.encode("utf-8") + cdef Py_ssize_t length = len(encoded_name) + + # Allocate memory for the char* and add 1 for the null character + self.__name = malloc(length + 1) + """An optional name for the counter, used in debug logs.""" + + if self.__name == NULL: + raise MemoryError("Failed to allocate memory for __name.") + # Copy the bytes data into the char* + strcpy(self.__name, encoded_name) + + def __dealloc__(self): + # Free the memory allocated for __name + if self.__name is not NULL: + free(self.__name) + + def __repr__(self) -> str: + cdef str label = ( + "name={}".format(self.__name.decode("utf-8")) + if self.__name + else "object" + ) + cdef str status = "set" if self._value else "unset" + if self._waiters: + status += ", waiters:{}".format(len(self._waiters)) + return "<{}.{} {} at {} [{}]>".format( + self.__class__.__module__, + self.__class__.__name__, + label, + hex(id(self)), + status, + ) + + cpdef bint is_set(self): + """Return True if and only if the internal flag is true.""" + return self._value + + cpdef void set(self): + """Set the internal flag to true. All coroutines waiting for it to + become true are awakened. Coroutine that call wait() once the flag is + true will not block at all. + """ + self.c_set() + + cdef void c_set(self): + """Set the internal flag to true. All coroutines waiting for it to + become true are awakened. Coroutine that call wait() once the flag is + true will not block at all. + """ + cdef object fut + + if not self._value: + self._value = True + + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + + cpdef void clear(self): + """Reset the internal flag to false. Subsequently, coroutines calling + wait() will block until set() is called to set the internal flag + to true again.""" + self._value = False + + cdef void c_clear(self): + """Reset the internal flag to false. Subsequently, coroutines calling + wait() will block until set() is called to set the internal flag + to true again.""" + self._value = False + + cpdef object wait(self): + """Block until the internal flag is true. + + If the internal flag is true on entry, return True + immediately. Otherwise, block until another coroutine calls + set() to set the flag to true, then return True. + + Returns: + True when the event is set. + """ + return self.c_wait() + + cdef object c_wait(self): + if self._value: + return _return_true() + + self._ensure_debug_daemon() + + cdef object fut = self._c_get_loop().create_future() + self._waiters.append(fut) + return self.__wait(fut) + + @property + def _name(self) -> str: + return self.__name.decode("utf-8") + + async def __wait(self, fut: asyncio.Future) -> Literal[True]: + try: + await fut + return True + finally: + self._waiters.remove(fut) + + async def _debug_daemon(self) -> None: + """ + Periodically logs debug information about the event state and waiters. + """ + cdef time_t start, now + cdef object weakself = weakref.ref(self) + cdef unsigned int loops = 0 + cdef uint16_t interval = self._debug_daemon_interval + + start = time(NULL) + while (self := weakself()) and not self._value: + if loops: + now = time(NULL) + self.get_logger().debug( + "Waiting for %s for %sm", self, round((now - start) / 60, 2) + ) + del self # no need to hold a reference here + await asyncio.sleep(interval) + loops += 1 diff --git a/a_sync/primitives/locks/semaphore.py b/a_sync/primitives/locks/semaphore.pyx similarity index 66% rename from a_sync/primitives/locks/semaphore.py rename to a_sync/primitives/locks/semaphore.pyx index 38d2cd06..c476edca 100644 --- a/a_sync/primitives/locks/semaphore.py +++ b/a_sync/primitives/locks/semaphore.pyx @@ -6,17 +6,16 @@ import asyncio import functools import logging -import sys from collections import defaultdict from threading import Thread, current_thread from a_sync._typing import * -from a_sync.primitives._debug import _DebugDaemonMixin +from a_sync.primitives._debug cimport _DebugDaemonMixin logger = logging.getLogger(__name__) -class Semaphore(asyncio.Semaphore, _DebugDaemonMixin): +cdef class Semaphore(_DebugDaemonMixin): """ A semaphore with additional debugging capabilities inherited from :class:`_DebugDaemonMixin`. @@ -47,13 +46,14 @@ async def limited(): See Also: :class:`_DebugDaemonMixin` for more details on debugging capabilities. """ + cdef str name + cdef int _value + cdef object _waiters + cdef set _decorated + cdef dict __dict__ - if sys.version_info >= (3, 10): - __slots__ = "name", "_value", "_waiters", "_decorated" - else: - __slots__ = "name", "_value", "_waiters", "_loop", "_decorated" - def __init__(self, value: int, name=None, **kwargs) -> None: + def __init__(self, value: int=1, name=None, loop=None, **kwargs) -> None: """ Initialize the semaphore with a given value and optional name for debugging. @@ -61,7 +61,12 @@ def __init__(self, value: int, name=None, **kwargs) -> None: value: The initial value for the semaphore. name (optional): An optional name used only to provide useful context in debug logs. """ - super().__init__(value, **kwargs) + super().__init__(loop=loop) + if value < 0: + raise ValueError("Semaphore initial value must be >= 0") + + self._waiters = None + self._value = value self.name = name or self.__origin__ if hasattr(self, "__origin__") else None self._decorated: Set[str] = set() @@ -86,6 +91,20 @@ def __repr__(self) -> str: representation = f"{representation[:-1]} decorates={self._decorated}" return representation + async def __aenter__(self): + await self.acquire() + # We have no use for the "as ..." clause in the with + # statement for locks. + return None + + async def __aexit__(self, exc_type, exc, tb): + self.c_release() + + def locked(self): + """Returns True if semaphore cannot be acquired immediately.""" + return self._value == 0 or ( + any(not w.cancelled() for w in (self._waiters or ()))) + def __len__(self) -> int: return len(self._waiters) if self._waiters else 0 @@ -115,6 +134,10 @@ async def acquire(self) -> Literal[True]: """ Acquire the semaphore, ensuring that debug logging is enabled if there are waiters. + If the internal counter is larger than zero on entry, decrement it by one and return + True immediately. If it is zero on entry, block, waiting until some other coroutine + has called release() to make it larger than 0, and then return True. + If the semaphore value is zero or less, the debug daemon is started to log the state of the semaphore. Returns: @@ -122,8 +145,58 @@ async def acquire(self) -> Literal[True]: """ if self._value <= 0: self._ensure_debug_daemon() - return await super().acquire() + if not self.locked(): + self._value -= 1 + return True + import collections + if self._waiters is None: + self._waiters = collections.deque() + fut = self._c_get_loop().create_future() + self._waiters.append(fut) + + # Finally block should be called before the CancelledError + # handling as we don't want CancelledError to call + # _wake_up_first() and attempt to wake up itself. + try: + try: + await fut + finally: + self._waiters.remove(fut) + except asyncio.exceptions.CancelledError: + if not fut.cancelled(): + self._value += 1 + self._wake_up_next() + raise + + if self._value > 0: + self._wake_up_next() + return True + + cpdef void release(self): + """Release a semaphore, incrementing the internal counter by one. + + When it was zero on entry and another coroutine is waiting for it to + become larger than zero again, wake up that coroutine. + """ + self._value += 1 + self._wake_up_next() + + cdef void c_release(self): + self._value += 1 + self._wake_up_next() + + cdef void _wake_up_next(self): + """Wake up the first waiter that isn't done.""" + if not self._waiters: + return + + for fut in self._waiters: + if not fut.done(): + self._value -= 1 + fut.set_result(True) + return + async def _debug_daemon(self) -> None: """ Daemon coroutine (runs in a background task) which will emit a debug log every minute while the semaphore has waiters. @@ -139,7 +212,7 @@ async def monitor(): """ while self._waiters: await asyncio.sleep(60) - self.logger.debug( + self.get_logger().debug( "%s has %s waiters for any of: %s", self, len(self), @@ -147,7 +220,7 @@ async def monitor(): ) -class DummySemaphore(asyncio.Semaphore): +cdef class DummySemaphore(Semaphore): """ A dummy semaphore that implements the standard :class:`asyncio.Semaphore` API but does nothing. @@ -161,7 +234,9 @@ async def no_op(): return 1 """ - __slots__ = "name", "_value" + def __cinit__(self): + self._value = 0 + self._waiters = None def __init__(self, name: Optional[str] = None): """ @@ -171,16 +246,18 @@ def __init__(self, name: Optional[str] = None): name (optional): An optional name for the dummy semaphore. """ self.name = name - self._value = 0 def __repr__(self) -> str: - return f"<{self.__class__.__name__} name={self.name}>" + return "<{} name={}>".format(self.__class__.__name__, self.name) async def acquire(self) -> Literal[True]: """Acquire the dummy semaphore, which is a no-op.""" return True - def release(self) -> None: + cpdef void release(self): + """No-op release method.""" + + cdef void c_release(self): """No-op release method.""" async def __aenter__(self): @@ -191,7 +268,7 @@ async def __aexit__(self, *args): """No-op context manager exit.""" -class ThreadsafeSemaphore(Semaphore): +cdef class ThreadsafeSemaphore(Semaphore): """ A semaphore that works in a multi-threaded environment. @@ -210,7 +287,8 @@ async def limited(): :class:`Semaphore` for the base class implementation. """ - __slots__ = "semaphores", "dummy" + cdef object semaphores, dummy + cdef bint use_dummy def __init__(self, value: Optional[int], name: Optional[str] = None) -> None: """ @@ -222,22 +300,18 @@ def __init__(self, value: Optional[int], name: Optional[str] = None) -> None: """ assert isinstance(value, int), f"{value} should be an integer." super().__init__(value, name=name) - self.semaphores: DefaultDict[Thread, Semaphore] = defaultdict(lambda: Semaphore(value, name=self.name)) # type: ignore [arg-type] - self.dummy = DummySemaphore(name=name) + + self.use_dummy = value is -1 + if self.use_dummy: + self.semaphores = {} + self.dummy = DummySemaphore(name=name) + else: + self.semaphores: DefaultDict[Thread, Semaphore] = defaultdict(lambda: Semaphore(value, name=self.name)) # type: ignore [arg-type] + self.dummy = None def __len__(self) -> int: return sum(len(sem._waiters) for sem in self.semaphores.values()) - @functools.cached_property - def use_dummy(self) -> bool: - """ - Determine whether to use a dummy semaphore. - - Returns: - True if the semaphore value is None, indicating the use of a dummy semaphore. - """ - return self._value is None - @property def semaphore(self) -> Semaphore: """ @@ -252,10 +326,12 @@ async def limited(): async with semaphore.semaphore: return 1 """ + + cdef Semaphore c_get_semaphore(self): return self.dummy if self.use_dummy else self.semaphores[current_thread()] async def __aenter__(self): - await self.semaphore.acquire() + await self.c_get_semaphore().acquire() async def __aexit__(self, *args): - self.semaphore.release() + self.c_get_semaphore().c_release() diff --git a/a_sync/task.pyx b/a_sync/task.pyx index b80b050b..04aa3dc8 100644 --- a/a_sync/task.pyx +++ b/a_sync/task.pyx @@ -15,7 +15,6 @@ import inspect import logging import weakref -import a_sync.asyncio from a_sync import exceptions from a_sync._typing import * from a_sync.a_sync._kwargs cimport get_flag_name @@ -26,10 +25,12 @@ from a_sync.a_sync.method import ( ASyncMethodDescriptorSyncDefault, ) from a_sync.a_sync.property import _ASyncPropertyDescriptorBase +from a_sync.asyncio import create_task, gather +from a_sync.asyncio.as_completed cimport as_completed_mapping from a_sync.asyncio.gather import Excluder from a_sync.iter import ASyncIterator, ASyncGeneratorFunction, ASyncSorter -from a_sync.primitives.queue import Queue, ProcessingQueue -from a_sync.primitives.locks.event import Event +from a_sync.primitives import Queue, ProcessingQueue +from a_sync.primitives cimport Event from a_sync.utils.iterators import as_yielded, exhaust_iterator @@ -67,7 +68,7 @@ class TaskMapping(DefaultDict[K, "asyncio.Task[V]"], AsyncIterable[Tuple[K, V]]) See Also: - :class:`asyncio.Task` - :func:`asyncio.create_task` - - :func:`a_sync.asyncio.create_task` + - :func:`create_task` """ concurrency: Optional[int] = None @@ -127,6 +128,7 @@ class TaskMapping(DefaultDict[K, "asyncio.Task[V]"], AsyncIterable[Tuple[K, V]]) task_map = TaskMapping(process_item, [1, 2, 3], concurrency=2) """ + cdef Event _next if concurrency: self.concurrency = concurrency @@ -152,7 +154,11 @@ class TaskMapping(DefaultDict[K, "asyncio.Task[V]"], AsyncIterable[Tuple[K, V]]) self._name = name if iterables: - self._next = Event(name="{} `_next`".format(self)) + _next = Event(name="{} `_next`".format(self)) + + # NOTE self._next will be a python object when retrieved for __aiter__ + # but _next is still a c object for _wrapped_set_next + self._next = _next @functools.wraps(wrapped_func) async def _wrapped_set_next( @@ -193,8 +199,8 @@ class TaskMapping(DefaultDict[K, "asyncio.Task[V]"], AsyncIterable[Tuple[K, V]]) else e2.with_traceback(e2.__traceback__) ) finally: - self._next.set() - self._next.clear() + _next.c_set() + _next.c_clear() self._wrapped_func = _wrapped_set_next init_loader_queue: Queue[Tuple[K, "asyncio.Future[V]"]] = Queue() @@ -230,7 +236,7 @@ class TaskMapping(DefaultDict[K, "asyncio.Task[V]"], AsyncIterable[Tuple[K, V]]) # NOTE: we use a queue instead of a Semaphore to reduce memory use for use cases involving many many tasks fut = self._queue.put_nowait(item) else: - fut = a_sync.asyncio.create_task( + fut = create_task( coro=self._wrapped_func(item, **self._wrapped_func_kwargs), name="{}[{}]".format(self._name, item) if self._name else "{}".format(item), ) @@ -245,6 +251,8 @@ class TaskMapping(DefaultDict[K, "asyncio.Task[V]"], AsyncIterable[Tuple[K, V]]) """Asynchronously iterate through all key-task pairs, yielding the key-result pair as each task completes.""" cdef dict ready cdef object key, task, unyielded, value + cdef Event next = self._next + _if_pop_check_destroyed(self, pop) # if you inited the TaskMapping with some iterators, we will load those @@ -252,7 +260,7 @@ class TaskMapping(DefaultDict[K, "asyncio.Task[V]"], AsyncIterable[Tuple[K, V]]) try: if self._init_loader is None: # if you didn't init the TaskMapping with iterators and you didn't start any tasks manually, we should fail - _raise_if_empty(self, "") + _raise_if_empty(mapping=self, msg="") else: while not self._init_loader.done(): await self._wait_for_next_key() @@ -269,20 +277,30 @@ class TaskMapping(DefaultDict[K, "asyncio.Task[V]"], AsyncIterable[Tuple[K, V]]) yield key, await task yielded.add(key) else: - await self._next.wait() + await next.c_wait() # loader is already done by this point, but we need to check for exceptions await self._init_loader # if there are any tasks that still need to complete, we need to await them and yield them if unyielded := {key: self[key] for key in self if key not in yielded}: if pop: - async for key, value in a_sync.asyncio.as_completed( - unyielded, aiter=True + async for key, value in as_completed_mapping( + mapping=unyielded, + timeout=0, + return_exceptions=False, + aiter=True, + tqdm=False, + tqdm_kwargs={}, ): self.pop(key) yield key, value else: - async for key, value in a_sync.asyncio.as_completed( - unyielded, aiter=True + async for key, value in as_completed_mapping( + mapping=unyielded, + timeout=0, + return_exceptions=False, + aiter=True, + tqdm=False, + tqdm_kwargs={} ): yield key, value finally: @@ -346,7 +364,8 @@ class TaskMapping(DefaultDict[K, "asyncio.Task[V]"], AsyncIterable[Tuple[K, V]]) try: if iterables: - _raise_if_not_empty(self) + if self: + raise exceptions.MappingNotEmptyError(self) try: async for _ in self._tasks_for_iterables(*iterables): async for key, value in self.yield_completed(pop=pop): @@ -365,20 +384,30 @@ class TaskMapping(DefaultDict[K, "asyncio.Task[V]"], AsyncIterable[Tuple[K, V]]) else: _raise_if_empty( - self, - "You must either initialize your TaskMapping with an iterable(s) or provide them during your call to map", + mapping=self, + msg="You must either initialize your TaskMapping with an iterable(s) or provide them during your call to map", ) if self: if pop: - async for key, value in a_sync.asyncio.as_completed( - self, aiter=True + async for key, value in as_completed_mapping( + mapping=self, + timeout=0, + return_exceptions=False, + aiter=True, + tqdm=False, + tqdm_kwargs={}, ): self.pop(key) yield _yield(key, value, yields) else: - async for key, value in a_sync.asyncio.as_completed( - self, aiter=True + async for key, value in as_completed_mapping( + mapping=self, + timeout=0, + return_exceptions=False, + aiter=True, + tqdm=False, + tqdm_kwargs={}, ): yield _yield(key, value, yields) finally: @@ -499,8 +528,8 @@ class TaskMapping(DefaultDict[K, "asyncio.Task[V]"], AsyncIterable[Tuple[K, V]]) """Wait for all tasks to complete and return a dictionary of the results.""" if self._init_loader: await self._init_loader - _raise_if_empty(self, "") - return await a_sync.asyncio.gather( + _raise_if_empty(mapping=self, msg="") + return await gather( self, return_exceptions=return_exceptions, exclude_if=exclude_if, @@ -572,7 +601,7 @@ class TaskMapping(DefaultDict[K, "asyncio.Task[V]"], AsyncIterable[Tuple[K, V]]) logger.debug("starting %s init loader", self) name = "{} init loader loading {} for {}".format(type(self).__name__, self.__iterables__, self) try: - task = a_sync.asyncio.create_task( + task = create_task( coro=self.__init_loader_coro, name=name ) except RuntimeError as e: @@ -638,7 +667,7 @@ class TaskMapping(DefaultDict[K, "asyncio.Task[V]"], AsyncIterable[Tuple[K, V]]) done, pending = await asyncio.wait( [ - a_sync.asyncio.create_task( + create_task( self._init_loader_next(), log_destroy_pending=False ), self._init_loader, @@ -664,12 +693,7 @@ cdef void _if_pop_check_destroyed(object mapping, bint pop): cdef void _raise_if_empty(object mapping, str msg): if not mapping: - raise exceptions.MappingIsEmptyError(mapping, msg) - - -cdef void _raise_if_not_empty(object mapping): - if mapping: - raise exceptions.MappingNotEmptyError(mapping) + raise exceptions.MappingIsEmptyError(mapping, msg) class _NoRunningLoop(Exception): ... diff --git a/setup.py b/setup.py index 6ff06f2e..abef1527 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,11 @@ }, ext_modules=cythonize( "a_sync/**/*.pyx", - compiler_directives={"embedsignature": True, "linetrace": True}, + compiler_directives={ + "language_level": 3, + "embedsignature": True, + "linetrace": True, + }, ), zip_safe=False, )