From 396bf3325856ce9d9380a1aee7b090894d425b73 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 2 Dec 2024 17:38:58 +0100 Subject: [PATCH] Forming the contract with abc.ABC for state_machine.ABC --- src/plumpy/base/state_machine.py | 25 ++++++++++--------------- tests/test_processes.py | 2 +- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 8a8fb4af..a650d0ed 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -3,6 +3,7 @@ from __future__ import annotations +import abc import enum import functools import inspect @@ -13,15 +14,18 @@ from typing import ( Any, Callable, + ClassVar, Dict, Hashable, Iterable, List, Optional, + Protocol, Sequence, Set, Type, Union, + runtime_checkable, ) from plumpy.futures import Future @@ -127,39 +131,30 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any: return wrapper -class State: - LABEL: LABEL_TYPE = None - # A set containing the labels of states that can be entered - # from this one - ALLOWED: Set[LABEL_TYPE] = set() +class State(abc.ABC): + LABEL: ClassVar[LABEL_TYPE] def __init__(self, state_machine: 'StateMachine', *args: Any, **kwargs: Any): """ :param state_machine: The process this state belongs to """ self.state_machine = state_machine - self.in_state: bool = False - - def __str__(self) -> str: - return str(self.LABEL) - - @property - def label(self) -> LABEL_TYPE: - """Convenience property to get the state label""" - return self.LABEL def execute(self) -> Optional['State']: """ Execute the state, performing the actions that this state is responsible for. :returns: a state to transition to or None if finished. """ + ... def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State': return self.state_machine.create_state(state_label, *args, **kwargs) + @abc.abstractmethod def enter(self) -> None: ... + @abc.abstractmethod def exit(self) -> None: ... @@ -385,7 +380,7 @@ def _exit_current_state(self, next_state: State) -> None: # If we're just being constructed we may not have a state yet to exit, # in which case check the new state is the initial state if self._state is None: - if next_state.label != self.initial_state_label(): + if next_state.LABEL != self.initial_state_label(): raise RuntimeError(f"Cannot enter state '{next_state}' as the initial state") return # Nothing to exit diff --git a/tests/test_processes.py b/tests/test_processes.py index eb5bf599..4b8cc606 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -653,7 +653,7 @@ def test_exception_during_on_entered(self): class RaisingProcess(Process): def on_entered(self, from_state): - if from_state is not None and from_state.label == ProcessState.RUNNING: + if from_state is not None and from_state.LABEL == ProcessState.RUNNING: raise RuntimeError('exception during on_entered') super().on_entered(from_state)