Skip to content

Commit

Permalink
Message interface is hidden inside function call
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 16, 2024
1 parent e06d522 commit 560098c
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 45 deletions.
28 changes: 12 additions & 16 deletions src/plumpy/process_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,15 +200,15 @@ async def get_status(self, pid: 'PID_TYPE') -> 'ProcessStatus':
result = await asyncio.wrap_future(future)
return result

async def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'ProcessResult':
async def pause_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> 'ProcessResult':
"""
Pause the process
:param pid: the pid of the process to pause
:param msg: optional pause message
:return: True if paused, False otherwise
"""
msg = MessageBuilder.pause(text=msg)
msg = MessageBuilder.pause(text=msg_text)

pause_future = self._communicator.rpc_send(pid, msg)
# rpc_send return a thread future from communicator
Expand All @@ -229,16 +229,15 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult':
result = await asyncio.wrap_future(future)
return result

async def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> 'ProcessResult':
async def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> 'ProcessResult':
"""
Kill the process
:param pid: the pid of the process to kill
:param msg: optional kill message
:return: True if killed, False otherwise
"""
if msg is None:
msg = MessageBuilder.kill()
msg = MessageBuilder.kill(text=msg_text)

# Wait for the communication to go through
kill_future = self._communicator.rpc_send(pid, msg)
Expand Down Expand Up @@ -364,7 +363,7 @@ def get_status(self, pid: 'PID_TYPE') -> kiwipy.Future:
"""
return self._communicator.rpc_send(pid, MessageBuilder.status())

def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future:
def pause_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwipy.Future:
"""
Pause the process
Expand All @@ -373,16 +372,17 @@ def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fu
:return: a response future from the process to be paused
"""
msg = MessageBuilder.pause(text=msg)
msg = MessageBuilder.pause(text=msg_text)

return self._communicator.rpc_send(pid, msg)

def pause_all(self, msg: Any) -> None:
def pause_all(self, msg_text: Optional[str]) -> None:
"""
Pause all processes that are subscribed to the same communicator
:param msg: an optional pause message
"""
msg = MessageBuilder.pause(text=msg_text)
self._communicator.broadcast_send(msg, subject=Intent.PAUSE)

def play_process(self, pid: 'PID_TYPE') -> kiwipy.Future:
Expand All @@ -401,28 +401,24 @@ def play_all(self) -> None:
"""
self._communicator.broadcast_send(None, subject=Intent.PLAY)

def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> kiwipy.Future:
def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None, force_kill: bool = False) -> kiwipy.Future:
"""
Kill the process
:param pid: the pid of the process to kill
:param msg: optional kill message
:return: a response future from the process to be killed
"""
if msg is None:
msg = MessageBuilder.kill()

msg = MessageBuilder.kill(text=msg_text, force_kill=force_kill)
return self._communicator.rpc_send(pid, msg)

def kill_all(self, msg: Optional[MessageType]) -> None:
def kill_all(self, msg_text: Optional[str]) -> None:
"""
Kill all processes that are subscribed to the same communicator
:param msg: an optional pause message
"""
if msg is None:
msg = MessageBuilder.kill()
msg = MessageBuilder.kill(msg_text)

self._communicator.broadcast_send(msg, subject=Intent.KILL)

Expand Down
11 changes: 7 additions & 4 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,19 @@ class Interruption(Exception): # noqa: N818


class KillInterruption(Interruption):
def __init__(self, msg: MessageType | None):
def __init__(self, msg_text: str | None):
super().__init__()
if msg is None:
msg = MessageBuilder.kill()
msg = MessageBuilder.kill(text=msg_text)

self.msg: MessageType = msg


class PauseInterruption(Interruption):
pass
def __init__(self, msg_text: str | None):
super().__init__()
msg = MessageBuilder.pause(text=msg_text)

self.msg: MessageType = msg


# region Commands
Expand Down
43 changes: 25 additions & 18 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,7 @@ def init(self) -> None:

def try_killing(future: futures.Future) -> None:
if future.cancelled():
msg = MessageBuilder.kill(text='Killed by future being cancelled')
if not self.kill(msg):
if not self.kill('Killed by future being cancelled'):
self.logger.warning(
'Process<%s>: Failed to kill process on future cancel',
self.pid,
Expand Down Expand Up @@ -944,7 +943,7 @@ def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> Non

# region Communication

def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> Any:
def message_receive(self, _comm: kiwipy.Communicator, msg: MessageType) -> Any:
"""
Coroutine called when the process receives a message from the communicator
Expand All @@ -964,9 +963,9 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An
if intent == process_comms.Intent.PLAY:
return self._schedule_rpc(self.play)
if intent == process_comms.Intent.PAUSE:
return self._schedule_rpc(self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None))
return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_KEY, None))
if intent == process_comms.Intent.KILL:
return self._schedule_rpc(self.kill, msg=msg)
return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_KEY, None))
if intent == process_comms.Intent.STATUS:
status_info: Dict[str, Any] = {}
self.get_status_info(status_info)
Expand All @@ -976,7 +975,7 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An
raise RuntimeError('Unknown intent')

def broadcast_receive(
self, _comm: kiwipy.Communicator, body: Any, sender: Any, subject: Any, correlation_id: Any
self, _comm: kiwipy.Communicator, msg: MessageType, sender: Any, subject: Any, correlation_id: Any
) -> Optional[kiwipy.Future]:
"""
Coroutine called when the process receives a message from the communicator
Expand All @@ -990,16 +989,16 @@ def broadcast_receive(
self.pid,
subject,
_comm,
body,
msg,
)

# If we get a message we recognise then action it, otherwise ignore
if subject == process_comms.Intent.PLAY:
return self._schedule_rpc(self.play)
if subject == process_comms.Intent.PAUSE:
return self._schedule_rpc(self.pause, msg=body)
return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_KEY, None))
if subject == process_comms.Intent.KILL:
return self._schedule_rpc(self.kill, msg=body)
return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_KEY, None))
return None

