Skip to content

Commit

Permalink
debug
Browse files Browse the repository at this point in the history
  • Loading branch information
nathandf committed Sep 5, 2024
1 parent 188199d commit 5f2ee22
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions src/engine/src/core/workflows/WorkflowExecutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def start(self, ctx, threads):
)

if not validated:
self._on_pipeline_terminal_state(event=PIPELINE_FAILED, message=err)
self._handle_pipeline_terminal_state(event=PIPELINE_FAILED, message=err)
return

# Get the first tasks
Expand All @@ -177,7 +177,7 @@ def start(self, ctx, threads):
self.state.ready_tasks += unstarted_threads
except Exception as e:
# Trigger the terminal state callback.
self._on_pipeline_terminal_state(event=PIPELINE_FAILED, message=str(e))
self._handle_pipeline_terminal_state(event=PIPELINE_FAILED, message=str(e))

@interruptable()
def _staging(self, ctx):
Expand Down Expand Up @@ -216,9 +216,9 @@ def _staging(self, ctx):
try:
self._set_tasks(self.state.ctx.pipeline.tasks)
except InvalidDependenciesError as e:
self._on_pipeline_terminal_state(PIPELINE_FAILED, message=str(e))
self._handle_pipeline_terminal_state(PIPELINE_FAILED, message=str(e))
except Exception as e:
self._on_pipeline_terminal_state(PIPELINE_FAILED, message=str(e))
self._handle_pipeline_terminal_state(PIPELINE_FAILED, message=str(e))

@interruptable()
def _prepare_tasks(self):
Expand Down Expand Up @@ -269,7 +269,7 @@ def _prepare_tasks(self):
task = template_mapper.map(task, task.uses)
except Exception as e:
# Trigger the terminal state callback.
self._on_pipeline_terminal_state(event=PIPELINE_FAILED, message=str(e))
self._handle_pipeline_terminal_state(event=PIPELINE_FAILED, message=str(e))

# Add a key to the output for the task
self.state.ctx.output[task.id] = None
Expand Down Expand Up @@ -332,7 +332,7 @@ def _start_task(self, task):
task_input_file_staging_service.stage(task)
except TaskInputStagingError as e:
# Get the next queued tasks if any
unstarted_threads = self._on_task_terminal_state(
unstarted_threads = self._handle_task_terminal_state(
task,
TaskResult(1, errors=[str(e)])
)
Expand Down Expand Up @@ -372,21 +372,21 @@ def _start_task(self, task):
task_result = TaskResult(1, errors=[str(e)])

# Get the next queued tasks if any
unstarted_threads = self._on_task_terminal_state(task, task_result)
unstarted_threads = self._handle_task_terminal_state(task, task_result)

# NOTE Triggers hook _on_change_ready_task
self.state.ready_tasks += unstarted_threads

@interruptable()
def _on_task_terminal_state(self, task, task_result):
def _handle_task_terminal_state(self, task, task_result):
# Determine the correct callback to use.
callback = self._on_task_completed
callback = self._handle_task_completed

if task_result.skipped:
callback = self._on_task_skipped
callback = self._handle_task_skipped

if not task_result.success and not task_result.skipped:
callback = self._on_task_failed
callback = self._handle_task_failed

# Call the callback. Marks task as completed or failed.
# Also publishes a TASK_COMPLETED or TASK_FAILED based on the result
Expand All @@ -399,21 +399,22 @@ def _on_task_terminal_state(self, task, task_result):
# during the initialization and execution of the task executor
self._deregister_executor(self.state.ctx.pipeline_run.uuid, task)

# Run the on_pipeline_terminal_state callback if all tasks are complete.
# Run the handle_pipeline_terminal_state callback if all tasks are complete.
if len(self.state.tasks) == len(self.state.finished):
print("PIPELINE SHOWING AS COMPLETED")
self._on_pipeline_terminal_state(event=PIPELINE_COMPLETED)
print("*********** PIPELINE COMPLETED")
self._handle_pipeline_terminal_state(event=PIPELINE_COMPLETED)
return []

if task_result.status > 0 and task.can_fail == False:
self._on_pipeline_terminal_state(event=PIPELINE_FAILED)
print("*********** PIPELINE FAILED")
self._handle_pipeline_terminal_state(event=PIPELINE_FAILED)
return []

# Execute all possible queued tasks
return self._fetch_ready_tasks()

@interruptable()
def _on_pipeline_terminal_state(self, event=None, message=""):
def _handle_pipeline_terminal_state(self, event=None, message=""):
# No event was provided. Determine if complete or failed from number
# of failed tasks
if event == None:
Expand Down Expand Up @@ -443,7 +444,7 @@ def _on_pipeline_terminal_state(self, event=None, message=""):
self._set_initial_state()

@interruptable()
def _on_task_completed(self, task, task_result):
def _handle_task_completed(self, task, task_result):
# Log the completion
self.state.ctx.logger.info(self.t_log(task, "COMPLETED"))

Expand All @@ -456,7 +457,7 @@ def _on_task_completed(self, task, task_result):
self.state.succeeded.append(task.id)

@interruptable()
def _on_task_skipped(self, task, _):
def _handle_task_skipped(self, task, _):
# Log the task active
self.state.ctx.logger.info(self.t_log(task, "SKIPPED"))

Expand All @@ -468,7 +469,7 @@ def _on_task_skipped(self, task, _):
self.state.skipped.append(task.id)

@interruptable()
def _on_task_failed(self, task, task_result):
def _handle_task_failed(self, task, task_result):
# Log the failure
self.state.ctx.logger.info(self.t_log(task, f"FAILED: {task_result.errors}"))

Expand Down

0 comments on commit 5f2ee22

Please sign in to comment.