diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 3397c40d..a371084a 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -205,7 +205,7 @@ def __call__(cls, *args: Any, **kwargs: Any) -> 'StateMachine': :param kwargs: Any keyword arguments to be passed to the constructor :return: An instance of the state machine """ - inst = super().__call__(*args, **kwargs) + inst: StateMachine = super().__call__(*args, **kwargs) inst.transition_to(inst.create_initial_state()) call_with_super_check(inst.init) return inst @@ -325,13 +325,14 @@ def on_terminated(self) -> None: """Called when a terminal state is entered""" def transition_to( - self, new_state: Union[State, Type[State]], *args: Any, **kwargs: Any + self, new_state: Union[State, Type[State]], **kwargs: Any ) -> None: """Transite to the new state. - The new target state will be create lazily when the state - is not yet instantiated, which will happened for states not in the expect path such as - pause and kill. + The new target state will be create lazily when the state is not yet instantiated, + which will happened for states not in the expect path such as pause and kill. + The arguments are passed to the state class to create state instance. + (process arg does not need to pass since it will always call with 'self' as process) """ assert ( not self._transitioning @@ -344,7 +345,7 @@ def transition_to( if not isinstance(new_state, State): # Make sure we have a state instance - new_state = self._create_state_instance(new_state, *args, **kwargs) + new_state = self._create_state_instance(new_state, **kwargs) label = new_state.LABEL @@ -358,7 +359,7 @@ def transition_to( # Make sure we have a state instance if not isinstance(exception.state, State): new_state = self._create_state_instance( - exception.state, *exception.args, **exception.kwargs + exception.state, **exception.kwargs ) label = new_state.LABEL self._exit_current_state(new_state) @@ -435,9 +436,9 @@ def _enter_next_state(self, next_state: State) -> None: self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) def _create_state_instance( - self, state_cls: type[State], *args: Any, **kwargs: Any + self, state_cls: type[State], **kwargs: Any ) -> State: if state_cls.LABEL not in self.get_states_map(): raise ValueError(f"{state_cls.LABEL} is not a valid state") - return state_cls(self, *args, **kwargs) + return state_cls(self, **kwargs) diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index 1d280334..9e1e4110 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -386,12 +386,15 @@ def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> ki return self._communicator.rpc_send(pid, msg) - def kill_all(self, msg: Optional[Any]) -> None: + def kill_all(self, msg: Optional[MessageType]) -> None: """ Kill all processes that are subscribed to the same communicator :param msg: an optional pause message """ + if msg is None: + msg = copy.copy(KILL_MSG) + self._communicator.broadcast_send(msg, subject=Intent.KILL) def continue_process( diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 10ebfdab..46f29d8f 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- import sys +import copy import traceback from enum import Enum from types import TracebackType @@ -8,7 +9,7 @@ import yaml from yaml.loader import Loader -from plumpy.process_comms import MessageType +from plumpy.process_comms import KILL_MSG, MessageType try: import tblib @@ -50,7 +51,12 @@ class Interruption(Exception): # noqa: N818 class KillInterruption(Interruption): - pass + def __init__(self, msg: MessageType | None): + super().__init__() + if msg is None: + msg = copy.copy(KILL_MSG) + + self.msg: MessageType = msg class PauseInterruption(Interruption): @@ -64,9 +70,9 @@ class Command(persistence.Savable): pass -@auto_persist('msg') +@auto_persist("msg") class Kill(Command): - def __init__(self, msg: Optional[Any] = None): + def __init__(self, msg: Optional[MessageType] = None): super().__init__() self.msg = msg @@ -75,10 +81,13 @@ class Pause(Command): pass -@auto_persist('msg', 'data') +@auto_persist("msg", "data") class Wait(Command): def __init__( - self, continue_fn: Optional[Callable[..., Any]] = None, msg: Optional[Any] = None, data: Optional[Any] = None + self, + continue_fn: Optional[Callable[..., Any]] = None, + msg: Optional[Any] = None, + data: Optional[Any] = None, ): super().__init__() self.continue_fn = continue_fn @@ -86,7 +95,7 @@ def __init__( self.data = data -@auto_persist('result') +@auto_persist("result") class Stop(Command): def __init__(self, result: Any, successful: bool) -> None: super().__init__() @@ -94,9 +103,9 @@ def __init__(self, result: Any, successful: bool) -> None: self.successful = successful -@auto_persist('args', 'kwargs') +@auto_persist("args", "kwargs") class Continue(Command): - CONTINUE_FN = 'continue_fn' + CONTINUE_FN = "continue_fn" def __init__(self, continue_fn: Callable[..., Any], *args: Any, **kwargs: Any): super().__init__() @@ -104,11 +113,15 @@ def __init__(self, continue_fn: Callable[..., Any], *args: Any, **kwargs: Any): self.args = args self.kwargs = kwargs - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: + def save_instance_state( + self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext + ) -> None: super().save_instance_state(out_state, save_context) out_state[self.CONTINUE_FN] = self.continue_fn.__name__ - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + def load_instance_state( + self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext + ) -> None: super().load_instance_state(saved_state, load_context) try: self.continue_fn = utils.load_function(saved_state[self.CONTINUE_FN]) @@ -135,7 +148,7 @@ class ProcessState(Enum): KILLED: str = 'killed' -@auto_persist('in_state') +@auto_persist("in_state") class State(state_machine.State, persistence.Savable): @property def process(self) -> state_machine.StateMachine: @@ -144,7 +157,9 @@ def process(self) -> state_machine.StateMachine: """ return self.state_machine - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + def load_instance_state( + self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext + ) -> None: super().load_instance_state(saved_state, load_context) self.state_machine = load_context.process @@ -152,33 +167,41 @@ def interrupt(self, reason: Any) -> None: pass -@auto_persist('args', 'kwargs') +@auto_persist("args", "kwargs") class Created(State): LABEL = ProcessState.CREATED ALLOWED = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} - RUN_FN = 'run_fn' + RUN_FN = "run_fn" - def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: + def __init__( + self, process: "Process", run_fn: Callable[..., Any], *args: Any, **kwargs: Any + ) -> None: super().__init__(process) assert run_fn is not None self.run_fn = run_fn self.args = args self.kwargs = kwargs - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: + def save_instance_state( + self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext + ) -> None: super().save_instance_state(out_state, save_context) out_state[self.RUN_FN] = self.run_fn.__name__ - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + def load_instance_state( + self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext + ) -> None: super().load_instance_state(saved_state, load_context) self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) def execute(self) -> state_machine.State: - return self.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) + return self.create_state( + ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs + ) -@auto_persist('args', 'kwargs') +@auto_persist("args", "kwargs") class Running(State): LABEL = ProcessState.RUNNING ALLOWED = { @@ -189,15 +212,17 @@ class Running(State): ProcessState.EXCEPTED, } - RUN_FN = 'run_fn' # The key used to store the function to run - COMMAND = 'command' # The key used to store an upcoming command + RUN_FN = "run_fn" # The key used to store the function to run + COMMAND = "command" # The key used to store an upcoming command # Class level defaults _command: Union[None, Kill, Stop, Wait, Continue] = None _running: bool = False _run_handle = None - def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: + def __init__( + self, process: "Process", run_fn: Callable[..., Any], *args: Any, **kwargs: Any + ) -> None: super().__init__(process) assert run_fn is not None self.run_fn = run_fn @@ -205,17 +230,23 @@ def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, * self.kwargs = kwargs self._run_handle = None - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: + def save_instance_state( + self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext + ) -> None: super().save_instance_state(out_state, save_context) out_state[self.RUN_FN] = self.run_fn.__name__ if self._command is not None: out_state[self.COMMAND] = self._command.save() - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + def load_instance_state( + self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext + ) -> None: super().load_instance_state(saved_state, load_context) self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) if self.COMMAND in saved_state: - self._command = persistence.Savable.load(saved_state[self.COMMAND], load_context) # type: ignore + self._command = persistence.Savable.load( + saved_state[self.COMMAND], load_context + ) # type: ignore def interrupt(self, reason: Any) -> None: pass @@ -255,18 +286,24 @@ def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> State: # elif isinstance(command, Pause): # self.pause() elif isinstance(command, Stop): - state = self.create_state(ProcessState.FINISHED, command.result, command.successful) + state = self.create_state( + ProcessState.FINISHED, command.result, command.successful + ) elif isinstance(command, Wait): - state = self.create_state(ProcessState.WAITING, command.continue_fn, command.msg, command.data) + state = self.create_state( + ProcessState.WAITING, command.continue_fn, command.msg, command.data + ) elif isinstance(command, Continue): - state = self.create_state(ProcessState.RUNNING, command.continue_fn, *command.args) + state = self.create_state( + ProcessState.RUNNING, command.continue_fn, *command.args + ) else: - raise ValueError('Unrecognised command') + raise ValueError("Unrecognised command") return cast(State, state) # casting from base.State to process.State -@auto_persist('msg', 'data') +@auto_persist("msg", "data") class Waiting(State): LABEL = ProcessState.WAITING ALLOWED = { @@ -277,19 +314,19 @@ class Waiting(State): ProcessState.FINISHED, } - DONE_CALLBACK = 'DONE_CALLBACK' + DONE_CALLBACK = "DONE_CALLBACK" _interruption = None def __str__(self) -> str: state_info = super().__str__() if self.msg is not None: - state_info += f' ({self.msg})' + state_info += f" ({self.msg})" return state_info def __init__( self, - process: 'Process', + process: "Process", done_callback: Optional[Callable[..., Any]], msg: Optional[str] = None, data: Optional[Any] = None, @@ -300,12 +337,16 @@ def __init__( self.data = data self._waiting_future: futures.Future = futures.Future() - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: + def save_instance_state( + self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext + ) -> None: super().save_instance_state(out_state, save_context) if self.done_callback is not None: out_state[self.DONE_CALLBACK] = self.done_callback.__name__ - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + def load_instance_state( + self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext + ) -> None: super().load_instance_state(saved_state, load_context) callback_name = saved_state.get(self.DONE_CALLBACK, None) if callback_name is not None: @@ -331,12 +372,14 @@ async def execute(self) -> State: # type: ignore if result == NULL: next_state = self.create_state(ProcessState.RUNNING, self.done_callback) else: - next_state = self.create_state(ProcessState.RUNNING, self.done_callback, result) + next_state = self.create_state( + ProcessState.RUNNING, self.done_callback, result + ) return cast(State, next_state) # casting from base.State to process.State def resume(self, value: Any = NULL) -> None: - assert self._waiting_future is not None, 'Not yet waiting' + assert self._waiting_future is not None, "Not yet waiting" if self._waiting_future.done(): return @@ -345,13 +388,23 @@ def resume(self, value: Any = NULL) -> None: class Excepted(State): + """ + Excepted state, can optionally provide exception and trace_back + + :param exception: The exception instance + :param trace_back: An optional exception traceback + """ + LABEL = ProcessState.EXCEPTED - EXC_VALUE = 'ex_value' - TRACEBACK = 'traceback' + EXC_VALUE = "ex_value" + TRACEBACK = "traceback" def __init__( - self, process: 'Process', exception: Optional[BaseException], trace_back: Optional[TracebackType] = None + self, + process: "Process", + exception: Optional[BaseException], + trace_back: Optional[TracebackType] = None, ): """ :param process: The associated process @@ -363,16 +416,22 @@ def __init__( self.traceback = trace_back def __str__(self) -> str: - exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0] - return super().__str__() + f'({exception})' + exception = traceback.format_exception_only( + type(self.exception) if self.exception else None, self.exception + )[0] + return super().__str__() + f"({exception})" - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: + def save_instance_state( + self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext + ) -> None: super().save_instance_state(out_state, save_context) out_state[self.EXC_VALUE] = yaml.dump(self.exception) if self.traceback is not None: - out_state[self.TRACEBACK] = ''.join(traceback.format_tb(self.traceback)) + out_state[self.TRACEBACK] = "".join(traceback.format_tb(self.traceback)) - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + def load_instance_state( + self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext + ) -> None: super().load_instance_state(saved_state, load_context) self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: @@ -383,32 +442,53 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi else: self.traceback = None - def get_exc_info(self) -> Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]]: + def get_exc_info( + self, + ) -> Tuple[ + Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType] + ]: """ Recreate the exc_info tuple and return it """ - return type(self.exception) if self.exception else None, self.exception, self.traceback + return ( + type(self.exception) if self.exception else None, + self.exception, + self.traceback, + ) -@auto_persist('result', 'successful') +@auto_persist("result", "successful") class Finished(State): + """State for process is finished. + + :param result: The result of process + :param successful: Boolean for the exit code is ``0`` the process is successful. + """ LABEL = ProcessState.FINISHED - def __init__(self, process: 'Process', result: Any, successful: bool) -> None: + def __init__(self, process: "Process", result: Any, successful: bool) -> None: super().__init__(process) self.result = result self.successful = successful -@auto_persist('msg') +@auto_persist("msg") class Killed(State): + """ + Represents a state where a process has been killed. + + This state is used to indicate that a process has been terminated and can optionally + include a message providing details about the termination. + + :param msg: An optional message explaining the reason for the process termination. + """ + LABEL = ProcessState.KILLED - def __init__(self, process: 'Process', msg: Optional[MessageType]): + def __init__(self, process: "Process", msg: Optional[MessageType]): """ :param process: The associated process :param msg: Optional kill message - """ super().__init__(process) self.msg = msg diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index a4b3b017..07e2d20c 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -39,7 +39,16 @@ import yaml from aio_pika.exceptions import ChannelInvalidStateError, ConnectionClosed -from . import events, exceptions, futures, persistence, ports, process_comms, process_states, utils +from . import ( + events, + exceptions, + futures, + persistence, + ports, + process_comms, + process_states, + utils, +) from .base import state_machine from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event from .base.utils import call_with_super_check, super_check @@ -52,7 +61,7 @@ __all__ = ['BundleKeys', 'Process', 'ProcessSpec', 'TransitionFailed'] _LOGGER = logging.getLogger(__name__) -PROCESS_STACK = ContextVar('process stack', default=[]) +PROCESS_STACK = ContextVar("process stack", default=[]) class BundleKeys: @@ -85,14 +94,20 @@ def ensure_not_closed(func: Callable[..., Any]) -> Callable[..., Any]: @functools.wraps(func) def func_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: if self._closed: - raise exceptions.ClosedError('Process is closed') + raise exceptions.ClosedError("Process is closed") return func(self, *args, **kwargs) return func_wrapper @persistence.auto_persist( - '_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status', '_event_helper' + "_pid", + "_creation_time", + "_future", + "_paused", + "_status", + "_pre_paused_status", + "_event_helper", ) class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMeta): """ @@ -146,7 +161,7 @@ class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMe __called: bool = False @classmethod - def current(cls) -> Optional['Process']: + def current(cls) -> Optional["Process"]: """ Get the currently running process i.e. the one at the top of the stack @@ -182,15 +197,15 @@ def get_state_classes(cls) -> Dict[Hashable, Type[process_states.State]]: @classmethod def spec(cls) -> ProcessSpec: try: - return cls.__getattribute__(cls, '_spec') + return cls.__getattribute__(cls, "_spec") except AttributeError: try: cls._spec: ProcessSpec = cls._spec_class() # type: ignore cls.__called: bool = False # type: ignore cls.define(cls._spec) # type: ignore assert cls.__called, ( - f'Process.define() was not called by {cls}\nHint: Did you forget to call the superclass method in ' - 'your define? Try: super().define(spec)' + f"Process.define() was not called by {cls}\nHint: Did you forget to call the superclass method in " + "your define? Try: super().define(spec)" ) return cls._spec # type: ignore except Exception: @@ -222,18 +237,20 @@ def get_description(cls) -> Dict[str, Any]: description: Dict[str, Any] = {} if cls.__doc__: - description['description'] = cls.__doc__.strip() + description["description"] = cls.__doc__.strip() spec_description = cls.spec().get_description() if spec_description: - description['spec'] = spec_description + description["spec"] = spec_description return description @classmethod def recreate_from( - cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] = None - ) -> 'Process': + cls, + saved_state: SAVED_STATE_TYPE, + load_context: Optional[persistence.LoadSaveContext] = None, + ) -> "Process": """ Recreate a process from a saved state, passing any positional and keyword arguments on to load_instance_state @@ -281,7 +298,9 @@ def __init__( self._paused = None # Input/output - self._raw_inputs = None if inputs is None else utils.AttributesFrozendict(inputs) + self._raw_inputs = ( + None if inputs is None else utils.AttributesFrozendict(inputs) + ) self._pid = pid self._parsed_inputs: Optional[utils.AttributesFrozendict] = None self._outputs: Dict[str, Any] = {} @@ -304,27 +323,49 @@ def init(self) -> None: if self._communicator is not None: try: - identifier = self._communicator.add_rpc_subscriber(self.message_receive, identifier=str(self.pid)) - self.add_cleanup(functools.partial(self._communicator.remove_rpc_subscriber, identifier)) + identifier = self._communicator.add_rpc_subscriber( + self.message_receive, identifier=str(self.pid) + ) + self.add_cleanup( + functools.partial( + self._communicator.remove_rpc_subscriber, identifier + ) + ) except kiwipy.TimeoutError: - self.logger.exception('Process<%s>: failed to register as an RPC subscriber', self.pid) + self.logger.exception( + "Process<%s>: failed to register as an RPC subscriber", self.pid + ) try: # filter out state change broadcasts - subscriber = kiwipy.BroadcastFilter(self.broadcast_receive, subject=re.compile(r'^(?!state_changed).*')) - identifier = self._communicator.add_broadcast_subscriber(subscriber, identifier=str(self.pid)) - self.add_cleanup(functools.partial(self._communicator.remove_broadcast_subscriber, identifier)) + subscriber = kiwipy.BroadcastFilter( + self.broadcast_receive, subject=re.compile(r"^(?!state_changed).*") + ) + identifier = self._communicator.add_broadcast_subscriber( + subscriber, identifier=str(self.pid) + ) + self.add_cleanup( + functools.partial( + self._communicator.remove_broadcast_subscriber, identifier + ) + ) except kiwipy.TimeoutError: - self.logger.exception('Process<%s>: failed to register as a broadcast subscriber', self.pid) + self.logger.exception( + "Process<%s>: failed to register as a broadcast subscriber", + self.pid, + ) if not self._future.done(): def try_killing(future: futures.Future) -> None: if future.cancelled(): msg = copy.copy(KILL_MSG) - msg[MESSAGE_KEY] = 'Killed by future being cancelled' + msg[MESSAGE_KEY] = "Killed by future being cancelled" if not self.kill(msg): - self.logger.warning('Process<%s>: Failed to kill process on future cancel', self.pid) + self.logger.warning( + "Process<%s>: Failed to kill process on future cancel", + self.pid, + ) self._future.add_done_callback(try_killing) @@ -419,7 +460,7 @@ def future(self) -> persistence.SavableFuture: @ensure_not_closed def launch( self, - process_class: Type['Process'], + process_class: Type["Process"], inputs: Optional[dict] = None, pid: Optional[PID_TYPE] = None, logger: Optional[logging.Logger] = None, @@ -428,7 +469,13 @@ def launch( The process is started asynchronously, without blocking other task in the event loop. """ - process = process_class(inputs=inputs, pid=pid, logger=logger, loop=self.loop, communicator=self._communicator) + process = process_class( + inputs=inputs, + pid=pid, + logger=logger, + loop=self.loop, + communicator=self._communicator, + ) self.loop.create_task(process.step_until_terminated()) return process @@ -451,7 +498,7 @@ def result(self) -> Any: if isinstance(self._state, process_states.Killed): raise exceptions.KilledError(self._state.msg) if isinstance(self._state, process_states.Excepted): - raise (self._state.exception or Exception('process excepted')) + raise (self._state.exception or Exception("process excepted")) raise exceptions.InvalidStateError @@ -463,7 +510,9 @@ def successful(self) -> bool: try: return self._state.successful # type: ignore except AttributeError as exception: - raise exceptions.InvalidStateError('process is not in the finished state') from exception + raise exceptions.InvalidStateError( + "process is not in the finished state" + ) from exception @property def is_successful(self) -> bool: @@ -480,12 +529,12 @@ def killed(self) -> bool: """Return whether the process is killed.""" return self.state == process_states.ProcessState.KILLED - def killed_msg(self) -> Optional[str]: + def killed_msg(self) -> Optional[MessageType]: """Return the killed message.""" if isinstance(self._state, process_states.Killed): return self._state.msg - raise exceptions.InvalidStateError('Has not been killed') + raise exceptions.InvalidStateError("Has not been killed") def exception(self) -> Optional[BaseException]: """Return exception, if the process is terminated in excepted state.""" @@ -520,7 +569,9 @@ def loop(self) -> asyncio.AbstractEventLoop: """Return the event loop of the process.""" return self._loop - def call_soon(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> events.ProcessCallback: + def call_soon( + self, callback: Callable[..., Any], *args: Any, **kwargs: Any + ) -> events.ProcessCallback: """ Schedule a callback to what is considered an internal process function (this needn't be a method). @@ -532,7 +583,10 @@ def call_soon(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> return handle def callback_excepted( - self, _callback: Callable[..., Any], exception: Optional[BaseException], trace: Optional[TracebackType] + self, + _callback: Callable[..., Any], + exception: Optional[BaseException], + trace: Optional[TracebackType], ) -> None: if self.state != process_states.ProcessState.EXCEPTED: self.fail(exception, trace) @@ -551,14 +605,16 @@ def _process_scope(self) -> Generator[None, None, None]: yield None finally: assert Process.current() is self, ( - 'Somehow, the process at the top of the stack is not me, but another process! ' - f'({self} != {Process.current()})' + "Somehow, the process at the top of the stack is not me, but another process! " + f"({self} != {Process.current()})" ) stack_copy = PROCESS_STACK.get().copy() stack_copy.pop() PROCESS_STACK.set(stack_copy) - async def _run_task(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + async def _run_task( + self, callback: Callable[..., Any], *args: Any, **kwargs: Any + ) -> Any: """ This method should be used to run all Process related functions and coroutines. If there is an exception the process will enter the EXCEPTED state. @@ -579,7 +635,9 @@ async def _run_task(self, callback: Callable[..., Any], *args: Any, **kwargs: An # region Persistence def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: Optional[persistence.LoadSaveContext] + self, + out_state: SAVED_STATE_TYPE, + save_context: Optional[persistence.LoadSaveContext], ) -> None: """ Ask the process to save its current instance state. @@ -589,7 +647,7 @@ def save_instance_state( """ super().save_instance_state(out_state, save_context) - out_state['_state'] = self._state.save() + out_state["_state"] = self._state.save() # Inputs/outputs if self.raw_inputs is not None: @@ -602,7 +660,9 @@ def save_instance_state( out_state[BundleKeys.OUTPUTS] = self.encode_input_args(self.outputs) @protected - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + def load_instance_state( + self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext + ) -> None: """Load the process from its saved instance state. :param saved_state: A bundle to load the state from @@ -620,17 +680,17 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi self._logger = None self._communicator = None - if 'loop' in load_context: + if "loop" in load_context: self._loop = load_context.loop else: self._loop = asyncio.get_event_loop() - self._state: process_states.State = self.recreate_state(saved_state['_state']) + self._state: process_states.State = self.recreate_state(saved_state["_state"]) - if 'communicator' in load_context: + if "communicator" in load_context: self._communicator = load_context.communicator - if 'logger' in load_context: + if "logger" in load_context: self._logger = load_context.logger # Need to call this here as things downstream may rely on us having the runtime variable above @@ -679,7 +739,7 @@ def set_logger(self, logger: logging.Logger) -> None: @protected def log_with_pid(self, level: int, msg: str) -> None: """Log the message with the process pid.""" - self.logger.log(level, '%s: %s', self.pid, msg) + self.logger.log(level, "%s: %s", self.pid, msg) # region Events @@ -714,16 +774,24 @@ def on_entered(self, from_state: Optional[process_states.State]) -> None: call_with_super_check(self.on_killed) if self._communicator and isinstance(self.state, enum.Enum): - from_label = cast(enum.Enum, from_state.LABEL).value if from_state is not None else None - subject = f'state_changed.{from_label}.{self.state.value}' - self.logger.info('Process<%s>: Broadcasting state change: %s', self.pid, subject) + from_label = ( + cast(enum.Enum, from_state.LABEL).value + if from_state is not None + else None + ) + subject = f"state_changed.{from_label}.{self.state.value}" + self.logger.info( + "Process<%s>: Broadcasting state change: %s", self.pid, subject + ) try: - self._communicator.broadcast_send(body=None, sender=self.pid, subject=subject) + self._communicator.broadcast_send( + body=None, sender=self.pid, subject=subject + ) except (ConnectionClosed, ChannelInvalidStateError): - message = 'Process<%s>: no connection available to broadcast state change from %s to %s' + message = "Process<%s>: no connection available to broadcast state change from %s to %s" self.logger.warning(message, self.pid, from_label, self.state.value) except kiwipy.TimeoutError: - message = 'Process<%s>: sending broadcast of state change from %s to %s timed out' + message = "Process<%s>: sending broadcast of state change from %s to %s timed out" self.logger.warning(message, self.pid, from_label, self.state.value) def on_exiting(self) -> None: @@ -741,7 +809,10 @@ def on_create(self) -> None: def recursively_copy_dictionaries(value: Any) -> Any: """Recursively copy the mapping but only create copies of the dictionaries not the values.""" if isinstance(value, dict): - return {key: recursively_copy_dictionaries(subvalue) for key, subvalue in value.items()} + return { + key: recursively_copy_dictionaries(subvalue) + for key, subvalue in value.items() + } return value # This will parse the inputs with respect to the input portnamespace of the spec and validate them. The @@ -749,7 +820,11 @@ def recursively_copy_dictionaries(value: Any) -> Any: # ``_raw_inputs`` should not be modified, we pass a clone of it. Note that we only need a clone of the nested # dictionaries, so we don't use ``copy.deepcopy`` (which might seem like the obvious choice) as that will also # create a clone of the values, which we don't want. - raw_inputs = recursively_copy_dictionaries(dict(self._raw_inputs)) if self._raw_inputs else {} + raw_inputs = ( + recursively_copy_dictionaries(dict(self._raw_inputs)) + if self._raw_inputs + else {} + ) self._parsed_inputs = self.spec().inputs.pre_process(raw_inputs) result = self.spec().inputs.validate(self._parsed_inputs) @@ -782,7 +857,9 @@ def on_output_emitting(self, output_port: str, value: Any) -> None: """Output is about to be emitted.""" def on_output_emitted(self, output_port: str, value: Any, dynamic: bool) -> None: - self._event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic) + self._event_helper.fire_event( + ProcessListener.on_output_emitted, self, output_port, value, dynamic + ) @super_check def on_wait(self, awaitables: Sequence[Awaitable]) -> None: @@ -831,7 +908,9 @@ def on_finish(self, result: Any, successful: bool) -> None: if successful: validation_error = self.spec().outputs.validate(self.outputs) if validation_error: - raise StateEntryFailed(process_states.Finished, result, False) + raise StateEntryFailed( + process_states.Finished, result=result, successful=False + ) self.future().set_result(self.outputs) @@ -857,16 +936,17 @@ def on_except(self, exc_info: Tuple[Any, Exception, TracebackType]) -> None: @super_check def on_excepted(self) -> None: """Entered the EXCEPTED state.""" - self._fire_event(ProcessListener.on_process_excepted, str(self.future().exception())) + self._fire_event( + ProcessListener.on_process_excepted, str(self.future().exception()) + ) @super_check def on_kill(self, msg: Optional[MessageType]) -> None: """Entering the KILLED state.""" if msg is None: - msg_txt = '' + msg_txt = "" else: - # msg_txt = msg[MESSAGE_KEY] or '' - msg_txt = msg + msg_txt = msg[MESSAGE_KEY] or "" self.set_status(msg_txt) self.future().set_exception(exceptions.KilledError(msg_txt)) @@ -915,14 +995,21 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An :param msg: the message :return: the outcome of processing the message, the return value will be sent back as a response to the sender """ - self.logger.debug("Process<%s>: received RPC message with communicator '%s': %r", self.pid, _comm, msg) + self.logger.debug( + "Process<%s>: received RPC message with communicator '%s': %r", + self.pid, + _comm, + msg, + ) intent = msg[process_comms.INTENT_KEY] if intent == process_comms.Intent.PLAY: return self._schedule_rpc(self.play) if intent == process_comms.Intent.PAUSE: - return self._schedule_rpc(self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None)) + return self._schedule_rpc( + self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None) + ) if intent == process_comms.Intent.KILL: return self._schedule_rpc(self.kill, msg=msg) if intent == process_comms.Intent.STATUS: @@ -931,7 +1018,7 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An return status_info # Didn't match any known intents - raise RuntimeError('Unknown intent') + raise RuntimeError("Unknown intent") def broadcast_receive( self, _comm: kiwipy.Communicator, body: Any, sender: Any, subject: Any, correlation_id: Any @@ -944,7 +1031,11 @@ def broadcast_receive( """ self.logger.debug( - "Process<%s>: received broadcast message '%s' with communicator '%s': %r", self.pid, subject, _comm, body + "Process<%s>: received broadcast message '%s' with communicator '%s': %r", + self.pid, + subject, + _comm, + body, ) # If we get a message we recognise then action it, otherwise ignore @@ -956,7 +1047,9 @@ def broadcast_receive( return self._schedule_rpc(self.kill, msg=body) return None - def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> kiwipy.Future: + def _schedule_rpc( + self, callback: Callable[..., Any], *args: Any, **kwargs: Any + ) -> kiwipy.Future: """ Schedule a call to a callback as a result of an RPC communication call, this will return a future that resolves to the final result (even after one or more layer of futures being @@ -1010,15 +1103,23 @@ def close(self) -> None: # region State related methods def transition_failed( - self, initial_state: Hashable, final_state: Hashable, exception: Exception, trace: TracebackType + self, + initial_state: Hashable, + final_state: Hashable, + exception: Exception, + trace: TracebackType, ) -> None: # If we are creating, then reraise instead of failing. if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - self.transition_to(process_states.Excepted, exception, trace) + self.transition_to( + process_states.Excepted, exception=exception, trace_back=trace + ) - def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]: + def pause( + self, msg: Union[str, None] = None + ) -> Union[bool, futures.CancellableAction]: """Pause the process. :param msg: an optional message to set as the status. The current status will be saved in the private @@ -1063,7 +1164,9 @@ def _do_pause(self, state_msg: Optional[str], next_state: Optional[process_state return True - def _create_interrupt_action(self, exception: process_states.Interruption) -> futures.CancellableAction: + def _create_interrupt_action( + self, exception: process_states.Interruption + ) -> futures.CancellableAction: """ Create an interrupt action from the corresponding interrupt exception @@ -1079,9 +1182,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu def do_kill(_next_state: process_states.State) -> Any: try: - # Ignore the next state - # __import__('ipdb').set_trace() - self.transition_to(process_states.Killed, exception) + self.transition_to(process_states.Killed, msg=exception.msg) return True finally: self._killing = None @@ -1090,7 +1191,9 @@ def do_kill(_next_state: process_states.State) -> Any: raise ValueError(f"Got unknown interruption type '{type(exception)}'") - def _set_interrupt_action(self, new_action: Optional[futures.CancellableAction]) -> None: + def _set_interrupt_action( + self, new_action: Optional[futures.CancellableAction] + ) -> None: """ Set the interrupt action cancelling the current one if it exists :param new_action: The new interrupt action to set @@ -1127,13 +1230,17 @@ def resume(self, *args: Any) -> None: return self._state.resume(*args) # type: ignore @event(to_states=process_states.Excepted) - def fail(self, exception: Optional[BaseException], trace_back: Optional[TracebackType]) -> None: + def fail( + self, exception: Optional[BaseException], trace_back: Optional[TracebackType] + ) -> None: """ Fail the process in response to an exception :param exception: The exception that caused the failure :param trace_back: Optional exception traceback """ - self.transition_to(process_states.Excepted, exception, trace_back) + self.transition_to( + process_states.Excepted, exception=exception, trace_back=trace_back + ) def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: """ @@ -1161,7 +1268,7 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future] self._state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) - self.transition_to(process_states.Killed, msg) + self.transition_to(process_states.Killed, msg=msg) return True @property @@ -1178,7 +1285,10 @@ def create_initial_state(self) -> process_states.State: :return: A Created state """ - return cast(process_states.State, self.get_state_class(process_states.ProcessState.CREATED)(self, self.run)) + return cast( + process_states.State, + self.get_state_class(process_states.ProcessState.CREATED)(self, self.run), + ) def recreate_state(self, saved_state: persistence.Bundle) -> process_states.State: """ @@ -1188,7 +1298,9 @@ def recreate_state(self, saved_state: persistence.Bundle) -> process_states.Stat :return: An instance of the object with its state loaded from the save state. """ load_context = persistence.LoadSaveContext(process=self) - return cast(process_states.State, persistence.Savable.load(saved_state, load_context)) + return cast( + process_states.State, persistence.Savable.load(saved_state, load_context) + ) # endregion @@ -1221,7 +1333,7 @@ async def step(self) -> None: The execute function running in this method is dependent on the state of the process. """ - assert not self.has_terminated(), 'Cannot step, already terminated' + assert not self.has_terminated(), "Cannot step, already terminated" if self.paused and self._paused is not None: await self._paused @@ -1246,7 +1358,9 @@ async def step(self) -> None: raise except Exception: # Overwrite the next state to go to excepted directly - next_state = self.create_state(process_states.ProcessState.EXCEPTED, *sys.exc_info()[1:]) + next_state = self.create_state( + process_states.ProcessState.EXCEPTED, *sys.exc_info()[1:] + ) self._set_interrupt_action(None) if self._interrupt_action: