diff --git a/trio/_sync.py b/trio/_sync.py index 15bafe497c..91c23b3ee8 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -1,5 +1,5 @@ import operator -from collections import deque +from collections import deque, OrderedDict import attr @@ -823,21 +823,23 @@ class Queue: capacity (int): The maximum number of items allowed in the queue before :meth:`put` blocks. Choosing a sensible value here is important to ensure that backpressure is communicated promptly and avoid - unnecessary latency. If in doubt, use 1. + unnecessary latency. If in doubt, use 0. """ def __init__(self, capacity): if not isinstance(capacity, int): raise TypeError("capacity must be an integer") - if capacity < 1: - raise ValueError("capacity must be >= 1") - # Invariants: - # get_semaphore.value() == len(self._data) - # put_semaphore.value() + get_semaphore.value() = capacity + if capacity < 0: + raise ValueError("capacity must be >= 0") self.capacity = operator.index(capacity) - self._put_semaphore = Semaphore(capacity, max_value=capacity) - self._get_semaphore = Semaphore(0, max_value=capacity) + # {task: abort func} + self._get_wait = OrderedDict() + # {task: queued value} + self._put_wait = OrderedDict() + # invariants: + # if len(self._data) < self.capacity, then self._put_wait is empty + # if len(self._data) > 0, then self._get_wait is empty self._data = deque() def __repr__(self): @@ -873,10 +875,6 @@ def empty(self): """ return not self._data - def _put_protected(self, obj): - self._data.append(obj) - self._get_semaphore.release() - @_core.enable_ki_protection def put_nowait(self, obj): """Attempt to put an object into the queue, without blocking. @@ -888,8 +886,15 @@ def put_nowait(self, obj): WouldBlock: if the queue is full. """ - self._put_semaphore.acquire_nowait() - self._put_protected(obj) + if self._get_wait: + assert not self._data + task, abort_fn = self._get_wait.popitem(last=False) + abort_fn(None) + _core.reschedule(task, _core.Value((self, obj))) + elif len(self._data) < self.capacity: + self._data.append(obj) + else: + raise _core.WouldBlock() @_core.enable_ki_protection async def put(self, obj): @@ -899,12 +904,23 @@ async def put(self, obj): obj (object): The object to enqueue. """ - await self._put_semaphore.acquire() - self._put_protected(obj) + await _core.checkpoint_if_cancelled() + try: + self.put_nowait(obj) + except _core.WouldBlock: + pass + else: + await _core.cancel_shielded_checkpoint() + return + + task = _core.current_task() + self._put_wait[task] = obj - def _get_protected(self): - self._put_semaphore.release() - return self._data.popleft() + def abort_fn(_): + del self._put_wait[task] + return _core.Abort.SUCCEEDED + + await _core.wait_task_rescheduled(abort_fn) @_core.enable_ki_protection def get_nowait(self): @@ -917,8 +933,8 @@ def get_nowait(self): WouldBlock: if the queue is empty. """ - self._get_semaphore.acquire_nowait() - return self._get_protected() + _, value = multi_get_nowait([self]) + return value @_core.enable_ki_protection async def get(self): @@ -928,8 +944,8 @@ async def get(self): object: The dequeued object. """ - await self._get_semaphore.acquire() - return self._get_protected() + _, value = await multi_get([self]) + return value @aiter_compat def __aiter__(self): @@ -954,6 +970,50 @@ def statistics(self): return _QueueStats( qsize=len(self._data), capacity=self.capacity, - tasks_waiting_put=self._put_semaphore.statistics().tasks_waiting, - tasks_waiting_get=self._get_semaphore.statistics().tasks_waiting, + tasks_waiting_put=len(self._put_wait), + tasks_waiting_get=len(self._get_wait), ) + + +@_core.enable_ki_protection +def multi_get_nowait(queues): + for queue in queues: + if queue._put_wait: + task, value = queue._put_wait.popitem(last=False) + # No need to check max_size, b/c we'll pop an item off again right + # below. + queue._data.append(value) + _core.reschedule(task) + if queue._data: + value = queue._data.popleft() + return queue, value + raise _core.WouldBlock() + + +@_core.enable_ki_protection +async def multi_get(queues): + # Returns (queue object, value gotten) + await _core.checkpoint_if_cancelled() + try: + queue, value = multi_get_nowait(queues) + except _core.WouldBlock: + pass + else: + await _core.cancel_shielded_checkpoint() + return queue, value + # No queue had anything. + task = _core.current_task() + + def abort_fn(_): + for queue in queues: + try: + del queue._get_wait[task] + except KeyError: + # If we just pushed to this queue, we already popped. + # But is it alright to... always pass? + pass + return _core.Abort.SUCCEEDED + + for queue in queues: + queue._get_wait[task] = abort_fn + return await _core.wait_task_rescheduled(abort_fn) diff --git a/trio/tests/test_sync.py b/trio/tests/test_sync.py index 7523005ce2..6a6036584b 100644 --- a/trio/tests/test_sync.py +++ b/trio/tests/test_sync.py @@ -378,8 +378,6 @@ async def test_Queue(): Queue(1.0) with pytest.raises(ValueError): Queue(-1) - with pytest.raises(ValueError): - Queue(0) q = Queue(2) repr(q) # smoke test @@ -512,6 +510,29 @@ async def do_get(q): assert (await q.get()) == 2 +async def test_Queue_unbuffered(): + q = Queue(0) + assert q.capacity == 0 + assert q.qsize() == 0 + assert q.empty() + assert q.full() + with pytest.raises(_core.WouldBlock): + q.get_nowait() + with pytest.raises(_core.WouldBlock): + q.put_nowait(1) + + async def do_put(q, v): + with assert_checkpoints(): + await q.put(v) + + async with _core.open_nursery() as nursery: + nursery.start_soon(do_put, q, 1) + with assert_checkpoints(): + assert await q.get() == 1 + with pytest.raises(_core.WouldBlock): + q.get_nowait() + + # Two ways of implementing a Lock in terms of a Queue. Used to let us put the # Queue through the generic lock tests.