From 6bfb87df69889f0082e8ecbe7732d24ad69e64c2 Mon Sep 17 00:00:00 2001
From: Jusong Yu <jusong.yeu@gmail.com>
Date: Tue, 3 Dec 2024 00:48:53 +0100
Subject: [PATCH] Refactoring create_state as static function initialize state
 from label

create_state refact

Hashable initialized + parameters passed to Hashable

Fix pre-commit errors
---
 src/plumpy/base/state_machine.py |  45 +++---
 src/plumpy/process_states.py     | 235 ++++++++++++++++---------------
 src/plumpy/processes.py          |  42 +++---
 src/plumpy/workchains.py         |  10 +-
 tests/base/test_statemachine.py  |  15 +-
 5 files changed, 173 insertions(+), 174 deletions(-)

diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py
index 27b1e5f8..fc926008 100644
--- a/src/plumpy/base/state_machine.py
+++ b/src/plumpy/base/state_machine.py
@@ -34,7 +34,6 @@
 
 _LOGGER = logging.getLogger(__name__)
 
-LABEL_TYPE = Union[None, enum.Enum, str]
 EVENT_CALLBACK_TYPE = Callable[['StateMachine', Hashable, Optional['State']], None]
 
 
@@ -131,9 +130,12 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any:
 
 @runtime_checkable
 class State(Protocol):
-    LABEL: ClassVar[LABEL_TYPE]
+    LABEL: ClassVar[Any]
+    ALLOWED: ClassVar[set[Any]]
     is_terminal: ClassVar[bool]
 
+    def __init__(self, *args: Any, **kwargs: Any): ...
+
     def enter(self) -> None: ...
 
     def exit(self) -> None: ...
@@ -146,7 +148,6 @@ def interrupt(self, reason: Exception) -> None: ...
 
 @runtime_checkable
 class Proceedable(Protocol):
-
     def execute(self) -> State | None:
         """
         Execute the state, performing the actions that this state is responsible for.
@@ -155,6 +156,14 @@ def execute(self) -> State | None:
         ...
 
 
+def create_state(st: StateMachine, state_label: Hashable, *args: Any, **kwargs: Any) -> State:
+    if state_label not in st.get_states_map():
+        raise ValueError(f'{state_label} is not a valid state')
+
+    state_cls = st.get_states_map()[state_label]
+    return state_cls(*args, **kwargs)
+
+
 class StateEventHook(enum.Enum):
     """
     Hooks that can be used to register callback at various points in the state transition
@@ -203,13 +212,13 @@ def get_states(cls) -> Sequence[Type[State]]:
         raise RuntimeError('States not defined')
 
     @classmethod
-    def initial_state_label(cls) -> LABEL_TYPE:
+    def initial_state_label(cls) -> Any:
         cls.__ensure_built()
         assert cls.STATES is not None
         return cls.STATES[0].LABEL
 
     @classmethod
-    def get_state_class(cls, label: LABEL_TYPE) -> Type[State]:
+    def get_state_class(cls, label: Any) -> Type[State]:
         cls.__ensure_built()
         assert cls._STATES_MAP is not None
         return cls._STATES_MAP[label]
@@ -253,11 +262,11 @@ def init(self) -> None:
     def __str__(self) -> str:
         return f'<{self.__class__.__name__}> ({self.state})'
 
-    def create_initial_state(self) -> State:
-        return self.get_state_class(self.initial_state_label())(self)
+    def create_initial_state(self, *args: Any, **kwargs: Any) -> State:
+        return self.get_state_class(self.initial_state_label())(self, *args, **kwargs)
 
     @property
-    def state(self) -> Optional[LABEL_TYPE]:
+    def state(self) -> Any:
         if self._state is None:
             return None
         return self._state.LABEL
