Skip to content

Commit

Permalink
Remove the middle layer of statemachine.State + Savable abstraction
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 2, 2024
1 parent e5c74ad commit 26b50fd
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 49 deletions.
2 changes: 1 addition & 1 deletion docs/source/nitpick-exceptions
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ py:class kiwipy.communications.Communicator

# unavailable forward references
py:class plumpy.process_states.Command
py:class plumpy.process_states.State
py:class plumpy.state_machine.State
py:class plumpy.base.state_machine.State
py:class State
py:class Process
Expand Down
111 changes: 77 additions & 34 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.state_machine = load_context.process
try:
self.continue_fn = utils.load_function(saved_state[self.CONTINUE_FN])
except ValueError:
Expand All @@ -145,25 +146,8 @@ class ProcessState(Enum):
KILLED: str = 'killed'


@auto_persist('in_state')
class State(state_machine.State, persistence.Savable):
@property
def process(self) -> state_machine.StateMachine:
"""
:return: The process
"""
return self.state_machine

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.state_machine = load_context.process

def interrupt(self, reason: Any) -> None:
pass


@auto_persist('args', 'kwargs')
class Created(State):
@auto_persist('args', 'kwargs', 'in_state')
class Created(state_machine.State, persistence.Savable):
LABEL = ProcessState.CREATED
ALLOWED = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED}

Expand All @@ -182,14 +166,23 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.state_machine = load_context.process

self.run_fn = getattr(self.process, saved_state[self.RUN_FN])

def execute(self) -> state_machine.State:
return self.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs)

@property
def process(self) -> state_machine.StateMachine:
"""
:return: The process
"""
return self.state_machine

@auto_persist('args', 'kwargs')
class Running(State):

@auto_persist('args', 'kwargs', 'in_state')
class Running(state_machine.State, persistence.Savable):
LABEL = ProcessState.RUNNING
ALLOWED = {
ProcessState.RUNNING,
Expand Down Expand Up @@ -223,14 +216,16 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.state_machine = load_context.process

self.run_fn = getattr(self.process, saved_state[self.RUN_FN])
if self.COMMAND in saved_state:
self._command = persistence.Savable.load(saved_state[self.COMMAND], load_context) # type: ignore

def interrupt(self, reason: Any) -> None:
pass

async def execute(self) -> State: # type: ignore
async def execute(self) -> state_machine.State: # type: ignore
if self._command is not None:
command = self._command
else:
Expand All @@ -245,7 +240,7 @@ async def execute(self) -> State: # type: ignore
raise
except Exception:
excepted = self.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:])
return cast(State, excepted)
return cast(state_machine.State, excepted)
else:
if not isinstance(result, Command):
if isinstance(result, exceptions.UnsuccessfulResult):
Expand All @@ -259,7 +254,7 @@ async def execute(self) -> State: # type: ignore
next_state = self._action_command(command)
return next_state

def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> State:
def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> state_machine.State:
if isinstance(command, Kill):
state = self.create_state(ProcessState.KILLED, command.msg)
# elif isinstance(command, Pause):
Expand All @@ -273,11 +268,18 @@ def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> State:
else:
raise ValueError('Unrecognised command')

return cast(State, state) # casting from base.State to process.State
return cast(state_machine.State, state) # casting from base.State to process.State

@property
def process(self) -> state_machine.StateMachine:
"""
:return: The process
"""
return self.state_machine

@auto_persist('msg', 'data')
class Waiting(State):

