Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

State class de-abstract #4

Merged
merged 6 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/nitpick-exceptions
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ py:class kiwipy.communications.Communicator

# unavailable forward references
py:class plumpy.process_states.Command
py:class plumpy.process_states.State
py:class plumpy.state_machine.State
py:class plumpy.base.state_machine.State
py:class State
py:class Process
Expand Down
120 changes: 41 additions & 79 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@
from typing import (
Any,
Callable,
ClassVar,
Dict,
Hashable,
Iterable,
List,
Optional,
Protocol,
Sequence,
Set,
Type,
Union,
runtime_checkable,
)

from plumpy.futures import Future
Expand All @@ -32,7 +34,6 @@

_LOGGER = logging.getLogger(__name__)

LABEL_TYPE = Union[None, enum.Enum, str]
EVENT_CALLBACK_TYPE = Callable[['StateMachine', Hashable, Optional['State']], None]


Expand All @@ -45,7 +46,7 @@ class StateEntryFailed(Exception): # noqa: N818
Failed to enter a state, can provide the next state to go to via this exception
"""

def __init__(self, state: type['State'], *args: Any, **kwargs: Any) -> None:
def __init__(self, state: State, *args: Any, **kwargs: Any) -> None:
super().__init__('failed to enter state')
self.state = state
self.args = args
Expand Down Expand Up @@ -88,12 +89,12 @@ def event(
if from_states != '*':
if inspect.isclass(from_states):
from_states = (from_states,)
if not all(issubclass(state, State) for state in from_states): # type: ignore
if not all(isinstance(state, State) for state in from_states): # type: ignore
raise TypeError(f'from_states: {from_states}')
if to_states != '*':
if inspect.isclass(to_states):
to_states = (to_states,)
if not all(issubclass(state, State) for state in to_states): # type: ignore
if not all(isinstance(state, State) for state in to_states): # type: ignore
raise TypeError(f'to_states: {to_states}')

def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]:
Expand Down Expand Up @@ -127,57 +128,40 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any:
return wrapper


class State:
LABEL: LABEL_TYPE = None
# A set containing the labels of states that can be entered
# from this one
ALLOWED: Set[LABEL_TYPE] = set()
@runtime_checkable
class State(Protocol):
LABEL: ClassVar[Any]
ALLOWED: ClassVar[set[Any]]
is_terminal: ClassVar[bool]

@classmethod
def is_terminal(cls) -> bool:
return not cls.ALLOWED
def __init__(self, *args: Any, **kwargs: Any): ...

def __init__(self, state_machine: 'StateMachine', *args: Any, **kwargs: Any):
"""
:param state_machine: The process this state belongs to
"""
self.state_machine = state_machine
self.in_state: bool = False
def enter(self) -> None: ...

def __str__(self) -> str:
return str(self.LABEL)
def exit(self) -> None: ...

@property
def label(self) -> LABEL_TYPE:
"""Convenience property to get the state label"""
return self.LABEL

@super_check
def enter(self) -> None:
"""Entering the state"""
@runtime_checkable
class Interruptable(Protocol):
def interrupt(self, reason: Exception) -> None: ...

def execute(self) -> Optional['State']:

@runtime_checkable
class Proceedable(Protocol):
def execute(self) -> State | None:
"""
Execute the state, performing the actions that this state is responsible for.
:returns: a state to transition to or None if finished.
"""
...

@super_check
def exit(self) -> None:
"""Exiting the state"""
if self.is_terminal():
raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}')

def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State':
return self.state_machine.create_state(state_label, *args, **kwargs)
def create_state(st: StateMachine, state_label: Hashable, *args: Any, **kwargs: Any) -> State:
if state_label not in st.get_states_map():
raise ValueError(f'{state_label} is not a valid state')

def do_enter(self) -> None:
call_with_super_check(self.enter)
self.in_state = True

def do_exit(self) -> None:
call_with_super_check(self.exit)
self.in_state = False
state_cls = st.get_states_map()[state_label]
return state_cls(*args, **kwargs)


class StateEventHook(enum.Enum):
Expand Down Expand Up @@ -228,13 +212,13 @@ def get_states(cls) -> Sequence[Type[State]]:
raise RuntimeError('States not defined')

@classmethod
def initial_state_label(cls) -> LABEL_TYPE:
def initial_state_label(cls) -> Any:
cls.__ensure_built()
assert cls.STATES is not None
return cls.STATES[0].LABEL

@classmethod
def get_state_class(cls, label: LABEL_TYPE) -> Type[State]:
def get_state_class(cls, label: Any) -> Type[State]:
cls.__ensure_built()
assert cls._STATES_MAP is not None
return cls._STATES_MAP[label]
Expand All @@ -254,7 +238,7 @@ def __ensure_built(cls) -> None:
# Build the states map
cls._STATES_MAP = {}
for state_cls in cls.STATES:
assert issubclass(state_cls, State)
assert isinstance(state_cls, State)
label = state_cls.LABEL
assert label not in cls._STATES_MAP, f"Duplicate label '{label}'"
cls._STATES_MAP[label] = state_cls
Expand All @@ -278,11 +262,11 @@ def init(self) -> None:
def __str__(self) -> str:
return f'<{self.__class__.__name__}> ({self.state})'

def create_initial_state(self) -> State:
return self.get_state_class(self.initial_state_label())(self)
def create_initial_state(self, *args: Any, **kwargs: Any) -> State:
return self.get_state_class(self.initial_state_label())(self, *args, **kwargs)

@property
def state(self) -> Optional[LABEL_TYPE]:
def state(self) -> Any:
if self._state is None:
return None
return self._state.LABEL
Expand Down Expand Up @@ -314,14 +298,15 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None:
def on_terminated(self) -> None:
"""Called when a terminal state is entered"""

def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) -> None:
def transition_to(self, new_state: State | None, **kwargs: Any) -> None:
"""Transite to the new state.

