From 6f626bf5ccf79446ed2d59d5bef428fc3465a34a Mon Sep 17 00:00:00 2001 From: Ali Date: Wed, 2 Oct 2024 15:16:37 +0200 Subject: [PATCH] Implementation with minimal changes --- src/plumpy/process_comms.py | 11 ++----- src/plumpy/process_states.py | 3 ++ src/plumpy/processes.py | 62 +++++++++--------------------------- 3 files changed, 20 insertions(+), 56 deletions(-) diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index 9a8b03df..c66e8431 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -30,7 +30,6 @@ INTENT_KEY = 'intent' MESSAGE_KEY = 'message' -FORCE_KILL_KEY = 'force_kill' class Intent: @@ -197,7 +196,7 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult': result = await asyncio.wrap_future(future) return result - async def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None, force_kill: bool = False) -> 'ProcessResult': + async def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'ProcessResult': """ Kill the process @@ -205,11 +204,9 @@ async def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None, force_k :param msg: optional kill message :return: True if killed, False otherwise """ - breakpoint() message = copy.copy(KILL_MSG) if msg is not None: message[MESSAGE_KEY] = msg - message[FORCE_KILL_KEY] = force_kill # Wait for the communication to go through kill_future = self._communicator.rpc_send(pid, message) @@ -378,7 +375,7 @@ def play_all(self) -> None: """ self._communicator.broadcast_send(None, subject=Intent.PLAY) - def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None, force_kill: bool = False) -> kiwipy.Future: + def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future: """ Kill the process @@ -387,11 +384,9 @@ def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None, force_kill: b :return: a response future from the process to be killed """ - breakpoint() message = copy.copy(KILL_MSG) if msg is not None: message[MESSAGE_KEY] = msg - message[FORCE_KILL_KEY] = force_kill return self._communicator.rpc_send(pid, message) @@ -410,7 +405,6 @@ def continue_process( nowait: bool = False, no_reply: bool = False ) -> Union[None, PID_TYPE, ProcessResult]: - breakpoint() message = create_continue_body(pid=pid, tag=tag, nowait=nowait) return self.task_send(message, no_reply=no_reply) @@ -485,7 +479,6 @@ def task_send(self, message: Any, no_reply: bool = False) -> Optional[Any]: :param no_reply: if True, this call will be fire-and-forget, i.e. no return value :return: the response from the remote side (if no_reply=False) """ - breakpoint() return self._communicator.task_send(message, no_reply=no_reply) diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 3407412d..77fb8ba4 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -36,6 +36,7 @@ 'Continue', 'Interruption', 'KillInterruption', + 'ForceKillInterruption', 'PauseInterruption', ] @@ -50,6 +51,8 @@ class Interruption(Exception): class KillInterruption(Interruption): pass +class ForceKillInterruption(Interruption): + pass class PauseInterruption(Interruption): pass diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 95941f70..3be7abe3 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -51,29 +51,7 @@ __all__ = ['Process', 'ProcessSpec', 'BundleKeys', 'TransitionFailed'] - -#file_handler = logging.FileHandler(filename='tmp.log') -#stdout_handler = logging.StreamHandler(stream=sys.stdout) -#handlers = [file_handler, stdout_handler] -# -#logging.basicConfig( -# level=logging.DEBUG, -# format='[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s', -# handlers=handlers -#) - -#file_handler = logging.FileHandler(filename="/Users/alexgo/code/aiida-core/plumpy2.log") -#stdout_handler = logging.StreamHandler(stream=sys.stdout) -#handlers = [file_handler, stdout_handler] -# -#logging.basicConfig( -# level=logging.DEBUG, -# format='[%(asctime)s] {%(pathname)s:%(lineno)d} %(levelname)s - %(message)s', -# handlers=handlers -#) _LOGGER = logging.getLogger(__name__) - - PROCESS_STACK = ContextVar('process stack', default=[]) @@ -411,8 +389,8 @@ def logger(self) -> logging.Logger: :return: The logger. """ - #if self._logger is not None: - # return self._logger + if self._logger is not None: + return self._logger return _LOGGER @@ -930,7 +908,6 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An :param msg: the message :return: the outcome of processing the message, the return value will be sent back as a response to the sender """ - breakpoint() self.logger.debug("Process<%s>: received RPC message with communicator '%s': %r", self.pid, _comm, msg) intent = msg[process_comms.INTENT_KEY] @@ -940,11 +917,7 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An if intent == process_comms.Intent.PAUSE: return self._schedule_rpc(self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None)) if intent == process_comms.Intent.KILL: - breakpoint() - # have problems to pass new argument get - # Error: failed to kill Process<699>: Process.kill() got an unexpected keyword argument 'force_kill' - #return self._schedule_rpc(self.kill, msg=msg.get(process_comms.MESSAGE_KEY, None), force_kill=msg.get(process_comms.FORCE_KILL_KEY, False)) - return self._schedule_rpc(self.kill, msg=msg) + return self._schedule_rpc(self.kill, msg=msg.get(process_comms.MESSAGE_KEY, None)) if intent == process_comms.Intent.STATUS: status_info: Dict[str, Any] = {} self.get_status_info(status_info) @@ -961,7 +934,6 @@ def broadcast_receive(self, _comm: kiwipy.Communicator, body: Any, sender: Any, :param _comm: the communicator that sent the message :param msg: the message """ - breakpoint() # pylint: disable=unused-argument self.logger.debug( "Process<%s>: received broadcast message '%s' with communicator '%s': %r", self.pid, subject, _comm, body @@ -973,7 +945,6 @@ def broadcast_receive(self, _comm: kiwipy.Communicator, body: Any, sender: Any, if subject == process_comms.Intent.PAUSE: return self._schedule_rpc(self.pause, msg=body) if subject == process_comms.Intent.KILL: - # TODO deal with this return self._schedule_rpc(self.kill, msg=body) return None @@ -1096,7 +1067,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu do_pause = functools.partial(self._do_pause, str(exception)) return futures.CancellableAction(do_pause, cookie=exception) - if isinstance(exception, process_states.KillInterruption): + if isinstance(exception, (process_states.KillInterruption, process_states.ForceKillInterruption)): def do_kill(_next_state: process_states.State) -> Any: try: @@ -1155,20 +1126,12 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac """ self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace_back) - def kill(self, msg: Union[dict, None] = None, force_kill: bool = False) -> Union[bool, asyncio.Future]: + def kill(self, msg: Union[str, None] = None) -> Union[bool, asyncio.Future]: """ Kill the process - - # PR_COMMENT have not figured out how to integrate force_kill as argument - # so I just pass the dict - :param msg: An optional kill message """ - breakpoint() - if msg is None: - force_kill = False - else: - force_kill = msg.get(process_comms.FORCE_KILL_KEY, False) + force_kill = isinstance(msg, str) and '-F' in msg if self.state == process_states.ProcessState.KILLED: # Already killed @@ -1178,20 +1141,25 @@ def kill(self, msg: Union[dict, None] = None, force_kill: bool = False) -> Union # Can't kill return False - if self._killing: + if self._killing and not force_kill: # Already killing return self._killing - if self._stepping and not force_kill: + if force_kill: + # Skip interrupting the state and go straight to killed + interrupt_exception = process_states.ForceKillInterruption(msg) + self._killing = self._interrupt_action + self._state.interrupt(interrupt_exception) + + elif self._stepping: # Ask the step function to pause by setting this flag and giving the # caller back a future - interrupt_exception = process_states.KillInterruption(msg.get(process_comms.MESSAGE_KEY, None)) + interrupt_exception = process_states.KillInterruption(msg) # type: ignore self._set_interrupt_action_from_exception(interrupt_exception) self._killing = self._interrupt_action self._state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) - breakpoint() self.transition_to(process_states.ProcessState.KILLED, msg) return True