From ea1d655b12470f9f2a9cafea6cbb6dbeaaed0c97 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Sat, 30 Nov 2024 23:45:12 +0100 Subject: [PATCH] Alias MessageType for message passing --- src/plumpy/process_comms.py | 27 +++++++++++++-------------- src/plumpy/process_states.py | 4 +++- src/plumpy/processes.py | 23 ++++++++++++++++------- tests/test_processes.py | 2 ++ 4 files changed, 34 insertions(+), 22 deletions(-) diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index 13aa5fb3..1d280334 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -41,11 +41,12 @@ class Intent: KILL: str = 'kill' STATUS: str = 'status' +MessageType = dict[str, Any] -PAUSE_MSG = {INTENT_KEY: Intent.PAUSE, MESSAGE_KEY: None} -PLAY_MSG = {INTENT_KEY: Intent.PLAY, MESSAGE_KEY: None} -KILL_MSG = {INTENT_KEY: Intent.KILL, MESSAGE_KEY: None, FORCE_KILL_KEY: False} -STATUS_MSG = {INTENT_KEY: Intent.STATUS, MESSAGE_KEY: None} +PAUSE_MSG: MessageType= {INTENT_KEY: Intent.PAUSE, MESSAGE_KEY: None} +PLAY_MSG: MessageType = {INTENT_KEY: Intent.PLAY, MESSAGE_KEY: None} +KILL_MSG: MessageType = {INTENT_KEY: Intent.KILL, MESSAGE_KEY: None, FORCE_KILL_KEY: False} +STATUS_MSG: MessageType = {INTENT_KEY: Intent.STATUS, MESSAGE_KEY: None} TASK_KEY = 'task' TASK_ARGS = 'args' @@ -197,7 +198,7 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult': result = await asyncio.wrap_future(future) return result - async def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'ProcessResult': + async def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> 'ProcessResult': """ Kill the process @@ -205,12 +206,11 @@ async def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'Pro :param msg: optional kill message :return: True if killed, False otherwise """ - message = copy.copy(KILL_MSG) - if msg is not None: - message[MESSAGE_KEY] = msg + if msg is None: + msg = copy.copy(KILL_MSG) # Wait for the communication to go through - kill_future = self._communicator.rpc_send(pid, message) + kill_future = self._communicator.rpc_send(pid, msg) future = await asyncio.wrap_future(kill_future) # Now wait for the kill to be enacted result = await asyncio.wrap_future(future) @@ -372,7 +372,7 @@ def play_all(self) -> None: """ self._communicator.broadcast_send(None, subject=Intent.PLAY) - def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future: + def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> kiwipy.Future: """ Kill the process @@ -381,11 +381,10 @@ def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fut :return: a response future from the process to be killed """ - message = copy.copy(KILL_MSG) - if msg is not None: - message[MESSAGE_KEY] = msg + if msg is None: + msg = copy.copy(KILL_MSG) - return self._communicator.rpc_send(pid, message) + return self._communicator.rpc_send(pid, msg) def kill_all(self, msg: Optional[Any]) -> None: """ diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 7ae6e9bd..10ebfdab 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -8,6 +8,8 @@ import yaml from yaml.loader import Loader +from plumpy.process_comms import MessageType + try: import tblib @@ -402,7 +404,7 @@ def __init__(self, process: 'Process', result: Any, successful: bool) -> None: class Killed(State): LABEL = ProcessState.KILLED - def __init__(self, process: 'Process', msg: Optional[str]): + def __init__(self, process: 'Process', msg: Optional[MessageType]): """ :param process: The associated process :param msg: Optional kill message diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index ba7967d3..25a8f78e 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -47,6 +47,7 @@ from .process_listener import ProcessListener from .process_spec import ProcessSpec from .utils import PID_TYPE, SAVED_STATE_TYPE, protected +from .process_comms import KILL_MSG, MESSAGE_KEY, MessageType __all__ = ['BundleKeys', 'Process', 'ProcessSpec', 'TransitionFailed'] @@ -320,7 +321,9 @@ def init(self) -> None: def try_killing(future: futures.Future) -> None: if future.cancelled(): - if not self.kill('Killed by future being cancelled'): + msg = copy.copy(KILL_MSG) + msg[MESSAGE_KEY] = 'Killed by future being cancelled' + if not self.kill(msg): self.logger.warning('Process<%s>: Failed to kill process on future cancel', self.pid) self._future.add_done_callback(try_killing) @@ -857,10 +860,15 @@ def on_excepted(self) -> None: self._fire_event(ProcessListener.on_process_excepted, str(self.future().exception())) @super_check - def on_kill(self, msg: Optional[str]) -> None: + def on_kill(self, msg: Optional[MessageType]) -> None: """Entering the KILLED state.""" - self.set_status(msg) - self.future().set_exception(exceptions.KilledError(msg)) + if msg is None: + msg_txt = '' + else: + msg_txt = msg[MESSAGE_KEY] or '' + + self.set_status(msg_txt) + self.future().set_exception(exceptions.KilledError(msg_txt)) @super_check def on_killed(self) -> None: @@ -915,7 +923,7 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An if intent == process_comms.Intent.PAUSE: return self._schedule_rpc(self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None)) if intent == process_comms.Intent.KILL: - return self._schedule_rpc(self.kill, msg=msg.get(process_comms.MESSAGE_KEY, None)) + return self._schedule_rpc(self.kill, msg=msg) if intent == process_comms.Intent.STATUS: status_info: Dict[str, Any] = {} self.get_status_info(status_info) @@ -1071,7 +1079,8 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu def do_kill(_next_state: process_states.State) -> Any: try: # Ignore the next state - self.transition_to(process_states.ProcessState.KILLED, str(exception)) + __import__('ipdb').set_trace() + self.transition_to(process_states.ProcessState.KILLED, exception) return True finally: self._killing = None @@ -1125,7 +1134,7 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac """ self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace_back) - def kill(self, msg: Union[str, None] = None) -> Union[bool, asyncio.Future]: + def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: """ Kill the process :param msg: An optional kill message diff --git a/tests/test_processes.py b/tests/test_processes.py index faea9eae..b4526403 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -4,6 +4,8 @@ import asyncio import copy import enum +from plumpy.process_comms import KILL_MSG, MESSAGE_KEY +from test import utils import unittest import kiwipy