@@ -297,6 +306,7 @@ def transition_to(self, new_state: State | None, **kwargs: Any) -> None:
         The arguments are passed to the state class to create state instance.
         (process arg does not need to pass since it will always call with 'self' as process)
         """
+        print(f'try: {self._state} -> {new_state}')
         assert not self._transitioning, 'Cannot call transition_to when already transitioning state'
 
         if new_state is None:
@@ -353,17 +363,6 @@ def get_debug(self) -> bool:
     def set_debug(self, enabled: bool) -> None:
         self._debug: bool = enabled
 
-    def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State:
-        # XXX: this method create state from label, which is duplicate as _create_state_instance and less generic
-        # because the label is defined after the state and required to be know before calling this function.
-        # This method should be replaced by `_create_state_instance`.
-        # aiida-core using this method for its Waiting state override.
-        try:
-            state_cls = self.get_states_map()[state_label]
-            return state_cls(self, *args, **kwargs)
-        except KeyError:
-            raise ValueError(f'{state_label} is not a valid state')
-
     def _exit_current_state(self, next_state: State) -> None:
         """Exit the given state"""
 
@@ -375,7 +374,7 @@ def _exit_current_state(self, next_state: State) -> None:
             return  # Nothing to exit
 
         if next_state.LABEL not in self._state.ALLOWED:
-            raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}')
+            raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.LABEL}')
         self._fire_state_event(StateEventHook.EXITING_STATE, next_state)
         self._state.exit()
 
@@ -386,9 +385,3 @@ def _enter_next_state(self, next_state: State) -> None:
         next_state.enter()
         self._state = next_state
         self._fire_state_event(StateEventHook.ENTERED_STATE, last_state)
-
-    def _create_state_instance(self, state_cls: type[State], **kwargs: Any) -> State:
-        if state_cls.LABEL not in self.get_states_map():
-            raise ValueError(f'{state_cls.LABEL} is not a valid state')
-
-        return state_cls(self, **kwargs)
diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py
index cc9169c7..5f3e8237 100644
--- a/src/plumpy/process_states.py
+++ b/src/plumpy/process_states.py
@@ -5,7 +5,20 @@
 import traceback
 from enum import Enum
 from types import TracebackType
-from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast, final
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    ClassVar,
+    Optional,
+    Protocol,
+    Tuple,
+    Type,
+    Union,
+    cast,
+    final,
+    runtime_checkable,
+)
 
 import yaml
 from yaml.loader import Loader
@@ -20,9 +33,9 @@
     _HAS_TBLIB = False
 
 from . import exceptions, futures, persistence, utils
-from .base import state_machine
+from .base import state_machine as st
 from .lang import NULL
-from .persistence import auto_persist
+from .persistence import LoadSaveContext, auto_persist
 from .utils import SAVED_STATE_TYPE
 
 __all__ = [
@@ -138,22 +151,28 @@ class ProcessState(Enum):
     The possible states that a :class:`~plumpy.processes.Process` can be in.
     """
 
-    CREATED: str = 'created'
-    RUNNING: str = 'running'
-    WAITING: str = 'waiting'
-    FINISHED: str = 'finished'
-    EXCEPTED: str = 'excepted'
-    KILLED: str = 'killed'
+    # FIXME: see LSP error of return a exception, the type is Literal[str] which is invariant, tricky
+    CREATED = 'created'
+    RUNNING = 'running'
+    WAITING = 'waiting'
+    FINISHED = 'finished'
+    EXCEPTED = 'excepted'
+    KILLED = 'killed'
+
+
+@runtime_checkable
+class Savable(Protocol):
+    def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ...
 
 
 @final
-@auto_persist('args', 'kwargs', 'in_state')
-class Created(state_machine.State, persistence.Savable):
-    LABEL = ProcessState.CREATED
-    ALLOWED = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED}
+@auto_persist('args', 'kwargs')
+class Created(persistence.Savable):
+    LABEL: ClassVar = ProcessState.CREATED
+    ALLOWED: ClassVar = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED}
 
     RUN_FN = 'run_fn'
-    is_terminal = False
+    is_terminal: ClassVar[bool] = False
 
     def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
         assert run_fn is not None
@@ -161,7 +180,6 @@ def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, *
         self.run_fn = run_fn
         self.args = args
         self.kwargs = kwargs
-        self.in_state = True
 
     def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None:
         super().save_instance_state(out_state, save_context)
