diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 5b0841b6..53ca396a 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -147,7 +147,6 @@ def interrupt(self, reason: Exception) -> None: ... @runtime_checkable class Proceedable(Protocol): - def execute(self) -> State | None: """ Execute the state, performing the actions that this state is responsible for. @@ -356,17 +355,13 @@ def get_debug(self) -> bool: def set_debug(self, enabled: bool) -> None: self._debug: bool = enabled - def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State: - # XXX: this method create state from label, which is duplicate as _create_state_instance and less generic - # because the label is defined after the state and required to be know before calling this function. - # This method should be replaced by `_create_state_instance`. - # aiida-core using this method for its Waiting state override. - try: - state_cls = self.get_states_map()[state_label] - return state_cls(self, *args, **kwargs) - except KeyError: + def create_state(self, state_label: Hashable, **kwargs: Any) -> State: + if state_label not in self.get_states_map(): raise ValueError(f'{state_label} is not a valid state') + state_cls = self.get_states_map()[state_label] + return state_cls(**kwargs) + def _exit_current_state(self, next_state: State) -> None: """Exit the given state""" @@ -389,9 +384,3 @@ def _enter_next_state(self, next_state: State) -> None: next_state.enter() self._state = next_state self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) - - def _create_state_instance(self, state_cls: type[State], **kwargs: Any) -> State: - if state_cls.LABEL not in self.get_states_map(): - raise ValueError(f'{state_cls.LABEL} is not a valid state') - - return state_cls(self, **kwargs) diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 4348803e..c6c7e8e8 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -241,7 +241,10 @@ async def execute(self) -> state_machine.State: # Let this bubble up to the caller raise except Exception: - excepted = self.process.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:]) + # state_cls: Excepted = self.process.get_states_map()[ProcessState.EXCEPTED] + _, exception, traceback = sys.exc_info() + # excepted = state_cls(exception=exception, traceback=traceback) + excepted = Excepted(exception=exception, traceback=traceback) return cast(state_machine.State, excepted) else: if not isinstance(result, Command): @@ -367,10 +370,10 @@ def exit(self) -> None: ... @final class Excepted(persistence.Savable): """ - Excepted state, can optionally provide exception and trace_back + Excepted state, can optionally provide exception and traceback :param exception: The exception instance - :param trace_back: An optional exception traceback + :param traceback: An optional exception traceback """ LABEL = ProcessState.EXCEPTED @@ -383,18 +386,15 @@ class Excepted(persistence.Savable): def __init__( self, - process: 'Process', exception: Optional[BaseException], - trace_back: Optional[TracebackType] = None, + traceback: Optional[TracebackType] = None, ): """ - :param process: The associated process :param exception: The exception instance - :param trace_back: An optional exception traceback + :param traceback: An optional exception traceback """ - self.process = process self.exception = exception - self.traceback = trace_back + self.traceback = traceback def __str__(self) -> str: exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0] @@ -408,7 +408,6 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.process = load_context.process self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: @@ -450,14 +449,12 @@ class Finished(persistence.Savable): is_terminal = True - def __init__(self, process: 'Process', result: Any, successful: bool) -> None: - self.process = process + def __init__(self, result: Any, successful: bool) -> None: self.result = result self.successful = successful def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.process = load_context.process def enter(self) -> None: ... @@ -481,17 +478,14 @@ class Killed(persistence.Savable): is_terminal = True - def __init__(self, process: 'Process', msg: Optional[MessageType]): + def __init__(self, msg: Optional[MessageType]): """ - :param process: The associated process :param msg: Optional kill message """ - self.process = process self.msg = msg def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.process = load_context.process def enter(self) -> None: ... diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 08899f9e..d01092f4 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -1072,7 +1072,7 @@ def transition_failed( if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - self.transition_to(process_states.Excepted(self, exception=exception, trace_back=trace)) + self.transition_to(process_states.Excepted(self, exception=exception, traceback=trace)) def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]: """Pause the process. @@ -1190,13 +1190,13 @@ def resume(self, *args: Any) -> None: return self._state.resume(*args) # type: ignore @event(to_states=process_states.Excepted) - def fail(self, exception: Optional[BaseException], trace_back: Optional[TracebackType]) -> None: + def fail(self, exception: Optional[BaseException], traceback: Optional[TracebackType]) -> None: """ Fail the process in response to an exception :param exception: The exception that caused the failure - :param trace_back: Optional exception traceback + :param traceback: Optional exception traceback """ - self.transition_to(process_states.Excepted(self, exception=exception, trace_back=trace_back)) + self.transition_to(process_states.Excepted(self, exception=exception, traceback=traceback)) def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: """ @@ -1315,7 +1315,10 @@ async def step(self) -> None: raise except Exception: # Overwrite the next state to go to excepted directly - next_state = self.create_state(process_states.ProcessState.EXCEPTED, *sys.exc_info()[1:]) + _, exception, traceback = sys.exc_info() + next_state = self.create_state( + process_states.ProcessState.EXCEPTED, exception=exception, traceback=traceback + ) self._set_interrupt_action(None) if self._interrupt_action: