Skip to content

Commit

Permalink
Furthur simplipy _create_state_instant only create state from class
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 1, 2024
1 parent 28db202 commit dad61ac
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 17 deletions.
31 changes: 19 additions & 12 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
Set,
Type,
Union,
cast,
)

from plumpy.futures import Future
Expand Down Expand Up @@ -325,8 +324,18 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None:
def on_terminated(self) -> None:
"""Called when a terminal state is entered"""

def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> None:
assert not self._transitioning, 'Cannot call transition_to when already transitioning state'
def transition_to(
self, new_state: Union[State, Type[State]], *args: Any, **kwargs: Any
) -> None:
"""Transite to the new state.
The new target state will be create lazily when the state
is not yet instantiated, which will happened for states not in the expect path such as
pause and kill.
"""
assert (
not self._transitioning
), "Cannot call transition_to when already transitioning state"

initial_state_label = self._state.LABEL if self._state is not None else None
label = None
Expand Down Expand Up @@ -389,6 +398,10 @@ def set_debug(self, enabled: bool) -> None:
self._debug: bool = enabled

def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State:
# XXX: this method create state from label, which is duplicate as _create_state_instance and less generic
# because the label is defined after the state and required to be know before calling this function.
# This method should be replaced by `_create_state_instance`.
# aiida-core using this method for its Waiting state override.
try:
return self.get_states_map()[state_label](self, *args, **kwargs)
except KeyError:
Expand Down Expand Up @@ -422,15 +435,9 @@ def _enter_next_state(self, next_state: State) -> None:
self._fire_state_event(StateEventHook.ENTERED_STATE, last_state)

def _create_state_instance(
self, state: Union[Hashable, Type[State]], *args: Any, **kwargs: Any
self, state_cls: type[State], *args: Any, **kwargs: Any
) -> State:
# build from state class
if inspect.isclass(state) and issubclass(state, State):
state_cls = state
else:
try:
state_cls = self.get_states_map()[cast(Hashable, state)] # pylint: disable=unsubscriptable-object
except KeyError:
raise ValueError(f"{state} is not a valid state")
if state_cls.LABEL not in self.get_states_map():
raise ValueError(f"{state_cls.LABEL} is not a valid state")

return state_cls(self, *args, **kwargs)
10 changes: 5 additions & 5 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,7 @@ def on_finish(self, result: Any, successful: bool) -> None:
if successful:
validation_error = self.spec().outputs.validate(self.outputs)
if validation_error:
raise StateEntryFailed(process_states.ProcessState.FINISHED, result, False)
raise StateEntryFailed(process_states.Finished, result, False)

self.future().set_result(self.outputs)

Expand Down Expand Up @@ -1016,7 +1016,7 @@ def transition_failed(
if final_state == process_states.ProcessState.CREATED:
raise exception.with_traceback(trace)

self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace)
self.transition_to(process_states.Excepted, exception, trace)

def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]:
"""Pause the process.
Expand Down Expand Up @@ -1081,7 +1081,7 @@ def do_kill(_next_state: process_states.State) -> Any:
try:
# Ignore the next state
# __import__('ipdb').set_trace()
self.transition_to(process_states.ProcessState.KILLED, exception)
self.transition_to(process_states.Killed, exception)
return True
finally:
self._killing = None
Expand Down Expand Up @@ -1133,7 +1133,7 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac
:param exception: The exception that caused the failure
:param trace_back: Optional exception traceback
"""
self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace_back)
self.transition_to(process_states.Excepted, exception, trace_back)

def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]:
"""
Expand Down Expand Up @@ -1161,7 +1161,7 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]
self._state.interrupt(interrupt_exception)
return cast(futures.CancellableAction, self._interrupt_action)

self.transition_to(process_states.ProcessState.KILLED, msg)
self.transition_to(process_states.Killed, msg)
return True

@property
Expand Down

0 comments on commit dad61ac

Please sign in to comment.