Skip to content

Commit

Permalink
Alias MessageType for message passing
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 1, 2024
1 parent 7732647 commit ea1d655
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 22 deletions.
27 changes: 13 additions & 14 deletions src/plumpy/process_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,12 @@ class Intent:
KILL: str = 'kill'
STATUS: str = 'status'

MessageType = dict[str, Any]

PAUSE_MSG = {INTENT_KEY: Intent.PAUSE, MESSAGE_KEY: None}
PLAY_MSG = {INTENT_KEY: Intent.PLAY, MESSAGE_KEY: None}
KILL_MSG = {INTENT_KEY: Intent.KILL, MESSAGE_KEY: None, FORCE_KILL_KEY: False}
STATUS_MSG = {INTENT_KEY: Intent.STATUS, MESSAGE_KEY: None}
PAUSE_MSG: MessageType= {INTENT_KEY: Intent.PAUSE, MESSAGE_KEY: None}
PLAY_MSG: MessageType = {INTENT_KEY: Intent.PLAY, MESSAGE_KEY: None}
KILL_MSG: MessageType = {INTENT_KEY: Intent.KILL, MESSAGE_KEY: None, FORCE_KILL_KEY: False}
STATUS_MSG: MessageType = {INTENT_KEY: Intent.STATUS, MESSAGE_KEY: None}

TASK_KEY = 'task'
TASK_ARGS = 'args'
Expand Down Expand Up @@ -197,20 +198,19 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult':
result = await asyncio.wrap_future(future)
return result

async def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'ProcessResult':
async def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> 'ProcessResult':
"""
Kill the process
:param pid: the pid of the process to kill
:param msg: optional kill message
:return: True if killed, False otherwise
"""
message = copy.copy(KILL_MSG)
if msg is not None:
message[MESSAGE_KEY] = msg
if msg is None:
msg = copy.copy(KILL_MSG)

# Wait for the communication to go through
kill_future = self._communicator.rpc_send(pid, message)
kill_future = self._communicator.rpc_send(pid, msg)
future = await asyncio.wrap_future(kill_future)
# Now wait for the kill to be enacted
result = await asyncio.wrap_future(future)
Expand Down Expand Up @@ -372,7 +372,7 @@ def play_all(self) -> None:
"""
self._communicator.broadcast_send(None, subject=Intent.PLAY)

def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future:
def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> kiwipy.Future:
"""
Kill the process
Expand All @@ -381,11 +381,10 @@ def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fut
:return: a response future from the process to be killed
"""
message = copy.copy(KILL_MSG)
if msg is not None:
message[MESSAGE_KEY] = msg
if msg is None:
msg = copy.copy(KILL_MSG)

return self._communicator.rpc_send(pid, message)
return self._communicator.rpc_send(pid, msg)

def kill_all(self, msg: Optional[Any]) -> None:
"""
Expand Down
4 changes: 3 additions & 1 deletion src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import yaml
from yaml.loader import Loader

from plumpy.process_comms import MessageType

try:
import tblib

Expand Down Expand Up @@ -402,7 +404,7 @@ def __init__(self, process: 'Process', result: Any, successful: bool) -> None:
class Killed(State):
LABEL = ProcessState.KILLED

def __init__(self, process: 'Process', msg: Optional[str]):
def __init__(self, process: 'Process', msg: Optional[MessageType]):
"""
:param process: The associated process
:param msg: Optional kill message
Expand Down
23 changes: 16 additions & 7 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from .process_listener import ProcessListener
from .process_spec import ProcessSpec
from .utils import PID_TYPE, SAVED_STATE_TYPE, protected
from .process_comms import KILL_MSG, MESSAGE_KEY, MessageType

__all__ = ['BundleKeys', 'Process', 'ProcessSpec', 'TransitionFailed']

Expand Down Expand Up @@ -320,7 +321,9 @@ def init(self) -> None:

def try_killing(future: futures.Future) -> None:
if future.cancelled():
if not self.kill('Killed by future being cancelled'):
msg = copy.copy(KILL_MSG)
msg[MESSAGE_KEY] = 'Killed by future being cancelled'
if not self.kill(msg):
self.logger.warning('Process<%s>: Failed to kill process on future cancel', self.pid)

self._future.add_done_callback(try_killing)
Expand Down Expand Up @@ -857,10 +860,15 @@ def on_excepted(self) -> None:
self._fire_event(ProcessListener.on_process_excepted, str(self.future().exception()))

@super_check
def on_kill(self, msg: Optional[str]) -> None:
def on_kill(self, msg: Optional[MessageType]) -> None:
"""Entering the KILLED state."""
self.set_status(msg)
self.future().set_exception(exceptions.KilledError(msg))
if msg is None:
msg_txt = ''
else:
msg_txt = msg[MESSAGE_KEY] or ''

self.set_status(msg_txt)
self.future().set_exception(exceptions.KilledError(msg_txt))

@super_check
def on_killed(self) -> None:
Expand Down Expand Up @@ -915,7 +923,7 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An
if intent == process_comms.Intent.PAUSE:
return self._schedule_rpc(self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None))
if intent == process_comms.Intent.KILL:
return self._schedule_rpc(self.kill, msg=msg.get(process_comms.MESSAGE_KEY, None))
return self._schedule_rpc(self.kill, msg=msg)
if intent == process_comms.Intent.STATUS:
status_info: Dict[str, Any] = {}
self.get_status_info(status_info)
Expand Down Expand Up @@ -1071,7 +1079,8 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu
def do_kill(_next_state: process_states.State) -> Any:
try:
# Ignore the next state
self.transition_to(process_states.ProcessState.KILLED, str(exception))
__import__('ipdb').set_trace()
self.transition_to(process_states.ProcessState.KILLED, exception)
return True
finally:
self._killing = None
Expand Down Expand Up @@ -1125,7 +1134,7 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac
"""
self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace_back)

def kill(self, msg: Union[str, None] = None) -> Union[bool, asyncio.Future]:
def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]:
"""
Kill the process
:param msg: An optional kill message
Expand Down
2 changes: 2 additions & 0 deletions tests/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import asyncio
import copy
import enum
from plumpy.process_comms import KILL_MSG, MESSAGE_KEY
from test import utils
import unittest

import kiwipy
Expand Down

0 comments on commit ea1d655

Please sign in to comment.