From 22c184a5d275565bff6ea5e970d691631ec7dc00 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Fri, 13 Dec 2024 01:38:34 +0100 Subject: [PATCH] Move communication to rmq --- src/plumpy/__init__.py | 6 +- src/plumpy/futures.py | 2 +- src/plumpy/message.py | 310 +++++++++++++++++ src/plumpy/processes.py | 47 +-- src/plumpy/rmq/__init__.py | 4 +- src/plumpy/{ => rmq}/process_comms.py | 313 +----------------- tests/{ => rmq}/test_communications.py | 0 tests/rmq/test_communicator.py | 3 +- tests/rmq/test_process_comms.py | 8 +- ...{test_process_comms.py => test_message.py} | 8 +- tests/test_processes.py | 2 +- tests/utils.py | 2 +- 12 files changed, 365 insertions(+), 340 deletions(-) create mode 100644 src/plumpy/message.py rename src/plumpy/{ => rmq}/process_comms.py (56%) rename tests/{ => rmq}/test_communications.py (100%) rename tests/{test_process_comms.py => test_message.py} (90%) diff --git a/src/plumpy/__init__.py b/src/plumpy/__init__.py index 8fd94df9..237617ac 100644 --- a/src/plumpy/__init__.py +++ b/src/plumpy/__init__.py @@ -8,14 +8,13 @@ from .exceptions import * from .futures import * from .loaders import * +from .message import * from .mixins import * from .persistence import * from .ports import * -from .process_comms import * from .process_listener import * from .process_states import * from .processes import * -from .rmq import * from .utils import * from .workchains import * @@ -27,8 +26,7 @@ + futures.__all__ + mixins.__all__ + persistence.__all__ - + rmq.__all__ - + process_comms.__all__ + + message.__all__ + process_listener.__all__ + workchains.__all__ + loaders.__all__ diff --git a/src/plumpy/futures.py b/src/plumpy/futures.py index a467f5d8..2f861d64 100644 --- a/src/plumpy/futures.py +++ b/src/plumpy/futures.py @@ -7,7 +7,7 @@ import contextlib from typing import Any, Awaitable, Callable, Generator, Optional -__all__ = ['CancellableAction', 'create_task', 'create_task'] +__all__ = ['CancellableAction', 'create_task', 'create_task', 'capture_exceptions'] class InvalidFutureError(Exception): diff --git a/src/plumpy/message.py b/src/plumpy/message.py new file mode 100644 index 00000000..b18d4123 --- /dev/null +++ b/src/plumpy/message.py @@ -0,0 +1,310 @@ +# -*- coding: utf-8 -*- +"""Module for process level communication functions and classes""" + +import asyncio +import logging +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast + +import kiwipy + +from . import loaders, persistence +from .utils import PID_TYPE + +__all__ = [ + 'KILL_MSG', + 'PAUSE_MSG', + 'PLAY_MSG', + 'STATUS_MSG', + 'ProcessLauncher', + 'create_continue_body', + 'create_launch_body', +] + +if TYPE_CHECKING: + from .processes import Process + +INTENT_KEY = 'intent' +MESSAGE_KEY = 'message' + + +class Intent: + """Intent constants for a process message""" + + PLAY: str = 'play' + PAUSE: str = 'pause' + KILL: str = 'kill' + STATUS: str = 'status' + + +PAUSE_MSG = {INTENT_KEY: Intent.PAUSE} +PLAY_MSG = {INTENT_KEY: Intent.PLAY} +KILL_MSG = {INTENT_KEY: Intent.KILL} +STATUS_MSG = {INTENT_KEY: Intent.STATUS} + +TASK_KEY = 'task' +TASK_ARGS = 'args' +PERSIST_KEY = 'persist' +# Launch +PROCESS_CLASS_KEY = 'process_class' +ARGS_KEY = 'init_args' +KWARGS_KEY = 'init_kwargs' +NOWAIT_KEY = 'nowait' +# Continue +PID_KEY = 'pid' +TAG_KEY = 'tag' +# Task types +LAUNCH_TASK = 'launch' +CONTINUE_TASK = 'continue' +CREATE_TASK = 'create' + +LOGGER = logging.getLogger(__name__) + + +def create_launch_body( + process_class: str, + init_args: Optional[Sequence[Any]] = None, + init_kwargs: Optional[Dict[str, Any]] = None, + persist: bool = False, + loader: Optional[loaders.ObjectLoader] = None, + nowait: bool = True, +) -> Dict[str, Any]: + """ + Create a message body for the launch action + + :param process_class: the class of the process to launch + :param init_args: any initialisation positional arguments + :param init_kwargs: any initialisation keyword arguments + :param persist: persist this process if True, otherwise don't + :param loader: the loader to use to load the persisted process + :param nowait: wait for the process to finish before completing the task, otherwise just return the PID + :return: a dictionary with the body of the message to launch the process + :rtype: dict + """ + if loader is None: + loader = loaders.get_object_loader() + + msg_body = { + TASK_KEY: LAUNCH_TASK, + TASK_ARGS: { + PROCESS_CLASS_KEY: loader.identify_object(process_class), + PERSIST_KEY: persist, + NOWAIT_KEY: nowait, + ARGS_KEY: init_args, + KWARGS_KEY: init_kwargs, + }, + } + return msg_body + + +def create_continue_body(pid: 'PID_TYPE', tag: Optional[str] = None, nowait: bool = False) -> Dict[str, Any]: + """ + Create a message body to continue an existing process + :param pid: the pid of the existing process + :param tag: the optional persistence tag + :param nowait: wait for the process to finish before completing the task, otherwise just return the PID + :return: a dictionary with the body of the message to continue the process + + """ + msg_body = {TASK_KEY: CONTINUE_TASK, TASK_ARGS: {PID_KEY: pid, NOWAIT_KEY: nowait, TAG_KEY: tag}} + return msg_body + + +def create_create_body( + process_class: str, + init_args: Optional[Sequence[Any]] = None, + init_kwargs: Optional[Dict[str, Any]] = None, + persist: bool = False, + loader: Optional[loaders.ObjectLoader] = None, +) -> Dict[str, Any]: + """ + Create a message body to create a new process + :param process_class: the class of the process to launch + :param init_args: any initialisation positional arguments + :param init_kwargs: any initialisation keyword arguments + :param persist: persist this process if True, otherwise don't + :param loader: the loader to use to load the persisted process + :return: a dictionary with the body of the message to launch the process + + """ + if loader is None: + loader = loaders.get_object_loader() + + msg_body = { + TASK_KEY: CREATE_TASK, + TASK_ARGS: { + PROCESS_CLASS_KEY: loader.identify_object(process_class), + PERSIST_KEY: persist, + ARGS_KEY: init_args, + KWARGS_KEY: init_kwargs, + }, + } + return msg_body + + +class ProcessLauncher: + """ + Takes incoming task messages and uses them to launch processes. + + Expected format of task: + + For launch:: + + { + 'task': + 'process_class': + 'args': + 'kwargs': . + 'nowait': True or False + } + + For continue:: + + { + 'task': + 'pid': + 'nowait': True or False + } + """ + + def __init__( + self, + loop: Optional[asyncio.AbstractEventLoop] = None, + persister: Optional[persistence.Persister] = None, + load_context: Optional[persistence.LoadSaveContext] = None, + loader: Optional[loaders.ObjectLoader] = None, + ) -> None: + self._loop = loop + self._persister = persister + self._load_context = load_context if load_context is not None else persistence.LoadSaveContext() + + if loader is not None: + self._loader = loader + self._load_context = self._load_context.copyextend(loader=loader) + else: + self._loader = loaders.get_object_loader() + + async def __call__(self, communicator: kiwipy.Communicator, task: Dict[str, Any]) -> Union[PID_TYPE, Any]: + """ + Receive a task. + :param task: The task message + """ + from plumpy.rmq import communications + + task_type = task[TASK_KEY] + if task_type == LAUNCH_TASK: + return await self._launch(communicator, **task.get(TASK_ARGS, {})) + if task_type == CONTINUE_TASK: + return await self._continue(communicator, **task.get(TASK_ARGS, {})) + if task_type == CREATE_TASK: + return await self._create(communicator, **task.get(TASK_ARGS, {})) + + raise communications.TaskRejected + + async def _launch( + self, + _communicator: kiwipy.Communicator, + process_class: str, + persist: bool, + nowait: bool, + init_args: Optional[Sequence[Any]] = None, + init_kwargs: Optional[Dict[str, Any]] = None, + ) -> Union[PID_TYPE, Any]: + """ + Launch the process + + :param _communicator: the communicator + :param process_class: the process class to launch + :param persist: should the process be persisted + :param nowait: if True only return when the process finishes + :param init_args: positional arguments to the process constructor + :param init_kwargs: keyword arguments to the process constructor + :return: the pid of the created process or the outputs (if nowait=False) + """ + from plumpy.rmq import communications + + if persist and not self._persister: + raise communications.TaskRejected('Cannot persist process, no persister') + + if init_args is None: + init_args = () + if init_kwargs is None: + init_kwargs = {} + + proc_class = self._loader.load_object(process_class) + proc = proc_class(*init_args, **init_kwargs) + if persist and self._persister is not None: + self._persister.save_checkpoint(proc) + + if nowait: + # XXX: can return a reference and gracefully use task to cancel itself when the upper call stack fails + asyncio.ensure_future(proc.step_until_terminated()) # noqa: RUF006 + return proc.pid + + await proc.step_until_terminated() + + return proc.future().result() + + async def _continue( + self, _communicator: kiwipy.Communicator, pid: 'PID_TYPE', nowait: bool, tag: Optional[str] = None + ) -> Union[PID_TYPE, Any]: + """ + Continue the process + + :param _communicator: the communicator + :param pid: the pid of the process to continue + :param nowait: if True don't wait for the process to complete + :param tag: the checkpoint tag to continue from + """ + from plumpy.rmq import communications + + if not self._persister: + LOGGER.warning('rejecting task: cannot continue process<%d> because no persister is available', pid) + raise communications.TaskRejected('Cannot continue process, no persister') + + # Do not catch exceptions here, because if these operations fail, the continue task should except and bubble up + saved_state = self._persister.load_checkpoint(pid, tag) + proc = cast('Process', saved_state.unbundle(self._load_context)) + + if nowait: + # XXX: can return a reference and gracefully use task to cancel itself when the upper call stack fails + asyncio.ensure_future(proc.step_until_terminated()) # noqa: RUF006 + return proc.pid + + await proc.step_until_terminated() + + return proc.future().result() + + async def _create( + self, + _communicator: kiwipy.Communicator, + process_class: str, + persist: bool, + init_args: Optional[Sequence[Any]] = None, + init_kwargs: Optional[Dict[str, Any]] = None, + ) -> 'PID_TYPE': + """ + Create the process + + :param _communicator: the communicator + :param process_class: the process class to create + :param persist: should the process be persisted + :param init_args: positional arguments to the process constructor + :param init_kwargs: keyword arguments to the process constructor + :return: the pid of the created process + """ + from plumpy.rmq import communications + + if persist and not self._persister: + raise communications.TaskRejected('Cannot persist process, no persister') + + if init_args is None: + init_args = () + if init_kwargs is None: + init_kwargs = {} + + proc_class = self._loader.load_object(process_class) + proc = proc_class(*init_args, **init_kwargs) + if persist and self._persister is not None: + self._persister.save_checkpoint(proc) + + return proc.pid diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index c8159ec6..091fd05d 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -38,7 +38,8 @@ import kiwipy import yaml -from . import events, exceptions, futures, persistence, ports, process_comms, process_states, utils +from . import events, exceptions, message, persistence, ports, process_states, utils +from .futures import capture_exceptions, CancellableAction from .base import state_machine from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event from .base.utils import call_with_super_check, super_check @@ -134,10 +135,10 @@ class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMe _spec_class = ProcessSpec # Default placeholders, will be populated in init() _stepping = False - _pausing: Optional[futures.CancellableAction] = None + _pausing: Optional[CancellableAction] = None _paused: Optional[persistence.SavableFuture] = None - _killing: Optional[futures.CancellableAction] = None - _interrupt_action: Optional[futures.CancellableAction] = None + _killing: Optional[CancellableAction] = None + _interrupt_action: Optional[CancellableAction] = None _closed = False _cleanups: Optional[List[Callable[[], None]]] = None @@ -317,7 +318,7 @@ def init(self) -> None: if not self._future.done(): - def try_killing(future: futures.Future) -> None: + def try_killing(future: asyncio.Future) -> None: if future.cancelled(): if not self.kill('Killed by future being cancelled'): self.logger.warning('Process<%s>: Failed to kill process on future cancel', self.pid) @@ -909,15 +910,15 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An """ self.logger.debug("Process<%s>: received RPC message with communicator '%s': %r", self.pid, _comm, msg) - intent = msg[process_comms.INTENT_KEY] + intent = msg[message.INTENT_KEY] - if intent == process_comms.Intent.PLAY: + if intent == message.Intent.PLAY: return self._schedule_rpc(self.play) - if intent == process_comms.Intent.PAUSE: - return self._schedule_rpc(self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None)) - if intent == process_comms.Intent.KILL: - return self._schedule_rpc(self.kill, msg=msg.get(process_comms.MESSAGE_KEY, None)) - if intent == process_comms.Intent.STATUS: + if intent == message.Intent.PAUSE: + return self._schedule_rpc(self.pause, msg=msg.get(message.MESSAGE_KEY, None)) + if intent == message.Intent.KILL: + return self._schedule_rpc(self.kill, msg=msg.get(message.MESSAGE_KEY, None)) + if intent == message.Intent.STATUS: status_info: Dict[str, Any] = {} self.get_status_info(status_info) return status_info @@ -940,11 +941,11 @@ def broadcast_receive( ) # If we get a message we recognise then action it, otherwise ignore - if subject == process_comms.Intent.PLAY: + if subject == message.Intent.PLAY: return self._schedule_rpc(self.play) - if subject == process_comms.Intent.PAUSE: + if subject == message.Intent.PAUSE: return self._schedule_rpc(self.pause, msg=body) - if subject == process_comms.Intent.KILL: + if subject == message.Intent.KILL: return self._schedule_rpc(self.kill, msg=body) return None @@ -966,7 +967,7 @@ def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) kiwi_future = kiwipy.Future() async def run_callback() -> None: - with kiwipy.capture_exceptions(kiwi_future): + with capture_exceptions(kiwi_future): result = callback(*args, **kwargs) while asyncio.isfuture(result): result = await result @@ -1010,7 +1011,7 @@ def transition_failed( self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace) - def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]: + def pause(self, msg: Union[str, None] = None) -> Union[bool, CancellableAction]: """Pause the process. :param msg: an optional message to set as the status. The current status will be saved in the private @@ -1039,7 +1040,7 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable self._pausing = self._interrupt_action # Try to interrupt the state self._state.interrupt(interrupt_exception) - return cast(futures.CancellableAction, self._interrupt_action) + return cast(CancellableAction, self._interrupt_action) return self._do_pause(msg) @@ -1055,7 +1056,7 @@ def _do_pause(self, state_msg: Optional[str], next_state: Optional[process_state return True - def _create_interrupt_action(self, exception: process_states.Interruption) -> futures.CancellableAction: + def _create_interrupt_action(self, exception: process_states.Interruption) -> CancellableAction: """ Create an interrupt action from the corresponding interrupt exception @@ -1065,7 +1066,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu """ if isinstance(exception, process_states.PauseInterruption): do_pause = functools.partial(self._do_pause, str(exception)) - return futures.CancellableAction(do_pause, cookie=exception) + return CancellableAction(do_pause, cookie=exception) if isinstance(exception, process_states.KillInterruption): @@ -1077,11 +1078,11 @@ def do_kill(_next_state: process_states.State) -> Any: finally: self._killing = None - return futures.CancellableAction(do_kill, cookie=exception) + return CancellableAction(do_kill, cookie=exception) raise ValueError(f"Got unknown interruption type '{type(exception)}'") - def _set_interrupt_action(self, new_action: Optional[futures.CancellableAction]) -> None: + def _set_interrupt_action(self, new_action: Optional[CancellableAction]) -> None: """ Set the interrupt action cancelling the current one if it exists :param new_action: The new interrupt action to set @@ -1150,7 +1151,7 @@ def kill(self, msg: Union[str, None] = None) -> Union[bool, asyncio.Future]: self._set_interrupt_action_from_exception(interrupt_exception) self._killing = self._interrupt_action self._state.interrupt(interrupt_exception) - return cast(futures.CancellableAction, self._interrupt_action) + return cast(CancellableAction, self._interrupt_action) self.transition_to(process_states.ProcessState.KILLED, msg) return True diff --git a/src/plumpy/rmq/__init__.py b/src/plumpy/rmq/__init__.py index 779cc42b..ca14e02e 100644 --- a/src/plumpy/rmq/__init__.py +++ b/src/plumpy/rmq/__init__.py @@ -2,5 +2,7 @@ # mypy: disable-error-code=name-defined from .communications import * from .exceptions import * +from .futures import * +from .process_comms import * -__all__ = exceptions.__all__ + communications.__all__ +__all__ = exceptions.__all__ + communications.__all__ + futures.__all__ + process_comms.__all__ diff --git a/src/plumpy/process_comms.py b/src/plumpy/rmq/process_comms.py similarity index 56% rename from src/plumpy/process_comms.py rename to src/plumpy/rmq/process_comms.py index 10abeb53..75db4e6e 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/rmq/process_comms.py @@ -3,149 +3,33 @@ import asyncio import copy -import logging from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast import kiwipy -from . import loaders, persistence -from .utils import PID_TYPE +from plumpy.message import ( + MESSAGE_KEY, + PAUSE_MSG, + PLAY_MSG, + STATUS_MSG, + KILL_MSG, + Intent, + create_continue_body, + create_create_body, + create_launch_body, +) + +from plumpy import loaders +from plumpy.utils import PID_TYPE __all__ = [ - 'KILL_MSG', - 'PAUSE_MSG', - 'PLAY_MSG', - 'STATUS_MSG', - 'ProcessLauncher', 'RemoteProcessController', 'RemoteProcessThreadController', - 'create_continue_body', - 'create_launch_body', ] -if TYPE_CHECKING: - from .processes import Process - ProcessResult = Any ProcessStatus = Any -INTENT_KEY = 'intent' -MESSAGE_KEY = 'message' - - -class Intent: - """Intent constants for a process message""" - - PLAY: str = 'play' - PAUSE: str = 'pause' - KILL: str = 'kill' - STATUS: str = 'status' - - -PAUSE_MSG = {INTENT_KEY: Intent.PAUSE} -PLAY_MSG = {INTENT_KEY: Intent.PLAY} -KILL_MSG = {INTENT_KEY: Intent.KILL} -STATUS_MSG = {INTENT_KEY: Intent.STATUS} - -TASK_KEY = 'task' -TASK_ARGS = 'args' -PERSIST_KEY = 'persist' -# Launch -PROCESS_CLASS_KEY = 'process_class' -ARGS_KEY = 'init_args' -KWARGS_KEY = 'init_kwargs' -NOWAIT_KEY = 'nowait' -# Continue -PID_KEY = 'pid' -TAG_KEY = 'tag' -# Task types -LAUNCH_TASK = 'launch' -CONTINUE_TASK = 'continue' -CREATE_TASK = 'create' - -LOGGER = logging.getLogger(__name__) - - -def create_launch_body( - process_class: str, - init_args: Optional[Sequence[Any]] = None, - init_kwargs: Optional[Dict[str, Any]] = None, - persist: bool = False, - loader: Optional[loaders.ObjectLoader] = None, - nowait: bool = True, -) -> Dict[str, Any]: - """ - Create a message body for the launch action - - :param process_class: the class of the process to launch - :param init_args: any initialisation positional arguments - :param init_kwargs: any initialisation keyword arguments - :param persist: persist this process if True, otherwise don't - :param loader: the loader to use to load the persisted process - :param nowait: wait for the process to finish before completing the task, otherwise just return the PID - :return: a dictionary with the body of the message to launch the process - :rtype: dict - """ - if loader is None: - loader = loaders.get_object_loader() - - msg_body = { - TASK_KEY: LAUNCH_TASK, - TASK_ARGS: { - PROCESS_CLASS_KEY: loader.identify_object(process_class), - PERSIST_KEY: persist, - NOWAIT_KEY: nowait, - ARGS_KEY: init_args, - KWARGS_KEY: init_kwargs, - }, - } - return msg_body - - -def create_continue_body(pid: 'PID_TYPE', tag: Optional[str] = None, nowait: bool = False) -> Dict[str, Any]: - """ - Create a message body to continue an existing process - :param pid: the pid of the existing process - :param tag: the optional persistence tag - :param nowait: wait for the process to finish before completing the task, otherwise just return the PID - :return: a dictionary with the body of the message to continue the process - - """ - msg_body = {TASK_KEY: CONTINUE_TASK, TASK_ARGS: {PID_KEY: pid, NOWAIT_KEY: nowait, TAG_KEY: tag}} - return msg_body - - -def create_create_body( - process_class: str, - init_args: Optional[Sequence[Any]] = None, - init_kwargs: Optional[Dict[str, Any]] = None, - persist: bool = False, - loader: Optional[loaders.ObjectLoader] = None, -) -> Dict[str, Any]: - """ - Create a message body to create a new process - :param process_class: the class of the process to launch - :param init_args: any initialisation positional arguments - :param init_kwargs: any initialisation keyword arguments - :param persist: persist this process if True, otherwise don't - :param loader: the loader to use to load the persisted process - :return: a dictionary with the body of the message to launch the process - - """ - if loader is None: - loader = loaders.get_object_loader() - - msg_body = { - TASK_KEY: CREATE_TASK, - TASK_ARGS: { - PROCESS_CLASS_KEY: loader.identify_object(process_class), - PERSIST_KEY: persist, - ARGS_KEY: init_args, - KWARGS_KEY: init_kwargs, - }, - } - return msg_body - class RemoteProcessController: """ @@ -471,172 +355,3 @@ def task_send(self, message: Any, no_reply: bool = False) -> Optional[Any]: :return: the response from the remote side (if no_reply=False) """ return self._communicator.task_send(message, no_reply=no_reply) - - -class ProcessLauncher: - """ - Takes incoming task messages and uses them to launch processes. - - Expected format of task: - - For launch:: - - { - 'task': - 'process_class': - 'args': - 'kwargs': . - 'nowait': True or False - } - - For continue:: - - { - 'task': - 'pid': - 'nowait': True or False - } - """ - - def __init__( - self, - loop: Optional[asyncio.AbstractEventLoop] = None, - persister: Optional[persistence.Persister] = None, - load_context: Optional[persistence.LoadSaveContext] = None, - loader: Optional[loaders.ObjectLoader] = None, - ) -> None: - self._loop = loop - self._persister = persister - self._load_context = load_context if load_context is not None else persistence.LoadSaveContext() - - if loader is not None: - self._loader = loader - self._load_context = self._load_context.copyextend(loader=loader) - else: - self._loader = loaders.get_object_loader() - - async def __call__(self, communicator: kiwipy.Communicator, task: Dict[str, Any]) -> Union[PID_TYPE, ProcessResult]: - """ - Receive a task. - :param task: The task message - """ - from plumpy.rmq import communications - - task_type = task[TASK_KEY] - if task_type == LAUNCH_TASK: - return await self._launch(communicator, **task.get(TASK_ARGS, {})) - if task_type == CONTINUE_TASK: - return await self._continue(communicator, **task.get(TASK_ARGS, {})) - if task_type == CREATE_TASK: - return await self._create(communicator, **task.get(TASK_ARGS, {})) - - raise communications.TaskRejected - - async def _launch( - self, - _communicator: kiwipy.Communicator, - process_class: str, - persist: bool, - nowait: bool, - init_args: Optional[Sequence[Any]] = None, - init_kwargs: Optional[Dict[str, Any]] = None, - ) -> Union[PID_TYPE, ProcessResult]: - """ - Launch the process - - :param _communicator: the communicator - :param process_class: the process class to launch - :param persist: should the process be persisted - :param nowait: if True only return when the process finishes - :param init_args: positional arguments to the process constructor - :param init_kwargs: keyword arguments to the process constructor - :return: the pid of the created process or the outputs (if nowait=False) - """ - from plumpy.rmq import communications - - if persist and not self._persister: - raise communications.TaskRejected('Cannot persist process, no persister') - - if init_args is None: - init_args = () - if init_kwargs is None: - init_kwargs = {} - - proc_class = self._loader.load_object(process_class) - proc = proc_class(*init_args, **init_kwargs) - if persist and self._persister is not None: - self._persister.save_checkpoint(proc) - - if nowait: - # XXX: can return a reference and gracefully use task to cancel itself when the upper call stack fails - asyncio.ensure_future(proc.step_until_terminated()) # noqa: RUF006 - return proc.pid - - await proc.step_until_terminated() - - return proc.future().result() - - async def _continue( - self, _communicator: kiwipy.Communicator, pid: 'PID_TYPE', nowait: bool, tag: Optional[str] = None - ) -> Union[PID_TYPE, ProcessResult]: - """ - Continue the process - - :param _communicator: the communicator - :param pid: the pid of the process to continue - :param nowait: if True don't wait for the process to complete - :param tag: the checkpoint tag to continue from - """ - from plumpy.rmq import communications - - if not self._persister: - LOGGER.warning('rejecting task: cannot continue process<%d> because no persister is available', pid) - raise communications.TaskRejected('Cannot continue process, no persister') - - # Do not catch exceptions here, because if these operations fail, the continue task should except and bubble up - saved_state = self._persister.load_checkpoint(pid, tag) - proc = cast('Process', saved_state.unbundle(self._load_context)) - - if nowait: - # XXX: can return a reference and gracefully use task to cancel itself when the upper call stack fails - asyncio.ensure_future(proc.step_until_terminated()) # noqa: RUF006 - return proc.pid - - await proc.step_until_terminated() - - return proc.future().result() - - async def _create( - self, - _communicator: kiwipy.Communicator, - process_class: str, - persist: bool, - init_args: Optional[Sequence[Any]] = None, - init_kwargs: Optional[Dict[str, Any]] = None, - ) -> 'PID_TYPE': - """ - Create the process - - :param _communicator: the communicator - :param process_class: the process class to create - :param persist: should the process be persisted - :param init_args: positional arguments to the process constructor - :param init_kwargs: keyword arguments to the process constructor - :return: the pid of the created process - """ - from plumpy.rmq import communications - - if persist and not self._persister: - raise communications.TaskRejected('Cannot persist process, no persister') - - if init_args is None: - init_args = () - if init_kwargs is None: - init_kwargs = {} - - proc_class = self._loader.load_object(process_class) - proc = proc_class(*init_args, **init_kwargs) - if persist and self._persister is not None: - self._persister.save_checkpoint(proc) - - return proc.pid diff --git a/tests/test_communications.py b/tests/rmq/test_communications.py similarity index 100% rename from tests/test_communications.py rename to tests/rmq/test_communications.py diff --git a/tests/rmq/test_communicator.py b/tests/rmq/test_communicator.py index 2a0bfebc..26c9a852 100644 --- a/tests/rmq/test_communicator.py +++ b/tests/rmq/test_communicator.py @@ -13,8 +13,7 @@ from kiwipy import BroadcastFilter, rmq import plumpy -from plumpy import process_comms -from plumpy.rmq import communications +from plumpy.rmq import communications, process_comms from .. import utils diff --git a/tests/rmq/test_process_comms.py b/tests/rmq/test_process_comms.py index 593c8ef4..bba06518 100644 --- a/tests/rmq/test_process_comms.py +++ b/tests/rmq/test_process_comms.py @@ -8,8 +8,8 @@ from kiwipy import rmq import plumpy -import plumpy.rmq.communications -from plumpy import process_comms +from plumpy.message import KILL_MSG, MESSAGE_KEY +from plumpy.rmq import process_comms from .. import utils @@ -196,8 +196,8 @@ async def test_kill_all(self, thread_communicator, sync_controller): for _ in range(10): procs.append(utils.WaitForSignalProcess(communicator=thread_communicator)) - msg = copy.copy(process_comms.KILL_MSG) - msg[process_comms.MESSAGE_KEY] = 'bang bang, I shot you down' + msg = copy.copy(KILL_MSG) + msg[MESSAGE_KEY] = 'bang bang, I shot you down' sync_controller.kill_all(msg) await utils.wait_util(lambda: all([proc.killed() for proc in procs])) diff --git a/tests/test_process_comms.py b/tests/test_message.py similarity index 90% rename from tests/test_process_comms.py rename to tests/test_message.py index c59737ac..82951afd 100644 --- a/tests/test_process_comms.py +++ b/tests/test_message.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- import pytest -from tests import utils import plumpy -from plumpy import process_comms +from plumpy import message +from tests import utils class Process(plumpy.Process): @@ -37,7 +37,7 @@ async def test_continue(): del process process = None - result = await launcher._continue(None, **plumpy.create_continue_body(pid)[process_comms.TASK_ARGS]) + result = await launcher._continue(None, **plumpy.create_continue_body(pid)[message.TASK_ARGS]) assert result == utils.DummyProcess.EXPECTED_OUTPUTS @@ -51,5 +51,5 @@ async def test_loader_is_used(): launcher = plumpy.ProcessLauncher(persister=persister, loader=loader) continue_task = plumpy.create_continue_body(proc.pid) - result = await launcher._continue(None, **continue_task[process_comms.TASK_ARGS]) + result = await launcher._continue(None, **continue_task[message.TASK_ARGS]) assert result == utils.DummyProcess.EXPECTED_OUTPUTS diff --git a/tests/test_processes.py b/tests/test_processes.py index a4238fbd..cc57f7dc 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -13,7 +13,7 @@ import plumpy from plumpy import BundleKeys, Process, ProcessState -from plumpy.process_comms import KILL_MSG, MESSAGE_KEY +from plumpy.message import KILL_MSG, MESSAGE_KEY from plumpy.utils import AttributesFrozendict diff --git a/tests/utils.py b/tests/utils.py index 9f7bfb22..a614b853 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -9,7 +9,7 @@ import plumpy from plumpy import persistence, process_states, processes, utils -from plumpy.process_comms import KILL_MSG, MESSAGE_KEY +from plumpy.message import KILL_MSG, MESSAGE_KEY Snapshot = collections.namedtuple('Snapshot', ['state', 'bundle', 'outputs'])