@auto_persist('msg', 'data', 'in_state')
class Waiting(state_machine.State, persistence.Savable):
LABEL = ProcessState.WAITING
ALLOWED = {
ProcessState.RUNNING,
Expand Down Expand Up @@ -317,6 +319,8 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.state_machine = load_context.process

callback_name = saved_state.get(self.DONE_CALLBACK, None)
if callback_name is not None:
self.done_callback = getattr(self.process, callback_name)
Expand All @@ -328,7 +332,7 @@ def interrupt(self, reason: Any) -> None:
# This will cause the future in execute() to raise the exception
self._waiting_future.set_exception(reason)

async def execute(self) -> State: # type: ignore
async def execute(self) -> state_machine.State: # type: ignore
try:
result = await self._waiting_future
except Interruption:
Expand All @@ -343,7 +347,7 @@ async def execute(self) -> State: # type: ignore
else:
next_state = self.create_state(ProcessState.RUNNING, self.done_callback, result)

return cast(State, next_state) # casting from base.State to process.State
return cast(state_machine.State, next_state) # casting from base.State to process.State

def resume(self, value: Any = NULL) -> None:
assert self._waiting_future is not None, 'Not yet waiting'
Expand All @@ -353,8 +357,16 @@ def resume(self, value: Any = NULL) -> None:

self._waiting_future.set_result(value)

@property
def process(self) -> state_machine.StateMachine:
"""
:return: The process
"""
return self.state_machine


class Excepted(State):
@auto_persist('in_state')
class Excepted(state_machine.State, persistence.Savable):
"""
Excepted state, can optionally provide exception and trace_back
Expand Down Expand Up @@ -394,6 +406,8 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.state_machine = load_context.process

self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader)
if _HAS_TBLIB:
try:
Expand All @@ -415,9 +429,16 @@ def get_exc_info(
self.traceback,
)

@property
def process(self) -> state_machine.StateMachine:
"""
:return: The process
"""
return self.state_machine


@auto_persist('result', 'successful')
class Finished(State):
@auto_persist('result', 'successful', 'in_state')
class Finished(state_machine.State, persistence.Savable):
"""State for process is finished.
:param result: The result of process
Expand All @@ -431,9 +452,20 @@ def __init__(self, process: 'Process', result: Any, successful: bool) -> None:
self.result = result
self.successful = successful

@property
def process(self) -> state_machine.StateMachine:
"""
:return: The process
"""
return self.state_machine

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.state_machine = load_context.process


@auto_persist('msg')
class Killed(State):
@auto_persist('msg', 'in_state')
class Killed(state_machine.State, persistence.Savable):
"""
Represents a state where a process has been killed.
Expand All @@ -453,5 +485,16 @@ def __init__(self, process: 'Process', msg: Optional[MessageType]):
super().__init__(process)
self.msg = msg

@property
def process(self) -> state_machine.StateMachine:
"""
:return: The process
"""
return self.state_machine

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.state_machine = load_context.process


# endregion
26 changes: 13 additions & 13 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def current(cls) -> Optional['Process']:
return None

@classmethod
def get_states(cls) -> Sequence[Type[process_states.State]]:
def get_states(cls) -> Sequence[Type[state_machine.State]]:
"""Return all allowed states of the process."""
state_classes = cls.get_state_classes()
return (
Expand All @@ -186,7 +186,7 @@ def get_states(cls) -> Sequence[Type[process_states.State]]:
)

@classmethod
def get_state_classes(cls) -> Dict[Hashable, Type[process_states.State]]:
def get_state_classes(cls) -> Dict[Hashable, Type[state_machine.State]]:
# A mapping of the State constants to the corresponding state class
return {
process_states.ProcessState.CREATED: process_states.Created,
Expand Down Expand Up @@ -357,10 +357,10 @@ def _setup_event_hooks(self) -> None:
"""Set the event hooks to process, when it is created or loaded(recreated)."""
event_hooks = {
state_machine.StateEventHook.ENTERING_STATE: lambda _s, _h, state: self.on_entering(
cast(process_states.State, state)
cast(state_machine.State, state)
),
state_machine.StateEventHook.ENTERED_STATE: lambda _s, _h, from_state: self.on_entered(
cast(Optional[process_states.State], from_state)
cast(Optional[state_machine.State], from_state)
),
state_machine.StateEventHook.EXITING_STATE: lambda _s, _h, _state: self.on_exiting(),
}
Expand Down Expand Up @@ -661,7 +661,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi
else:
self._loop = asyncio.get_event_loop()

self._state: process_states.State = self.recreate_state(saved_state['_state'])
self._state: state_machine.State = self.recreate_state(saved_state['_state'])

if 'communicator' in load_context:
self._communicator = load_context.communicator
Expand Down Expand Up @@ -719,7 +719,7 @@ def log_with_pid(self, level: int, msg: str) -> None:

# region Events

def on_entering(self, state: process_states.State) -> None:
def on_entering(self, state: state_machine.State) -> None:
# Map these onto direct functions that the subclass can implement
state_label = state.LABEL
if state_label == process_states.ProcessState.CREATED:
Expand All @@ -735,7 +735,7 @@ def on_entering(self, state: process_states.State) -> None:
elif state_label == process_states.ProcessState.EXCEPTED:
call_with_super_check(self.on_except, state.get_exc_info()) # type: ignore

def on_entered(self, from_state: Optional[process_states.State]) -> None:
def on_entered(self, from_state: Optional[state_machine.State]) -> None:
# Map these onto direct functions that the subclass can implement
state_label = self._state.LABEL
if state_label == process_states.ProcessState.RUNNING:
Expand Down Expand Up @@ -1099,7 +1099,7 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable

return self._do_pause(msg)

def _do_pause(self, state_msg: Optional[str], next_state: Optional[process_states.State] = None) -> bool:
def _do_pause(self, state_msg: Optional[str], next_state: Optional[state_machine.State] = None) -> bool:
"""Carry out the pause procedure, optionally transitioning to the next state first"""
try:
if next_state is not None:
Expand All @@ -1125,7 +1125,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu

if isinstance(exception, process_states.KillInterruption):

def do_kill(_next_state: process_states.State) -> Any:
def do_kill(_next_state: state_machine.State) -> Any:
try:
self.transition_to(process_states.Killed, msg=exception.msg)
return True
Expand Down Expand Up @@ -1217,27 +1217,27 @@ def is_killing(self) -> bool:

# endregion

def create_initial_state(self) -> process_states.State:
def create_initial_state(self) -> state_machine.State:
"""This method is here to override its superclass.
Automatically enter the CREATED state when the process is created.
:return: A Created state
"""
return cast(
process_states.State,
state_machine.State,
self.get_state_class(process_states.ProcessState.CREATED)(self, self.run),
)

def recreate_state(self, saved_state: persistence.Bundle) -> process_states.State:
def recreate_state(self, saved_state: persistence.Bundle) -> state_machine.State:
"""
Create a state object from a saved state
:param saved_state: The saved state
:return: An instance of the object with its state loaded from the save state.
"""
load_context = persistence.LoadSaveContext(process=self)
return cast(process_states.State, persistence.Savable.load(saved_state, load_context))
return cast(state_machine.State, persistence.Savable.load(saved_state, load_context))

# endregion

Expand Down
4 changes: 3 additions & 1 deletion src/plumpy/workchains.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

import kiwipy

from plumpy.base import state_machine

from . import lang, mixins, persistence, process_states, processes
from .utils import PID_TYPE, SAVED_STATE_TYPE

Expand Down Expand Up @@ -117,7 +119,7 @@ class WorkChain(mixins.ContextMixin, processes.Process):
_CONTEXT = 'CONTEXT'

@classmethod
def get_state_classes(cls) -> Dict[Hashable, Type[process_states.State]]:
def get_state_classes(cls) -> Dict[Hashable, Type[state_machine.State]]:
states_map = super().get_state_classes()
states_map[process_states.ProcessState.WAITING] = Waiting
return states_map
Expand Down

0 comments on commit 26b50fd

Please sign in to comment.