Skip to content

Commit

Permalink
distinguish concurrent.future.Future and asyncio.Future
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 12, 2024
1 parent 317e434 commit 6184c11
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 51 deletions.
62 changes: 21 additions & 41 deletions src/plumpy/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,33 @@
"""

import asyncio
import contextlib
from typing import Any, Awaitable, Callable, Optional

import kiwipy
__all__ = ['create_task', 'CancellableAction', 'create_task']

__all__ = ['CancelledError', 'Future', 'chain', 'copy_future', 'create_task', 'gather']

CancelledError = kiwipy.CancelledError
class InvalidFutureError(Exception):
"""Exception for when a future or action is in an invalid state"""


class InvalidStateError(Exception):
"""Exception for when a future or action is in an invalid state"""
Future = asyncio.Future


copy_future = kiwipy.copy_future
chain = kiwipy.chain
gather = asyncio.gather
@contextlib.contextmanager
def capture_exceptions(future: Future[Any], ignore: tuple[type[BaseException], ...] = ()):
"""
Capture any exceptions in the context and set them as the result of the given future
Future = asyncio.Future
:param future: The future to the exception on
:param ignore: An optional list of exception types to ignore, these will be raised and not set on the future
"""
try:
yield
except ignore:
raise
except Exception as exception:
future.set_exception(exception)


class CancellableAction(Future):
Expand All @@ -46,10 +55,10 @@ def run(self, *args: Any, **kwargs: Any) -> None:
:param kwargs: the keyword arguments to the action
"""
if self.done():
raise InvalidStateError('Action has already been ran')
raise InvalidFutureError('Action has already been ran')

try:
with kiwipy.capture_exceptions(self):
with capture_exceptions(self):
self.set_result(self._action(*args, **kwargs))
finally:
self._action = None # type: ignore
Expand All @@ -70,38 +79,9 @@ def create_task(coro: Callable[[], Awaitable[Any]], loop: Optional[asyncio.Abstr
future = loop.create_future()

async def run_task() -> None:
with kiwipy.capture_exceptions(future):
with capture_exceptions(future):
res = await coro()
future.set_result(res)

asyncio.run_coroutine_threadsafe(run_task(), loop)
return future


def unwrap_kiwi_future(future: kiwipy.Future) -> kiwipy.Future:
"""
Create a kiwi future that represents the final results of a nested series of futures,
meaning that if the futures provided itself resolves to a future the returned
future will not resolve to a value until the final chain of futures is not a future
but a concrete value. If at any point in the chain a future resolves to an exception
then the returned future will also resolve to that exception.
:param future: the future to unwrap
:return: the unwrapping future
"""
unwrapping = kiwipy.Future()

def unwrap(fut: kiwipy.Future) -> None:
if fut.cancelled():
unwrapping.cancel()
else:
with kiwipy.capture_exceptions(unwrapping):
result = fut.result()
if isinstance(result, kiwipy.Future):
result.add_done_callback(unwrap)
else:
unwrapping.set_result(result)

future.add_done_callback(unwrap)
return unwrapping
5 changes: 3 additions & 2 deletions src/plumpy/process_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import kiwipy

from . import futures, loaders, persistence
from . import loaders, persistence
from .utils import PID_TYPE

__all__ = [
Expand Down Expand Up @@ -448,11 +448,12 @@ def execute_process(
:param no_reply: if True, this call will be fire-and-forget, i.e. no return value
:return: the result of executing the process
"""
from plumpy.rmq.futures import unwrap_kiwi_future

message = create_create_body(process_class, init_args, init_kwargs, persist=True, loader=loader)

execute_future = kiwipy.Future()
create_future = futures.unwrap_kiwi_future(self._communicator.task_send(message))
create_future = unwrap_kiwi_future(self._communicator.task_send(message))

def on_created(_: Any) -> None:
with kiwipy.capture_exceptions(execute_future):
Expand Down
3 changes: 3 additions & 0 deletions src/plumpy/rmq/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
from aio_pika.exceptions import ChannelInvalidStateError, ConnectionClosed
import kiwipy

__all__ = [
'CommunicatorChannelInvalidStateError',
Expand All @@ -9,3 +10,5 @@
# Alias aio_pika
CommunicatorConnectionClosed = ConnectionClosed
CommunicatorChannelInvalidStateError = ChannelInvalidStateError

CancelledError = kiwipy.CancelledError
40 changes: 40 additions & 0 deletions src/plumpy/rmq/futures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# -*- coding: utf-8 -*-
"""
Module containing future related methods and classes
"""

import kiwipy

__all__ = ['chain', 'copy_future', 'unwrap_kiwi_future']

copy_future = kiwipy.copy_future
chain = kiwipy.chain


def unwrap_kiwi_future(future: kiwipy.Future) -> kiwipy.Future:
"""
Create a kiwi future that represents the final results of a nested series of futures,
meaning that if the futures provided itself resolves to a future the returned
future will not resolve to a value until the final chain of futures is not a future
but a concrete value. If at any point in the chain a future resolves to an exception
then the returned future will also resolve to that exception.
:param future: the future to unwrap
:return: the unwrapping future
"""
unwrapping = kiwipy.Future()

def unwrap(fut: kiwipy.Future) -> None:
if fut.cancelled():
unwrapping.cancel()
else:
with kiwipy.capture_exceptions(unwrapping):
result = fut.result()
if isinstance(result, kiwipy.Future):
result.add_done_callback(unwrap)
else:
unwrapping.set_result(result)

future.add_done_callback(unwrap)
return unwrapping
8 changes: 4 additions & 4 deletions tests/rmq/test_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class TestLoopCommunicator:
@pytest.mark.asyncio
async def test_broadcast(self, loop_communicator):
BROADCAST = {'body': 'present', 'sender': 'Martin', 'subject': 'sup', 'correlation_id': 420} # noqa: N806
broadcast_future = plumpy.Future()
broadcast_future = asyncio.Future()

loop = asyncio.get_event_loop()

Expand All @@ -83,7 +83,7 @@ def get_broadcast(_comm, body, sender, subject, correlation_id):

@pytest.mark.asyncio
async def test_broadcast_filter(self, loop_communicator):
broadcast_future = plumpy.Future()
broadcast_future = asyncio.Future()

def ignore_broadcast(_comm, body, sender, subject, correlation_id):
broadcast_future.set_exception(AssertionError('broadcast received'))
Expand All @@ -103,7 +103,7 @@ def get_broadcast(_comm, body, sender, subject, correlation_id):
@pytest.mark.asyncio
async def test_rpc(self, loop_communicator):
MSG = 'rpc this' # noqa: N806
rpc_future = plumpy.Future()
rpc_future = asyncio.Future()

loop = asyncio.get_event_loop()

Expand All @@ -120,7 +120,7 @@ def get_rpc(_comm, msg):
@pytest.mark.asyncio
async def test_task(self, loop_communicator):
TASK = 'task this' # noqa: N806
task_future = plumpy.Future()
task_future = asyncio.Future()

loop = asyncio.get_event_loop()

Expand Down
5 changes: 3 additions & 2 deletions tests/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import kiwipy
import pytest
from plumpy.futures import CancellableAction
from tests import utils

import plumpy
Expand Down Expand Up @@ -540,7 +541,7 @@ def test_pause_in_process(self):
class TestPausePlay(plumpy.Process):
def run(self):
fut = self.pause()
test_case.assertIsInstance(fut, plumpy.Future)
assert isinstance(fut, CancellableAction)

loop = asyncio.get_event_loop()

Expand All @@ -564,7 +565,7 @@ def test_pause_play_in_process(self):
class TestPausePlay(plumpy.Process):
def run(self):
fut = self.pause()
test_case.assertIsInstance(fut, plumpy.Future)
test_case.assertIsInstance(fut, CancellableAction)
result = self.play()
test_case.assertTrue(result)

Expand Down
4 changes: 2 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def run_until_waiting(proc):
from plumpy import ProcessState

listener = plumpy.ProcessListener()
in_waiting = plumpy.Future()
in_waiting = asyncio.Future()

if proc.state == ProcessState.WAITING:
in_waiting.set_result(True)
Expand All @@ -488,7 +488,7 @@ def run_until_paused(proc):
"""Set up a future that will be resolved when the process is paused"""

listener = plumpy.ProcessListener()
paused = plumpy.Future()
paused = asyncio.Future()

if proc.paused:
paused.set_result(True)
Expand Down

0 comments on commit 6184c11

Please sign in to comment.