diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index e615ee4a..3b1556fb 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -200,7 +200,7 @@ async def get_status(self, pid: 'PID_TYPE') -> 'ProcessStatus': result = await asyncio.wrap_future(future) return result - async def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'ProcessResult': + async def pause_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> 'ProcessResult': """ Pause the process @@ -208,7 +208,7 @@ async def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'Pr :param msg: optional pause message :return: True if paused, False otherwise """ - msg = MessageBuilder.pause(text=msg) + msg = MessageBuilder.pause(text=msg_text) pause_future = self._communicator.rpc_send(pid, msg) # rpc_send return a thread future from communicator @@ -229,7 +229,7 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult': result = await asyncio.wrap_future(future) return result - async def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> 'ProcessResult': + async def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> 'ProcessResult': """ Kill the process @@ -237,8 +237,7 @@ async def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) :param msg: optional kill message :return: True if killed, False otherwise """ - if msg is None: - msg = MessageBuilder.kill() + msg = MessageBuilder.kill(text=msg_text) # Wait for the communication to go through kill_future = self._communicator.rpc_send(pid, msg) @@ -364,7 +363,7 @@ def get_status(self, pid: 'PID_TYPE') -> kiwipy.Future: """ return self._communicator.rpc_send(pid, MessageBuilder.status()) - def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future: + def pause_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwipy.Future: """ Pause the process @@ -373,16 +372,17 @@ def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fu :return: a response future from the process to be paused """ - msg = MessageBuilder.pause(text=msg) + msg = MessageBuilder.pause(text=msg_text) return self._communicator.rpc_send(pid, msg) - def pause_all(self, msg: Any) -> None: + def pause_all(self, msg_text: Optional[str]) -> None: """ Pause all processes that are subscribed to the same communicator :param msg: an optional pause message """ + msg = MessageBuilder.pause(text=msg_text) self._communicator.broadcast_send(msg, subject=Intent.PAUSE) def play_process(self, pid: 'PID_TYPE') -> kiwipy.Future: @@ -401,28 +401,24 @@ def play_all(self) -> None: """ self._communicator.broadcast_send(None, subject=Intent.PLAY) - def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> kiwipy.Future: + def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None, force_kill: bool = False) -> kiwipy.Future: """ Kill the process :param pid: the pid of the process to kill :param msg: optional kill message :return: a response future from the process to be killed - """ - if msg is None: - msg = MessageBuilder.kill() - + msg = MessageBuilder.kill(text=msg_text, force_kill=force_kill) return self._communicator.rpc_send(pid, msg) - def kill_all(self, msg: Optional[MessageType]) -> None: + def kill_all(self, msg_text: Optional[str]) -> None: """ Kill all processes that are subscribed to the same communicator :param msg: an optional pause message """ - if msg is None: - msg = MessageBuilder.kill() + msg = MessageBuilder.kill(msg_text) self._communicator.broadcast_send(msg, subject=Intent.KILL) diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index d369a1e9..931dbc5e 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -52,16 +52,19 @@ class Interruption(Exception): # noqa: N818 class KillInterruption(Interruption): - def __init__(self, msg: MessageType | None): + def __init__(self, msg_text: str | None): super().__init__() - if msg is None: - msg = MessageBuilder.kill() + msg = MessageBuilder.kill(text=msg_text) self.msg: MessageType = msg class PauseInterruption(Interruption): - pass + def __init__(self, msg_text: str | None): + super().__init__() + msg = MessageBuilder.pause(text=msg_text) + + self.msg: MessageType = msg # region Commands diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 0866ee41..d984e171 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -344,8 +344,7 @@ def init(self) -> None: def try_killing(future: futures.Future) -> None: if future.cancelled(): - msg = MessageBuilder.kill(text='Killed by future being cancelled') - if not self.kill(msg): + if not self.kill('Killed by future being cancelled'): self.logger.warning( 'Process<%s>: Failed to kill process on future cancel', self.pid, @@ -944,7 +943,7 @@ def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> Non # region Communication - def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> Any: + def message_receive(self, _comm: kiwipy.Communicator, msg: MessageType) -> Any: """ Coroutine called when the process receives a message from the communicator @@ -964,9 +963,9 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An if intent == process_comms.Intent.PLAY: return self._schedule_rpc(self.play) if intent == process_comms.Intent.PAUSE: - return self._schedule_rpc(self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None)) + return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_KEY, None)) if intent == process_comms.Intent.KILL: - return self._schedule_rpc(self.kill, msg=msg) + return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_KEY, None)) if intent == process_comms.Intent.STATUS: status_info: Dict[str, Any] = {} self.get_status_info(status_info) @@ -976,7 +975,7 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An raise RuntimeError('Unknown intent') def broadcast_receive( - self, _comm: kiwipy.Communicator, body: Any, sender: Any, subject: Any, correlation_id: Any + self, _comm: kiwipy.Communicator, msg: MessageType, sender: Any, subject: Any, correlation_id: Any ) -> Optional[kiwipy.Future]: """ Coroutine called when the process receives a message from the communicator @@ -990,16 +989,16 @@ def broadcast_receive( self.pid, subject, _comm, - body, + msg, ) # If we get a message we recognise then action it, otherwise ignore if subject == process_comms.Intent.PLAY: return self._schedule_rpc(self.play) if subject == process_comms.Intent.PAUSE: - return self._schedule_rpc(self.pause, msg=body) + return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_KEY, None)) if subject == process_comms.Intent.KILL: - return self._schedule_rpc(self.kill, msg=body) + return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_KEY, None)) return None def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> kiwipy.Future: @@ -1071,7 +1070,7 @@ def transition_failed( ) self.transition_to(new_state) - def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]: + def pause(self, msg_text: Optional[str] = None) -> Union[bool, futures.CancellableAction]: """Pause the process. :param msg: an optional message to set as the status. The current status will be saved in the private @@ -1095,22 +1094,29 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable if self._stepping: # Ask the step function to pause by setting this flag and giving the # caller back a future - interrupt_exception = process_states.PauseInterruption(msg) + interrupt_exception = process_states.PauseInterruption(msg_text) self._set_interrupt_action_from_exception(interrupt_exception) self._pausing = self._interrupt_action # Try to interrupt the state self._state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) - return self._do_pause(msg) + msg = MessageBuilder.pause(msg_text) + return self._do_pause(state_msg=msg) - def _do_pause(self, state_msg: Optional[str], next_state: Optional[process_states.State] = None) -> bool: + def _do_pause(self, state_msg: Optional[MessageType], next_state: Optional[process_states.State] = None) -> bool: """Carry out the pause procedure, optionally transitioning to the next state first""" try: if next_state is not None: self.transition_to(next_state) - call_with_super_check(self.on_pausing, state_msg) - call_with_super_check(self.on_paused, state_msg) + + if state_msg is None: + msg_text = '' + else: + msg_text = state_msg[MESSAGE_KEY] + + call_with_super_check(self.on_pausing, msg_text) + call_with_super_check(self.on_paused, msg_text) finally: self._pausing = None @@ -1125,7 +1131,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu """ if isinstance(exception, process_states.PauseInterruption): - do_pause = functools.partial(self._do_pause, str(exception)) + do_pause = functools.partial(self._do_pause, exception.msg) return futures.CancellableAction(do_pause, cookie=exception) if isinstance(exception, process_states.KillInterruption): @@ -1190,7 +1196,7 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac ) self.transition_to(new_state) - def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: + def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]: """ Kill the process :param msg: An optional kill message @@ -1210,12 +1216,13 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future] if self._stepping: # Ask the step function to pause by setting this flag and giving the # caller back a future - interrupt_exception = process_states.KillInterruption(msg) + interrupt_exception = process_states.KillInterruption(msg_text) self._set_interrupt_action_from_exception(interrupt_exception) self._killing = self._interrupt_action self._state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) + msg = MessageBuilder.kill(msg_text) new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=msg) self.transition_to(new_state) return True diff --git a/tests/rmq/test_process_comms.py b/tests/rmq/test_process_comms.py index a6249d10..7a03fac4 100644 --- a/tests/rmq/test_process_comms.py +++ b/tests/rmq/test_process_comms.py @@ -195,9 +195,7 @@ async def test_kill_all(self, thread_communicator, sync_controller): for _ in range(10): procs.append(utils.WaitForSignalProcess(communicator=thread_communicator)) - msg = process_comms.MessageBuilder.kill(text='bang bang, I shot you down') - - sync_controller.kill_all(msg) + sync_controller.kill_all(msg_text='bang bang, I shot you down') await utils.wait_util(lambda: all([proc.killed() for proc in procs])) assert all([proc.state == plumpy.ProcessState.KILLED for proc in procs]) diff --git a/tests/test_processes.py b/tests/test_processes.py index 7b21c463..bba80739 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -10,7 +10,7 @@ import plumpy from plumpy import BundleKeys, Process, ProcessState -from plumpy.process_comms import MessageBuilder +from plumpy.process_comms import MESSAGE_KEY, MessageBuilder from plumpy.utils import AttributesFrozendict from tests import utils @@ -322,10 +322,10 @@ def run(self, **kwargs): def test_kill(self): proc: Process = utils.DummyProcess() - msg = MessageBuilder.kill(text='Farewell!') - proc.kill(msg) + msg_text = 'Farewell!' + proc.kill(msg_text=msg_text) self.assertTrue(proc.killed()) - self.assertEqual(proc.killed_msg(), msg) + self.assertEqual(proc.killed_msg()[MESSAGE_KEY], msg_text) self.assertEqual(proc.state, ProcessState.KILLED) def test_wait_continue(self):