Skip to content

Commit

Permalink
Improve test coverage (#372)
Browse files Browse the repository at this point in the history
* Add test for the normal task, vector socket, get_or_create_code, test_get_parent_workgraphs, test_generate_node_graph, widget
* Remove unused code
* fix max number of running processes
* fix organize_nested_inputs
* remove append in the awaitable manager
* increase check interval to fix unstable `play` and `pause` test
  • Loading branch information
superstar54 authored Dec 2, 2024
1 parent 512d1a8 commit 7f87a26
Show file tree
Hide file tree
Showing 28 changed files with 367 additions and 320 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ jobs:
playwright install
pip list
- name: Install system dependencies
run: sudo apt update && sudo apt install --no-install-recommends graphviz

- name: Create AiiDA profile
run: verdi setup -n --config .github/config/profile.yaml

Expand Down
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,5 @@ dmypy.json
tests/work
/tests/**/*.png
/tests/**/*txt
.vscode/
/tests/**/*html
.vscode
24 changes: 3 additions & 21 deletions aiida_workgraph/engine/awaitable_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,9 @@ def insert_awaitable(self, awaitable: Awaitable) -> None:
ctx, key = self.ctx_manager.resolve_nested_context(awaitable.key)

# Already assign the awaitable itself to the location in the context container where it is supposed to end up
# once it is resolved. This is especially important for the `APPEND` action, since it needs to maintain the
# order, but the awaitables will not necessarily be resolved in the order in which they are added. By using the
# awaitable as a placeholder, in the `_resolve_awaitable`, it can be found and replaced by the resolved value.
# once it is resolved.
if awaitable.action == AwaitableAction.ASSIGN:
ctx[key] = awaitable
elif awaitable.action == AwaitableAction.APPEND:
ctx.setdefault(key, []).append(awaitable)
else:
raise AssertionError(f"Unsupported awaitable action: {awaitable.action}")

Expand All @@ -67,26 +63,12 @@ def resolve_awaitable(self, awaitable: Awaitable, value: Any) -> None:

if awaitable.action == AwaitableAction.ASSIGN:
ctx[key] = value
elif awaitable.action == AwaitableAction.APPEND:
# Find the same awaitable inserted in the context
container = ctx[key]
for index, placeholder in enumerate(container):
if (
isinstance(placeholder, Awaitable)
and placeholder.pk == awaitable.pk
):
container[index] = value
break
else:
raise AssertionError(
f"Awaitable `{awaitable.pk} was not in `ctx.{awaitable.key}`"
)
else:
raise AssertionError(f"Unsupported awaitable action: {awaitable.action}")

awaitable.resolved = True
# remove awaitabble from the list
self._awaitables = [a for a in self._awaitables if a.pk != awaitable.pk]
# remove awaitabble from the list, and use the same list reference
self._awaitables[:] = [a for a in self._awaitables if a.pk != awaitable.pk]

if not self.process.has_terminated():
# the process may be terminated, for example, if the process was killed or excepted
Expand Down
9 changes: 5 additions & 4 deletions aiida_workgraph/engine/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ def is_workgraph_finished(self) -> bool:

def continue_workgraph(self) -> None:
self.process.report("Continue workgraph.")
# self.update_workgraph_from_base()
task_to_run = []
for name, task in self.ctx._tasks.items():
# update task state
Expand Down Expand Up @@ -734,9 +733,11 @@ def update_normal_task_state(self, name, results, success=True):
if success:
task = self.ctx._tasks[name]
if isinstance(results, tuple):
if len(task["outputs"]) != len(results):
return self.exit_codes.OUTPUS_NOT_MATCH_RESULTS
output_names = get_sorted_names(task["outputs"])
# there are two built-in outputs: _wait and _outputs
if len(task["outputs"]) - 2 != len(results):
self.on_task_failed(name)
return self.process.exit_codes.OUTPUS_NOT_MATCH_RESULTS
output_names = get_sorted_names(task["outputs"])[0:-2]
for i, output_name in enumerate(output_names):
task["results"][output_name] = results[i]
elif isinstance(results, dict):
Expand Down
4 changes: 2 additions & 2 deletions aiida_workgraph/engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

def prepare_for_workgraph_task(task: dict, kwargs: dict) -> tuple:
"""Prepare the inputs for WorkGraph task"""
from aiida_workgraph.utils import merge_properties, serialize_properties
from aiida_workgraph.utils import organize_nested_inputs, serialize_properties
from aiida.orm.utils.serialize import deserialize_unsafe

wgdata = deserialize_unsafe(task["executor"]["wgdata"])
Expand All @@ -19,7 +19,7 @@ def prepare_for_workgraph_task(task: dict, kwargs: dict) -> tuple:
"value"
] = value
# merge the properties
merge_properties(wgdata)
organize_nested_inputs(wgdata)
serialize_properties(wgdata)
metadata = {"call_link_label": task["name"]}
inputs = {"wg": wgdata, "metadata": metadata}
Expand Down
7 changes: 0 additions & 7 deletions aiida_workgraph/engine/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,13 +322,6 @@ def read_wgdata_from_base(self) -> t.Dict[str, t.Any]:
wgdata["context"] = deserialize_unsafe(wgdata["context"])
return wgdata

def update_workgraph_from_base(self) -> None:
"""Update the ctx from base.extras."""
wgdata = self.read_wgdata_from_base()
for name, task in wgdata["tasks"].items():
task["results"] = self.ctx._tasks[name].get("results")
self.setup_ctx_workgraph(wgdata)

def init_ctx(self, wgdata: t.Dict[str, t.Any]) -> None:
"""Init the context from the workgraph data."""
from aiida_workgraph.utils import update_nested_dict
Expand Down
33 changes: 0 additions & 33 deletions aiida_workgraph/executors/builtins.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,5 @@
from aiida.engine import WorkChain
from aiida import orm
from aiida.engine.processes.workchains.workchain import WorkChainSpec


def select(condition, true, false):
"""Select the data based on the condition."""
if condition:
return true
return false


class GatherWorkChain(WorkChain):
@classmethod
def define(cls, spec: WorkChainSpec) -> None:
"""Define the process specification."""

super().define(spec)
spec.input_namespace(
"datas",
dynamic=True,
help=('Dynamic namespace for the datas, "{key}" : {Data}".'),
)
spec.outline(
cls.gather,
)
spec.output(
"result",
valid_type=orm.List,
required=True,
help="A list of the uuid of the outputs.",
)

def gather(self) -> None:
datas = self.inputs.datas.values()
uuids = [data.uuid for data in datas]
# uuids = gather(uuids)
self.out("result", orm.List(uuids).store())
9 changes: 0 additions & 9 deletions aiida_workgraph/executors/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,6 @@ def add(
return {"sum": x + y}


@calcfunction
def greater(
x: Union[Int, Float], y: Union[Int, Float], t: Union[Int, Float] = 1.0
) -> Dict[str, bool]:
"""Compare node."""
time.sleep(t.value)
return {"result": x > y}


@calcfunction
def sum_diff(
x: Union[Int, Float], y: Union[Int, Float], t: Union[Int, Float] = 1.0
Expand Down
11 changes: 0 additions & 11 deletions aiida_workgraph/property.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,4 @@ def set_value(self, value: Any) -> None:
else:
raise Exception("{} is not an {}.".format(value, DataClass.__name__))

def get_serialize(self) -> Dict[str, str]:
serialize = {"module": "aiida.orm.utils.serialize", "name": "serialize"}
return serialize

def get_deserialize(self) -> Dict[str, str]:
deserialize = {
"module": "aiida.orm.utils.serialize",
"name": "deserialize_unsafe",
}
return deserialize

return AiiDATaskProperty
11 changes: 0 additions & 11 deletions aiida_workgraph/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,4 @@ def __init__(
super().__init__(name, parent, type, index, uuid=uuid)
self.add_property(DataClass, name, **kwargs)

def get_serialize(self) -> dict:
serialize = {"module": "aiida.orm.utils.serialize", "name": "serialize"}
return serialize

def get_deserialize(self) -> dict:
deserialize = {
"module": "aiida.orm.utils.serialize",
"name": "deserialize_unsafe",
}
return deserialize

return AiiDATaskSocket
22 changes: 0 additions & 22 deletions aiida_workgraph/tasks/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,28 +73,6 @@ def create_sockets(self) -> None:
self.outputs.new("workgraph.any", "_wait")


class Gather(Task):
"""Gather"""

identifier = "workgraph.aiida_gather"
name = "Gather"
node_type = "WORKCHAIN"
catalog = "Control"

_executor = {
"module": "aiida_workgraph.executors.builtins",
"name": "GatherWorkChain",
}
kwargs = ["datas"]

def create_sockets(self) -> None:
self.inputs.clear()
self.outputs.clear()
inp = self.inputs.new("workgraph.any", "datas")
inp.link_limit = 100000
self.outputs.new("workgraph.any", "result")


class SetContext(Task):
"""SetContext"""

Expand Down
24 changes: 0 additions & 24 deletions aiida_workgraph/tasks/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,30 +28,6 @@ def create_sockets(self) -> None:
self.outputs.new("workgraph.aiida_float", "sum")


class TestGreater(Task):

identifier: str = "workgraph.test_greater"
name = "TestGreater"
node_type = "CALCFUNCTION"
catalog = "Test"

_executor = {
"module": "aiida_workgraph.executors.test",
"name": "greater",
}
kwargs = ["x", "y"]

def create_properties(self) -> None:
pass

def create_sockets(self) -> None:
self.inputs.clear()
self.outputs.clear()
self.inputs.new("workgraph.aiida_float", "x")
self.inputs.new("workgraph.aiida_float", "y")
self.outputs.new("workgraph.aiida_bool", "result")


class TestSumDiff(Task):

identifier: str = "workgraph.test_sum_diff"
Expand Down
Loading

0 comments on commit 7f87a26

Please sign in to comment.