@@ -173,24 +191,21 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi
 
         self.run_fn = getattr(self.process, saved_state[self.RUN_FN])
 
-    async def execute(self) -> state_machine.State:
-        return self.process.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs)
-
-    def enter(self) -> None:
-        self.in_state = True
+    def execute(self) -> st.State:
+        return st.create_state(
+            self.process, ProcessState.RUNNING, process=self.process, run_fn=self.run_fn, *self.args, **self.kwargs
+        )
 
-    def exit(self) -> None:
-        if self.is_terminal:
-            raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}')
+    def enter(self) -> None: ...
 
-        self.in_state = False
+    def exit(self) -> None: ...
 
 
 @final
-@auto_persist('args', 'kwargs', 'in_state')
-class Running(state_machine.State, persistence.Savable):
-    LABEL = ProcessState.RUNNING
-    ALLOWED = {
+@auto_persist('args', 'kwargs')
+class Running(persistence.Savable):
+    LABEL: ClassVar = ProcessState.RUNNING
+    ALLOWED: ClassVar = {
         ProcessState.RUNNING,
         ProcessState.WAITING,
         ProcessState.FINISHED,
@@ -206,7 +221,7 @@ class Running(state_machine.State, persistence.Savable):
     _running: bool = False
     _run_handle = None
 
-    is_terminal = False
+    is_terminal: ClassVar[bool] = False
 
     def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
         assert run_fn is not None
@@ -215,7 +230,6 @@ def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, *
         self.args = args
         self.kwargs = kwargs
         self._run_handle = None
-        self.in_state = False
 
     def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None:
         super().save_instance_state(out_state, save_context)
@@ -234,7 +248,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi
     def interrupt(self, reason: Any) -> None:
         pass
 
-    async def execute(self) -> state_machine.State:
+    def execute(self) -> st.State:
         if self._command is not None:
             command = self._command
         else:
@@ -248,8 +262,10 @@ async def execute(self) -> state_machine.State:
                 # Let this bubble up to the caller
                 raise
             except Exception:
-                excepted = self.process.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:])
-                return cast(state_machine.State, excepted)
+                _, exception, traceback = sys.exc_info()
+                # excepted = state_cls(exception=exception, traceback=traceback)
+                excepted = Excepted(exception=exception, traceback=traceback)
+                return excepted
             else:
                 if not isinstance(result, Command):
                     if isinstance(result, exceptions.UnsuccessfulResult):
@@ -258,42 +274,52 @@ async def execute(self) -> state_machine.State:
                         # Got passed a basic return type
                         result = Stop(result, True)
 
-                command = result
+                command = cast(Stop, result)
 
         next_state = self._action_command(command)
         return next_state
 
-    def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> state_machine.State:
+    def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> st.State:
         if isinstance(command, Kill):
-            state = self.process.create_state(ProcessState.KILLED, command.msg)
+            state = st.create_state(self.process, ProcessState.KILLED, msg=command.msg)
         # elif isinstance(command, Pause):
         #     self.pause()
         elif isinstance(command, Stop):
-            state = self.process.create_state(ProcessState.FINISHED, command.result, command.successful)
+            state = st.create_state(
+                self.process, ProcessState.FINISHED, result=command.result, successful=command.successful
+            )
         elif isinstance(command, Wait):
-            state = self.process.create_state(ProcessState.WAITING, command.continue_fn, command.msg, command.data)
+            state = st.create_state(
+                self.process,
+                ProcessState.WAITING,
+                process=self.process,
+                done_callback=command.continue_fn,
+                msg=command.msg,
+                data=command.data,
+            )
         elif isinstance(command, Continue):
-            state = self.process.create_state(ProcessState.RUNNING, command.continue_fn, *command.args)
+            state = st.create_state(
+                self.process,
+                ProcessState.RUNNING,
+                process=self.process,
+                run_fn=command.continue_fn,
+                *command.args,
+                **command.kwargs,
+            )
         else:
             raise ValueError('Unrecognised command')
 
