diff --git a/trio/__init__.py b/trio/__init__.py index 7ea2df21eb..acdc383b24 100644 --- a/trio/__init__.py +++ b/trio/__init__.py @@ -23,6 +23,9 @@ from ._sync import * __all__ += _sync.__all__ +from ._channel import * +__all__ += _channel.__all__ + from ._threads import * __all__ += _threads.__all__ diff --git a/trio/_channel.py b/trio/_channel.py new file mode 100644 index 0000000000..bb2995516b --- /dev/null +++ b/trio/_channel.py @@ -0,0 +1,202 @@ +from collections import deque, OrderedDict +from math import inf + +import attr +from outcome import Error, Value + +from . import _core +from ._util import aiter_compat + +__all__ = ["open_channel", "EndOfChannel", "BrokenChannelError"] + +# TODO: +# - introspection: +# - statistics +# - capacity, usage +# - repr +# - BrokenResourceError? +# - tests +# - docs + + +class EndOfChannel(Exception): + pass + + +class BrokenChannelError(Exception): + pass + + +def open_channel(capacity): + if capacity != inf and not isinstance(capacity, int): + raise TypeError("capacity must be an integer or math.inf") + if capacity < 0: + raise ValueError("capacity must be >= 0") + buf = ChannelBuf(capacity) + return PutChannel(buf), GetChannel(buf) + + +@attr.s(cmp=False, hash=False) +class ChannelBuf: + capacity = attr.ib() + data = attr.ib(default=attr.Factory(deque)) + # counts + put_channels = attr.ib(default=0) + get_channels = attr.ib(default=0) + # {task: value} + put_tasks = attr.ib(default=attr.Factory(OrderedDict)) + # {task: None} + get_tasks = attr.ib(default=attr.Factory(OrderedDict)) + + +class PutChannel: + def __init__(self, buf): + self._buf = buf + self.closed = False + self._tasks = set() + self._buf.put_channels += 1 + + @_core.disable_ki_protection + def put_nowait(self, value): + if self.closed: + raise _core.ClosedResourceError + if not self._buf.get_channels: + raise BrokenChannelError + if self._buf.get_tasks: + assert not self._buf.data + task = next(iter(self._buf.get_tasks)) + _core.reschedule(task, Value(value)) + elif len(self._buf.data) < self._buf.capacity: + self._buf.data.append(value) + else: + raise _core.WouldBlock + + @_core.disable_ki_protection + async def put(self, value): + await _core.checkpoint_if_cancelled() + try: + self.put_nowait(value) + except _core.WouldBlock: + pass + else: + await _core.cancel_shielded_checkpoint() + return + + task = _core.current_task() + self._tasks.add(task) + self._buf.put_tasks[task] = value + + def abort_fn(_): + self._tasks.remove(task) + del self._buf.put_tasks[task] + return _core.Abort.SUCCEEDED + + await _core.wait_task_rescheduled(abort_fn, always_abort=True) + + @_core.disable_ki_protection + def clone(self): + if self.closed: + raise _core.ClosedResourceError + return PutChannel(self._buf) + + @_core.disable_ki_protection + def close(self): + if self.closed: + return + self.closed = True + for task in list(self._tasks): + _core.reschedule(task, Error(ClosedResourceError())) + self._buf.put_channels -= 1 + if self._buf.put_channels == 0: + assert not self._buf.put_tasks + for task in list(self._buf.get_tasks): + _core.reschedule(task, Error(EndOfChannel())) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + +class GetChannel: + def __init__(self, buf): + self._buf = buf + self.closed = False + self._tasks = set() + self._buf.get_channels += 1 + + @_core.disable_ki_protection + def get_nowait(self): + if self.closed: + raise _core.ClosedResourceError + buf = self._buf + if buf.put_tasks: + task, value = next(iter(buf.put_tasks.items())) + _core.reschedule(task) + return value + if buf.data: + return buf.data.popleft() + if not buf.put_channels: + raise EndOfChannel + raise _core.WouldBlock + + @_core.disable_ki_protection + async def get(self): + await _core.checkpoint_if_cancelled() + try: + return self.get_nowait() + except _core.WouldBlock: + pass + else: + await _core.cancel_shielded_checkpoint() + return + + task = _core.current_task() + self._tasks.add(task) + self._buf.get_tasks[task] = None + + def abort_fn(_): + self._tasks.remove(task) + del self._buf.get_tasks[task] + return _core.Abort.SUCCEEDED + + return await _core.wait_task_rescheduled(abort_fn, always_abort=True) + + @_core.disable_ki_protection + def clone(self): + if self.closed: + raise _core.ClosedResourceError + return GetChannel(self._buf) + + @_core.disable_ki_protection + def close(self): + if self.closed: + return + self.closed = True + for task in list(self._tasks): + _core.reschedule(task, Error(ClosedResourceError())) + self._buf.get_channels -= 1 + if self._buf.get_channels == 0: + assert not self._buf.get_tasks + for task in list(self._buf.put_tasks): + _core.reschedule(task, Error(BrokenChannelError())) + # XX: or if we're losing data, maybe we should raise a + # BrokenChannelError here? + self._buf.data.clear() + + @aiter_compat + def __aiter__(self): + return self + + async def __anext__(self): + try: + return await self.get() + except EndOfChannel: + raise StopAsyncIteration + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close()