The new target state will be create lazily when the state is not yet instantiated,
which will happened for states not in the expect path such as pause and kill.
The arguments are passed to the state class to create state instance.
(process arg does not need to pass since it will always call with 'self' as process)
"""
print(f'try: {self._state} -> {new_state}')
assert not self._transitioning, 'Cannot call transition_to when already transitioning state'

if new_state is None:
Expand All @@ -331,11 +316,6 @@ def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) ->
label = None
try:
self._transitioning = True

if not isinstance(new_state, State):
# Make sure we have a state instance
new_state = self._create_state_instance(new_state, **kwargs)

label = new_state.LABEL

# If the previous transition failed, do not try to exit it but go straight to next state
Expand All @@ -345,14 +325,12 @@ def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) ->
try:
self._enter_next_state(new_state)
except StateEntryFailed as exception:
# Make sure we have a state instance
if not isinstance(exception.state, State):
new_state = self._create_state_instance(exception.state, **exception.kwargs)
new_state = exception.state
label = new_state.LABEL
self._exit_current_state(new_state)
self._enter_next_state(new_state)

if self._state is not None and self._state.is_terminal():
if self._state is not None and self._state.is_terminal:
call_with_super_check(self.on_terminated)
except Exception:
self._transitioning = False
Expand Down Expand Up @@ -385,41 +363,25 @@ def get_debug(self) -> bool:
def set_debug(self, enabled: bool) -> None:
self._debug: bool = enabled

def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State:
# XXX: this method create state from label, which is duplicate as _create_state_instance and less generic
# because the label is defined after the state and required to be know before calling this function.
# This method should be replaced by `_create_state_instance`.
# aiida-core using this method for its Waiting state override.
try:
return self.get_states_map()[state_label](self, *args, **kwargs)
except KeyError:
raise ValueError(f'{state_label} is not a valid state')

def _exit_current_state(self, next_state: State) -> None:
"""Exit the given state"""

# If we're just being constructed we may not have a state yet to exit,
# in which case check the new state is the initial state
if self._state is None:
if next_state.label != self.initial_state_label():
if next_state.LABEL != self.initial_state_label():
raise RuntimeError(f"Cannot enter state '{next_state}' as the initial state")
return # Nothing to exit

if next_state.LABEL not in self._state.ALLOWED:
raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}')
raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.LABEL}')
self._fire_state_event(StateEventHook.EXITING_STATE, next_state)
self._state.do_exit()
self._state.exit()

def _enter_next_state(self, next_state: State) -> None:
last_state = self._state
self._fire_state_event(StateEventHook.ENTERING_STATE, next_state)
# Enter the new state
next_state.do_enter()
next_state.enter()
self._state = next_state
self._fire_state_event(StateEventHook.ENTERED_STATE, last_state)

def _create_state_instance(self, state_cls: type[State], **kwargs: Any) -> State:
if state_cls.LABEL not in self.get_states_map():
raise ValueError(f'{state_cls.LABEL} is not a valid state')

return state_cls(self, **kwargs)
Loading
Loading