Skip to content

Commit

Permalink
Use MsgStream.subscribe() in Context.result()
Browse files Browse the repository at this point in the history
The case exists where there is multiple tasks consuming from an open
2-way stream created via `Context.open_stream()` where a sibling task is
pulling from the stream while some other task also calls `.result()`.
Previously the `.result()` call would consume (drain) stream messages
directly from the underlying mem chan which would mean any sibling task
would not receive those same messages. Instead, make `.result()` check
if a stream is open and instead consume (and discard) stream msgs using
a `BroadcastReceiver` (via `MsgStream.subscribe()`) such that all
interested tasks get copies of the same packets.
  • Loading branch information
goodboy committed Jan 30, 2023
1 parent f7a1f38 commit c9eb466
Showing 1 changed file with 46 additions and 13 deletions.
59 changes: 46 additions & 13 deletions tractor/_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
Optional,
Callable,
AsyncGenerator,
AsyncIterator
AsyncIterator,
TYPE_CHECKING,
)

import warnings
Expand All @@ -41,6 +42,10 @@
from .trionics import broadcast_receiver, BroadcastReceiver


if TYPE_CHECKING:
from ._portal import Portal


log = get_logger(__name__)


Expand Down Expand Up @@ -378,7 +383,8 @@ class Context:
_remote_func_type: Optional[str] = None

# only set on the caller side
_portal: Optional['Portal'] = None # type: ignore # noqa
_portal: Optional[Portal] = None # type: ignore # noqa
_stream: Optional[MsgStream] = None
_result: Optional[Any] = False
_error: Optional[BaseException] = None

Expand Down Expand Up @@ -486,6 +492,7 @@ async def cancel(
log.cancel(f'Cancelling {side} side of context to {self.chan.uid}')

self._cancel_called = True
ipc_broken: bool = False

if side == 'caller':
if not self._portal:
Expand All @@ -503,7 +510,14 @@ async def cancel(
# NOTE: we're telling the far end actor to cancel a task
# corresponding to *this actor*. The far end local channel
# instance is passed to `Actor._cancel_task()` implicitly.
await self._portal.run_from_ns('self', '_cancel_task', cid=cid)
try:
await self._portal.run_from_ns(
'self',
'_cancel_task',
cid=cid,
)
except trio.BrokenResourceError:
ipc_broken = True

if cs.cancelled_caught:
# XXX: there's no way to know if the remote task was indeed
Expand All @@ -519,7 +533,10 @@ async def cancel(
"Timed out on cancelling remote task "
f"{cid} for {self._portal.channel.uid}")

# callee side remote task
elif ipc_broken:
log.cancel(
"Transport layer was broken before cancel request "
f"{cid} for {self._portal.channel.uid}")
else:
self._cancel_msg = msg

Expand Down Expand Up @@ -607,6 +624,7 @@ async def open_stream(
ctx=self,
rx_chan=ctx._recv_chan,
) as stream:
self._stream = stream

if self._portal:
self._portal._streams.add(stream)
Expand Down Expand Up @@ -648,25 +666,22 @@ async def result(self) -> Any:

if not self._recv_chan._closed: # type: ignore

# wait for a final context result consuming
# and discarding any bi dir stream msgs still
# in transit from the far end.
while True:
def consume(
msg: dict,

msg = await self._recv_chan.receive()
) -> Optional[dict]:
try:
self._result = msg['return']
break
return msg['return']
except KeyError as msgerr:

if 'yield' in msg:
# far end task is still streaming to us so discard
log.warning(f'Discarding stream delivered {msg}')
continue
return

elif 'stop' in msg:
log.debug('Remote stream terminated')
continue
return

# internal error should never get here
assert msg.get('cid'), (
Expand All @@ -676,6 +691,24 @@ async def result(self) -> Any:
msg, self._portal.channel
) from msgerr

# wait for a final context result consuming
# and discarding any bi dir stream msgs still
# in transit from the far end.
if self._stream:
async with self._stream.subscribe() as bstream:
async for msg in bstream:
result = consume(msg)
if result:
self._result = result

if not self._result:
while True:
msg = await self._recv_chan.receive()
result = consume(msg)
if result:
self._result = result
break

return self._result

async def started(
Expand Down

0 comments on commit c9eb466

Please sign in to comment.