Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New faster Queue with capacity >= 0 #473

Merged
merged 10 commits into from
May 5, 2018
6 changes: 6 additions & 0 deletions docs/source/reference-core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1289,6 +1289,12 @@ If the queue gets too big, then it applies *backpressure*: ``put``
blocks and forces the producers to slow down and wait until the
consumer calls ``get``.

You can also create a :class:`Queue` with capacity 0. In that case any
task that calls ``put`` on the queue will wait until another task
calls ``get`` on the same queue, and vice versa. This is similar to
the behavior of `channels as described in the CSP model
<https://en.wikipedia.org/wiki/Channel_(programming)>`__.

.. autoclass:: Queue
:members:

Expand Down
2 changes: 2 additions & 0 deletions newsfragments/473.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Add support for :class:`trio.Queue` with ``capacity=0``. The :class:`trio.Queue`
implementation is also faster now.
92 changes: 66 additions & 26 deletions trio/_sync.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import operator
from collections import deque
from collections import deque, OrderedDict

import attr
import outcome

from . import _core
from ._util import aiter_compat
Expand Down Expand Up @@ -823,21 +824,23 @@ class Queue:
capacity (int): The maximum number of items allowed in the queue before
:meth:`put` blocks. Choosing a sensible value here is important to
ensure that backpressure is communicated promptly and avoid
unnecessary latency. If in doubt, use 0.

"""

def __init__(self, capacity):
if not isinstance(capacity, int):
raise TypeError("capacity must be an integer")
if capacity < 1:
raise ValueError("capacity must be >= 1")
# Invariants:
# get_semaphore.value() == len(self._data)
# put_semaphore.value() + get_semaphore.value() = capacity
if capacity < 0:
raise ValueError("capacity must be >= 0")
self.capacity = operator.index(capacity)
self._put_semaphore = Semaphore(capacity, max_value=capacity)
self._get_semaphore = Semaphore(0, max_value=capacity)
# {task: None} ordered set
self._get_wait = OrderedDict()
# {task: queued value}
self._put_wait = OrderedDict()
# invariants:
# if len(self._data) < self.capacity, then self._put_wait is empty
# if len(self._data) > 0, then self._get_wait is empty
self._data = deque()

def __repr__(self):
Expand Down Expand Up @@ -873,10 +876,6 @@ def empty(self):
"""
return not self._data

def _put_protected(self, obj):
self._data.append(obj)
self._get_semaphore.release()

@_core.enable_ki_protection
def put_nowait(self, obj):
"""Attempt to put an object into the queue, without blocking.
Expand All @@ -888,8 +887,14 @@ def put_nowait(self, obj):
WouldBlock: if the queue is full.

"""
self._put_semaphore.acquire_nowait()
self._put_protected(obj)
if self._get_wait:
assert not self._data
task, _ = self._get_wait.popitem(last=False)
_core.reschedule(task, outcome.Value(obj))
elif len(self._data) < self.capacity:
self._data.append(obj)
else:
raise _core.WouldBlock()

@_core.enable_ki_protection
async def put(self, obj):
Expand All @@ -899,12 +904,23 @@ async def put(self, obj):
obj (object): The object to enqueue.

"""
await self._put_semaphore.acquire()
self._put_protected(obj)
await _core.checkpoint_if_cancelled()
try:
self.put_nowait(obj)
except _core.WouldBlock:
pass
else:
await _core.cancel_shielded_checkpoint()
return

task = _core.current_task()
self._put_wait[task] = obj

def _get_protected(self):
self._put_semaphore.release()
return self._data.popleft()
def abort_fn(_):
del self._put_wait[task]
return _core.Abort.SUCCEEDED

await _core.wait_task_rescheduled(abort_fn)

@_core.enable_ki_protection
def get_nowait(self):
Expand All @@ -917,8 +933,16 @@ def get_nowait(self):
WouldBlock: if the queue is empty.

"""
self._get_semaphore.acquire_nowait()
return self._get_protected()
if self._put_wait:
task, value = self._put_wait.popitem(last=False)
# No need to check max_size, b/c we'll pop an item off again right
# below.
self._data.append(value)
_core.reschedule(task)
if self._data:
value = self._data.popleft()
return value
raise _core.WouldBlock()

