Skip to content

Commit

Permalink
new Queue with multi_get() and max_size>=0
Browse files Browse the repository at this point in the history
based on ideas and code from #272
  • Loading branch information
sorcio committed Mar 17, 2018
1 parent aee2676 commit a1ac052
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 28 deletions.
112 changes: 86 additions & 26 deletions trio/_sync.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import operator
from collections import deque
from collections import deque, OrderedDict

import attr

Expand Down Expand Up @@ -823,21 +823,23 @@ class Queue:
capacity (int): The maximum number of items allowed in the queue before
:meth:`put` blocks. Choosing a sensible value here is important to
ensure that backpressure is communicated promptly and avoid
unnecessary latency. If in doubt, use 1.
unnecessary latency. If in doubt, use 0.
"""

def __init__(self, capacity):
if not isinstance(capacity, int):
raise TypeError("capacity must be an integer")
if capacity < 1:
raise ValueError("capacity must be >= 1")
# Invariants:
# get_semaphore.value() == len(self._data)
# put_semaphore.value() + get_semaphore.value() = capacity
if capacity < 0:
raise ValueError("capacity must be >= 0")
self.capacity = operator.index(capacity)
self._put_semaphore = Semaphore(capacity, max_value=capacity)
self._get_semaphore = Semaphore(0, max_value=capacity)
# {task: abort func}
self._get_wait = OrderedDict()
# {task: queued value}
self._put_wait = OrderedDict()
# invariants:
# if len(self._data) < self.capacity, then self._put_wait is empty
# if len(self._data) > 0, then self._get_wait is empty
self._data = deque()

def __repr__(self):
Expand Down Expand Up @@ -873,10 +875,6 @@ def empty(self):
"""
return not self._data

def _put_protected(self, obj):
self._data.append(obj)
self._get_semaphore.release()

@_core.enable_ki_protection
def put_nowait(self, obj):
"""Attempt to put an object into the queue, without blocking.
Expand All @@ -888,8 +886,15 @@ def put_nowait(self, obj):
WouldBlock: if the queue is full.
"""
self._put_semaphore.acquire_nowait()
self._put_protected(obj)
if self._get_wait:
assert not self._data
task, abort_fn = self._get_wait.popitem(last=False)
abort_fn(None)
_core.reschedule(task, _core.Value((self, obj)))
elif len(self._data) < self.capacity:
self._data.append(obj)
else:
raise _core.WouldBlock()

@_core.enable_ki_protection
async def put(self, obj):
Expand All @@ -899,12 +904,23 @@ async def put(self, obj):
obj (object): The object to enqueue.
"""
await self._put_semaphore.acquire()
self._put_protected(obj)
await _core.checkpoint_if_cancelled()
try:
self.put_nowait(obj)
except _core.WouldBlock:
pass
else:
await _core.cancel_shielded_checkpoint()
return

task = _core.current_task()
self._put_wait[task] = obj

def _get_protected(self):
self._put_semaphore.release()
return self._data.popleft()
def abort_fn(_):
del self._put_wait[task]
return _core.Abort.SUCCEEDED

await _core.wait_task_rescheduled(abort_fn)

@_core.enable_ki_protection
def get_nowait(self):
Expand All @@ -917,8 +933,8 @@ def get_nowait(self):
WouldBlock: if the queue is empty.
"""
self._get_semaphore.acquire_nowait()
return self._get_protected()
_, value = multi_get_nowait([self])
return value

@_core.enable_ki_protection
async def get(self):
Expand All @@ -928,8 +944,8 @@ async def get(self):
object: The dequeued object.
"""
await self._get_semaphore.acquire()
return self._get_protected()
_, value = await multi_get([self])
return value

@aiter_compat
def __aiter__(self):
Expand All @@ -954,6 +970,50 @@ def statistics(self):
return _QueueStats(
qsize=len(self._data),
capacity=self.capacity,
tasks_waiting_put=self._put_semaphore.statistics().tasks_waiting,
tasks_waiting_get=self._get_semaphore.statistics().tasks_waiting,
tasks_waiting_put=len(self._put_wait),
tasks_waiting_get=len(self._get_wait),
)


@_core.enable_ki_protection
def multi_get_nowait(queues):
for queue in queues:
if queue._put_wait:
task, value = queue._put_wait.popitem(last=False)
# No need to check max_size, b/c we'll pop an item off again right
# below.
queue._data.append(value)
_core.reschedule(task)
if queue._data:
value = queue._data.popleft()
return queue, value
raise _core.WouldBlock()


@_core.enable_ki_protection
async def multi_get(queues):
# Returns (queue object, value gotten)
await _core.checkpoint_if_cancelled()
try:
queue, value = multi_get_nowait(queues)
except _core.WouldBlock:
pass
else:
await _core.cancel_shielded_checkpoint()
return queue, value
# No queue had anything.
task = _core.current_task()

def abort_fn(_):
for queue in queues:
try:
del queue._get_wait[task]
except KeyError:
# If we just pushed to this queue, we already popped.
# But is it alright to... always pass?
pass
return _core.Abort.SUCCEEDED

for queue in queues:
queue._get_wait[task] = abort_fn
return await _core.wait_task_rescheduled(abort_fn)
25 changes: 23 additions & 2 deletions trio/tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,6 @@ async def test_Queue():
Queue(1.0)
with pytest.raises(ValueError):
Queue(-1)
with pytest.raises(ValueError):
Queue(0)

q = Queue(2)
repr(q) # smoke test
Expand Down Expand Up @@ -512,6 +510,29 @@ async def do_get(q):
assert (await q.get()) == 2


async def test_Queue_unbuffered():
q = Queue(0)
assert q.capacity == 0
assert q.qsize() == 0
assert q.empty()
assert q.full()
with pytest.raises(_core.WouldBlock):
q.get_nowait()
with pytest.raises(_core.WouldBlock):
q.put_nowait(1)

async def do_put(q, v):
with assert_checkpoints():
await q.put(v)

async with _core.open_nursery() as nursery:
nursery.start_soon(do_put, q, 1)
with assert_checkpoints():
assert await q.get() == 1
with pytest.raises(_core.WouldBlock):
q.get_nowait()


# Two ways of implementing a Lock in terms of a Queue. Used to let us put the
# Queue through the generic lock tests.

Expand Down

0 comments on commit a1ac052

Please sign in to comment.