Skip to content

Commit

Permalink
Mapping states from state name
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 3, 2024
1 parent e5c74ad commit 17a541a
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 19 deletions.
13 changes: 3 additions & 10 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class StateEntryFailed(Exception): # noqa: N818
Failed to enter a state, can provide the next state to go to via this exception
"""

def __init__(self, state: type['State'], *args: Any, **kwargs: Any) -> None:
def __init__(self, state: State, *args: Any, **kwargs: Any) -> None:
super().__init__('failed to enter state')
self.state = state
self.args = args
Expand Down Expand Up @@ -314,7 +314,7 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None:
def on_terminated(self) -> None:
"""Called when a terminal state is entered"""

def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) -> None:
def transition_to(self, new_state: State | None, **kwargs: Any) -> None:
"""Transite to the new state.
The new target state will be create lazily when the state is not yet instantiated,
Expand All @@ -331,11 +331,6 @@ def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) ->
label = None
try:
self._transitioning = True

if not isinstance(new_state, State):
# Make sure we have a state instance
new_state = self._create_state_instance(new_state, **kwargs)

label = new_state.LABEL

# If the previous transition failed, do not try to exit it but go straight to next state
Expand All @@ -345,9 +340,7 @@ def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) ->
try:
self._enter_next_state(new_state)
except StateEntryFailed as exception:
# Make sure we have a state instance
if not isinstance(exception.state, State):
new_state = self._create_state_instance(exception.state, **exception.kwargs)
new_state = exception.state
label = new_state.LABEL
self._exit_current_state(new_state)
self._enter_next_state(new_state)
Expand Down
20 changes: 15 additions & 5 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,9 @@ def on_finish(self, result: Any, successful: bool) -> None:
if successful:
validation_error = self.spec().outputs.validate(self.outputs)
if validation_error:
raise StateEntryFailed(process_states.Finished, result=result, successful=False)
state_cls = self.get_states_map()[process_states.ProcessState.FINISHED]
finished_state = state_cls(self, result=result, successful=False)
raise StateEntryFailed(finished_state)

self.future().set_result(self.outputs)

Expand Down Expand Up @@ -1064,7 +1066,9 @@ def transition_failed(
if final_state == process_states.ProcessState.CREATED:
raise exception.with_traceback(trace)

self.transition_to(process_states.Excepted, exception=exception, trace_back=trace)
state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED]
new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace)
self.transition_to(new_state)

def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]:
"""Pause the process.
Expand Down Expand Up @@ -1127,7 +1131,9 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu

def do_kill(_next_state: process_states.State) -> Any:
try:
self.transition_to(process_states.Killed, msg=exception.msg)
state_class = self.get_states_map()[process_states.ProcessState.KILLED]
new_state = self._create_state_instance(state_class, msg=exception.msg)
self.transition_to(new_state)
return True
finally:
self._killing = None
Expand Down Expand Up @@ -1179,7 +1185,9 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac
:param exception: The exception that caused the failure
:param trace_back: Optional exception traceback
"""
self.transition_to(process_states.Excepted, exception=exception, trace_back=trace_back)
state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED]
new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace_back)
self.transition_to(new_state)

def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]:
"""
Expand Down Expand Up @@ -1207,7 +1215,9 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]
self._state.interrupt(interrupt_exception)
return cast(futures.CancellableAction, self._interrupt_action)

self.transition_to(process_states.Killed, msg=msg)
state_class = self.get_states_map()[process_states.ProcessState.KILLED]
new_state = self._create_state_instance(state_class, msg=msg)
self.transition_to(new_state)
return True

@property
Expand Down
9 changes: 5 additions & 4 deletions tests/base/test_statemachine.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,15 @@ class Paused(state_machine.State):
def __init__(self, player, playing_state):
assert isinstance(playing_state, Playing), 'Must provide the playing state to pause'
super().__init__(player)
self._player = player
self.playing_state = playing_state

def __str__(self):
return f'|| ({self.playing_state})'

def play(self, track=None):
if track is not None:
self.state_machine.transition_to(Playing, track=track)
self.state_machine.transition_to(Playing(player=self.state_machine, track=track))
else:
self.state_machine.transition_to(self.playing_state)

Expand All @@ -80,7 +81,7 @@ def __str__(self):
return '[]'

def play(self, track):
self.state_machine.transition_to(Playing, track=track)
self.state_machine.transition_to(Playing(self.state_machine, track=track))


class CdPlayer(state_machine.StateMachine):
Expand All @@ -107,12 +108,12 @@ def play(self, track=None):

@state_machine.event(from_states=Playing, to_states=Paused)
def pause(self):
self.transition_to(Paused, playing_state=self._state)
self.transition_to(Paused(self, playing_state=self._state))
return True

@state_machine.event(from_states=(Playing, Paused), to_states=Stopped)
def stop(self):
self.transition_to(Stopped)
self.transition_to(Stopped(self))


class TestStateMachine(unittest.TestCase):
Expand Down

0 comments on commit 17a541a

Please sign in to comment.