@_core.enable_ki_protection
async def get(self):
Expand All @@ -928,8 +952,24 @@ async def get(self):
object: The dequeued object.

"""
await self._get_semaphore.acquire()
return self._get_protected()
await _core.checkpoint_if_cancelled()
try:
value = self.get_nowait()
except _core.WouldBlock:
pass
else:
await _core.cancel_shielded_checkpoint()
return value

# Queue doesn't have anything, we must wait.
task = _core.current_task()

def abort_fn(_):
return _core.Abort.SUCCEEDED

self._get_wait[task] = None
value = await _core.wait_task_rescheduled(abort_fn)
return value

@aiter_compat
def __aiter__(self):
Expand All @@ -954,6 +994,6 @@ def statistics(self):
return _QueueStats(
qsize=len(self._data),
capacity=self.capacity,
tasks_waiting_put=self._put_semaphore.statistics().tasks_waiting,
tasks_waiting_get=self._get_semaphore.statistics().tasks_waiting,
tasks_waiting_put=len(self._put_wait),
tasks_waiting_get=len(self._get_wait),
)
55 changes: 53 additions & 2 deletions trio/tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,6 @@ async def test_Queue():
Queue(1.0)
with pytest.raises(ValueError):
Queue(-1)
with pytest.raises(ValueError):
Queue(0)

q = Queue(2)
repr(q) # smoke test
Expand Down Expand Up @@ -512,6 +510,29 @@ async def do_get(q):
assert (await q.get()) == 2


async def test_Queue_unbuffered():
Copy link
Member

@njsmith njsmith Apr 21, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For a bit more checking of unbuffered Queue, I'd suggest adding a QueueLock1(0) to the lock_factories and lock_factory_names lists in this file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this works. Attempting acquire the lock (i.e. await queue.put()) will block indefinitely. I definitely like idea to test Queue(0) as used for explicit synchronization though, so I may add a test that relies on rendezvous behaviour.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Attempting acquire the lock (i.e. await queue.put()) will block indefinitely

Oh duh, of course you're right.

q = Queue(0)
assert q.capacity == 0
assert q.qsize() == 0
assert q.empty()
assert q.full()
with pytest.raises(_core.WouldBlock):
q.get_nowait()
with pytest.raises(_core.WouldBlock):
q.put_nowait(1)

async def do_put(q, v):
with assert_checkpoints():
await q.put(v)

async with _core.open_nursery() as nursery:
nursery.start_soon(do_put, q, 1)
with assert_checkpoints():
assert await q.get() == 1
with pytest.raises(_core.WouldBlock):
q.get_nowait()


# Two ways of implementing a Lock in terms of a Queue. Used to let us put the
# Queue through the generic lock tests.

Expand Down Expand Up @@ -551,6 +572,34 @@ def release(self):
self.q.put_nowait(None)


@async_cm
class QueueLock3:
def __init__(self):
self.q = Queue(0)
# self.acquired is true when one task acquires the lock and
# only becomes false when it's released and no tasks are
# waiting to acquire.
self.acquired = False

def acquire_nowait(self):
assert not self.acquired
self.acquired = True

async def acquire(self):
if self.acquired:
await self.q.put(None)
else:
self.acquired = True
await _core.checkpoint()

def release(self):
try:
self.q.get_nowait()
except _core.WouldBlock:
assert self.acquired
self.acquired = False


lock_factories = [
lambda: CapacityLimiter(1),
lambda: Semaphore(1),
Expand All @@ -559,6 +608,7 @@ def release(self):
lambda: QueueLock1(10),
lambda: QueueLock1(1),
QueueLock2,
QueueLock3,
]
lock_factory_names = [
"CapacityLimiter(1)",
Expand All @@ -568,6 +618,7 @@ def release(self):
"QueueLock1(10)",
"QueueLock1(1)",
"QueueLock2",
"QueueLock3",
]

generic_lock_test = pytest.mark.parametrize(
Expand Down