Skip to content

Commit

Permalink
Hand write asynio wrap for concurrent future
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 12, 2024
1 parent e32ffa2 commit 0e64b5b
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 11 deletions.
4 changes: 2 additions & 2 deletions src/plumpy/rmq/communications.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import kiwipy

from plumpy import futures
from plumpy.rmq.futures import wrap_to_kiwi_future
from plumpy.rmq.futures import wrap_to_concurrent_future
from plumpy.utils import ensure_coroutine

__all__ = [
Expand Down Expand Up @@ -72,7 +72,7 @@ def converted(communicator: kiwipy.Communicator, *args: Any, **kwargs: Any) -> k

msg_fn = functools.partial(coro, communicator, *args, **kwargs)
task_future = futures.create_task(msg_fn, loop)
return wrap_to_kiwi_future(task_future)
return wrap_to_concurrent_future(task_future)

return converted

Expand Down
102 changes: 93 additions & 9 deletions src/plumpy/rmq/futures.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,107 @@
# -*- coding: utf-8 -*-
"""
Module containing future related methods and classes
"""
# mypy: disable-error-code="no-untyped-def, no-untyped-call"
"""Module containing future related methods and classes"""

import asyncio
import concurrent.futures
from asyncio.futures import _chain_future, _copy_future_state # type: ignore[attr-defined]
from typing import Any

import kiwipy

__all__ = ['chain', 'copy_future', 'wrap_to_kiwi_future']
__all__ = ['wrap_to_concurrent_future']

copy_future = _copy_future_state
chain = _chain_future

def _convert_future_exc(exc):
exc_class = type(exc)
if exc_class is concurrent.futures.CancelledError:
return asyncio.exceptions.CancelledError(*exc.args)
elif exc_class is concurrent.futures.TimeoutError:
return asyncio.exceptions.TimeoutError(*exc.args)
elif exc_class is concurrent.futures.InvalidStateError:
return asyncio.exceptions.InvalidStateError(*exc.args)
else:
return exc

def wrap_to_kiwi_future(future: asyncio.Future[Any]) -> kiwipy.Future:
"""Wrap to concurrent.futures.Future object."""

def _set_concurrent_future_state(concurrent, source):
"""Copy state from a future to a concurrent.futures.Future."""
assert source.done()
if source.cancelled():
concurrent.cancel()
if not concurrent.set_running_or_notify_cancel():
return
exception = source.exception()
if exception is not None:
concurrent.set_exception(_convert_future_exc(exception))
else:
result = source.result()
concurrent.set_result(result)


def _copy_future_state(source, dest):
"""Internal helper to copy state from another Future.
The other Future may be a concurrent.futures.Future.
"""
assert source.done()
if dest.cancelled():
return
assert not dest.done()
if source.cancelled():
dest.cancel()
else:
exception = source.exception()
if exception is not None:
dest.set_exception(_convert_future_exc(exception))
else:
result = source.result()
dest.set_result(result)


def _chain_future(source, destination):
"""Chain two futures so that when one completes, so does the other.
The result (or exception) of source will be copied to destination.
If destination is cancelled, source gets cancelled too.
Compatible with both asyncio.Future and concurrent.futures.Future.
"""
if not asyncio.isfuture(source) and not isinstance(source, concurrent.futures.Future):
raise TypeError('A future is required for source argument')
if not asyncio.isfuture(destination) and not isinstance(destination, concurrent.futures.Future):
raise TypeError('A future is required for destination argument')
source_loop = asyncio.Future.get_loop(source) if asyncio.isfuture(source) else None
dest_loop = asyncio.Future.get_loop(destination) if asyncio.isfuture(destination) else None

def _set_state(future, other):
if asyncio.isfuture(future):
_copy_future_state(other, future)
else:
_set_concurrent_future_state(future, other)

def _call_check_cancel(destination):
if destination.cancelled():
if source_loop is None or source_loop is dest_loop:
source.cancel()
else:
source_loop.call_soon_threadsafe(source.cancel)

def _call_set_state(source):
if destination.cancelled() and dest_loop is not None and dest_loop.is_closed():
return
if dest_loop is None or dest_loop is source_loop:
_set_state(destination, source)
else:
if dest_loop.is_closed():
return
dest_loop.call_soon_threadsafe(_set_state, destination, source)

destination.add_done_callback(_call_check_cancel)
source.add_done_callback(_call_set_state)


def wrap_to_concurrent_future(future: asyncio.Future[Any]) -> kiwipy.Future:
"""Wrap to concurrent.futures.Future object. (the function is adapted from asyncio.future.wrap_future).
The function `_chain_future`, `_copy_future_state` is from asyncio future module."""
if isinstance(future, concurrent.futures.Future):
return future
assert asyncio.isfuture(future), f'concurrent.futures.Future is expected, got {future!r}'
Expand Down

0 comments on commit 0e64b5b

Please sign in to comment.