Skip to content

Commit

Permalink
Forming the contract with abc.ABC for state_machine.ABC
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 2, 2024
1 parent 31167cf commit 396bf33
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 16 deletions.
25 changes: 10 additions & 15 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from __future__ import annotations

import abc
import enum
import functools
import inspect
Expand All @@ -13,15 +14,18 @@
from typing import (
Any,
Callable,
ClassVar,
Dict,
Hashable,
Iterable,
List,
Optional,
Protocol,
Sequence,
Set,
Type,
Union,
runtime_checkable,
)

from plumpy.futures import Future
Expand Down Expand Up @@ -127,39 +131,30 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any:
return wrapper


class State:
LABEL: LABEL_TYPE = None
# A set containing the labels of states that can be entered
# from this one
ALLOWED: Set[LABEL_TYPE] = set()
class State(abc.ABC):
LABEL: ClassVar[LABEL_TYPE]

def __init__(self, state_machine: 'StateMachine', *args: Any, **kwargs: Any):
"""
:param state_machine: The process this state belongs to
"""
self.state_machine = state_machine
self.in_state: bool = False

def __str__(self) -> str:
return str(self.LABEL)

@property
def label(self) -> LABEL_TYPE:
"""Convenience property to get the state label"""
return self.LABEL

def execute(self) -> Optional['State']:
"""
Execute the state, performing the actions that this state is responsible for.
:returns: a state to transition to or None if finished.
"""
...

def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State':
return self.state_machine.create_state(state_label, *args, **kwargs)

@abc.abstractmethod
def enter(self) -> None:
...

@abc.abstractmethod
def exit(self) -> None:
...

Expand Down Expand Up @@ -385,7 +380,7 @@ def _exit_current_state(self, next_state: State) -> None:
# If we're just being constructed we may not have a state yet to exit,
# in which case check the new state is the initial state
if self._state is None:
if next_state.label != self.initial_state_label():
if next_state.LABEL != self.initial_state_label():
raise RuntimeError(f"Cannot enter state '{next_state}' as the initial state")
return # Nothing to exit

Expand Down
2 changes: 1 addition & 1 deletion tests/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ def test_exception_during_on_entered(self):

class RaisingProcess(Process):
def on_entered(self, from_state):
if from_state is not None and from_state.label == ProcessState.RUNNING:
if from_state is not None and from_state.LABEL == ProcessState.RUNNING:
raise RuntimeError('exception during on_entered')
super().on_entered(from_state)

Expand Down

0 comments on commit 396bf33

Please sign in to comment.