diff --git a/docs/source/howto/html/while_task.html b/docs/source/howto/html/while_task.html
index b00c2eff..1239e9a5 100644
--- a/docs/source/howto/html/while_task.html
+++ b/docs/source/howto/html/while_task.html
@@ -61,7 +61,7 @@
const { RenderUtils } = ReteRenderUtils;
const styled = window.styled;
- const workgraphData = {"name": "while_task", "uuid": "11cc817c-5e20-11ef-9a7c-906584de3e5b", "state": "CREATED", "nodes": {"add1": {"label": "add1", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "11d730f4-5e20-11ef-9a7c-906584de3e5b", "node_uuid": "11d72a96-5e20-11ef-9a7c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "11d73144-5e20-11ef-9a7c-906584de3e5b", "node_uuid": "11d72a96-5e20-11ef-9a7c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}], "outputs": [{"name": "_wait"}], "position": [30, 30], "children": []}, "compare1": {"label": "compare1", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "11dfbf8a-5e20-11ef-9a7c-906584de3e5b", "node_uuid": "11dfb9cc-5e20-11ef-9a7c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "11dfbfe4-5e20-11ef-9a7c-906584de3e5b", "node_uuid": "11dfb9cc-5e20-11ef-9a7c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}], "outputs": [{"name": "result"}], "position": [60, 60], "children": []}, "while3": {"label": "while3", "node_type": "WHILE", "inputs": [{"name": "conditions"}], "outputs": [], "position": [90, 90], "children": ["add2", "multiply1"]}, "add2": {"label": "add2", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "11f043be-5e20-11ef-9a7c-906584de3e5b", "node_uuid": "11f03e0a-5e20-11ef-9a7c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "11f04418-5e20-11ef-9a7c-906584de3e5b", "node_uuid": "11f03e0a-5e20-11ef-9a7c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "_wait"}], "outputs": [{"name": "result"}], "position": [120, 120], "children": []}, "multiply1": {"label": "multiply1", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "11f84e9c-5e20-11ef-9a7c-906584de3e5b", "node_uuid": "11f8485c-5e20-11ef-9a7c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [{"from_node": "add2", "from_socket": "result", "from_socket_uuid": "11f044e0-5e20-11ef-9a7c-906584de3e5b"}], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "11f84f00-5e20-11ef-9a7c-906584de3e5b", "node_uuid": "11f8485c-5e20-11ef-9a7c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "x"}], "outputs": [{"name": "result"}], "position": [150, 150], "children": []}, "add3": {"label": "add3", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "1204244c-5e20-11ef-9a7c-906584de3e5b", "node_uuid": "1204186c-5e20-11ef-9a7c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [{"from_node": "multiply1", "from_socket": "result", "from_socket_uuid": "11f84fc8-5e20-11ef-9a7c-906584de3e5b"}], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "12042546-5e20-11ef-9a7c-906584de3e5b", "node_uuid": "1204186c-5e20-11ef-9a7c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "x"}], "outputs": [], "position": [180, 180], "children": []}}, "links": [{"from_socket": "result", "from_node": "compare1", "from_socket_uuid": "11dfc0b6-5e20-11ef-9a7c-906584de3e5b", "to_socket": "conditions", "to_node": "while3", "state": false}, {"from_socket": "result", "from_node": "add2", "from_socket_uuid": "11f044e0-5e20-11ef-9a7c-906584de3e5b", "to_socket": "x", "to_node": "multiply1", "state": false}, {"from_socket": "result", "from_node": "multiply1", "from_socket_uuid": "11f84fc8-5e20-11ef-9a7c-906584de3e5b", "to_socket": "x", "to_node": "add3", "state": false}, {"from_node": "add1", "from_socket": "_wait", "to_node": "add2", "to_socket": "_wait"}]}
+ const nodegraphData = {"name": "while_task", "uuid": "2251e0fe-b934-11ef-a5ab-906584de3e5b", "state": "CREATED", "nodes": {"add1": {"label": "add1", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any"}, {"name": "y", "identifier": "workgraph.any"}], "properties": {"x": {"identifier": "workgraph.any", "value": null}, "y": {"identifier": "workgraph.any", "value": null}, "_wait": {"identifier": "workgraph.any", "value": null}}, "outputs": [{"name": "_wait"}], "position": [30, 30], "children": []}, "compare1": {"label": "compare1", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any"}, {"name": "y", "identifier": "workgraph.any"}], "properties": {"x": {"identifier": "workgraph.any", "value": null}, "y": {"identifier": "workgraph.any", "value": null}, "_wait": {"identifier": "workgraph.any", "value": null}}, "outputs": [{"name": "result"}], "position": [60, 60], "children": []}, "while3": {"label": "while3", "node_type": "WHILE", "inputs": [{"name": "conditions"}, {"name": "_wait"}], "properties": {"_wait": {"identifier": "workgraph.any", "value": null}, "max_iterations": {"identifier": "node_graph.int", "value": null}, "conditions": {"identifier": "workgraph.any", "value": null}}, "outputs": [], "position": [90, 90], "children": ["add2", "multiply1"]}, "add2": {"label": "add2", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any"}, {"name": "y", "identifier": "workgraph.any"}], "properties": {"x": {"identifier": "workgraph.any", "value": null}, "y": {"identifier": "workgraph.any", "value": null}, "_wait": {"identifier": "workgraph.any", "value": null}}, "outputs": [{"name": "result"}], "position": [120, 120], "children": []}, "multiply1": {"label": "multiply1", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any"}, {"name": "y", "identifier": "workgraph.any"}, {"name": "x"}], "properties": {"x": {"identifier": "workgraph.any", "value": null}, "y": {"identifier": "workgraph.any", "value": null}, "_wait": {"identifier": "workgraph.any", "value": null}}, "outputs": [{"name": "result"}], "position": [150, 150], "children": []}, "add3": {"label": "add3", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any"}, {"name": "y", "identifier": "workgraph.any"}, {"name": "x"}], "properties": {"x": {"identifier": "workgraph.any", "value": null}, "y": {"identifier": "workgraph.any", "value": null}, "_wait": {"identifier": "workgraph.any", "value": null}}, "outputs": [], "position": [180, 180], "children": []}}, "links": [{"from_socket": "result", "from_node": "compare1", "to_socket": "conditions", "to_node": "while3", "state": false}, {"from_socket": "result", "from_node": "add2", "to_socket": "x", "to_node": "multiply1", "state": false}, {"from_socket": "result", "from_node": "multiply1", "to_socket": "x", "to_node": "add3", "state": false}, {"from_node": "add1", "from_socket": "_wait", "to_node": "while3", "to_socket": "_wait"}]}
// Define Schemes to use in vanilla JS
const Schemes = {
@@ -133,21 +133,21 @@
}
}
- async function loadJSON(editor, area, layout, workgraphData) {
- for (const nodeId in workgraphData.nodes) {
- const nodeData = workgraphData.nodes[nodeId];
+ async function loadJSON(editor, area, layout, nodegraphData) {
+ for (const nodeId in nodegraphData.nodes) {
+ const nodeData = nodegraphData.nodes[nodeId];
await addNode(editor, area, nodeData);
}
- // Adding connections based on workgraphData
- workgraphData.links.forEach(async (link) => { // Specify the type of link here
+ // Adding connections based on nodegraphData
+ nodegraphData.links.forEach(async (link) => { // Specify the type of link here
await addLink(editor, area, layout, link);
});
// Add while zones
console.log("Adding while zone: ");
- for (const nodeId in workgraphData.nodes) {
- const nodeData = workgraphData.nodes[nodeId];
+ for (const nodeId in nodegraphData.nodes) {
+ const nodeData = nodegraphData.nodes[nodeId];
const node_type = nodeData['node_type'];
if (node_type === "WHILE" || node_type === "IF" || node_type === "ZONE") {
// find the node
@@ -162,6 +162,18 @@
}
}
+ /**
+ * Defines custom padding for a scope layout.
+ * The padding values are used by the ScopesPlugin to avoid node overlapping with the socket of the parent node.
+ */
+ const customScopePadding = () => ({
+ top: 80,
+ left: 30,
+ right: 30,
+ bottom: 50
+ });
+
+
async function createEditor(container) {
const socket = new ClassicPreset.Socket("socket");
@@ -169,7 +181,7 @@
const area = new AreaPlugin(container);
const connection = new ConnectionPlugin();
const render = new ReactPlugin({ createRoot });
- const scopes = new ScopesPlugin();
+ const scopes = new ScopesPlugin({padding: customScopePadding});
const arrange = new AutoArrangePlugin();
const minimap = new MinimapPlugin({
@@ -209,7 +221,7 @@
AreaExtensions.zoomAt(area, editor.getNodes());
}
- // Adding nodes based on workgraphData
+ // Adding nodes based on nodegraphData
const nodeMap = {}; // To keep track of created nodes for linking
editor.nodeMap = nodeMap;
@@ -238,7 +250,7 @@
if (containerRef.current && !editor) {
createEditor(containerRef.current).then((editor) => {
setEditor(editor);
- loadJSON(editor.editor, editor.area, editor.layout, workgraphData).then(() => {
+ loadJSON(editor.editor, editor.area, editor.layout, nodegraphData).then(() => {
// aplly layout twice to ensure all nodes are arranged
editor?.layout(false).then(() => editor?.layout(true));
});
diff --git a/docs/source/howto/if.ipynb b/docs/source/howto/if.ipynb
index eabca7ac..1e1471f0 100644
--- a/docs/source/howto/if.ipynb
+++ b/docs/source/howto/if.ipynb
@@ -149,15 +149,13 @@
"wg = WorkGraph(\"if_task\")\n",
"add1 = wg.add_task(add, name=\"add1\", x=1, y=1)\n",
"condition1 = wg.add_task(compare, name=\"condition1\", x=1, y=0)\n",
- "add2 = wg.add_task(add, name=\"add2\", x=add1.outputs[\"result\"], y=2)\n",
- "if1 = wg.add_task(\"If\", name=\"if_true\",\n",
+ "if_true_zone = wg.add_task(\"If\", name=\"if_true\",\n",
" conditions=condition1.outputs[\"result\"])\n",
- "if1.children.add(\"add2\")\n",
- "multiply1 = wg.add_task(multiply, name=\"multiply1\", x=add1.outputs[\"result\"], y=2)\n",
- "if2 = wg.add_task(\"If\", name=\"if_false\",\n",
+ "add2 = if_true_zone.add_task(add, name=\"add2\", x=add1.outputs[\"result\"], y=2)\n",
+ "if_false_zone = wg.add_task(\"If\", name=\"if_false\",\n",
" conditions=condition1.outputs[\"result\"],\n",
" invert_condition=True)\n",
- "if2.children.add(\"multiply1\")\n",
+ "multiply1 = if_false_zone.add_task(multiply, name=\"multiply1\", x=add1.outputs[\"result\"], y=2)\n",
"#---------------------------------------------------------------------\n",
"select1 = wg.add_task(\"workgraph.select\", name=\"select1\", true=add2.outputs[\"result\"],\n",
" false=multiply1.outputs[\"result\"],\n",
@@ -377,17 +375,17 @@
"source": [
"# Create a WorkGraph which is dynamically generated based on the input\n",
"# then we output the result of from the context (context)\n",
- "@task.graph_builder(outputs = [{\"name\": \"result\", \"from\": \"context.result\"}])\n",
+ "@task.graph_builder(outputs = [{\"name\": \"result\", \"from\": \"context.data\"}])\n",
"def add_multiply_if(x, y):\n",
" wg = WorkGraph()\n",
" if x.value > 0:\n",
" add1 = wg.add_task(add, name=\"add1\", x=x, y=y)\n",
- " # export the result of add1 to the context\n",
- " add1.set_context({\"result\": \"result\"})\n",
+ " # export the result of add1 to the context.data\n",
+ " add1.set_context({\"data\": \"result\"})\n",
" else:\n",
" multiply1 = wg.add_task(multiply, name=\"multiply1\", x=x, y=y)\n",
- " # export the result of multiply1 to the context\n",
- " multiply1.set_context({\"result\": \"result\"})\n",
+ " # export the result of multiply1 to the context.dadta\n",
+ " multiply1.set_context({\"data\": \"result\"})\n",
" return wg"
]
},
@@ -433,8 +431,8 @@
"\n",
"wg = WorkGraph(\"if_graph_builer\")\n",
"add1 = wg.add_task(add, name=\"add1\", x=1, y=1)\n",
- "add_multiply_if1 = wg.add_task(add_multiply_if, name=\"add_multiply_if1\", x=add1.outputs[\"result\"], y=2)\n",
- "add1 = wg.add_task(add, name=\"add2\", x=add_multiply_if1.outputs[\"result\"], y=1)\n",
+ "add_multiply_if1 = wg.add_task(add_multiply_if, name=\"add_multiply_if1\", x=add1.outputs.result, y=2)\n",
+ "add1 = wg.add_task(add, name=\"add2\", x=add_multiply_if1.outputs.result, y=1)\n",
"# export the workgraph to html file so that it can be visualized in a browser\n",
"wg.to_html()\n",
"# comment out the following line to visualize the workgraph in jupyter-notebook\n",
@@ -471,7 +469,7 @@
"source": [
"wg.submit(wait=True)\n",
"print(\"State of WorkGraph : {}\".format(wg.state))\n",
- "print('Result of add1 : {}'.format(wg.tasks[\"add2\"].outputs[\"result\"].value))"
+ "print('Result of add1 : {}'.format(wg.tasks.add2.outputs.result.value))"
]
},
{
diff --git a/docs/source/howto/while.ipynb b/docs/source/howto/while.ipynb
index af3eff2e..d9da34c2 100644
--- a/docs/source/howto/while.ipynb
+++ b/docs/source/howto/while.ipynb
@@ -35,20 +35,28 @@
"```\n",
"\n",
"### Adding tasks to the While loop\n",
- "We can add tasks to the `While` task using the `children` attribute.\n",
+ "We can add tasks to the `While` zone using the `add_task` method.\n",
"\n",
"```python\n",
- "# add task1 and task2 to the while loop\n",
- "while_task.children.add([\"task1\", \"task2\"])\n",
+ "# Add a new task to the while zone\n",
+ "while_zone1.add_task(add, name=\"task1\", a=1, b=2)\n",
"```\n"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
"id": "8f5e7642",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Result: uuid: ddc3f239-3776-4746-8c1b-9b7ae9c5f2a4 (pk: 2551) value: 63\n"
+ ]
+ }
+ ],
"source": [
"from aiida_workgraph import WorkGraph, task\n",
"from aiida import load_profile\n",
@@ -92,7 +100,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 2,
"id": "8ee799d2-0b5b-4609-957f-6b3f2cd451f0",
"metadata": {},
"outputs": [
@@ -111,10 +119,10 @@
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 10,
+ "execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
@@ -128,18 +136,17 @@
"add1 = wg.add_task(add, name=\"add1\", x=1, y=1)\n",
"add1.set_context({\"n\": \"result\"})\n",
"#---------------------------------------------------------------------\n",
- "# Create the while tasks\n",
+ "# Create the while zone\n",
"compare1 = wg.add_task(compare, name=\"compare1\", x=\"{{n}}\", y=50)\n",
- "while1 = wg.add_task(\"While\", max_iterations=100, conditions=compare1.outputs[\"result\"])\n",
- "# Create the tasks in the while loop.\n",
- "add2 = wg.add_task(add, name=\"add2\", x=\"{{n}}\", y=1)\n",
- "add2.waiting_on.add(\"add1\")\n",
- "multiply1 = wg.add_task(multiply, name=\"multiply1\",\n",
+ "while_zone1 = wg.add_task(\"While\", max_iterations=100, conditions=compare1.outputs[\"result\"])\n",
+ "while_zone1.waiting_on.add(\"add1\")\n",
+ "# Create the tasks in the while zone.\n",
+ "add2 = while_zone1.add_task(add, name=\"add2\", x=\"{{n}}\", y=1)\n",
+ "multiply1 = while_zone1.add_task(multiply, name=\"multiply1\",\n",
" x=add2.outputs[\"result\"],\n",
" y=2)\n",
"# update the context variable\n",
"multiply1.set_context({\"n\": \"result\"})\n",
- "while1.children.add([\"add2\", \"multiply1\"])\n",
"#---------------------------------------------------------------------\n",
"add3 = wg.add_task(add, name=\"add3\", x=1, y=1)\n",
"wg.add_link(multiply1.outputs[\"result\"], add3.inputs[\"x\"])\n",
@@ -165,7 +172,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 3,
"id": "9ebf35aa",
"metadata": {},
"outputs": [
@@ -173,9 +180,10 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "WorkGraph process created, PK: 117695\n",
+ "WorkGraph process created, PK: 2557\n",
+ "Process 2557 finished with state: FINISHED\n",
"State of WorkGraph: FINISHED\n",
- "Result of add1 : uuid: 061abd52-84fe-4d4e-b381-c303e4c25e19 (pk: 117741) value: 63\n"
+ "Result of add1 : uuid: 3b982d3c-3aaf-4c87-826c-fe2172939061 (pk: 2603) value: 63\n"
]
}
],
@@ -1632,7 +1640,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3.10.4 ('scinode')",
+ "display_name": "aiida",
"language": "python",
"name": "python3"
},
@@ -1647,11 +1655,6 @@
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
- },
- "vscode": {
- "interpreter": {
- "hash": "2f450c1ff08798c4974437dd057310afef0de414c25d1fd960ad375311c3f6ff"
- }
}
},
"nbformat": 4,
diff --git a/docs/source/howto/zone.ipynb b/docs/source/howto/zone.ipynb
index ce1a970e..9bf80c70 100644
--- a/docs/source/howto/zone.ipynb
+++ b/docs/source/howto/zone.ipynb
@@ -58,12 +58,11 @@
"wg = WorkGraph(\"test_zone\")\n",
"wg.context = {}\n",
"add1 = wg.add_task(add, name=\"add1\", x=1, y=1)\n",
- "wg.add_task(add, name=\"add2\", x=1, y=1)\n",
- "add3 = wg.add_task(add, name=\"add3\", x=1, y=add1.outputs[\"result\"])\n",
- "wg.add_task(add, name=\"add4\", x=1, y=add3.outputs[\"result\"])\n",
- "wg.add_task(add, name=\"add5\", x=1, y=add3.outputs[\"result\"])\n",
- "zone1 = wg.add_task(\"workgraph.zone\", name=\"Zone1\")\n",
- "zone1.children.add([\"add2\", \"add3\", \"add4\"])\n",
+ "zone1 = wg.add_task(\"workgraph.zone\", name=\"zone1\")\n",
+ "zone1.add_task(add, name=\"add2\", x=1, y=1)\n",
+ "add3 = zone1.add_task(add, name=\"add3\", x=1, y=add1.outputs.result)\n",
+ "zone1.add_task(add, name=\"add4\", x=1, y=add3.outputs.result)\n",
+ "wg.add_task(add, name=\"add5\", x=1, y=add3.outputs.result)\n",
"# export the workgraph to html file so that it can be visualized in a browser\n",
"wg.to_html()\n",
"# comment out the following line to visualize the workgraph in jupyter-notebook\n",
diff --git a/src/aiida_workgraph/tasks/builtins.py b/src/aiida_workgraph/tasks/builtins.py
index 9460f911..8c7c93e1 100644
--- a/src/aiida_workgraph/tasks/builtins.py
+++ b/src/aiida_workgraph/tasks/builtins.py
@@ -16,6 +16,12 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.children = TaskCollection(parent=self)
+ def add_task(self, *args, **kwargs):
+ """Syntactic sugar to add a task to the zone."""
+ task = self.parent.add_task(*args, **kwargs)
+ self.children.add(task)
+ return task
+
def create_sockets(self) -> None:
self.inputs._clear()
self.outputs._clear()
diff --git a/tests/test_if.py b/tests/test_if.py
index c63b03b8..bc1bb0b2 100644
--- a/tests/test_if.py
+++ b/tests/test_if.py
@@ -7,9 +7,8 @@ def test_if_task(decorated_add, decorated_multiply, decorated_compare):
wg = WorkGraph("test_if")
add1 = wg.add_task(decorated_add, name="add1", x=1, y=1)
condition1 = wg.add_task(decorated_compare, name="condition1", x=1, y=0)
- add2 = wg.add_task(decorated_add, name="add2", x=add1.outputs.result, y=2)
- if1 = wg.add_task("If", name="if_true", conditions=condition1.outputs.result)
- if1.children.add("add2")
+ if_zone = wg.add_task("If", name="if_true", conditions=condition1.outputs.result)
+ add2 = if_zone.add_task(decorated_add, name="add2", x=add1.outputs.result, y=2)
multiply1 = wg.add_task(
decorated_multiply, name="multiply1", x=add1.outputs.result, y=2
)
diff --git a/tests/test_zone.py b/tests/test_zone.py
index b46d575d..b0043607 100644
--- a/tests/test_zone.py
+++ b/tests/test_zone.py
@@ -6,17 +6,13 @@ def test_zone_task(decorated_add):
"""Test the zone task."""
wg = WorkGraph("test_zone")
- wg.context = {}
add1 = wg.add_task(decorated_add, name="add1", x=1, y=1)
- wg.add_task(decorated_add, name="add2", x=1, y=1)
- add3 = wg.add_task(decorated_add, name="add3", x=1, y=add1.outputs.result)
- wg.add_task(decorated_add, name="add4", x=1, y=add3.outputs.result)
- wg.add_task(decorated_add, name="add5", x=1, y=add3.outputs.result)
- zone1 = wg.add_task("workgraph.zone", name="Zone1")
- zone1.children.add(["add2", "add3"])
+ zone1 = wg.add_task("workgraph.zone", name="zone1")
+ zone1.add_task(decorated_add, name="add2", x=1, y=1)
+ zone1.add_task(decorated_add, name="add3", x=1, y=add1.outputs.result)
+ wg.add_task(decorated_add, name="add4", x=1, y=wg.tasks.add2.outputs.result)
+ wg.add_task(decorated_add, name="add5", x=1, y=wg.tasks.add3.outputs.result)
wg.run()
report = get_workchain_report(wg.process, "REPORT")
- assert "tasks ready to run: add1" in report
assert "tasks ready to run: add2,add3" in report
- assert "tasks ready to run: add4" in report
- assert "tasks ready to run: add5" in report
+ assert "tasks ready to run: add4,add5" in report