Skip to content

Commit

Permalink
bugfix can_fail determination logic
Browse files Browse the repository at this point in the history
  • Loading branch information
nathandf committed Nov 16, 2023
1 parent 4723a16 commit 77a5a8b
Showing 1 changed file with 21 additions and 22 deletions.
43 changes: 21 additions & 22 deletions src/engine/src/core/workflows/executors/WorkflowExecutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,34 +435,33 @@ def _set_tasks(self, tasks):

self.state.tasks = tasks

# Build 2 graphs:
# The first is a mapping between each task and the tasks that depend on them,
# and the second is a mapping between a task and tasks it depends on.
# Suboptimal? Yes, Space complexity is ~O(n^2), but makes for easy lookups
# Build the dependency graph where the key is a task id and the value is
# an array of all the dependent tasks
self.state.dependency_graph = {task.id: [] for task in self.state.tasks}
self.state.reverse_dependency_graph = {task.id: [] for task in self.state.tasks}
for child_task in self.state.tasks:
for parent_task in child_task.depends_on:
self.state.dependency_graph[parent_task.id].append(child_task.id)
self.state.reverse_dependency_graph[child_task.id].append(parent_task.id)
for task in self.state.tasks:
for parent_task in task.depends_on:
self.state.dependency_graph[parent_task.id].append(task.id)

# Determine if a task can fail and set the tasks' can_fail flags.
# A parent task is permitted to fail iff all of the following criteria are met:
# - It has children
# - All can_fail flags for a given parent task's children's task_dependency object == True
pprint(self.state.dependency_graph)
print()
pprint(self.state.reverse_dependency_graph)
for task in self.state.tasks:
can_fail_flags = []
for child_task_id in self.state.reverse_dependency_graph[task.id]:
# Get the task
child_task = self._get_task_by_id(child_task_id)
dep = next(filter(lambda dep: dep.id == task.id, child_task.depends_on))
can_fail_flags.append(dep.can_fail)

# If the length of can_fail_flags == 0, then this task has no child tasks
task.can_fail = False if len(can_fail_flags) == 0 else all(can_fail_flags)
try:
for parent_task_id in self.state.dependency_graph:
child_tasks = [
self._get_task_by_id(child_task_id)
for child_task_id
in self.state.dependency_graph[parent_task_id]
]
parent_can_fail_flags = []
for child_task in child_tasks:
dep = next(filter(lambda dep: dep.id == parent_task_id, child_task.depends_on))
parent_can_fail_flags.append(dep.can_fail)

# If the length of can_fail_flags == 0, then this task has no child tasks
child_task.can_fail = False if len(parent_can_fail_flags) == 0 else all(parent_can_fail_flags)
except Exception as e:
raise Exception(f"Error resolving can_fail flag for parent task '{parent_task_id}': {e}")

# Detect loops in the graph
try:
Expand Down

0 comments on commit 77a5a8b

Please sign in to comment.