Skip to content

Commit

Permalink
fix: accidentally made Semaphore._value unreachable (#414)
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler authored Nov 21, 2024
1 parent b0ba49f commit 620a3eb
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 33 deletions.
82 changes: 58 additions & 24 deletions a_sync/primitives/locks/semaphore.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ a dummy semaphore that does nothing, and a threadsafe semaphore for use in multi
import asyncio
import functools
import logging
from collections import defaultdict
from collections import defaultdict, deque
from threading import Thread, current_thread

from a_sync._typing import *
Expand All @@ -15,6 +15,10 @@ from a_sync.primitives._debug cimport _DebugDaemonMixin
logger = logging.getLogger(__name__)


async def __acquire() -> Literal[True]:
return True


cdef class Semaphore(_DebugDaemonMixin):
"""
A semaphore with additional debugging capabilities inherited from :class:`_DebugDaemonMixin`.
Expand Down Expand Up @@ -47,12 +51,12 @@ cdef class Semaphore(_DebugDaemonMixin):
:class:`_DebugDaemonMixin` for more details on debugging capabilities.
"""
cdef str name
cdef int _value
cdef int __value
cdef object _waiters
cdef set _decorated
cdef dict __dict__


def __init__(self, value: int=1, name=None, loop=None, **kwargs) -> None:
"""
Initialize the semaphore with a given value and optional name for debugging.
Expand All @@ -66,7 +70,7 @@ cdef class Semaphore(_DebugDaemonMixin):
raise ValueError("Semaphore initial value must be >= 0")
self._waiters = None
self._value = value
self.__value = value
self.name = name or self.__origin__ if hasattr(self, "__origin__") else None
self._decorated: Set[str] = set()
Expand All @@ -86,23 +90,27 @@ cdef class Semaphore(_DebugDaemonMixin):
return self.decorate(fn) # type: ignore [arg-type, return-value]
def __repr__(self) -> str:
representation = f"<{self.__class__.__name__} name={self.name} value={self._value} waiters={len(self)}>"
representation = f"<{self.__class__.__name__} name={self.name} value={self.__value} waiters={len(self)}>"
if self._decorated:
representation = f"{representation[:-1]} decorates={self._decorated}"
return representation
async def __aenter__(self):
await self.acquire()
await self.c_acquire()
# We have no use for the "as ..." clause in the with
# statement for locks.
return None
async def __aexit__(self, exc_type, exc, tb):
self.c_release()
def locked(self):
cpdef bint locked(self):
"""Returns True if semaphore cannot be acquired immediately."""
return self._value == 0 or (
return self.c_locked()
cdef bint c_locked(self):
"""Returns True if semaphore cannot be acquired immediately."""
return self.__value == 0 or (
any(not w.cancelled() for w in (self._waiters or ())))
def __len__(self) -> int:
Expand Down Expand Up @@ -130,7 +138,7 @@ cdef class Semaphore(_DebugDaemonMixin):
self._decorated.add(f"{fn.__module__}.{fn.__name__}")
return semaphore_wrapper
async def acquire(self) -> Literal[True]:
cpdef object acquire(self):
"""
Acquire the semaphore, ensuring that debug logging is enabled if there are waiters.

Expand All @@ -143,33 +151,52 @@ cdef class Semaphore(_DebugDaemonMixin):
Returns:
True when the semaphore is successfully acquired.
"""
if self._value <= 0:
return self.c_acquire()
cdef object c_acquire(self):
"""
Acquire the semaphore, ensuring that debug logging is enabled if there are waiters.

If the internal counter is larger than zero on entry, decrement it by one and return
True immediately. If it is zero on entry, block, waiting until some other coroutine
has called release() to make it larger than 0, and then return True.

If the semaphore value is zero or less, the debug daemon is started to log the state of the semaphore.

Returns:
True when the semaphore is successfully acquired.
"""
if self.__value <= 0:
self._ensure_debug_daemon()
if not self.locked():
self._value -= 1
return True
import collections
if not self.c_locked():
self.__value -= 1
return __acquire()
if self._waiters is None:
self._waiters = collections.deque()
fut = self._c_get_loop().create_future()
self._waiters.append(fut)
self._waiters = deque()
return self.__acquire()
async def __acquire(self) -> Literal[True]:
# Finally block should be called before the CancelledError
# handling as we don't want CancelledError to call
# _wake_up_first() and attempt to wake up itself.
cdef object fut = self._c_get_loop().create_future()
self._waiters.append(fut)
try:
try:
await fut
finally:
self._waiters.remove(fut)
except asyncio.exceptions.CancelledError:
if not fut.cancelled():
self._value += 1
self.__value += 1
self._wake_up_next()
raise
if self._value > 0:
if self.__value > 0:
self._wake_up_next()
return True
Expand All @@ -179,21 +206,25 @@ cdef class Semaphore(_DebugDaemonMixin):
When it was zero on entry and another coroutine is waiting for it to
become larger than zero again, wake up that coroutine.
"""
self._value += 1
self.__value += 1
self._wake_up_next()
cdef void c_release(self):
self._value += 1
self.__value += 1
self._wake_up_next()
@property
def _value(self) -> int:
return self.__value
cdef void _wake_up_next(self):
"""Wake up the first waiter that isn't done."""
if not self._waiters:
return

for fut in self._waiters:
if not fut.done():
self._value -= 1
self.__value -= 1
fut.set_result(True)
return

Expand Down Expand Up @@ -235,7 +266,7 @@ cdef class DummySemaphore(Semaphore):
"""

def __cinit__(self):
self._value = 0
self.__value = 0
self._waiters = None

def __init__(self, name: Optional[str] = None):
Expand All @@ -253,6 +284,9 @@ cdef class DummySemaphore(Semaphore):
async def acquire(self) -> Literal[True]:
"""Acquire the dummy semaphore, which is a no-op."""
return True

async def c_acquire(self) -> Literal[True]:
return True

cpdef void release(self):
"""No-op release method."""
Expand Down Expand Up @@ -331,7 +365,7 @@ cdef class ThreadsafeSemaphore(Semaphore):
return self.dummy if self.use_dummy else self.semaphores[current_thread()]
async def __aenter__(self):
await self.c_get_semaphore().acquire()
await self.c_get_semaphore().c_acquire()
async def __aexit__(self, *args):
self.c_get_semaphore().c_release()
21 changes: 12 additions & 9 deletions tests/test_semaphore.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest
from time import time

from a_sync import Semaphore
from tests.fixtures import TestSemaphore, increment


Expand All @@ -12,15 +14,18 @@
instance = TestSemaphore(1, sync=False)


def test_semaphore_init():
assert Semaphore(1)._value == Semaphore()._value == 1


@increment
@pytest.mark.asyncio_cooperative
async def test_semaphore(i: int):
start = time()
assert await instance.test_fn() == 1
duration = time() - start
assert (
i < 3 or duration > i
) # There is a 1 second sleep in this fn. If the semaphore is not working, all tests will complete in 1 second.
# There is a 1 second sleep in this fn. If the semaphore is not working, all tests will complete in 1 second.
assert i < 3 or duration > i


@increment
Expand All @@ -29,9 +34,8 @@ async def test_semaphore_property(i: int):
start = time()
assert await instance.test_property == 2
duration = time() - start
assert (
i < 3 or duration > i
) # There is a 1 second sleep in this fn. If the semaphore is not working, all tests will complete in 1 second.
# There is a 1 second sleep in this fn. If the semaphore is not working, all tests will complete in 1 second.
assert i < 3 or duration > i


@increment
Expand All @@ -43,6 +47,5 @@ async def test_semaphore_cached_property(i: int):
# There is a 1 second sleep in this fn but a semaphore override with a value of 50.
# You can tell it worked correctly because the class-defined semaphore value is just one, whch would cause this test to fail if it were used.
# If the override is not working, all tests will complete in just over 1 second.
assert (
i == 1 or duration < 1.4
) # We increased the threshold from 1.05 to 1.4 to help tests pass on slow github runners
# We increased the threshold from 1.05 to 1.4 to help tests pass on slow github runners
assert i == 1 or duration < 1.4

0 comments on commit 620a3eb

Please sign in to comment.