From 6fc9ee93c402753a815f17644b2a164fbc43e7e6 Mon Sep 17 00:00:00 2001 From: BobTheBuidler Date: Fri, 22 Nov 2024 20:39:18 +0000 Subject: [PATCH] fix: semaphore overflow error on very large values --- a_sync/primitives/locks/semaphore.pxd | 2 +- a_sync/primitives/locks/semaphore.pyx | 14 +- tests/test_semaphore.py | 267 +++++++++++++++++++++++++- 3 files changed, 277 insertions(+), 6 deletions(-) diff --git a/a_sync/primitives/locks/semaphore.pxd b/a_sync/primitives/locks/semaphore.pxd index 6ac023cc..e15d23a9 100644 --- a/a_sync/primitives/locks/semaphore.pxd +++ b/a_sync/primitives/locks/semaphore.pxd @@ -2,7 +2,7 @@ from a_sync.primitives._debug cimport _DebugDaemonMixin cdef class Semaphore(_DebugDaemonMixin): cdef str _name - cdef int __value + cdef unsigned long long __value cdef object _waiters cdef set _decorated cdef dict __dict__ diff --git a/a_sync/primitives/locks/semaphore.pyx b/a_sync/primitives/locks/semaphore.pyx index f84a6f0a..a40d6656 100644 --- a/a_sync/primitives/locks/semaphore.pyx +++ b/a_sync/primitives/locks/semaphore.pyx @@ -61,12 +61,18 @@ cdef class Semaphore(_DebugDaemonMixin): value: The initial value for the semaphore. name (optional): An optional name used only to provide useful context in debug logs. """ - if value is None: - raise ValueError(value) super().__init__(loop=loop) if value < 0: raise ValueError("Semaphore initial value must be >= 0") - self.__value = value + + try: + self.__value = value + except OverflowError as e: + raise OverflowError( + {"error": str(e), "value": value, "max value": 18446744073709551615}, + "If you need a Semaphore with a larger value, you should just use asyncio.Semaphore", + ) from e.__cause__ + self._name = name or getattr(self, "__origin__", "") def __call__(self, fn: CoroFn[P, T]) -> CoroFn[P, T]: @@ -220,7 +226,7 @@ cdef class Semaphore(_DebugDaemonMixin): return self.__value @_value.setter - def _value(self, int value): + def _value(self, unsigned long long value): # required for subclass compatability self.__value = value diff --git a/tests/test_semaphore.py b/tests/test_semaphore.py index e330eab4..779586b4 100644 --- a/tests/test_semaphore.py +++ b/tests/test_semaphore.py @@ -1,7 +1,9 @@ import pytest +import asyncio +import sys from time import time -from a_sync import Semaphore +from a_sync.primitives.locks.semaphore import Semaphore from tests.fixtures import TestSemaphore, increment @@ -49,3 +51,266 @@ async def test_semaphore_cached_property(i: int): # If the override is not working, all tests will complete in just over 1 second. # We increased the threshold from 1.05 to 1.4 to help tests pass on slow github runners assert i == 1 or duration < 1.4 + +@pytest.mark.asyncio_cooperative +async def test_semaphore_acquire_release(): + semaphore = Semaphore(2) + await semaphore.acquire() + assert semaphore._value == 1 + semaphore.release() + assert semaphore._value == 2 + +@pytest.mark.asyncio_cooperative +async def test_semaphore_blocking(): + semaphore = Semaphore(1) + + async def task(): + await semaphore.acquire() + await asyncio.sleep(0.1) + semaphore.release() + + await semaphore.acquire() + task1 = asyncio.create_task(task()) + await asyncio.sleep(0.05) + assert semaphore.locked() + semaphore.release() + await task1 + +@pytest.mark.asyncio_cooperative +async def test_semaphore_multiple_tasks(): + semaphore = Semaphore(2) + results = [] + + async def task(index): + await semaphore.acquire() + results.append(index) + await asyncio.sleep(0.1) + semaphore.release() + + tasks = [asyncio.create_task(task(i)) for i in range(4)] + await asyncio.gather(*tasks) + assert results == [0, 1, 2, 3] + +@pytest.mark.asyncio_cooperative +async def test_semaphore_with_zero_initial_value(): + semaphore = Semaphore(0) + + async def task(): + await semaphore.acquire() + return 'done' + + task1 = asyncio.create_task(task()) + await asyncio.sleep(0.1) + assert not task1.done() + semaphore.release() + result = await task1 + assert result == 'done' + +def test_semaphore_negative_initial_value(): + with pytest.raises(ValueError): + Semaphore(-1) + with pytest.raises(TypeError): + Semaphore(None) + with pytest.raises(TypeError): + Semaphore("None") +""" +@pytest.mark.asyncio_cooperative +async def test_semaphore_releasing_without_acquiring(): + semaphore = Semaphore(1) + semaphore.release() + assert semaphore._value == 2 + +@pytest.mark.asyncio_cooperative +async def test_semaphore_with_custom_loop(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + semaphore = Semaphore(1) + + async def task(): + await semaphore.acquire() + return 'done' + + task1 = loop.create_task(task()) + loop.call_soon(semaphore.release) + result = loop.run_until_complete(task1) + assert result == 'done' + loop.close() + +@pytest.mark.asyncio_cooperative +async def test_concurrent_acquire_release(): + semaphore = Semaphore(2) + results = [] + + async def task(index): + await semaphore.acquire() + results.append(index) + await asyncio.sleep(0.1) + semaphore.release() + + tasks = [asyncio.create_task(task(i)) for i in range(10)] + await asyncio.gather(*tasks) + assert sorted(results) == list(range(10)) + +@pytest.mark.asyncio_cooperative +async def test_semaphore_max_integer_value(): + semaphore = Semaphore(sys.maxsize) + await semaphore.acquire() + assert semaphore._value == sys.maxsize - 1 + semaphore.release() + assert semaphore._value == sys.maxsize + +@pytest.mark.asyncio_cooperative +async def test_rapid_acquire_release(): + semaphore = Semaphore(1) + + async def task(): + for _ in range(100): + await semaphore.acquire() + semaphore.release() + + await asyncio.gather(*[task() for _ in range(10)]) + +@pytest.mark.asyncio_cooperative +async def test_exception_handling_during_acquire(): + semaphore = Semaphore(1) + + async def task(): + try: + await semaphore.acquire() + raise ValueError("Intentional error") + except ValueError: + semaphore.release() + + await task() + assert semaphore._value == 1 + +@pytest.mark.asyncio_cooperative +async def test_cancelled_acquire(): + semaphore = Semaphore(1) + + async def task(): + await semaphore.acquire() + await asyncio.sleep(0.5) + semaphore.release() + + task1 = asyncio.create_task(task()) + await asyncio.sleep(0.1) + + async def waiting_task(): + await semaphore.acquire() + + waiting_task1 = asyncio.create_task(waiting_task()) + await asyncio.sleep(0.1) + waiting_task1.cancel() + await task1 + +@pytest.mark.asyncio_cooperative +async def test_nested_acquires(): + semaphore = Semaphore(1) + + async def task(): + await semaphore.acquire() + await semaphore.acquire() + semaphore.release() + semaphore.release() + + with pytest.raises(RuntimeError): + await task() + +@pytest.mark.asyncio_cooperative +async def test_delayed_release(): + semaphore = Semaphore(1) + + async def task(): + await semaphore.acquire() + await asyncio.sleep(1) + semaphore.release() + + task1 = asyncio.create_task(task()) + await asyncio.sleep(0.1) + assert semaphore.locked() + await task1 + +@pytest.mark.asyncio_cooperative +async def test_invalid_release(): + semaphore = Semaphore(1) + semaphore.release() + semaphore.release() + assert semaphore._value == 3 + +@pytest.mark.asyncio_cooperative +async def test_custom_error_handling(): + semaphore = Semaphore(1) + + async def task(): + try: + await semaphore.acquire() + raise RuntimeError("Custom error") + except RuntimeError: + semaphore.release() + + await task() + assert semaphore._value == 1 + +@pytest.mark.asyncio_cooperative +async def test_dynamic_resizing(): + semaphore = Semaphore(2) + semaphore._value = 5 + await semaphore.acquire() + assert semaphore._value == 4 + semaphore.release() + assert semaphore._value == 5 + +@pytest.mark.asyncio_cooperative +async def test_high_frequency_operations(): + semaphore = Semaphore(1) + + async def task(): + for _ in range(1000): + await semaphore.acquire() + semaphore.release() + + await asyncio.gather(*[task() for _ in range(10)]) + +@pytest.mark.asyncio_cooperative +async def test_semaphore_as_context_manager(): + semaphore = Semaphore(1) + + async def task(): + async with semaphore: + assert semaphore._value == 0 + + await task() + assert semaphore._value == 1 + +@pytest.mark.asyncio_cooperative +async def test_exception_in_release(): + semaphore = Semaphore(1) + + async def task(): + await semaphore.acquire() + try: + raise RuntimeError("Error during release") + finally: + semaphore.release() + + with pytest.raises(RuntimeError): + await task() + + assert semaphore._value == 1 + +@pytest.mark.asyncio_cooperative +async def test_external_interruptions(): + semaphore = Semaphore(1) + + async def task(): + await semaphore.acquire() + await asyncio.sleep(0.1) + semaphore.release() + + task1 = asyncio.create_task(task()) + await asyncio.sleep(0.05) + task1.cancel() + await asyncio.sleep(0.1) + assert semaphore._value == 1 +""" \ No newline at end of file