Skip to content

Commit

Permalink
Forming a State protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 2, 2024
1 parent 396bf33 commit 8841574
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 114 deletions.
45 changes: 14 additions & 31 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from __future__ import annotations

import abc
from abc import abstractmethod, ABCMeta
import enum
import functools
import inspect
Expand All @@ -22,7 +22,6 @@
Optional,
Protocol,
Sequence,
Set,
Type,
Union,
runtime_checkable,
Expand All @@ -49,7 +48,7 @@ class StateEntryFailed(Exception): # noqa: N818
Failed to enter a state, can provide the next state to go to via this exception
"""

def __init__(self, state: type['State'], *args: Any, **kwargs: Any) -> None:
def __init__(self, state: 'State', *args: Any, **kwargs: Any) -> None:
super().__init__('failed to enter state')
self.state = state
self.args = args
Expand Down Expand Up @@ -92,12 +91,12 @@ def event(
if from_states != '*':
if inspect.isclass(from_states):
from_states = (from_states,)
if not all(issubclass(state, State) for state in from_states): # type: ignore
if not all(isinstance(state, State) for state in from_states): # type: ignore
raise TypeError(f'from_states: {from_states}')
if to_states != '*':
if inspect.isclass(to_states):
to_states = (to_states,)
if not all(issubclass(state, State) for state in to_states): # type: ignore
if not all(isinstance(state, State) for state in to_states): # type: ignore
raise TypeError(f'to_states: {to_states}')

def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]:
Expand Down Expand Up @@ -131,32 +130,20 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any:
return wrapper


class State(abc.ABC):
@runtime_checkable
class State(Protocol):
LABEL: ClassVar[LABEL_TYPE]

def __init__(self, state_machine: 'StateMachine', *args: Any, **kwargs: Any):
"""
:param state_machine: The process this state belongs to
"""
self.state_machine = state_machine

def execute(self) -> Optional['State']:
async def execute(self) -> State | None:
"""
Execute the state, performing the actions that this state is responsible for.
:returns: a state to transition to or None if finished.
"""
...

def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State':
return self.state_machine.create_state(state_label, *args, **kwargs)

@abc.abstractmethod
def enter(self) -> None:
...
def enter(self) -> None: ...

@abc.abstractmethod
def exit(self) -> None:
...
def exit(self) -> None: ...


class StateEventHook(enum.Enum):
Expand Down Expand Up @@ -233,7 +220,7 @@ def __ensure_built(cls) -> None:
# Build the states map
cls._STATES_MAP = {}
for state_cls in cls.STATES:
assert issubclass(state_cls, State)
assert isinstance(state_cls, State)
label = state_cls.LABEL
assert label not in cls._STATES_MAP, f"Duplicate label '{label}'"
cls._STATES_MAP[label] = state_cls
Expand Down Expand Up @@ -293,7 +280,7 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None:
def on_terminated(self) -> None:
"""Called when a terminal state is entered"""

def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) -> None:
def transition_to(self, new_state: State | None, **kwargs: Any) -> None:
"""Transite to the new state.
The new target state will be create lazily when the state is not yet instantiated,
Expand All @@ -311,10 +298,6 @@ def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) ->
try:
self._transitioning = True

if not isinstance(new_state, State):
# Make sure we have a state instance
new_state = self._create_state_instance(new_state, **kwargs)

label = new_state.LABEL

# If the previous transition failed, do not try to exit it but go straight to next state
Expand All @@ -325,8 +308,7 @@ def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) ->
self._enter_next_state(new_state)
except StateEntryFailed as exception:
# Make sure we have a state instance
if not isinstance(exception.state, State):
new_state = self._create_state_instance(exception.state, **exception.kwargs)
new_state = exception.state
label = new_state.LABEL
self._exit_current_state(new_state)
self._enter_next_state(new_state)
Expand Down Expand Up @@ -370,7 +352,8 @@ def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> Stat
# This method should be replaced by `_create_state_instance`.
# aiida-core using this method for its Waiting state override.
try:
return self.get_states_map()[state_label](self, *args, **kwargs)
state_cls = self.get_states_map()[state_label]
return state_cls(self, *args, **kwargs)
except KeyError:
raise ValueError(f'{state_label} is not a valid state')

Expand Down
Loading

0 comments on commit 8841574

Please sign in to comment.