def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> kiwipy.Future:
Expand Down Expand Up @@ -1071,7 +1070,7 @@ def transition_failed(
)
self.transition_to(new_state)

def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]:
def pause(self, msg_text: Optional[str] = None) -> Union[bool, futures.CancellableAction]:
"""Pause the process.
:param msg: an optional message to set as the status. The current status will be saved in the private
Expand All @@ -1095,22 +1094,29 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable
if self._stepping:
# Ask the step function to pause by setting this flag and giving the
# caller back a future
interrupt_exception = process_states.PauseInterruption(msg)
interrupt_exception = process_states.PauseInterruption(msg_text)
self._set_interrupt_action_from_exception(interrupt_exception)
self._pausing = self._interrupt_action
# Try to interrupt the state
self._state.interrupt(interrupt_exception)
return cast(futures.CancellableAction, self._interrupt_action)

return self._do_pause(msg)
msg = MessageBuilder.pause(msg_text)
return self._do_pause(state_msg=msg)

def _do_pause(self, state_msg: Optional[str], next_state: Optional[process_states.State] = None) -> bool:
def _do_pause(self, state_msg: Optional[MessageType], next_state: Optional[process_states.State] = None) -> bool:
"""Carry out the pause procedure, optionally transitioning to the next state first"""
try:
if next_state is not None:
self.transition_to(next_state)
call_with_super_check(self.on_pausing, state_msg)
call_with_super_check(self.on_paused, state_msg)

if state_msg is None:
msg_text = ''
else:
msg_text = state_msg[MESSAGE_KEY]

call_with_super_check(self.on_pausing, msg_text)
call_with_super_check(self.on_paused, msg_text)
finally:
self._pausing = None

Expand All @@ -1125,7 +1131,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu
"""
if isinstance(exception, process_states.PauseInterruption):
do_pause = functools.partial(self._do_pause, str(exception))
do_pause = functools.partial(self._do_pause, exception.msg)
return futures.CancellableAction(do_pause, cookie=exception)

if isinstance(exception, process_states.KillInterruption):
Expand Down Expand Up @@ -1190,7 +1196,7 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac
)
self.transition_to(new_state)

def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]:
def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]:
"""
Kill the process
:param msg: An optional kill message
Expand All @@ -1210,12 +1216,13 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]
if self._stepping:
# Ask the step function to pause by setting this flag and giving the
# caller back a future
interrupt_exception = process_states.KillInterruption(msg)
interrupt_exception = process_states.KillInterruption(msg_text)
self._set_interrupt_action_from_exception(interrupt_exception)
self._killing = self._interrupt_action
self._state.interrupt(interrupt_exception)
return cast(futures.CancellableAction, self._interrupt_action)

msg = MessageBuilder.kill(msg_text)
new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=msg)
self.transition_to(new_state)
return True
Expand Down
4 changes: 1 addition & 3 deletions tests/rmq/test_process_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,7 @@ async def test_kill_all(self, thread_communicator, sync_controller):
for _ in range(10):
procs.append(utils.WaitForSignalProcess(communicator=thread_communicator))

msg = process_comms.MessageBuilder.kill(text='bang bang, I shot you down')

sync_controller.kill_all(msg)
sync_controller.kill_all(msg_text='bang bang, I shot you down')
await utils.wait_util(lambda: all([proc.killed() for proc in procs]))
assert all([proc.state == plumpy.ProcessState.KILLED for proc in procs])

Expand Down
8 changes: 4 additions & 4 deletions tests/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import plumpy
from plumpy import BundleKeys, Process, ProcessState
from plumpy.process_comms import MessageBuilder
from plumpy.process_comms import MESSAGE_KEY, MessageBuilder
from plumpy.utils import AttributesFrozendict
from tests import utils

Expand Down Expand Up @@ -322,10 +322,10 @@ def run(self, **kwargs):
def test_kill(self):
proc: Process = utils.DummyProcess()

msg = MessageBuilder.kill(text='Farewell!')
proc.kill(msg)
msg_text = 'Farewell!'
proc.kill(msg_text=msg_text)
self.assertTrue(proc.killed())
self.assertEqual(proc.killed_msg(), msg)
self.assertEqual(proc.killed_msg()[MESSAGE_KEY], msg_text)
self.assertEqual(proc.state, ProcessState.KILLED)

def test_wait_continue(self):
Expand Down

0 comments on commit 560098c

Please sign in to comment.