Skip to content

Commit

Permalink
Furthur simplipy _create_state_instant only create state from class
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Nov 30, 2024
1 parent 989f995 commit e551eff
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 16 deletions.
20 changes: 9 additions & 11 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class StateEntryFailed(Exception):
Failed to enter a state, can provide the next state to go to via this exception
"""

def __init__(self, state: Hashable = None, *args: Any, **kwargs: Any) -> None: # pylint: disable=keyword-arg-before-vararg
def __init__(self, state: type[State] = None, *args: Any, **kwargs: Any) -> None: # pylint: disable=keyword-arg-before-vararg
super().__init__("failed to enter state")
self.state = state
self.args = args
Expand Down Expand Up @@ -330,7 +330,7 @@ def on_terminated(self) -> None:
"""Called when a terminal state is entered"""

def transition_to(
self, new_state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any
self, new_state: Union[State, Type[State]], *args: Any, **kwargs: Any
) -> None:
"""Transite to the new state.
Expand Down Expand Up @@ -403,6 +403,10 @@ def set_debug(self, enabled: bool) -> None:
self._debug: bool = enabled

def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State:
# XXX: this method create state from label, which is duplicate as _create_state_instance and less generic
# because the label is defined after the state and required to be know before calling this function.
# This method should be replaced by `_create_state_instance`.
# aiida-core using this method for its Waiting state override.
try:
return self.get_states_map()[state_label](self, *args, **kwargs) # pylint: disable=unsubscriptable-object
except KeyError:
Expand Down Expand Up @@ -436,15 +440,9 @@ def _enter_next_state(self, next_state: State) -> None:
self._fire_state_event(StateEventHook.ENTERED_STATE, last_state)

def _create_state_instance(
self, state: Union[Hashable, Type[State]], *args: Any, **kwargs: Any
self, state_cls: type[State], *args: Any, **kwargs: Any
) -> State:
# build from state class
if inspect.isclass(state) and issubclass(state, State):
state_cls = state
else:
try:
state_cls = self.get_states_map()[cast(Hashable, state)] # pylint: disable=unsubscriptable-object
except KeyError:
raise ValueError(f"{state} is not a valid state")
if state_cls.LABEL not in self.get_states_map():
raise ValueError(f"{state_cls.LABEL} is not a valid state")

return state_cls(self, *args, **kwargs)
10 changes: 5 additions & 5 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,7 @@ def on_finish(self, result: Any, successful: bool) -> None:
if successful:
validation_error = self.spec().outputs.validate(self.outputs)
if validation_error:
raise StateEntryFailed(process_states.ProcessState.FINISHED, result, False)
raise StateEntryFailed(process_states.Finished, result, False)

self.future().set_result(self.outputs)

Expand Down Expand Up @@ -1017,7 +1017,7 @@ def transition_failed(
if final_state == process_states.ProcessState.CREATED:
raise exception.with_traceback(trace)

self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace)
self.transition_to(process_states.Excepted, exception, trace)

def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]:
"""Pause the process.
Expand Down Expand Up @@ -1082,7 +1082,7 @@ def do_kill(_next_state: process_states.State) -> Any:
try:
# Ignore the next state
# __import__('ipdb').set_trace()
self.transition_to(process_states.ProcessState.KILLED, exception)
self.transition_to(process_states.Killed, exception)
return True
finally:
self._killing = None
Expand Down Expand Up @@ -1134,7 +1134,7 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac
:param exception: The exception that caused the failure
:param trace_back: Optional exception traceback
"""
self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace_back)
self.transition_to(process_states.Excepted, exception, trace_back)

def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]:
"""
Expand Down Expand Up @@ -1162,7 +1162,7 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]
self._state.interrupt(interrupt_exception)
return cast(futures.CancellableAction, self._interrupt_action)

self.transition_to(process_states.ProcessState.KILLED, msg)
self.transition_to(process_states.Killed, msg)
return True

@property
Expand Down

0 comments on commit e551eff

Please sign in to comment.