Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Syntactic sugar to add a task to the zone #391

Merged
merged 2 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions docs/source/howto/html/while_task.html
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
const { RenderUtils } = ReteRenderUtils;
const styled = window.styled;

const workgraphData = {"name": "while_task", "uuid": "11cc817c-5e20-11ef-9a7c-906584de3e5b", "state": "CREATED", "nodes": {"add1": {"label": "add1", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "11d730f4-5e20-11ef-9a7c-906584de3e5b", "node_uuid": "11d72a96-5e20-11ef-9a7c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "11d73144-5e20-11ef-9a7c-906584de3e5b", "node_uuid": "11d72a96-5e20-11ef-9a7c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}], "outputs": [{"name": "_wait"}], "position": [30, 30], "children": []}, "compare1": {"label": "compare1", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "11dfbf8a-5e20-11ef-9a7c-906584de3e5b", "node_uuid": "11dfb9cc-5e20-11ef-9a7c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "11dfbfe4-5e20-11ef-9a7c-906584de3e5b", "node_uuid": "11dfb9cc-5e20-11ef-9a7c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}], "outputs": [{"name": "result"}], "position": [60, 60], "children": []}, "while3": {"label": "while3", "node_type": "WHILE", "inputs": [{"name": "conditions"}], "outputs": [], "position": [90, 90], "children": ["add2", "multiply1"]}, "add2": {"label": "add2", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "11f043be-5e20-11ef-9a7c-906584de3e5b", "node_uuid": "11f03e0a-5e20-11ef-9a7c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "11f04418-5e20-11ef-9a7c-906584de3e5b", "node_uuid": "11f03e0a-5e20-11ef-9a7c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "_wait"}], "outputs": [{"name": "result"}], "position": [120, 120], "children": []}, "multiply1": {"label": "multiply1", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "11f84e9c-5e20-11ef-9a7c-906584de3e5b", "node_uuid": "11f8485c-5e20-11ef-9a7c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [{"from_node": "add2", "from_socket": "result", "from_socket_uuid": "11f044e0-5e20-11ef-9a7c-906584de3e5b"}], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "11f84f00-5e20-11ef-9a7c-906584de3e5b", "node_uuid": "11f8485c-5e20-11ef-9a7c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "x"}], "outputs": [{"name": "result"}], "position": [150, 150], "children": []}, "add3": {"label": "add3", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "1204244c-5e20-11ef-9a7c-906584de3e5b", "node_uuid": "1204186c-5e20-11ef-9a7c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [{"from_node": "multiply1", "from_socket": "result", "from_socket_uuid": "11f84fc8-5e20-11ef-9a7c-906584de3e5b"}], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "12042546-5e20-11ef-9a7c-906584de3e5b", "node_uuid": "1204186c-5e20-11ef-9a7c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "x"}], "outputs": [], "position": [180, 180], "children": []}}, "links": [{"from_socket": "result", "from_node": "compare1", "from_socket_uuid": "11dfc0b6-5e20-11ef-9a7c-906584de3e5b", "to_socket": "conditions", "to_node": "while3", "state": false}, {"from_socket": "result", "from_node": "add2", "from_socket_uuid": "11f044e0-5e20-11ef-9a7c-906584de3e5b", "to_socket": "x", "to_node": "multiply1", "state": false}, {"from_socket": "result", "from_node": "multiply1", "from_socket_uuid": "11f84fc8-5e20-11ef-9a7c-906584de3e5b", "to_socket": "x", "to_node": "add3", "state": false}, {"from_node": "add1", "from_socket": "_wait", "to_node": "add2", "to_socket": "_wait"}]}
const nodegraphData = {"name": "while_task", "uuid": "2251e0fe-b934-11ef-a5ab-906584de3e5b", "state": "CREATED", "nodes": {"add1": {"label": "add1", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any"}, {"name": "y", "identifier": "workgraph.any"}], "properties": {"x": {"identifier": "workgraph.any", "value": null}, "y": {"identifier": "workgraph.any", "value": null}, "_wait": {"identifier": "workgraph.any", "value": null}}, "outputs": [{"name": "_wait"}], "position": [30, 30], "children": []}, "compare1": {"label": "compare1", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any"}, {"name": "y", "identifier": "workgraph.any"}], "properties": {"x": {"identifier": "workgraph.any", "value": null}, "y": {"identifier": "workgraph.any", "value": null}, "_wait": {"identifier": "workgraph.any", "value": null}}, "outputs": [{"name": "result"}], "position": [60, 60], "children": []}, "while3": {"label": "while3", "node_type": "WHILE", "inputs": [{"name": "conditions"}, {"name": "_wait"}], "properties": {"_wait": {"identifier": "workgraph.any", "value": null}, "max_iterations": {"identifier": "node_graph.int", "value": null}, "conditions": {"identifier": "workgraph.any", "value": null}}, "outputs": [], "position": [90, 90], "children": ["add2", "multiply1"]}, "add2": {"label": "add2", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any"}, {"name": "y", "identifier": "workgraph.any"}], "properties": {"x": {"identifier": "workgraph.any", "value": null}, "y": {"identifier": "workgraph.any", "value": null}, "_wait": {"identifier": "workgraph.any", "value": null}}, "outputs": [{"name": "result"}], "position": [120, 120], "children": []}, "multiply1": {"label": "multiply1", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any"}, {"name": "y", "identifier": "workgraph.any"}, {"name": "x"}], "properties": {"x": {"identifier": "workgraph.any", "value": null}, "y": {"identifier": "workgraph.any", "value": null}, "_wait": {"identifier": "workgraph.any", "value": null}}, "outputs": [{"name": "result"}], "position": [150, 150], "children": []}, "add3": {"label": "add3", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any"}, {"name": "y", "identifier": "workgraph.any"}, {"name": "x"}], "properties": {"x": {"identifier": "workgraph.any", "value": null}, "y": {"identifier": "workgraph.any", "value": null}, "_wait": {"identifier": "workgraph.any", "value": null}}, "outputs": [], "position": [180, 180], "children": []}}, "links": [{"from_socket": "result", "from_node": "compare1", "to_socket": "conditions", "to_node": "while3", "state": false}, {"from_socket": "result", "from_node": "add2", "to_socket": "x", "to_node": "multiply1", "state": false}, {"from_socket": "result", "from_node": "multiply1", "to_socket": "x", "to_node": "add3", "state": false}, {"from_node": "add1", "from_socket": "_wait", "to_node": "while3", "to_socket": "_wait"}]}

// Define Schemes to use in vanilla JS
const Schemes = {
Expand Down Expand Up @@ -133,21 +133,21 @@
}
}

async function loadJSON(editor, area, layout, workgraphData) {
for (const nodeId in workgraphData.nodes) {
const nodeData = workgraphData.nodes[nodeId];
async function loadJSON(editor, area, layout, nodegraphData) {
for (const nodeId in nodegraphData.nodes) {
const nodeData = nodegraphData.nodes[nodeId];
await addNode(editor, area, nodeData);
}

// Adding connections based on workgraphData
workgraphData.links.forEach(async (link) => { // Specify the type of link here
// Adding connections based on nodegraphData
nodegraphData.links.forEach(async (link) => { // Specify the type of link here
await addLink(editor, area, layout, link);
});

// Add while zones
console.log("Adding while zone: ");
for (const nodeId in workgraphData.nodes) {
const nodeData = workgraphData.nodes[nodeId];
for (const nodeId in nodegraphData.nodes) {
const nodeData = nodegraphData.nodes[nodeId];
const node_type = nodeData['node_type'];
if (node_type === "WHILE" || node_type === "IF" || node_type === "ZONE") {
// find the node
Expand All @@ -162,14 +162,26 @@
}
}

/**
* Defines custom padding for a scope layout.
* The padding values are used by the ScopesPlugin to avoid node overlapping with the socket of the parent node.
*/
const customScopePadding = () => ({
top: 80,
left: 30,
right: 30,
bottom: 50
});


async function createEditor(container) {
const socket = new ClassicPreset.Socket("socket");

const editor = new NodeEditor(Schemes);
const area = new AreaPlugin(container);
const connection = new ConnectionPlugin();
const render = new ReactPlugin({ createRoot });
const scopes = new ScopesPlugin();
const scopes = new ScopesPlugin({padding: customScopePadding});
const arrange = new AutoArrangePlugin();

const minimap = new MinimapPlugin({
Expand Down Expand Up @@ -209,7 +221,7 @@
AreaExtensions.zoomAt(area, editor.getNodes());
}

// Adding nodes based on workgraphData
// Adding nodes based on nodegraphData
const nodeMap = {}; // To keep track of created nodes for linking
editor.nodeMap = nodeMap;

Expand Down Expand Up @@ -238,7 +250,7 @@
if (containerRef.current && !editor) {
createEditor(containerRef.current).then((editor) => {
setEditor(editor);
loadJSON(editor.editor, editor.area, editor.layout, workgraphData).then(() => {
loadJSON(editor.editor, editor.area, editor.layout, nodegraphData).then(() => {
// aplly layout twice to ensure all nodes are arranged
editor?.layout(false).then(() => editor?.layout(true));
});
Expand Down
26 changes: 12 additions & 14 deletions docs/source/howto/if.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,13 @@
"wg = WorkGraph(\"if_task\")\n",
"add1 = wg.add_task(add, name=\"add1\", x=1, y=1)\n",
"condition1 = wg.add_task(compare, name=\"condition1\", x=1, y=0)\n",
"add2 = wg.add_task(add, name=\"add2\", x=add1.outputs[\"result\"], y=2)\n",
"if1 = wg.add_task(\"If\", name=\"if_true\",\n",
"if_true_zone = wg.add_task(\"If\", name=\"if_true\",\n",
" conditions=condition1.outputs[\"result\"])\n",
"if1.children.add(\"add2\")\n",
"multiply1 = wg.add_task(multiply, name=\"multiply1\", x=add1.outputs[\"result\"], y=2)\n",
"if2 = wg.add_task(\"If\", name=\"if_false\",\n",
"add2 = if_true_zone.add_task(add, name=\"add2\", x=add1.outputs[\"result\"], y=2)\n",
"if_false_zone = wg.add_task(\"If\", name=\"if_false\",\n",
" conditions=condition1.outputs[\"result\"],\n",
" invert_condition=True)\n",
"if2.children.add(\"multiply1\")\n",
"multiply1 = if_false_zone.add_task(multiply, name=\"multiply1\", x=add1.outputs[\"result\"], y=2)\n",
"#---------------------------------------------------------------------\n",
"select1 = wg.add_task(\"workgraph.select\", name=\"select1\", true=add2.outputs[\"result\"],\n",
" false=multiply1.outputs[\"result\"],\n",
Expand Down Expand Up @@ -377,17 +375,17 @@
"source": [
"# Create a WorkGraph which is dynamically generated based on the input\n",
"# then we output the result of from the context (context)\n",
"@task.graph_builder(outputs = [{\"name\": \"result\", \"from\": \"context.result\"}])\n",
"@task.graph_builder(outputs = [{\"name\": \"result\", \"from\": \"context.data\"}])\n",
"def add_multiply_if(x, y):\n",
" wg = WorkGraph()\n",
" if x.value > 0:\n",
" add1 = wg.add_task(add, name=\"add1\", x=x, y=y)\n",
" # export the result of add1 to the context\n",
" add1.set_context({\"result\": \"result\"})\n",
" # export the result of add1 to the context.data\n",
" add1.set_context({\"data\": \"result\"})\n",
" else:\n",
" multiply1 = wg.add_task(multiply, name=\"multiply1\", x=x, y=y)\n",
" # export the result of multiply1 to the context\n",
" multiply1.set_context({\"result\": \"result\"})\n",
" # export the result of multiply1 to the context.dadta\n",
" multiply1.set_context({\"data\": \"result\"})\n",
" return wg"
]
},
Expand Down Expand Up @@ -433,8 +431,8 @@
"\n",
"wg = WorkGraph(\"if_graph_builer\")\n",
"add1 = wg.add_task(add, name=\"add1\", x=1, y=1)\n",
"add_multiply_if1 = wg.add_task(add_multiply_if, name=\"add_multiply_if1\", x=add1.outputs[\"result\"], y=2)\n",
"add1 = wg.add_task(add, name=\"add2\", x=add_multiply_if1.outputs[\"result\"], y=1)\n",
"add_multiply_if1 = wg.add_task(add_multiply_if, name=\"add_multiply_if1\", x=add1.outputs.result, y=2)\n",
"add1 = wg.add_task(add, name=\"add2\", x=add_multiply_if1.outputs.result, y=1)\n",
"# export the workgraph to html file so that it can be visualized in a browser\n",
"wg.to_html()\n",
"# comment out the following line to visualize the workgraph in jupyter-notebook\n",
Expand Down Expand Up @@ -471,7 +469,7 @@
"source": [
"wg.submit(wait=True)\n",
"print(\"State of WorkGraph : {}\".format(wg.state))\n",
"print('Result of add1 : {}'.format(wg.tasks[\"add2\"].outputs[\"result\"].value))"
"print('Result of add1 : {}'.format(wg.tasks.add2.outputs.result.value))"
]
},
{
Expand Down
51 changes: 27 additions & 24 deletions docs/source/howto/while.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,28 @@
"```\n",
"\n",
"### Adding tasks to the While loop\n",
"We can add tasks to the `While` task using the `children` attribute.\n",
"We can add tasks to the `While` zone using the `add_task` method.\n",
"\n",
"```python\n",
"# add task1 and task2 to the while loop\n",
"while_task.children.add([\"task1\", \"task2\"])\n",
"# Add a new task to the while zone\n",
"while_zone1.add_task(add, name=\"task1\", a=1, b=2)\n",
"```\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "8f5e7642",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Result: uuid: ddc3f239-3776-4746-8c1b-9b7ae9c5f2a4 (pk: 2551) value: 63\n"
]
}
],
"source": [
"from aiida_workgraph import WorkGraph, task\n",
"from aiida import load_profile\n",
Expand Down Expand Up @@ -92,7 +100,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 2,
"id": "8ee799d2-0b5b-4609-957f-6b3f2cd451f0",
"metadata": {},
"outputs": [
Expand All @@ -111,10 +119,10 @@
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x792d26afe0d0>"
"<IPython.lib.display.IFrame at 0x73847c2f07d0>"
]
},
"execution_count": 10,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -128,18 +136,17 @@
"add1 = wg.add_task(add, name=\"add1\", x=1, y=1)\n",
"add1.set_context({\"n\": \"result\"})\n",
"#---------------------------------------------------------------------\n",
"# Create the while tasks\n",
"# Create the while zone\n",
"compare1 = wg.add_task(compare, name=\"compare1\", x=\"{{n}}\", y=50)\n",
"while1 = wg.add_task(\"While\", max_iterations=100, conditions=compare1.outputs[\"result\"])\n",
"# Create the tasks in the while loop.\n",
"add2 = wg.add_task(add, name=\"add2\", x=\"{{n}}\", y=1)\n",
"add2.waiting_on.add(\"add1\")\n",
"multiply1 = wg.add_task(multiply, name=\"multiply1\",\n",
"while_zone1 = wg.add_task(\"While\", max_iterations=100, conditions=compare1.outputs[\"result\"])\n",
"while_zone1.waiting_on.add(\"add1\")\n",
"# Create the tasks in the while zone.\n",
"add2 = while_zone1.add_task(add, name=\"add2\", x=\"{{n}}\", y=1)\n",
"multiply1 = while_zone1.add_task(multiply, name=\"multiply1\",\n",
" x=add2.outputs[\"result\"],\n",
" y=2)\n",
"# update the context variable\n",
"multiply1.set_context({\"n\": \"result\"})\n",
"while1.children.add([\"add2\", \"multiply1\"])\n",
"#---------------------------------------------------------------------\n",
"add3 = wg.add_task(add, name=\"add3\", x=1, y=1)\n",
"wg.add_link(multiply1.outputs[\"result\"], add3.inputs[\"x\"])\n",
Expand All @@ -165,17 +172,18 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 3,
"id": "9ebf35aa",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WorkGraph process created, PK: 117695\n",
"WorkGraph process created, PK: 2557\n",
"Process 2557 finished with state: FINISHED\n",
"State of WorkGraph: FINISHED\n",
"Result of add1 : uuid: 061abd52-84fe-4d4e-b381-c303e4c25e19 (pk: 117741) value: 63\n"
"Result of add1 : uuid: 3b982d3c-3aaf-4c87-826c-fe2172939061 (pk: 2603) value: 63\n"
]
}
],
Expand Down Expand Up @@ -1632,7 +1640,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.4 ('scinode')",
"display_name": "aiida",
"language": "python",
"name": "python3"
},
Expand All @@ -1647,11 +1655,6 @@
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
},
"vscode": {
"interpreter": {
"hash": "2f450c1ff08798c4974437dd057310afef0de414c25d1fd960ad375311c3f6ff"
}
}
},
"nbformat": 4,
Expand Down
Loading
Loading