Skip to content

Commit

Permalink
KillMessage build msg from parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 2, 2024
1 parent 4be6931 commit 88259d6
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 28 deletions.
21 changes: 16 additions & 5 deletions src/plumpy/process_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from .utils import PID_TYPE

__all__ = [
'KILL_MSG',
'PAUSE_MSG',
'PLAY_MSG',
'STATUS_MSG',
'KillMessage',
'ProcessLauncher',
'RemoteProcessController',
'RemoteProcessThreadController',
Expand Down Expand Up @@ -47,9 +47,20 @@ class Intent:

PAUSE_MSG: MessageType = {INTENT_KEY: Intent.PAUSE, MESSAGE_KEY: None}
PLAY_MSG: MessageType = {INTENT_KEY: Intent.PLAY, MESSAGE_KEY: None}
KILL_MSG: MessageType = {INTENT_KEY: Intent.KILL, MESSAGE_KEY: None, FORCE_KILL_KEY: False}
# KILL_MSG: MessageType = {INTENT_KEY: Intent.KILL, MESSAGE_KEY: None, FORCE_KILL_KEY: False}
STATUS_MSG: MessageType = {INTENT_KEY: Intent.STATUS, MESSAGE_KEY: None}


class KillMessage:
@classmethod
def build(cls, message: str | None = None, force: bool = False) -> MessageType:
return {
INTENT_KEY: Intent.KILL,
MESSAGE_KEY: message,
FORCE_KILL_KEY: force,
}


TASK_KEY = 'task'
TASK_ARGS = 'args'
PERSIST_KEY = 'persist'
Expand Down Expand Up @@ -209,7 +220,7 @@ async def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None)
:return: True if killed, False otherwise
"""
if msg is None:
msg = copy.copy(KILL_MSG)
msg = KillMessage.build()

# Wait for the communication to go through
kill_future = self._communicator.rpc_send(pid, msg)
Expand Down Expand Up @@ -384,7 +395,7 @@ def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> ki
"""
if msg is None:
msg = copy.copy(KILL_MSG)
msg = KillMessage.build()

return self._communicator.rpc_send(pid, msg)

Expand All @@ -395,7 +406,7 @@ def kill_all(self, msg: Optional[MessageType]) -> None:
:param msg: an optional pause message
"""
if msg is None:
msg = copy.copy(KILL_MSG)
msg = KillMessage.build()

self._communicator.broadcast_send(msg, subject=Intent.KILL)

Expand Down
5 changes: 2 additions & 3 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# -*- coding: utf-8 -*-
import copy
import sys
import traceback
from enum import Enum
Expand All @@ -9,7 +8,7 @@
import yaml
from yaml.loader import Loader

from plumpy.process_comms import KILL_MSG, MessageType
from plumpy.process_comms import KillMessage, MessageType

try:
import tblib
Expand Down Expand Up @@ -54,7 +53,7 @@ class KillInterruption(Interruption):
def __init__(self, msg: MessageType | None):
super().__init__()
if msg is None:
msg = copy.copy(KILL_MSG)
msg = KillMessage.build()

self.msg: MessageType = msg

Expand Down
12 changes: 5 additions & 7 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import warnings
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Expand All @@ -27,6 +26,7 @@
Sequence,
Tuple,
Type,
TypeVar,
Union,
cast,
)
Expand Down Expand Up @@ -54,13 +54,12 @@
from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event
from .base.utils import call_with_super_check, super_check
from .event_helper import EventHelper
from .process_comms import KILL_MSG, MESSAGE_KEY, MessageType
from .process_comms import MESSAGE_KEY, KillMessage, MessageType
from .process_listener import ProcessListener
from .process_spec import ProcessSpec
from .utils import PID_TYPE, SAVED_STATE_TYPE, protected

if TYPE_CHECKING:
from .process_states import State
T = TypeVar('T')

__all__ = ['BundleKeys', 'Process', 'ProcessSpec', 'TransitionFailed']

Expand Down Expand Up @@ -345,8 +344,7 @@ def init(self) -> None:

def try_killing(future: futures.Future) -> None:
if future.cancelled():
msg = copy.copy(KILL_MSG)
msg[MESSAGE_KEY] = 'Killed by future being cancelled'
msg = KillMessage.build(message='Killed by future being cancelled')
if not self.kill(msg):
self.logger.warning(
'Process<%s>: Failed to kill process on future cancel',
Expand Down Expand Up @@ -594,7 +592,7 @@ def _process_scope(self) -> Generator[None, None, None]:
stack_copy.pop()
PROCESS_STACK.set(stack_copy)

async def _run_task(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
async def _run_task(self, callback: Callable[..., T], *args: Any, **kwargs: Any) -> T:
"""
This method should be used to run all Process related functions and coroutines.
If there is an exception the process will enter the EXCEPTED state.
Expand Down
3 changes: 1 addition & 2 deletions tests/rmq/test_process_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,7 @@ async def test_kill_all(self, thread_communicator, sync_controller):
for _ in range(10):
procs.append(utils.WaitForSignalProcess(communicator=thread_communicator))

msg = copy.copy(process_comms.KILL_MSG)
msg[process_comms.MESSAGE_KEY] = 'bang bang, I shot you down'
msg = process_comms.KillMessage.build(message='bang bang, I shot you down')

sync_controller.kill_all(msg)
await utils.wait_util(lambda: all([proc.killed() for proc in procs]))
Expand Down
10 changes: 3 additions & 7 deletions tests/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
"""Process tests"""

import asyncio
import copy
import enum
from plumpy.process_comms import KILL_MSG, MESSAGE_KEY
from plumpy.process_comms import KillMessage
import unittest

import kiwipy
Expand All @@ -16,7 +15,6 @@

import plumpy
from plumpy import BundleKeys, Process, ProcessState
from plumpy.process_comms import KILL_MSG, MESSAGE_KEY
from plumpy.utils import AttributesFrozendict


Expand Down Expand Up @@ -327,8 +325,7 @@ def run(self, **kwargs):
def test_kill(self):
proc: Process = utils.DummyProcess()

msg = copy.copy(KILL_MSG)
msg[MESSAGE_KEY] = 'Farewell!'
msg = KillMessage.build(message='Farewell!')
proc.kill(msg)
self.assertTrue(proc.killed())
self.assertEqual(proc.killed_msg(), msg)
Expand Down Expand Up @@ -434,8 +431,7 @@ class KillProcess(Process):
after_kill = False

def run(self, **kwargs):
msg = copy.copy(KILL_MSG)
msg[MESSAGE_KEY] = 'killed'
msg = KillMessage.build(message='killed')
self.kill(msg)
# The following line should be executed because kill will not
# interrupt execution of a method call in the RUNNING state
Expand Down
6 changes: 2 additions & 4 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@

import asyncio
import collections
import copy
import unittest
from collections.abc import Mapping

import plumpy
from plumpy import persistence, process_states, processes, utils
from plumpy.process_comms import KILL_MSG, MESSAGE_KEY
from plumpy.process_comms import KillMessage

Snapshot = collections.namedtuple('Snapshot', ['state', 'bundle', 'outputs'])

Expand Down Expand Up @@ -86,8 +85,7 @@ def last_step(self):
class KillProcess(processes.Process):
@utils.override
def run(self):
msg = copy.copy(KILL_MSG)
msg[MESSAGE_KEY] = 'killed'
msg = KillMessage.build(message='killed')
return process_states.Kill(msg=msg)


Expand Down

0 comments on commit 88259d6

Please sign in to comment.