Skip to content

Commit

Permalink
WIP: create_state refact
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 3, 2024
1 parent 531c696 commit c747b06
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 38 deletions.
21 changes: 5 additions & 16 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ def interrupt(self, reason: Exception) -> None: ...

@runtime_checkable
class Proceedable(Protocol):

def execute(self) -> State | None:
"""
Execute the state, performing the actions that this state is responsible for.
Expand Down Expand Up @@ -356,17 +355,13 @@ def get_debug(self) -> bool:
def set_debug(self, enabled: bool) -> None:
self._debug: bool = enabled

def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State:
# XXX: this method create state from label, which is duplicate as _create_state_instance and less generic
# because the label is defined after the state and required to be know before calling this function.
# This method should be replaced by `_create_state_instance`.
# aiida-core using this method for its Waiting state override.
try:
state_cls = self.get_states_map()[state_label]
return state_cls(self, *args, **kwargs)
except KeyError:
def create_state(self, state_label: Hashable, **kwargs: Any) -> State:
if state_label not in self.get_states_map():
raise ValueError(f'{state_label} is not a valid state')

state_cls = self.get_states_map()[state_label]
return state_cls(**kwargs)

def _exit_current_state(self, next_state: State) -> None:
"""Exit the given state"""

Expand All @@ -389,9 +384,3 @@ def _enter_next_state(self, next_state: State) -> None:
next_state.enter()
self._state = next_state
self._fire_state_event(StateEventHook.ENTERED_STATE, last_state)

def _create_state_instance(self, state_cls: type[State], **kwargs: Any) -> State:
if state_cls.LABEL not in self.get_states_map():
raise ValueError(f'{state_cls.LABEL} is not a valid state')

return state_cls(self, **kwargs)
28 changes: 11 additions & 17 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,10 @@ async def execute(self) -> state_machine.State:
# Let this bubble up to the caller
raise
except Exception:
excepted = self.process.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:])
# state_cls: Excepted = self.process.get_states_map()[ProcessState.EXCEPTED]
_, exception, traceback = sys.exc_info()
# excepted = state_cls(exception=exception, traceback=traceback)
excepted = Excepted(exception=exception, traceback=traceback)
return cast(state_machine.State, excepted)
else:
if not isinstance(result, Command):
Expand Down Expand Up @@ -367,10 +370,10 @@ def exit(self) -> None: ...
@final
class Excepted(persistence.Savable):
"""
Excepted state, can optionally provide exception and trace_back
Excepted state, can optionally provide exception and traceback
:param exception: The exception instance
:param trace_back: An optional exception traceback
:param traceback: An optional exception traceback
"""

LABEL = ProcessState.EXCEPTED
Expand All @@ -383,18 +386,15 @@ class Excepted(persistence.Savable):

def __init__(
self,
process: 'Process',
exception: Optional[BaseException],
trace_back: Optional[TracebackType] = None,
traceback: Optional[TracebackType] = None,
):
"""
:param process: The associated process
:param exception: The exception instance
:param trace_back: An optional exception traceback
:param traceback: An optional exception traceback
"""
self.process = process
self.exception = exception
self.traceback = trace_back
self.traceback = traceback

def __str__(self) -> str:
exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0]
Expand All @@ -408,7 +408,6 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.process = load_context.process

self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader)
if _HAS_TBLIB:
Expand Down Expand Up @@ -450,14 +449,12 @@ class Finished(persistence.Savable):

is_terminal = True

def __init__(self, process: 'Process', result: Any, successful: bool) -> None:
self.process = process
def __init__(self, result: Any, successful: bool) -> None:
self.result = result
self.successful = successful

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.process = load_context.process

def enter(self) -> None: ...

Expand All @@ -481,17 +478,14 @@ class Killed(persistence.Savable):

is_terminal = True

def __init__(self, process: 'Process', msg: Optional[MessageType]):
def __init__(self, msg: Optional[MessageType]):
"""
:param process: The associated process
:param msg: Optional kill message
"""
self.process = process
self.msg = msg

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.process = load_context.process

def enter(self) -> None: ...

Expand Down
13 changes: 8 additions & 5 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,7 +1072,7 @@ def transition_failed(
if final_state == process_states.ProcessState.CREATED:
raise exception.with_traceback(trace)

self.transition_to(process_states.Excepted(self, exception=exception, trace_back=trace))
self.transition_to(process_states.Excepted(self, exception=exception, traceback=trace))

def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]:
"""Pause the process.
Expand Down Expand Up @@ -1190,13 +1190,13 @@ def resume(self, *args: Any) -> None:
return self._state.resume(*args) # type: ignore

@event(to_states=process_states.Excepted)
def fail(self, exception: Optional[BaseException], trace_back: Optional[TracebackType]) -> None:
def fail(self, exception: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
"""
Fail the process in response to an exception
:param exception: The exception that caused the failure
:param trace_back: Optional exception traceback
:param traceback: Optional exception traceback
"""
self.transition_to(process_states.Excepted(self, exception=exception, trace_back=trace_back))
self.transition_to(process_states.Excepted(self, exception=exception, traceback=traceback))

def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]:
"""
Expand Down Expand Up @@ -1315,7 +1315,10 @@ async def step(self) -> None:
raise
except Exception:
# Overwrite the next state to go to excepted directly
next_state = self.create_state(process_states.ProcessState.EXCEPTED, *sys.exc_info()[1:])
_, exception, traceback = sys.exc_info()
next_state = self.create_state(
process_states.ProcessState.EXCEPTED, exception=exception, traceback=traceback
)
self._set_interrupt_action(None)

if self._interrupt_action:
Expand Down

0 comments on commit c747b06

Please sign in to comment.