diff --git a/src/plumpy/rmq/communications.py b/src/plumpy/rmq/communications.py index 188108be..9dbafbed 100644 --- a/src/plumpy/rmq/communications.py +++ b/src/plumpy/rmq/communications.py @@ -8,7 +8,7 @@ import kiwipy from plumpy import futures -from plumpy.rmq.futures import wrap_to_kiwi_future +from plumpy.rmq.futures import wrap_to_concurrent_future from plumpy.utils import ensure_coroutine __all__ = [ @@ -72,7 +72,7 @@ def converted(communicator: kiwipy.Communicator, *args: Any, **kwargs: Any) -> k msg_fn = functools.partial(coro, communicator, *args, **kwargs) task_future = futures.create_task(msg_fn, loop) - return wrap_to_kiwi_future(task_future) + return wrap_to_concurrent_future(task_future) return converted diff --git a/src/plumpy/rmq/futures.py b/src/plumpy/rmq/futures.py index 649fb2b4..897c8147 100644 --- a/src/plumpy/rmq/futures.py +++ b/src/plumpy/rmq/futures.py @@ -1,23 +1,107 @@ # -*- coding: utf-8 -*- -""" -Module containing future related methods and classes -""" +# mypy: disable-error-code="no-untyped-def, no-untyped-call" +"""Module containing future related methods and classes""" import asyncio import concurrent.futures -from asyncio.futures import _chain_future, _copy_future_state # type: ignore[attr-defined] from typing import Any import kiwipy -__all__ = ['chain', 'copy_future', 'wrap_to_kiwi_future'] +__all__ = ['wrap_to_concurrent_future'] -copy_future = _copy_future_state -chain = _chain_future +def _convert_future_exc(exc): + exc_class = type(exc) + if exc_class is concurrent.futures.CancelledError: + return asyncio.exceptions.CancelledError(*exc.args) + elif exc_class is concurrent.futures.TimeoutError: + return asyncio.exceptions.TimeoutError(*exc.args) + elif exc_class is concurrent.futures.InvalidStateError: + return asyncio.exceptions.InvalidStateError(*exc.args) + else: + return exc -def wrap_to_kiwi_future(future: asyncio.Future[Any]) -> kiwipy.Future: - """Wrap to concurrent.futures.Future object.""" + +def _set_concurrent_future_state(concurrent, source): + """Copy state from a future to a concurrent.futures.Future.""" + assert source.done() + if source.cancelled(): + concurrent.cancel() + if not concurrent.set_running_or_notify_cancel(): + return + exception = source.exception() + if exception is not None: + concurrent.set_exception(_convert_future_exc(exception)) + else: + result = source.result() + concurrent.set_result(result) + + +def _copy_future_state(source, dest): + """Internal helper to copy state from another Future. + + The other Future may be a concurrent.futures.Future. + """ + assert source.done() + if dest.cancelled(): + return + assert not dest.done() + if source.cancelled(): + dest.cancel() + else: + exception = source.exception() + if exception is not None: + dest.set_exception(_convert_future_exc(exception)) + else: + result = source.result() + dest.set_result(result) + + +def _chain_future(source, destination): + """Chain two futures so that when one completes, so does the other. + + The result (or exception) of source will be copied to destination. + If destination is cancelled, source gets cancelled too. + Compatible with both asyncio.Future and concurrent.futures.Future. + """ + if not asyncio.isfuture(source) and not isinstance(source, concurrent.futures.Future): + raise TypeError('A future is required for source argument') + if not asyncio.isfuture(destination) and not isinstance(destination, concurrent.futures.Future): + raise TypeError('A future is required for destination argument') + source_loop = asyncio.Future.get_loop(source) if asyncio.isfuture(source) else None + dest_loop = asyncio.Future.get_loop(destination) if asyncio.isfuture(destination) else None + + def _set_state(future, other): + if asyncio.isfuture(future): + _copy_future_state(other, future) + else: + _set_concurrent_future_state(future, other) + + def _call_check_cancel(destination): + if destination.cancelled(): + if source_loop is None or source_loop is dest_loop: + source.cancel() + else: + source_loop.call_soon_threadsafe(source.cancel) + + def _call_set_state(source): + if destination.cancelled() and dest_loop is not None and dest_loop.is_closed(): + return + if dest_loop is None or dest_loop is source_loop: + _set_state(destination, source) + else: + if dest_loop.is_closed(): + return + dest_loop.call_soon_threadsafe(_set_state, destination, source) + + destination.add_done_callback(_call_check_cancel) + source.add_done_callback(_call_set_state) + + +def wrap_to_concurrent_future(future: asyncio.Future[Any]) -> kiwipy.Future: + """Wrap to concurrent.futures.Future object. (the function is adapted from asyncio.future.wrap_future). + The function `_chain_future`, `_copy_future_state` is from asyncio future module.""" if isinstance(future, concurrent.futures.Future): return future assert asyncio.isfuture(future), f'concurrent.futures.Future is expected, got {future!r}'