-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix recursive workgraph #336
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -1032,6 +1032,13 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None | |||||
|
||||||
self.report(f"Run task: {name}, type: {task['metadata']['node_type']}") | ||||||
executor, _ = get_executor(task["executor"]) | ||||||
# Add the executor to the globals so that it can be used in the task | ||||||
# in the case of recursive workgraph | ||||||
# We also need to rebuild the Task calss and attach it to the executor | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
if task["metadata"]["node_type"].upper() == "GRAPH_BUILDER": | ||||||
task_class = Task.from_dict(self.ctx._tasks[name]) | ||||||
executor.node = executor.task = task_class.__class__ | ||||||
executor.__globals__[executor.__name__] = executor | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My concern is that executors will have the same name as it is not always unique and then override each other. For calcfunctions for example it is the name of the function which is reasonable, as it is only overridden if a function with the same name is defined. But for the case of local functions, this overrides any global definition. from aiida_workgraph import task, WorkGraph
from aiida.engine import calcfunction
from aiida import load_profile
load_profile()
@task.graph_builder()
def my_add():
@calcfunction
def add(x, y):
return x+y
wg = WorkGraph()
task = wg.add_task(add, x=1, y=1)
return wg
wg = my_add()
print(wg.tasks["add1"].get_executor()['name']) # out 'add' but better 'my_add.add' I guess solving this issue requires much more work and since we don't have any examples defining calcfunctions locally (however we have to load codes locally in graph_builder), I think it is not very crucial, but an issue would be nice to keep this in mind. |
||||||
# print("executor: ", executor) | ||||||
args, kwargs, var_args, var_kwargs, args_dict = self.get_inputs(name) | ||||||
for i, key in enumerate(self.ctx._tasks[name]["args"]): | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What kind of task is this? When I create add a task in WorkGraph there is no
task["executor"]
but I can dotask.get_executor()