-        return cast(state_machine.State, state)  # casting from base.State to process.State
-
-    def enter(self) -> None:
-        self.in_state = True
+        return state
 
-    def exit(self) -> None:
-        if self.is_terminal:
-            raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}')
+    def enter(self) -> None: ...
 
-        self.in_state = False
+    def exit(self) -> None: ...
 
 
-@final
-@auto_persist('msg', 'data', 'in_state')
-class Waiting(state_machine.State, persistence.Savable):
-    LABEL = ProcessState.WAITING
-    ALLOWED = {
+@auto_persist('msg', 'data')
+class Waiting(persistence.Savable):
+    LABEL: ClassVar = ProcessState.WAITING
+    ALLOWED: ClassVar = {
         ProcessState.RUNNING,
         ProcessState.WAITING,
         ProcessState.KILLED,
@@ -305,7 +331,7 @@ class Waiting(state_machine.State, persistence.Savable):
 
     _interruption = None
 
-    is_terminal = False
+    is_terminal: ClassVar[bool] = False
 
     def __str__(self) -> str:
         state_info = super().__str__()
@@ -325,7 +351,6 @@ def __init__(
         self.msg = msg
         self.data = data
         self._waiting_future: futures.Future = futures.Future()
-        self.in_state = False
 
     def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None:
         super().save_instance_state(out_state, save_context)
@@ -347,7 +372,7 @@ def interrupt(self, reason: Exception) -> None:
         # This will cause the future in execute() to raise the exception
         self._waiting_future.set_exception(reason)
 
-    async def execute(self) -> state_machine.State:  # type: ignore
+    async def execute(self) -> st.State:
         try:
             result = await self._waiting_future
         except Interruption:
@@ -358,11 +383,15 @@ async def execute(self) -> state_machine.State:  # type: ignore
             raise
 
         if result == NULL:
-            next_state = self.process.create_state(ProcessState.RUNNING, self.done_callback)
+            next_state = st.create_state(
+                self.process, ProcessState.RUNNING, process=self.process, run_fn=self.done_callback
+            )
         else:
-            next_state = self.process.create_state(ProcessState.RUNNING, self.done_callback, result)
+            next_state = st.create_state(
+                self.process, ProcessState.RUNNING, process=self.process, done_callback=self.done_callback, *result
+            )
 
-        return cast(state_machine.State, next_state)  # casting from base.State to process.State
+        return next_state
 
     def resume(self, value: Any = NULL) -> None:
         assert self._waiting_future is not None, 'Not yet waiting'
@@ -372,47 +401,39 @@ def resume(self, value: Any = NULL) -> None:
 
         self._waiting_future.set_result(value)
 
-    def enter(self) -> None:
-        self.in_state = True
+    def enter(self) -> None: ...
 
-    def exit(self) -> None:
-        if self.is_terminal:
-            raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}')
+    def exit(self) -> None: ...
 
-        self.in_state = False
 
-
-@auto_persist('in_state')
-class Excepted(state_machine.State, persistence.Savable):
+@final
+class Excepted(persistence.Savable):
     """
-    Excepted state, can optionally provide exception and trace_back
+    Excepted state, can optionally provide exception and traceback
 
     :param exception: The exception instance
-    :param trace_back: An optional exception traceback
+    :param traceback: An optional exception traceback
     """
 
-    LABEL = ProcessState.EXCEPTED
+    LABEL: ClassVar = ProcessState.EXCEPTED
+    ALLOWED: ClassVar[set[str]] = set()
 
     EXC_VALUE = 'ex_value'
     TRACEBACK = 'traceback'
 
-    is_terminal = True
+    is_terminal: ClassVar = True
 
     def __init__(
         self,
-        process: 'Process',
         exception: Optional[BaseException],
-        trace_back: Optional[TracebackType] = None,
+        traceback: Optional[TracebackType] = None,
     ):
         """
-        :param process: The associated process
         :param exception: The exception instance
-        :param trace_back: An optional exception traceback
+        :param traceback: An optional exception traceback
         """
-        self.process = process
         self.exception = exception
-        self.traceback = trace_back
-        self.in_state = False
+        self.traceback = traceback
 
     def __str__(self) -> str:
         exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0]
@@ -426,7 +447,6 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist
 
     def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
         super().load_instance_state(saved_state, load_context)
-        self.process = load_context.process
 
         self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader)
         if _HAS_TBLIB:
