diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index 773a9742..bc2fa125 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -12,10 +12,10 @@ from .utils import PID_TYPE __all__ = [ - 'KILL_MSG', 'PAUSE_MSG', 'PLAY_MSG', 'STATUS_MSG', + 'KillMessage', 'ProcessLauncher', 'RemoteProcessController', 'RemoteProcessThreadController', @@ -47,9 +47,20 @@ class Intent: PAUSE_MSG: MessageType = {INTENT_KEY: Intent.PAUSE, MESSAGE_KEY: None} PLAY_MSG: MessageType = {INTENT_KEY: Intent.PLAY, MESSAGE_KEY: None} -KILL_MSG: MessageType = {INTENT_KEY: Intent.KILL, MESSAGE_KEY: None, FORCE_KILL_KEY: False} +# KILL_MSG: MessageType = {INTENT_KEY: Intent.KILL, MESSAGE_KEY: None, FORCE_KILL_KEY: False} STATUS_MSG: MessageType = {INTENT_KEY: Intent.STATUS, MESSAGE_KEY: None} + +class KillMessage: + @classmethod + def build(cls, message: str | None = None, force: bool = False) -> MessageType: + return { + INTENT_KEY: Intent.KILL, + MESSAGE_KEY: message, + FORCE_KILL_KEY: force, + } + + TASK_KEY = 'task' TASK_ARGS = 'args' PERSIST_KEY = 'persist' @@ -209,7 +220,7 @@ async def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) :return: True if killed, False otherwise """ if msg is None: - msg = copy.copy(KILL_MSG) + msg = KillMessage.build() # Wait for the communication to go through kill_future = self._communicator.rpc_send(pid, msg) @@ -384,7 +395,7 @@ def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> ki """ if msg is None: - msg = copy.copy(KILL_MSG) + msg = KillMessage.build() return self._communicator.rpc_send(pid, msg) @@ -395,7 +406,7 @@ def kill_all(self, msg: Optional[MessageType]) -> None: :param msg: an optional pause message """ if msg is None: - msg = copy.copy(KILL_MSG) + msg = KillMessage.build() self._communicator.broadcast_send(msg, subject=Intent.KILL) diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index ede846e4..45178b42 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- -import copy import sys import traceback from enum import Enum @@ -9,7 +8,7 @@ import yaml from yaml.loader import Loader -from plumpy.process_comms import KILL_MSG, MessageType +from plumpy.process_comms import KillMessage, MessageType try: import tblib @@ -54,7 +53,7 @@ class KillInterruption(Interruption): def __init__(self, msg: MessageType | None): super().__init__() if msg is None: - msg = copy.copy(KILL_MSG) + msg = KillMessage.build() self.msg: MessageType = msg diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 9358d927..ef558fa1 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -15,7 +15,6 @@ import warnings from types import TracebackType from typing import ( - TYPE_CHECKING, Any, Awaitable, Callable, @@ -27,6 +26,7 @@ Sequence, Tuple, Type, + TypeVar, Union, cast, ) @@ -54,13 +54,12 @@ from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event from .base.utils import call_with_super_check, super_check from .event_helper import EventHelper -from .process_comms import KILL_MSG, MESSAGE_KEY, MessageType +from .process_comms import MESSAGE_KEY, KillMessage, MessageType from .process_listener import ProcessListener from .process_spec import ProcessSpec from .utils import PID_TYPE, SAVED_STATE_TYPE, protected -if TYPE_CHECKING: - from .process_states import State +T = TypeVar('T') __all__ = ['BundleKeys', 'Process', 'ProcessSpec', 'TransitionFailed'] @@ -345,8 +344,7 @@ def init(self) -> None: def try_killing(future: futures.Future) -> None: if future.cancelled(): - msg = copy.copy(KILL_MSG) - msg[MESSAGE_KEY] = 'Killed by future being cancelled' + msg = KillMessage.build(message='Killed by future being cancelled') if not self.kill(msg): self.logger.warning( 'Process<%s>: Failed to kill process on future cancel', @@ -594,7 +592,7 @@ def _process_scope(self) -> Generator[None, None, None]: stack_copy.pop() PROCESS_STACK.set(stack_copy) - async def _run_task(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + async def _run_task(self, callback: Callable[..., T], *args: Any, **kwargs: Any) -> T: """ This method should be used to run all Process related functions and coroutines. If there is an exception the process will enter the EXCEPTED state. diff --git a/tests/rmq/test_process_comms.py b/tests/rmq/test_process_comms.py index 7223b888..4c7a4f1a 100644 --- a/tests/rmq/test_process_comms.py +++ b/tests/rmq/test_process_comms.py @@ -196,8 +196,7 @@ async def test_kill_all(self, thread_communicator, sync_controller): for _ in range(10): procs.append(utils.WaitForSignalProcess(communicator=thread_communicator)) - msg = copy.copy(process_comms.KILL_MSG) - msg[process_comms.MESSAGE_KEY] = 'bang bang, I shot you down' + msg = process_comms.KillMessage.build(message='bang bang, I shot you down') sync_controller.kill_all(msg) await utils.wait_util(lambda: all([proc.killed() for proc in procs])) diff --git a/tests/test_processes.py b/tests/test_processes.py index 47085c90..cec20c51 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -2,9 +2,8 @@ """Process tests""" import asyncio -import copy import enum -from plumpy.process_comms import KILL_MSG, MESSAGE_KEY +from plumpy.process_comms import KillMessage import unittest import kiwipy @@ -16,7 +15,6 @@ import plumpy from plumpy import BundleKeys, Process, ProcessState -from plumpy.process_comms import KILL_MSG, MESSAGE_KEY from plumpy.utils import AttributesFrozendict @@ -327,8 +325,7 @@ def run(self, **kwargs): def test_kill(self): proc: Process = utils.DummyProcess() - msg = copy.copy(KILL_MSG) - msg[MESSAGE_KEY] = 'Farewell!' + msg = KillMessage.build(message='Farewell!') proc.kill(msg) self.assertTrue(proc.killed()) self.assertEqual(proc.killed_msg(), msg) @@ -434,8 +431,7 @@ class KillProcess(Process): after_kill = False def run(self, **kwargs): - msg = copy.copy(KILL_MSG) - msg[MESSAGE_KEY] = 'killed' + msg = KillMessage.build(message='killed') self.kill(msg) # The following line should be executed because kill will not # interrupt execution of a method call in the RUNNING state diff --git a/tests/utils.py b/tests/utils.py index f2a58dfc..88638e01 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,13 +3,12 @@ import asyncio import collections -import copy import unittest from collections.abc import Mapping import plumpy from plumpy import persistence, process_states, processes, utils -from plumpy.process_comms import KILL_MSG, MESSAGE_KEY +from plumpy.process_comms import KillMessage Snapshot = collections.namedtuple('Snapshot', ['state', 'bundle', 'outputs']) @@ -86,8 +85,7 @@ def last_step(self): class KillProcess(processes.Process): @utils.override def run(self): - msg = copy.copy(KILL_MSG) - msg[MESSAGE_KEY] = 'killed' + msg = KillMessage.build(message='killed') return process_states.Kill(msg=msg)