Skip to content

Commit

Permalink
fix: semaphore overflow error on very large values
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler committed Nov 22, 2024
1 parent 2062c02 commit 6fc9ee9
Show file tree
Hide file tree
Showing 3 changed files with 277 additions and 6 deletions.
2 changes: 1 addition & 1 deletion a_sync/primitives/locks/semaphore.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ from a_sync.primitives._debug cimport _DebugDaemonMixin

cdef class Semaphore(_DebugDaemonMixin):
cdef str _name
cdef int __value
cdef unsigned long long __value
cdef object _waiters
cdef set _decorated
cdef dict __dict__
Expand Down
14 changes: 10 additions & 4 deletions a_sync/primitives/locks/semaphore.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,18 @@ cdef class Semaphore(_DebugDaemonMixin):
value: The initial value for the semaphore.
name (optional): An optional name used only to provide useful context in debug logs.
"""
if value is None:
raise ValueError(value)
super().__init__(loop=loop)
if value < 0:
raise ValueError("Semaphore initial value must be >= 0")
self.__value = value
try:
self.__value = value
except OverflowError as e:
raise OverflowError(
{"error": str(e), "value": value, "max value": 18446744073709551615},
"If you need a Semaphore with a larger value, you should just use asyncio.Semaphore",
) from e.__cause__
self._name = name or getattr(self, "__origin__", "")
def __call__(self, fn: CoroFn[P, T]) -> CoroFn[P, T]:
Expand Down Expand Up @@ -220,7 +226,7 @@ cdef class Semaphore(_DebugDaemonMixin):
return self.__value
@_value.setter
def _value(self, int value):
def _value(self, unsigned long long value):
# required for subclass compatability
self.__value = value
Expand Down
267 changes: 266 additions & 1 deletion tests/test_semaphore.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pytest
import asyncio
import sys
from time import time

from a_sync import Semaphore
from a_sync.primitives.locks.semaphore import Semaphore
from tests.fixtures import TestSemaphore, increment


Expand Down Expand Up @@ -49,3 +51,266 @@ async def test_semaphore_cached_property(i: int):
# If the override is not working, all tests will complete in just over 1 second.
# We increased the threshold from 1.05 to 1.4 to help tests pass on slow github runners
assert i == 1 or duration < 1.4

@pytest.mark.asyncio_cooperative
async def test_semaphore_acquire_release():
semaphore = Semaphore(2)
await semaphore.acquire()
assert semaphore._value == 1
semaphore.release()
assert semaphore._value == 2

@pytest.mark.asyncio_cooperative
async def test_semaphore_blocking():
semaphore = Semaphore(1)

async def task():
await semaphore.acquire()
await asyncio.sleep(0.1)
semaphore.release()

await semaphore.acquire()
task1 = asyncio.create_task(task())
await asyncio.sleep(0.05)
assert semaphore.locked()
semaphore.release()
await task1

@pytest.mark.asyncio_cooperative
async def test_semaphore_multiple_tasks():
semaphore = Semaphore(2)
results = []

async def task(index):
await semaphore.acquire()
results.append(index)
await asyncio.sleep(0.1)
semaphore.release()

tasks = [asyncio.create_task(task(i)) for i in range(4)]
await asyncio.gather(*tasks)
assert results == [0, 1, 2, 3]

@pytest.mark.asyncio_cooperative
async def test_semaphore_with_zero_initial_value():
semaphore = Semaphore(0)

async def task():
await semaphore.acquire()
return 'done'

task1 = asyncio.create_task(task())
await asyncio.sleep(0.1)
assert not task1.done()
semaphore.release()
result = await task1
assert result == 'done'

def test_semaphore_negative_initial_value():
with pytest.raises(ValueError):
Semaphore(-1)
with pytest.raises(TypeError):
Semaphore(None)
with pytest.raises(TypeError):
Semaphore("None")
"""
@pytest.mark.asyncio_cooperative
async def test_semaphore_releasing_without_acquiring():
semaphore = Semaphore(1)
semaphore.release()
assert semaphore._value == 2
@pytest.mark.asyncio_cooperative
async def test_semaphore_with_custom_loop():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
semaphore = Semaphore(1)
async def task():
await semaphore.acquire()
return 'done'
task1 = loop.create_task(task())
loop.call_soon(semaphore.release)
result = loop.run_until_complete(task1)
assert result == 'done'
loop.close()
@pytest.mark.asyncio_cooperative
async def test_concurrent_acquire_release():
semaphore = Semaphore(2)
results = []
async def task(index):
await semaphore.acquire()
results.append(index)
await asyncio.sleep(0.1)
semaphore.release()
tasks = [asyncio.create_task(task(i)) for i in range(10)]
await asyncio.gather(*tasks)
assert sorted(results) == list(range(10))
@pytest.mark.asyncio_cooperative
async def test_semaphore_max_integer_value():
semaphore = Semaphore(sys.maxsize)
await semaphore.acquire()
assert semaphore._value == sys.maxsize - 1
semaphore.release()
assert semaphore._value == sys.maxsize
@pytest.mark.asyncio_cooperative
async def test_rapid_acquire_release():
semaphore = Semaphore(1)
async def task():
for _ in range(100):
await semaphore.acquire()
semaphore.release()
await asyncio.gather(*[task() for _ in range(10)])
@pytest.mark.asyncio_cooperative
async def test_exception_handling_during_acquire():
semaphore = Semaphore(1)
async def task():
try:
await semaphore.acquire()
raise ValueError("Intentional error")
except ValueError:
semaphore.release()
await task()
assert semaphore._value == 1
@pytest.mark.asyncio_cooperative
async def test_cancelled_acquire():
semaphore = Semaphore(1)
async def task():
await semaphore.acquire()
await asyncio.sleep(0.5)
semaphore.release()
task1 = asyncio.create_task(task())
await asyncio.sleep(0.1)
async def waiting_task():
await semaphore.acquire()
waiting_task1 = asyncio.create_task(waiting_task())
await asyncio.sleep(0.1)
waiting_task1.cancel()
await task1
@pytest.mark.asyncio_cooperative
async def test_nested_acquires():
semaphore = Semaphore(1)
async def task():
await semaphore.acquire()
await semaphore.acquire()
semaphore.release()
semaphore.release()
with pytest.raises(RuntimeError):
await task()
@pytest.mark.asyncio_cooperative
async def test_delayed_release():
semaphore = Semaphore(1)
async def task():
await semaphore.acquire()
await asyncio.sleep(1)
semaphore.release()
task1 = asyncio.create_task(task())
await asyncio.sleep(0.1)
assert semaphore.locked()
await task1
@pytest.mark.asyncio_cooperative
async def test_invalid_release():
semaphore = Semaphore(1)
semaphore.release()
semaphore.release()
assert semaphore._value == 3
@pytest.mark.asyncio_cooperative
async def test_custom_error_handling():
semaphore = Semaphore(1)
async def task():
try:
await semaphore.acquire()
raise RuntimeError("Custom error")
except RuntimeError:
semaphore.release()
await task()
assert semaphore._value == 1
@pytest.mark.asyncio_cooperative
async def test_dynamic_resizing():
semaphore = Semaphore(2)
semaphore._value = 5
await semaphore.acquire()
assert semaphore._value == 4
semaphore.release()
assert semaphore._value == 5
@pytest.mark.asyncio_cooperative
async def test_high_frequency_operations():
semaphore = Semaphore(1)
async def task():
for _ in range(1000):
await semaphore.acquire()
semaphore.release()
await asyncio.gather(*[task() for _ in range(10)])
@pytest.mark.asyncio_cooperative
async def test_semaphore_as_context_manager():
semaphore = Semaphore(1)
async def task():
async with semaphore:
assert semaphore._value == 0
await task()
assert semaphore._value == 1
@pytest.mark.asyncio_cooperative
async def test_exception_in_release():
semaphore = Semaphore(1)
async def task():
await semaphore.acquire()
try:
raise RuntimeError("Error during release")
finally:
semaphore.release()
with pytest.raises(RuntimeError):
await task()
assert semaphore._value == 1
@pytest.mark.asyncio_cooperative
async def test_external_interruptions():
semaphore = Semaphore(1)
async def task():
await semaphore.acquire()
await asyncio.sleep(0.1)
semaphore.release()
task1 = asyncio.create_task(task())
await asyncio.sleep(0.05)
task1.cancel()
await asyncio.sleep(0.1)
assert semaphore._value == 1
"""

0 comments on commit 6fc9ee9

Please sign in to comment.