@@ -449,50 +469,40 @@ def get_exc_info(
             self.traceback,
         )
 
-    def enter(self) -> None:
-        self.in_state = True
+    def enter(self) -> None: ...
 
-    def exit(self) -> None:
-        if self.is_terminal:
-            raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}')
+    def exit(self) -> None: ...
 
-        self.in_state = False
 
-
-@auto_persist('result', 'successful', 'in_state')
-class Finished(state_machine.State, persistence.Savable):
+@final
+@auto_persist('result', 'successful')
+class Finished(persistence.Savable):
     """State for process is finished.
 
     :param result: The result of process
     :param successful: Boolean for the exit code is ``0`` the process is successful.
     """
 
-    LABEL = ProcessState.FINISHED
+    LABEL: ClassVar = ProcessState.FINISHED
+    ALLOWED: ClassVar[set[str]] = set()
 
-    is_terminal = True
+    is_terminal: ClassVar[bool] = True
 
-    def __init__(self, process: 'Process', result: Any, successful: bool) -> None:
-        self.process = process
+    def __init__(self, result: Any, successful: bool) -> None:
         self.result = result
         self.successful = successful
-        self.in_state = False
 
     def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
         super().load_instance_state(saved_state, load_context)
-        self.process = load_context.process
 
-    def enter(self) -> None:
-        self.in_state = True
+    def enter(self) -> None: ...
 
-    def exit(self) -> None:
-        if self.is_terminal:
-            raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}')
+    def exit(self) -> None: ...
 
-        self.in_state = False
 
-
-@auto_persist('msg', 'in_state')
-class Killed(state_machine.State, persistence.Savable):
+@final
+@auto_persist('msg')
+class Killed(persistence.Savable):
     """
     Represents a state where a process has been killed.
 
@@ -502,30 +512,23 @@ class Killed(state_machine.State, persistence.Savable):
     :param msg: An optional message explaining the reason for the process termination.
     """
 
-    LABEL = ProcessState.KILLED
+    LABEL: ClassVar = ProcessState.KILLED
+    ALLOWED: ClassVar[set[str]] = set()
 
-    is_terminal = True
+    is_terminal: ClassVar[bool] = True
 
-    def __init__(self, process: 'Process', msg: Optional[MessageType]):
+    def __init__(self, msg: Optional[MessageType]):
         """
-        :param process: The associated process
         :param msg: Optional kill message
         """
-        self.process = process
         self.msg = msg
 
     def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
         super().load_instance_state(saved_state, load_context)
-        self.process = load_context.process
-
-    def enter(self) -> None:
-        self.in_state = True
 
-    def exit(self) -> None:
-        if self.is_terminal:
-            raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}')
+    def enter(self) -> None: ...
 
-        self.in_state = False
+    def exit(self) -> None: ...
 
 
 # endregion
diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py
index 74808291..bae08dd4 100644
--- a/src/plumpy/processes.py
+++ b/src/plumpy/processes.py
@@ -1,6 +1,8 @@
 # -*- coding: utf-8 -*-
 """The main Process module"""
 
+from __future__ import annotations
+
 import abc
 import asyncio
 import contextlib
@@ -58,6 +60,7 @@
     StateMachine,
     StateMachineError,
     TransitionFailed,
+    create_state,
     event,
 )
 from .base.utils import call_with_super_check, super_check
@@ -194,7 +197,7 @@ def get_states(cls) -> Sequence[Type[state_machine.State]]:
         )
 
     @classmethod
