diff --git a/a_sync/primitives/locks/counter.pxd b/a_sync/primitives/locks/counter.pxd index 2513fa08..6aca939f 100644 --- a/a_sync/primitives/locks/counter.pxd +++ b/a_sync/primitives/locks/counter.pxd @@ -5,8 +5,9 @@ from a_sync.primitives.locks.event cimport CythonEvent as Event cdef class CounterLock(_DebugDaemonMixin): cdef char* __name cdef long long _value + cdef list _heap cdef dict[long long, Event] _events cpdef bint is_ready(self, long long v) cdef bint c_is_ready(self, long long v) cpdef void set(self, long long value) - cdef void c_set(self, long long value) \ No newline at end of file + cdef void c_set(self, long long value) diff --git a/a_sync/primitives/locks/counter.pyx b/a_sync/primitives/locks/counter.pyx index 7544b682..5d2f5af7 100644 --- a/a_sync/primitives/locks/counter.pyx +++ b/a_sync/primitives/locks/counter.pyx @@ -5,11 +5,11 @@ These primitives manage synchronization of tasks that must wait for an internal """ import asyncio -from collections import defaultdict +import heapq from libc.string cimport strcpy from libc.stdlib cimport malloc, free from libc.time cimport time -from typing import DefaultDict, Iterable +from typing import Iterable from a_sync.primitives._debug cimport _DebugDaemonMixin from a_sync.primitives.locks.event cimport CythonEvent as Event @@ -35,6 +35,8 @@ cdef class CounterLock(_DebugDaemonMixin): self._events = {} """A defaultdict that maps each awaited value to an :class:`Event` that manages the waiters for that value.""" + self._heap = [] + def __init__(self, start_value: int = 0, str name = ""): """ Initializes the :class:`CounterLock` with a starting value and an optional name. @@ -111,6 +113,7 @@ cdef class CounterLock(_DebugDaemonMixin): if event is None: event = Event() self._events[value] = event + heapq.heappush(self._heap, value) self._c_ensure_debug_daemon((),{}) await (event).c_wait() return True @@ -174,15 +177,16 @@ cdef class CounterLock(_DebugDaemonMixin): self.c_set(value) cdef void c_set(self, long long value): + cdef long long key if value > self._value: self._value = value - ready = [ - self._events.pop(key) - for key in list(self._events.keys()) - if key <= self._value - ] - for event in ready: - event.set() + while self._heap: + key = heapq.heappop(self._heap) + if key <= self._value: + (self._events.pop(key)).c_set() + else: + heapq.heappush(self._heap, key) + return elif value < self._value: raise ValueError("You cannot decrease the value.")