diff --git a/newsfragments/573.feature.rst b/newsfragments/573.feature.rst new file mode 100644 index 0000000000..e3f6b9e9fa --- /dev/null +++ b/newsfragments/573.feature.rst @@ -0,0 +1,2 @@ +Add the ability to close a :class:`trio.Queue`, cancelling all waiting getters and putters, and +preventing anyone else from getting or putting onto it. \ No newline at end of file diff --git a/trio/_sync.py b/trio/_sync.py index fe806dd34a..c9a2c0a950 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -15,6 +15,7 @@ "StrictFIFOLock", "Condition", "Queue", + "QueueClosed", ] @@ -802,6 +803,11 @@ class _QueueStats: tasks_waiting_get = attr.ib() +class QueueClosed(Exception): + """Raised on waiters for the queue when a queue is closed. + """ + + # Like queue.Queue, with the notable difference that the capacity argument is # mandatory. class Queue: @@ -842,6 +848,9 @@ def __init__(self, capacity): # if len(self._data) < self.capacity, then self._put_wait is empty # if len(self._data) > 0, then self._get_wait is empty self._data = deque() + # closed state + self._put_close = False + self._all_closed = False def __repr__(self): return ( @@ -887,6 +896,9 @@ def put_nowait(self, obj): WouldBlock: if the queue is full. """ + if self._put_close or self._all_closed: + raise QueueClosed + if self._get_wait: assert not self._data task, _ = self._get_wait.popitem(last=False) @@ -905,6 +917,9 @@ async def put(self, obj): """ await _core.checkpoint_if_cancelled() + if self._put_close or self._all_closed: + raise QueueClosed + try: self.put_nowait(obj) except _core.WouldBlock: @@ -933,6 +948,9 @@ def get_nowait(self): WouldBlock: if the queue is empty. """ + if self._all_closed: + raise QueueClosed + if self._put_wait: task, value = self._put_wait.popitem(last=False) # No need to check max_size, b/c we'll pop an item off again right @@ -942,6 +960,16 @@ def get_nowait(self): if self._data: value = self._data.popleft() return value + if self._put_close: + # this confused me a bit so its bound to confuse somebody else as to why this is here + # 1) there's no put waiters, so we skip that branch + # 2) there's no data so we skip that branch + # that means that if there's no data at all, and the put side is closed + # we cannot ever have more data, so we close this side and raise QueueClosed so that + # any getters from here on close early + self._all_closed = True + raise QueueClosed + raise _core.WouldBlock() @_core.enable_ki_protection @@ -953,6 +981,9 @@ async def get(self): """ await _core.checkpoint_if_cancelled() + if self._all_closed: + raise QueueClosed + try: value = self.get_nowait() except _core.WouldBlock: @@ -972,12 +1003,46 @@ def abort_fn(_): value = await _core.wait_task_rescheduled(abort_fn) return value + def close_put(self): + """Closes one side of this queue, preventing any putters from putting data onto the queue. + + If this queue is empty, it will also cancel all getters. + """ + if self.empty(): + # pointless to let the getters wait on closed data + self.close_both_sides() + else: + self._put_close = True + for task in self._put_wait.values(): + _core.reschedule(task, outcome.Error(QueueClosed)) + + self._put_wait.clear() + + def close_both_sides(self): + """Closes both the getter and putter sides of the queue, discarding all data. + """ + self._put_close, self._all_closed = True, True + for task in self._get_wait.values(): + _core.reschedule(task, outcome.Error(QueueClosed)) + + self._get_wait.clear() + + for task in self._put_wait.values(): + _core.reschedule(task, outcome.Error(QueueClosed)) + + self._put_wait.clear() + + self._data.clear() + @aiter_compat def __aiter__(self): return self async def __anext__(self): - return await self.get() + try: + return await self.get() + except QueueClosed: + raise StopAsyncIteration from None def statistics(self): """Returns an object containing debugging information. diff --git a/trio/tests/test_sync.py b/trio/tests/test_sync.py index bded1b0544..2d0c354907 100644 --- a/trio/tests/test_sync.py +++ b/trio/tests/test_sync.py @@ -5,7 +5,6 @@ from ..testing import wait_all_tasks_blocked, assert_checkpoints from .. import _core -from .. import _timeouts from .._timeouts import sleep_forever, move_on_after from .._sync import * @@ -542,6 +541,26 @@ async def do_put(q, v): q.get_nowait() +async def test_Queue_close(): + q1 = Queue(capacity=1) + + await q1.put(1) + q1.close_put() + with pytest.raises(QueueClosed): + await q1.put(2) + + assert (await q1.get()) == 1 + with pytest.raises(QueueClosed): + await q1.get() + + q2 = Queue(capacity=1) + await q2.put(1) + q2.close_both_sides() + + with pytest.raises(QueueClosed): + await q2.get() + + # Two ways of implementing a Lock in terms of a Queue. Used to let us put the # Queue through the generic lock tests.