Skip to content

Commit

Permalink
forming the enter/exit for State protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 4, 2024
1 parent 985c2ed commit e4e0df5
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 128 deletions.
66 changes: 18 additions & 48 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@
from typing import (
Any,
Callable,
ClassVar,
Dict,
Hashable,
Iterable,
List,
Optional,
Protocol,
Sequence,
Set,
Type,
Union,
runtime_checkable,
)

from plumpy.futures import Future
Expand Down Expand Up @@ -88,12 +90,12 @@ def event(
if from_states != '*':
if inspect.isclass(from_states):
from_states = (from_states,)
if not all(issubclass(state, State) for state in from_states): # type: ignore
if not all(isinstance(state, State) for state in from_states): # type: ignore
raise TypeError(f'from_states: {from_states}')
if to_states != '*':
if inspect.isclass(to_states):
to_states = (to_states,)
if not all(issubclass(state, State) for state in to_states): # type: ignore
if not all(isinstance(state, State) for state in to_states): # type: ignore
raise TypeError(f'to_states: {to_states}')

def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]:
Expand Down Expand Up @@ -127,53 +129,20 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any:
return wrapper


class State:
LABEL: LABEL_TYPE = None
# A set containing the labels of states that can be entered
# from this one
ALLOWED: Set[LABEL_TYPE] = set()
@runtime_checkable
class State(Protocol):
LABEL: ClassVar[LABEL_TYPE]

def __init__(self, state_machine: 'StateMachine', *args: Any, **kwargs: Any):
"""
:param state_machine: The process this state belongs to
"""
self.state_machine = state_machine
self.in_state: bool = False

def __str__(self) -> str:
return str(self.LABEL)

@property
def label(self) -> LABEL_TYPE:
"""Convenience property to get the state label"""
return self.LABEL

@super_check
def enter(self) -> None:
"""Entering the state"""

def execute(self) -> Optional['State']:
async def execute(self) -> State | None:
"""
Execute the state, performing the actions that this state is responsible for.
:returns: a state to transition to or None if finished.
"""
...

@super_check
def exit(self) -> None:
"""Exiting the state"""
if self.is_terminal:
raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}')

def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State':
return self.state_machine.create_state(state_label, *args, **kwargs)

def do_enter(self) -> None:
call_with_super_check(self.enter)
self.in_state = True
def enter(self) -> None: ...

def do_exit(self) -> None:
call_with_super_check(self.exit)
self.in_state = False
def exit(self) -> None: ...


class StateEventHook(enum.Enum):
Expand Down Expand Up @@ -250,7 +219,7 @@ def __ensure_built(cls) -> None:
# Build the states map
cls._STATES_MAP = {}
for state_cls in cls.STATES:
assert issubclass(state_cls, State)
assert isinstance(state_cls, State)
label = state_cls.LABEL
assert label not in cls._STATES_MAP, f"Duplicate label '{label}'"
cls._STATES_MAP[label] = state_cls
Expand Down Expand Up @@ -380,7 +349,8 @@ def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> Stat
# This method should be replaced by `_create_state_instance`.
# aiida-core using this method for its Waiting state override.
try:
return self.get_states_map()[state_label](self, *args, **kwargs)
state_cls = self.get_states_map()[state_label]
return state_cls(self, *args, **kwargs)
except KeyError:
raise ValueError(f'{state_label} is not a valid state')

Expand All @@ -390,20 +360,20 @@ def _exit_current_state(self, next_state: State) -> None:
# If we're just being constructed we may not have a state yet to exit,
# in which case check the new state is the initial state
if self._state is None:
if next_state.label != self.initial_state_label():
if next_state.LABEL != self.initial_state_label():
raise RuntimeError(f"Cannot enter state '{next_state}' as the initial state")
return # Nothing to exit

if next_state.LABEL not in self._state.ALLOWED:
raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}')
self._fire_state_event(StateEventHook.EXITING_STATE, next_state)
self._state.do_exit()
self._state.exit()

def _enter_next_state(self, next_state: State) -> None:
last_state = self._state
self._fire_state_event(StateEventHook.ENTERING_STATE, next_state)
# Enter the new state
next_state.do_enter()
next_state.enter()
self._state = next_state
self._fire_state_event(StateEventHook.ENTERED_STATE, last_state)

Expand Down
Loading

0 comments on commit e4e0df5

Please sign in to comment.