-
Notifications
You must be signed in to change notification settings - Fork 17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Message passing with more information #291
base: master
Are you sure you want to change the base?
Changes from all commits
1117eeb
b82791d
d4c0489
c5a195c
8db6675
74d048d
667af7a
4be6931
88259d6
c3c9db4
d0e4e73
e3c2ae8
e5c74ad
17a541a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,28 @@ | ||
# -*- coding: utf-8 -*- | ||
"""The state machine for processes""" | ||
|
||
from __future__ import annotations | ||
|
||
import enum | ||
import functools | ||
import inspect | ||
import logging | ||
import os | ||
import sys | ||
from types import TracebackType | ||
from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional, Sequence, Set, Type, Union, cast | ||
from typing import ( | ||
Any, | ||
Callable, | ||
Dict, | ||
Hashable, | ||
Iterable, | ||
List, | ||
Optional, | ||
Sequence, | ||
Set, | ||
Type, | ||
Union, | ||
) | ||
|
||
from plumpy.futures import Future | ||
|
||
|
@@ -31,7 +45,7 @@ class StateEntryFailed(Exception): # noqa: N818 | |
Failed to enter a state, can provide the next state to go to via this exception | ||
""" | ||
|
||
def __init__(self, state: Hashable = None, *args: Any, **kwargs: Any) -> None: | ||
def __init__(self, state: State, *args: Any, **kwargs: Any) -> None: | ||
super().__init__('failed to enter state') | ||
self.state = state | ||
self.args = args | ||
|
@@ -187,7 +201,7 @@ def __call__(cls, *args: Any, **kwargs: Any) -> 'StateMachine': | |
:param kwargs: Any keyword arguments to be passed to the constructor | ||
:return: An instance of the state machine | ||
""" | ||
inst = super().__call__(*args, **kwargs) | ||
inst: StateMachine = super().__call__(*args, **kwargs) | ||
inst.transition_to(inst.create_initial_state()) | ||
call_with_super_check(inst.init) | ||
return inst | ||
|
@@ -300,16 +314,23 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None: | |
def on_terminated(self) -> None: | ||
"""Called when a terminal state is entered""" | ||
|
||
def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> None: | ||
def transition_to(self, new_state: State | None, **kwargs: Any) -> None: | ||
"""Transite to the new state. | ||
|
||
The new target state will be create lazily when the state is not yet instantiated, | ||
which will happened for states not in the expect path such as pause and kill. | ||
The arguments are passed to the state class to create state instance. | ||
(process arg does not need to pass since it will always call with 'self' as process) | ||
""" | ||
assert not self._transitioning, 'Cannot call transition_to when already transitioning state' | ||
|
||
if new_state is None: | ||
return None | ||
|
||
initial_state_label = self._state.LABEL if self._state is not None else None | ||
label = None | ||
try: | ||
self._transitioning = True | ||
|
||
# Make sure we have a state instance | ||
new_state = self._create_state_instance(new_state, *args, **kwargs) | ||
label = new_state.LABEL | ||
|
||
# If the previous transition failed, do not try to exit it but go straight to next state | ||
|
@@ -319,8 +340,7 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A | |
try: | ||
self._enter_next_state(new_state) | ||
except StateEntryFailed as exception: | ||
# Make sure we have a state instance | ||
new_state = self._create_state_instance(exception.state, *exception.args, **exception.kwargs) | ||
new_state = exception.state | ||
label = new_state.LABEL | ||
self._exit_current_state(new_state) | ||
self._enter_next_state(new_state) | ||
|
@@ -338,7 +358,11 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A | |
self._transitioning = False | ||
|
||
def transition_failed( | ||
self, initial_state: Hashable, final_state: Hashable, exception: Exception, trace: TracebackType | ||
self, | ||
initial_state: Hashable, | ||
final_state: Hashable, | ||
exception: Exception, | ||
trace: TracebackType, | ||
) -> None: | ||
"""Called when a state transitions fails. | ||
|
||
|
@@ -355,6 +379,10 @@ def set_debug(self, enabled: bool) -> None: | |
self._debug: bool = enabled | ||
|
||
def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State: | ||
# XXX: this method create state from label, which is duplicate as _create_state_instance and less generic | ||
# because the label is defined after the state and required to be know before calling this function. | ||
# This method should be replaced by `_create_state_instance`. | ||
# aiida-core using this method for its Waiting state override. | ||
try: | ||
return self.get_states_map()[state_label](self, *args, **kwargs) | ||
except KeyError: | ||
|
@@ -383,20 +411,8 @@ def _enter_next_state(self, next_state: State) -> None: | |
self._state = next_state | ||
self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) | ||
|
||
def _create_state_instance(self, state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> State: | ||
if isinstance(state, State): | ||
# It's already a state instance | ||
return state | ||
|
||
# OK, have to create it | ||
state_cls = self._ensure_state_class(state) | ||
return state_cls(self, *args, **kwargs) | ||
|
||
def _ensure_state_class(self, state: Union[Hashable, Type[State]]) -> Type[State]: | ||
if inspect.isclass(state) and issubclass(state, State): | ||
return state | ||
def _create_state_instance(self, state_cls: type[State], **kwargs: Any) -> State: | ||
if state_cls.LABEL not in self.get_states_map(): | ||
raise ValueError(f'{state_cls.LABEL} is not a valid state') | ||
Comment on lines
+414
to
+416
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is one of the major changes in this PR. I made this function only create the instance as name indicated. If it is already a state instance the logic is moved directly to |
||
|
||
try: | ||
return self.get_states_map()[cast(Hashable, state)] | ||
except KeyError: | ||
raise ValueError(f'{state} is not a valid state') | ||
return state_cls(self, **kwargs) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,9 @@ | ||
# -*- coding: utf-8 -*- | ||
"""Module for process level communication functions and classes""" | ||
|
||
from __future__ import annotations | ||
|
||
import asyncio | ||
import copy | ||
import logging | ||
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast | ||
|
||
|
@@ -12,13 +13,13 @@ | |
from .utils import PID_TYPE | ||
|
||
__all__ = [ | ||
'KILL_MSG', | ||
'PAUSE_MSG', | ||
'PLAY_MSG', | ||
'STATUS_MSG', | ||
'KillMessage', | ||
'PauseMessage', | ||
'PlayMessage', | ||
'ProcessLauncher', | ||
'RemoteProcessController', | ||
'RemoteProcessThreadController', | ||
'StatusMessage', | ||
'create_continue_body', | ||
'create_launch_body', | ||
] | ||
|
@@ -31,6 +32,7 @@ | |
|
||
INTENT_KEY = 'intent' | ||
MESSAGE_KEY = 'message' | ||
FORCE_KILL_KEY = 'force_kill' | ||
|
||
|
||
class Intent: | ||
|
@@ -42,10 +44,45 @@ class Intent: | |
STATUS: str = 'status' | ||
|
||
|
||
PAUSE_MSG = {INTENT_KEY: Intent.PAUSE} | ||
PLAY_MSG = {INTENT_KEY: Intent.PLAY} | ||
KILL_MSG = {INTENT_KEY: Intent.KILL} | ||
STATUS_MSG = {INTENT_KEY: Intent.STATUS} | ||
MessageType = Dict[str, Any] | ||
|
||
|
||
class PlayMessage: | ||
@classmethod | ||
def build(cls, message: str | None = None) -> MessageType: | ||
return { | ||
INTENT_KEY: Intent.PLAY, | ||
MESSAGE_KEY: message, | ||
} | ||
|
||
|
||
class PauseMessage: | ||
@classmethod | ||
def build(cls, message: str | None = None) -> MessageType: | ||
return { | ||
INTENT_KEY: Intent.PAUSE, | ||
MESSAGE_KEY: message, | ||
} | ||
|
||
|
||
class KillMessage: | ||
@classmethod | ||
def build(cls, message: str | None = None, force: bool = False) -> MessageType: | ||
return { | ||
INTENT_KEY: Intent.KILL, | ||
MESSAGE_KEY: message, | ||
FORCE_KILL_KEY: force, | ||
} | ||
|
||
|
||
class StatusMessage: | ||
@classmethod | ||
def build(cls, message: str | None = None) -> MessageType: | ||
return { | ||
INTENT_KEY: Intent.STATUS, | ||
MESSAGE_KEY: message, | ||
} | ||
|
||
Comment on lines
+50
to
+85
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of using global variable for message and using |
||
|
||
TASK_KEY = 'task' | ||
TASK_ARGS = 'args' | ||
|
@@ -162,7 +199,7 @@ async def get_status(self, pid: 'PID_TYPE') -> 'ProcessStatus': | |
:param pid: the process id | ||
:return: the status response from the process | ||
""" | ||
future = self._communicator.rpc_send(pid, STATUS_MSG) | ||
future = self._communicator.rpc_send(pid, StatusMessage.build()) | ||
result = await asyncio.wrap_future(future) | ||
return result | ||
|
||
|
@@ -174,11 +211,9 @@ async def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'Pr | |
:param msg: optional pause message | ||
:return: True if paused, False otherwise | ||
""" | ||
message = copy.copy(PAUSE_MSG) | ||
if msg is not None: | ||
message[MESSAGE_KEY] = msg | ||
msg = PauseMessage.build(message=msg) | ||
|
||
pause_future = self._communicator.rpc_send(pid, message) | ||
pause_future = self._communicator.rpc_send(pid, msg) | ||
# rpc_send return a thread future from communicator | ||
future = await asyncio.wrap_future(pause_future) | ||
# future is just returned from rpc call which return a kiwipy future | ||
|
@@ -192,25 +227,24 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult': | |
:param pid: the pid of the process to play | ||
:return: True if played, False otherwise | ||
""" | ||
play_future = self._communicator.rpc_send(pid, PLAY_MSG) | ||
play_future = self._communicator.rpc_send(pid, PlayMessage.build()) | ||
future = await asyncio.wrap_future(play_future) | ||
result = await asyncio.wrap_future(future) | ||
return result | ||
|
||
async def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'ProcessResult': | ||
async def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> 'ProcessResult': | ||
""" | ||
Kill the process | ||
|
||
:param pid: the pid of the process to kill | ||
:param msg: optional kill message | ||
:return: True if killed, False otherwise | ||
""" | ||
message = copy.copy(KILL_MSG) | ||
if msg is not None: | ||
message[MESSAGE_KEY] = msg | ||
if msg is None: | ||
msg = KillMessage.build() | ||
|
||
# Wait for the communication to go through | ||
kill_future = self._communicator.rpc_send(pid, message) | ||
kill_future = self._communicator.rpc_send(pid, msg) | ||
future = await asyncio.wrap_future(kill_future) | ||
# Now wait for the kill to be enacted | ||
result = await asyncio.wrap_future(future) | ||
|
@@ -331,7 +365,7 @@ def get_status(self, pid: 'PID_TYPE') -> kiwipy.Future: | |
:param pid: the process id | ||
:return: the status response from the process | ||
""" | ||
return self._communicator.rpc_send(pid, STATUS_MSG) | ||
return self._communicator.rpc_send(pid, StatusMessage.build()) | ||
|
||
def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future: | ||
""" | ||
|
@@ -342,11 +376,9 @@ def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fu | |
:return: a response future from the process to be paused | ||
|
||
""" | ||
message = copy.copy(PAUSE_MSG) | ||
if msg is not None: | ||
message[MESSAGE_KEY] = msg | ||
msg = PauseMessage.build(message=msg) | ||
|
||
return self._communicator.rpc_send(pid, message) | ||
return self._communicator.rpc_send(pid, msg) | ||
|
||
def pause_all(self, msg: Any) -> None: | ||
""" | ||
|
@@ -364,15 +396,15 @@ def play_process(self, pid: 'PID_TYPE') -> kiwipy.Future: | |
:return: a response future from the process to be played | ||
|
||
""" | ||
return self._communicator.rpc_send(pid, PLAY_MSG) | ||
return self._communicator.rpc_send(pid, PlayMessage.build()) | ||
|
||
def play_all(self) -> None: | ||
""" | ||
Play all processes that are subscribed to the same communicator | ||
""" | ||
self._communicator.broadcast_send(None, subject=Intent.PLAY) | ||
|
||
def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future: | ||
def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> kiwipy.Future: | ||
""" | ||
Kill the process | ||
|
||
|
@@ -381,18 +413,20 @@ def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fut | |
:return: a response future from the process to be killed | ||
|
||
""" | ||
message = copy.copy(KILL_MSG) | ||
if msg is not None: | ||
message[MESSAGE_KEY] = msg | ||
if msg is None: | ||
msg = KillMessage.build() | ||
|
||
return self._communicator.rpc_send(pid, message) | ||
return self._communicator.rpc_send(pid, msg) | ||
|
||
def kill_all(self, msg: Optional[Any]) -> None: | ||
def kill_all(self, msg: Optional[MessageType]) -> None: | ||
""" | ||
Kill all processes that are subscribed to the same communicator | ||
|
||
:param msg: an optional pause message | ||
""" | ||
if msg is None: | ||
msg = KillMessage.build() | ||
|
||
self._communicator.broadcast_send(msg, subject=Intent.KILL) | ||
|
||
def continue_process( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is to cover the case in
Process.step
when thestate.execute
lead to the next state is none from terminal state.