Skip to content

Commit

Permalink
Killed state all through passing msg
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 1, 2024
1 parent dad61ac commit d7078fc
Show file tree
Hide file tree
Showing 4 changed files with 336 additions and 138 deletions.
19 changes: 10 additions & 9 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def __call__(cls, *args: Any, **kwargs: Any) -> 'StateMachine':
:param kwargs: Any keyword arguments to be passed to the constructor
:return: An instance of the state machine
"""
inst = super().__call__(*args, **kwargs)
inst: StateMachine = super().__call__(*args, **kwargs)
inst.transition_to(inst.create_initial_state())
call_with_super_check(inst.init)
return inst
Expand Down Expand Up @@ -325,13 +325,14 @@ def on_terminated(self) -> None:
"""Called when a terminal state is entered"""

def transition_to(
self, new_state: Union[State, Type[State]], *args: Any, **kwargs: Any
self, new_state: Union[State, Type[State]], **kwargs: Any
) -> None:
"""Transite to the new state.
The new target state will be create lazily when the state
is not yet instantiated, which will happened for states not in the expect path such as
pause and kill.
The new target state will be create lazily when the state is not yet instantiated,
which will happened for states not in the expect path such as pause and kill.
The arguments are passed to the state class to create state instance.
(process arg does not need to pass since it will always call with 'self' as process)
"""
assert (
not self._transitioning
Expand All @@ -344,7 +345,7 @@ def transition_to(

if not isinstance(new_state, State):
# Make sure we have a state instance
new_state = self._create_state_instance(new_state, *args, **kwargs)
new_state = self._create_state_instance(new_state, **kwargs)

label = new_state.LABEL

Expand All @@ -358,7 +359,7 @@ def transition_to(
# Make sure we have a state instance
if not isinstance(exception.state, State):
new_state = self._create_state_instance(
exception.state, *exception.args, **exception.kwargs
exception.state, **exception.kwargs
)
label = new_state.LABEL
self._exit_current_state(new_state)
Expand Down Expand Up @@ -435,9 +436,9 @@ def _enter_next_state(self, next_state: State) -> None:
self._fire_state_event(StateEventHook.ENTERED_STATE, last_state)

def _create_state_instance(
self, state_cls: type[State], *args: Any, **kwargs: Any
self, state_cls: type[State], **kwargs: Any
) -> State:
if state_cls.LABEL not in self.get_states_map():
raise ValueError(f"{state_cls.LABEL} is not a valid state")

return state_cls(self, *args, **kwargs)
return state_cls(self, **kwargs)
5 changes: 4 additions & 1 deletion src/plumpy/process_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,12 +386,15 @@ def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> ki

return self._communicator.rpc_send(pid, msg)

def kill_all(self, msg: Optional[Any]) -> None:
def kill_all(self, msg: Optional[MessageType]) -> None:
"""
Kill all processes that are subscribed to the same communicator
:param msg: an optional pause message
"""
if msg is None:
msg = copy.copy(KILL_MSG)

self._communicator.broadcast_send(msg, subject=Intent.KILL)

def continue_process(
Expand Down
Loading

0 comments on commit d7078fc

Please sign in to comment.