diff --git a/src/plumpy/futures.py b/src/plumpy/futures.py index f52a0d09..056dfe32 100644 --- a/src/plumpy/futures.py +++ b/src/plumpy/futures.py @@ -4,24 +4,33 @@ """ import asyncio +import contextlib from typing import Any, Awaitable, Callable, Optional -import kiwipy +__all__ = ['create_task', 'CancellableAction', 'create_task'] -__all__ = ['CancelledError', 'Future', 'chain', 'copy_future', 'create_task', 'gather'] -CancelledError = kiwipy.CancelledError +class InvalidFutureError(Exception): + """Exception for when a future or action is in an invalid state""" -class InvalidStateError(Exception): - """Exception for when a future or action is in an invalid state""" +Future = asyncio.Future -copy_future = kiwipy.copy_future -chain = kiwipy.chain -gather = asyncio.gather +@contextlib.contextmanager +def capture_exceptions(future: Future[Any], ignore: tuple[type[BaseException], ...] = ()): + """ + Capture any exceptions in the context and set them as the result of the given future -Future = asyncio.Future + :param future: The future to the exception on + :param ignore: An optional list of exception types to ignore, these will be raised and not set on the future + """ + try: + yield + except ignore: + raise + except Exception as exception: + future.set_exception(exception) class CancellableAction(Future): @@ -46,10 +55,10 @@ def run(self, *args: Any, **kwargs: Any) -> None: :param kwargs: the keyword arguments to the action """ if self.done(): - raise InvalidStateError('Action has already been ran') + raise InvalidFutureError('Action has already been ran') try: - with kiwipy.capture_exceptions(self): + with capture_exceptions(self): self.set_result(self._action(*args, **kwargs)) finally: self._action = None # type: ignore @@ -70,38 +79,9 @@ def create_task(coro: Callable[[], Awaitable[Any]], loop: Optional[asyncio.Abstr future = loop.create_future() async def run_task() -> None: - with kiwipy.capture_exceptions(future): + with capture_exceptions(future): res = await coro() future.set_result(res) asyncio.run_coroutine_threadsafe(run_task(), loop) return future - - -def unwrap_kiwi_future(future: kiwipy.Future) -> kiwipy.Future: - """ - Create a kiwi future that represents the final results of a nested series of futures, - meaning that if the futures provided itself resolves to a future the returned - future will not resolve to a value until the final chain of futures is not a future - but a concrete value. If at any point in the chain a future resolves to an exception - then the returned future will also resolve to that exception. - - :param future: the future to unwrap - :return: the unwrapping future - - """ - unwrapping = kiwipy.Future() - - def unwrap(fut: kiwipy.Future) -> None: - if fut.cancelled(): - unwrapping.cancel() - else: - with kiwipy.capture_exceptions(unwrapping): - result = fut.result() - if isinstance(result, kiwipy.Future): - result.add_done_callback(unwrap) - else: - unwrapping.set_result(result) - - future.add_done_callback(unwrap) - return unwrapping diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index a71f2b06..058e1aa4 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -8,7 +8,7 @@ import kiwipy -from . import futures, loaders, persistence +from . import loaders, persistence from .utils import PID_TYPE __all__ = [ @@ -448,11 +448,12 @@ def execute_process( :param no_reply: if True, this call will be fire-and-forget, i.e. no return value :return: the result of executing the process """ + from plumpy.rmq.futures import unwrap_kiwi_future message = create_create_body(process_class, init_args, init_kwargs, persist=True, loader=loader) execute_future = kiwipy.Future() - create_future = futures.unwrap_kiwi_future(self._communicator.task_send(message)) + create_future = unwrap_kiwi_future(self._communicator.task_send(message)) def on_created(_: Any) -> None: with kiwipy.capture_exceptions(execute_future): diff --git a/src/plumpy/rmq/exceptions.py b/src/plumpy/rmq/exceptions.py index b15d51c4..15dad7bc 100644 --- a/src/plumpy/rmq/exceptions.py +++ b/src/plumpy/rmq/exceptions.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- from aio_pika.exceptions import ChannelInvalidStateError, ConnectionClosed +import kiwipy __all__ = [ 'CommunicatorChannelInvalidStateError', @@ -9,3 +10,5 @@ # Alias aio_pika CommunicatorConnectionClosed = ConnectionClosed CommunicatorChannelInvalidStateError = ChannelInvalidStateError + +CancelledError = kiwipy.CancelledError diff --git a/src/plumpy/rmq/futures.py b/src/plumpy/rmq/futures.py new file mode 100644 index 00000000..59e21d41 --- /dev/null +++ b/src/plumpy/rmq/futures.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +""" +Module containing future related methods and classes +""" + +import kiwipy + +__all__ = ['chain', 'copy_future', 'unwrap_kiwi_future'] + +copy_future = kiwipy.copy_future +chain = kiwipy.chain + + +def unwrap_kiwi_future(future: kiwipy.Future) -> kiwipy.Future: + """ + Create a kiwi future that represents the final results of a nested series of futures, + meaning that if the futures provided itself resolves to a future the returned + future will not resolve to a value until the final chain of futures is not a future + but a concrete value. If at any point in the chain a future resolves to an exception + then the returned future will also resolve to that exception. + + :param future: the future to unwrap + :return: the unwrapping future + + """ + unwrapping = kiwipy.Future() + + def unwrap(fut: kiwipy.Future) -> None: + if fut.cancelled(): + unwrapping.cancel() + else: + with kiwipy.capture_exceptions(unwrapping): + result = fut.result() + if isinstance(result, kiwipy.Future): + result.add_done_callback(unwrap) + else: + unwrapping.set_result(result) + + future.add_done_callback(unwrap) + return unwrapping diff --git a/tests/rmq/test_communicator.py b/tests/rmq/test_communicator.py index 3644befb..2a0bfebc 100644 --- a/tests/rmq/test_communicator.py +++ b/tests/rmq/test_communicator.py @@ -64,7 +64,7 @@ class TestLoopCommunicator: @pytest.mark.asyncio async def test_broadcast(self, loop_communicator): BROADCAST = {'body': 'present', 'sender': 'Martin', 'subject': 'sup', 'correlation_id': 420} # noqa: N806 - broadcast_future = plumpy.Future() + broadcast_future = asyncio.Future() loop = asyncio.get_event_loop() @@ -83,7 +83,7 @@ def get_broadcast(_comm, body, sender, subject, correlation_id): @pytest.mark.asyncio async def test_broadcast_filter(self, loop_communicator): - broadcast_future = plumpy.Future() + broadcast_future = asyncio.Future() def ignore_broadcast(_comm, body, sender, subject, correlation_id): broadcast_future.set_exception(AssertionError('broadcast received')) @@ -103,7 +103,7 @@ def get_broadcast(_comm, body, sender, subject, correlation_id): @pytest.mark.asyncio async def test_rpc(self, loop_communicator): MSG = 'rpc this' # noqa: N806 - rpc_future = plumpy.Future() + rpc_future = asyncio.Future() loop = asyncio.get_event_loop() @@ -120,7 +120,7 @@ def get_rpc(_comm, msg): @pytest.mark.asyncio async def test_task(self, loop_communicator): TASK = 'task this' # noqa: N806 - task_future = plumpy.Future() + task_future = asyncio.Future() loop = asyncio.get_event_loop() diff --git a/tests/test_processes.py b/tests/test_processes.py index faea9eae..a4238fbd 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -8,6 +8,7 @@ import kiwipy import pytest +from plumpy.futures import CancellableAction from tests import utils import plumpy @@ -540,7 +541,7 @@ def test_pause_in_process(self): class TestPausePlay(plumpy.Process): def run(self): fut = self.pause() - test_case.assertIsInstance(fut, plumpy.Future) + assert isinstance(fut, CancellableAction) loop = asyncio.get_event_loop() @@ -564,7 +565,7 @@ def test_pause_play_in_process(self): class TestPausePlay(plumpy.Process): def run(self): fut = self.pause() - test_case.assertIsInstance(fut, plumpy.Future) + test_case.assertIsInstance(fut, CancellableAction) result = self.play() test_case.assertTrue(result) diff --git a/tests/utils.py b/tests/utils.py index f2a58dfc..9f7bfb22 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -468,7 +468,7 @@ def run_until_waiting(proc): from plumpy import ProcessState listener = plumpy.ProcessListener() - in_waiting = plumpy.Future() + in_waiting = asyncio.Future() if proc.state == ProcessState.WAITING: in_waiting.set_result(True) @@ -488,7 +488,7 @@ def run_until_paused(proc): """Set up a future that will be resolved when the process is paused""" listener = plumpy.ProcessListener() - paused = plumpy.Future() + paused = asyncio.Future() if proc.paused: paused.set_result(True)