-    def get_state_classes(cls) -> Dict[Hashable, Type[state_machine.State]]:
+    def get_state_classes(cls) -> dict[process_states.ProcessState, Type[state_machine.State]]:
         # A mapping of the State constants to the corresponding state class
         return {
             process_states.ProcessState.CREATED: process_states.Created,
@@ -633,7 +636,9 @@ def save_instance_state(
         """
         super().save_instance_state(out_state, save_context)
 
-        out_state['_state'] = self._state.save()
+        # FIXME: the combined ProcessState protocol should cover the case
+        if isinstance(self._state, process_states.Savable):
+            out_state['_state'] = self._state.save()
 
         # Inputs/outputs
         if self.raw_inputs is not None:
@@ -876,7 +881,7 @@ def on_finish(self, result: Any, successful: bool) -> None:
             validation_error = self.spec().outputs.validate(self.outputs)
             if validation_error:
                 state_cls = self.get_states_map()[process_states.ProcessState.FINISHED]
-                finished_state = state_cls(self, result=result, successful=False)
+                finished_state = state_cls(result=result, successful=False)
                 raise StateEntryFailed(finished_state)
 
         self.future().set_result(self.outputs)
@@ -1074,8 +1079,8 @@ def transition_failed(
         if final_state == process_states.ProcessState.CREATED:
             raise exception.with_traceback(trace)
 
-        state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED]
-        new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace)
+        # state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED]
+        new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=trace)
         self.transition_to(new_state)
 
     def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]:
@@ -1148,10 +1153,11 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu
 
             def do_kill(_next_state: state_machine.State) -> Any:
                 try:
-                    state_class = self.get_states_map()[process_states.ProcessState.KILLED]
-                    new_state = self._create_state_instance(state_class, msg=exception.msg)
+                    new_state = create_state(self, process_states.ProcessState.KILLED, msg=exception.msg)
                     self.transition_to(new_state)
                     return True
+                    # FIXME: if try block except, will hit deadlock in event loop
+                    # need to know how to debug it, and where to set a timeout.
                 finally:
                     self._killing = None
 
@@ -1196,14 +1202,14 @@ def resume(self, *args: Any) -> None:
         return self._state.resume(*args)  # type: ignore
 
     @event(to_states=process_states.Excepted)
-    def fail(self, exception: Optional[BaseException], trace_back: Optional[TracebackType]) -> None:
+    def fail(self, exception: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
         """
         Fail the process in response to an exception
         :param exception: The exception that caused the failure
-        :param trace_back: Optional exception traceback
+        :param traceback: Optional exception traceback
         """
-        state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED]
-        new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace_back)
+        # state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED]
+        new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=traceback)
         self.transition_to(new_state)
 
     def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]:
@@ -1223,7 +1229,7 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]
             # Already killing
             return self._killing
 
-        if self._stepping:
+        if self._stepping and isinstance(self._state, Interruptable):
             # Ask the step function to pause by setting this flag and giving the
             # caller back a future
             interrupt_exception = process_states.KillInterruption(msg)
@@ -1232,8 +1238,7 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]
             self._state.interrupt(interrupt_exception)
             return cast(futures.CancellableAction, self._interrupt_action)
 
-        state_class = self.get_states_map()[process_states.ProcessState.KILLED]
-        new_state = self._create_state_instance(state_class, msg=msg)
+        new_state = create_state(self, process_states.ProcessState.KILLED, msg=msg)
         self.transition_to(new_state)
         return True
 
@@ -1251,10 +1256,7 @@ def create_initial_state(self) -> state_machine.State:
 
         :return: A Created state
         """
-        return cast(
-            state_machine.State,
-            self.get_state_class(process_states.ProcessState.CREATED)(self, self.run),
-        )
+        return self.get_state_class(process_states.ProcessState.CREATED)(self, self.run)
 
     def recreate_state(self, saved_state: persistence.Bundle) -> state_machine.State:
         """
@@ -1325,7 +1327,9 @@ async def step(self) -> None:
                 raise
             except Exception:
                 # Overwrite the next state to go to excepted directly
-                next_state = self.create_state(process_states.ProcessState.EXCEPTED, *sys.exc_info()[1:])
+                next_state = create_state(
+                    self, process_states.ProcessState.EXCEPTED, exception=sys.exc_info()[1], traceback=sys.exc_info()[2]
+                )
                 self._set_interrupt_action(None)
 
             if self._interrupt_action:
diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py
index eefd57f1..865a5b61 100644
--- a/src/plumpy/workchains.py
+++ b/src/plumpy/workchains.py
@@ -11,7 +11,6 @@
     Any,
     Callable,
     Dict,
-    Hashable,
     List,
     Mapping,
     MutableSequence,
@@ -71,6 +70,7 @@ def get_outline(self) -> Union['_Instruction', '_FunctionCall']:
         return self._outline
 
 
+# FIXME:  better use composition here
 @persistence.auto_persist('_awaiting')
 class Waiting(process_states.Waiting):
     """Overwrite the waiting state"""
@@ -80,11 +80,11 @@ def __init__(
         process: 'WorkChain',
         done_callback: Optional[Callable[..., Any]],
         msg: Optional[str] = None,
-        awaiting: Optional[Dict[Union[asyncio.Future, processes.Process], str]] = None,
+        data: Optional[Dict[Union[asyncio.Future, processes.Process], str]] = None,
     ) -> None:
-        super().__init__(process, done_callback, msg, awaiting)
+        super().__init__(process, done_callback, msg, data)
         self._awaiting: Dict[asyncio.Future, str] = {}
-        for awaitable, key in (awaiting or {}).items():
+        for awaitable, key in (data or {}).items():
             resolved_awaitable = awaitable.future() if isinstance(awaitable, processes.Process) else awaitable
             self._awaiting[resolved_awaitable] = key
 
@@ -124,7 +124,7 @@ class WorkChain(mixins.ContextMixin, processes.Process):
     _CONTEXT = 'CONTEXT'
 
     @classmethod
-    def get_state_classes(cls) -> Dict[Hashable, Type[state_machine.State]]:
+    def get_state_classes(cls) -> Dict[process_states.ProcessState, Type[state_machine.State]]:
         states_map = super().get_state_classes()
         states_map[process_states.ProcessState.WAITING] = Waiting
         return states_map
diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py
index b6100614..6a61fe00 100644
--- a/tests/base/test_statemachine.py
+++ b/tests/base/test_statemachine.py
@@ -17,7 +17,7 @@
 STOPPED = 'Stopped'
 
 
-class Playing(state_machine.State):
+class Playing:
     LABEL = PLAYING
     ALLOWED = {PAUSED, STOPPED}
     TRANSITIONS = {STOP: STOPPED}
@@ -56,7 +56,7 @@ def exit(self) -> None:
         self.in_state = False
 
 
-class Paused(state_machine.State):
+class Paused:
     LABEL = PAUSED
     ALLOWED = {PLAYING, STOPPED}
     TRANSITIONS = {STOP: STOPPED}
@@ -65,7 +65,6 @@ class Paused(state_machine.State):
 
     def __init__(self, player, playing_state):
         assert isinstance(playing_state, Playing), 'Must provide the playing state to pause'
-        super().__init__(player)
         self._player = player
         self.playing_state = playing_state
 
@@ -74,9 +73,9 @@ def __str__(self):
 
     def play(self, track=None):
         if track is not None:
-            self.state_machine.transition_to(Playing(player=self.state_machine, track=track))
+            self._player.transition_to(Playing(player=self.state_machine, track=track))
         else:
-            self.state_machine.transition_to(self.playing_state)
+            self._player.transition_to(self.playing_state)
 
     def enter(self) -> None:
         self.in_state = True
@@ -88,7 +87,7 @@ def exit(self) -> None:
         self.in_state = False
 
 
-class Stopped(state_machine.State):
+class Stopped:
     LABEL = STOPPED
     ALLOWED = {
         PLAYING,
@@ -98,13 +97,13 @@ class Stopped(state_machine.State):
     is_terminal = False
 
     def __init__(self, player):
-        self.state_machine = player
+        self._player = player
 
     def __str__(self):
         return '[]'
 
     def play(self, track):
-        self.state_machine.transition_to(Playing(self.state_machine, track=track))
+        self._player.transition_to(Playing(self._player, track=track))
 
     def enter(self) -> None:
         self.in_state = True