Skip to content

Commit

Permalink
Merge branch 'main' into improve_test_coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Dec 2, 2024
2 parents 7c83c57 + 512d1a8 commit ab113c1
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 69 deletions.
10 changes: 8 additions & 2 deletions aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,8 +557,8 @@ def decorator_task(
identifier: Optional[str] = None,
task_type: str = "Normal",
properties: Optional[List[Tuple[str, str]]] = None,
inputs: Optional[List[Tuple[str, str]]] = None,
outputs: Optional[List[Tuple[str, str]]] = None,
inputs: Optional[List[str | dict]] = None,
outputs: Optional[List[str | dict]] = None,
error_handlers: Optional[List[Dict[str, Any]]] = None,
catalog: str = "Others",
) -> Callable:
Expand All @@ -574,6 +574,12 @@ def decorator_task(
outputs (list): task outputs
"""

if inputs:
inputs = validate_task_inout(inputs, "inputs")

if outputs:
outputs = validate_task_inout(outputs, "outputs")

def decorator(func):
nonlocal identifier, task_type

Expand Down
13 changes: 10 additions & 3 deletions aiida_workgraph/task.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from node_graph.node import Node as GraphNode
from aiida_workgraph import USE_WIDGET
from aiida_workgraph.properties import property_pool
Expand Down Expand Up @@ -56,6 +58,7 @@ def __init__(
self._widget = None
self.state = "PLANNED"
self.action = ""
self.show_socket_depth = 0

def to_dict(self, short: bool = False) -> Dict[str, Any]:
from aiida.orm.utils.serialize import serialize
Expand Down Expand Up @@ -174,18 +177,22 @@ def _repr_mimebundle_(self, *args: Any, **kwargs: Any) -> any:
print(WIDGET_INSTALLATION_MESSAGE)
return
# if ipywdigets > 8.0.0, use _repr_mimebundle_ instead of _ipython_display_
self._widget.from_node(self)
self._widget.from_node(self, show_socket_depth=self.show_socket_depth)
if hasattr(self._widget, "_repr_mimebundle_"):
return self._widget._repr_mimebundle_(*args, **kwargs)
else:
return self._widget._ipython_display_(*args, **kwargs)

def to_html(self, output: str = None, **kwargs):
def to_html(
self, output: str = None, show_socket_depth: Optional[int] = None, **kwargs
):
"""Write a standalone html file to visualize the task."""
if show_socket_depth is None:
show_socket_depth = self.show_socket_depth
if self._widget is None:
print(WIDGET_INSTALLATION_MESSAGE)
return
self._widget.from_node(self)
self._widget.from_node(node=self, show_socket_depth=show_socket_depth)
return self._widget.to_html(output=output, **kwargs)


Expand Down
46 changes: 36 additions & 10 deletions aiida_workgraph/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,18 +595,44 @@ def validate_task_inout(inout_list: list[str | dict], list_type: str) -> list[di
if the former convert them to a list of `dict`s with `name` as the key.
:param inout_list: The input/output list to be validated.
:param list_type: "input" or "output" to indicate what is to be validated.
:raises TypeError: If a list of mixed or wrong types is provided to the task
:param list_type: "inputs" or "outputs" to indicate what is to be validated for better error message.
:raises TypeError: If wrong types are provided to the task
:return: Processed `inputs`/`outputs` list.
"""

if all(isinstance(item, str) for item in inout_list):
return [{"name": item} for item in inout_list]
elif all(isinstance(item, dict) for item in inout_list):
return inout_list
elif not all(isinstance(item, dict) for item in inout_list):
if not all(isinstance(item, (dict, str)) for item in inout_list):
raise TypeError(
f"Provide either a list of `str` or `dict` as `{list_type}`, not mixed types."
f"Wrong type provided in the `{list_type}` list to the task, must be either `str` or `dict`."
)
else:
raise TypeError(f"Wrong type provided in the `{list_type}` list to the task.")

processed_inout_list = []

for item in inout_list:
if isinstance(item, str):
processed_inout_list.append({"name": item})
elif isinstance(item, dict):
processed_inout_list.append(item)

return processed_inout_list


def filter_keys_namespace_depth(
dict_: dict[Any, Any], max_depth: int = 0
) -> dict[Any, Any]:
"""
Filter top-level keys of a dictionary based on the namespace nesting level (number of periods) in the key.
:param dict dict_: The dictionary to filter.
:param int max_depth: Maximum depth of namespaces to retain (number of periods).
:return: The filtered dictionary with only keys satisfying the depth condition.
:rtype: dict
"""
result: dict[Any, Any] = {}

for key, value in dict_.items():
depth = key.count(".")

if depth <= max_depth:
result[key] = value

return result
21 changes: 14 additions & 7 deletions aiida_workgraph/widget/src/widget/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import anywidget
import traitlets
from .utils import wait_to_link
from aiida_workgraph.utils import filter_keys_namespace_depth

try:
__version__ = importlib.metadata.version("widget")
Expand Down Expand Up @@ -37,16 +38,22 @@ def from_workgraph(self, workgraph: Any) -> None:
wgdata = workgraph_to_short_json(wgdata)
self.value = wgdata

def from_node(self, node: Any) -> None:
def from_node(self, node: Any, show_socket_depth: int = 0) -> None:

tdata = node.to_dict()
tdata.pop("properties", None)
tdata.pop("executor", None)
tdata.pop("node_class", None)
tdata.pop("process", None)
tdata["label"] = tdata["identifier"]

# Remove certain elements of the dict-representation of the Node that we don't want to show
for key in ("properties", "executor", "node_class", "process"):
tdata.pop(key, None)
for input in tdata["inputs"].values():
input.pop("property")
tdata["inputs"] = list(tdata["inputs"].values())

tdata["label"] = tdata["identifier"]

filtered_inputs = filter_keys_namespace_depth(
dict_=tdata["inputs"], max_depth=show_socket_depth
)
tdata["inputs"] = list(filtered_inputs.values())
tdata["outputs"] = list(tdata["outputs"].values())
wgdata = {"name": node.name, "nodes": {node.name: tdata}, "links": []}
self.value = wgdata
Expand Down
51 changes: 41 additions & 10 deletions docs/gallery/concept/autogen/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,27 @@ def multiply(x, y):
######################################################################
# If you want to change the name of the output ports, or if there are more
# than one output. You can define the outputs explicitly. For example:
# ``{"name": "sum", "identifier": "workgraph.Any"}``, where the ``identifier``
# indicates the data type. The data type tell the code how to display the
# port in the GUI, validate the data, and serialize data into database. We
# use ``workgraph.Any`` for any data type. For the moment, the data validation is


# define the outputs explicitly
@task(outputs=["sum", "diff"])
def add_minus(x, y):
return {"sum": x + y, "difference": x - y}


print("Inputs:", add_minus.task().inputs.keys())
print("Outputs:", add_minus.task().outputs.keys())

######################################################################
# One can also add an ``identifier`` to indicates the data type. The data
# type tell the code how to display the port in the GUI, validate the data,
# and serialize data into database.
# We use ``workgraph.Any`` for any data type. For the moment, the data validation is
# experimentally supported, and the GUI display is not implemented. Thus,
# I suggest you to always ``workgraph.Any`` for the port.
#

# define add calcfunction task
# define the outputs with identifier
@task(
outputs=[
{"name": "sum", "identifier": "workgraph.Any"},
Expand All @@ -73,10 +85,6 @@ def add_minus(x, y):
return {"sum": x + y, "difference": x - y}


print("Inputs:", add_minus.task().inputs.keys())
print("Outputs:", add_minus.task().outputs.keys())


######################################################################
# Then, one can use the task inside the WorkGraph:
#
Expand Down Expand Up @@ -125,9 +133,32 @@ def add_minus(x, y):
if "." not in output.name:
print(f" - {output.name}")

######################################################################
# For specifying the outputs, the most explicit way is to provide a list of dictionaries, as shown above. In addition,
# as a shortcut, it is also possible to pass a list of strings. In that case, WorkGraph will internally convert the list
# of strings into a list of dictionaries in which case, each ``name`` key will be assigned each passed string value.
# Furthermore, also a mixed list of string and dict elements can be passed, which can be useful in cases where multiple
# outputs should be specified, but more detailed properties are only required for some of the outputs. The above also
# applies for the ``outputs`` argument of the ``@task`` decorator introduced earlier, as well as the ``inputs``, given
# that they are explicitly specified rather than derived from the signature of the ``Callable``. Finally, all lines
# below are valid specifiers for the ``outputs`` of the ``build_task`:
#

NormTask = build_task(norm, outputs=["norm"])
NormTask = build_task(norm, outputs=["norm", "norm2"])
NormTask = build_task(
norm, outputs=["norm", {"name": "norm2", "identifier": "workgraph.Any"}]
)
NormTask = build_task(
norm,
outputs=[
{"name": "norm", "identifier": "workgraph.Any"},
{"name": "norm2", "identifier": "workgraph.Any"},
],
)

######################################################################
# One can use these AiiDA component direclty in the WorkGraph. The inputs
# One can use these AiiDA component directly in the WorkGraph. The inputs
# and outputs of the task is automatically generated based on the input
# and output port of the AiiDA component. In case of ``calcfunction``, the
# default output is ``result``. If there are more than one output task,
Expand Down
12 changes: 12 additions & 0 deletions tests/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,18 @@
from typing import Callable


def test_custom_outputs():
"""Test custom outputs."""

@task(outputs=["sum", {"name": "product", "identifier": "workgraph.any"}])
def add_multiply(x, y):
return {"sum": x + y, "product": x * y}

n = add_multiply.task()
assert "sum" in n.outputs.keys()
assert "product" in n.outputs.keys()


@pytest.fixture(params=["decorator_factory", "decorator"])
def task_calcfunction(request):
if request.param == "decorator_factory":
Expand Down
57 changes: 20 additions & 37 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,66 +3,49 @@
from aiida_workgraph.utils import validate_task_inout


def test_validate_task_inout_empty_list():
"""Test validation with a list of strings."""
input_list = []
result = validate_task_inout(input_list, "inputs")
assert result == []


def test_validate_task_inout_str_list():
"""Test validation with a list of strings."""
input_list = ["task1", "task2"]
result = validate_task_inout(input_list, "input")
result = validate_task_inout(input_list, "inputs")
assert result == [{"name": "task1"}, {"name": "task2"}]


def test_validate_task_inout_dict_list():
"""Test validation with a list of dictionaries."""
input_list = [{"name": "task1"}, {"name": "task2"}]
result = validate_task_inout(input_list, "input")
result = validate_task_inout(input_list, "inputs")
assert result == input_list


@pytest.mark.parametrize(
"input_list, list_type, expected_error",
[
# Mixed types error cases
(
["task1", {"name": "task2"}],
"input",
"Provide either a list of `str` or `dict` as `input`, not mixed types.",
),
(
[{"name": "task1"}, "task2"],
"output",
"Provide either a list of `str` or `dict` as `output`, not mixed types.",
),
# Empty list cases
([], "input", None),
([], "output", None),
],
)
def test_validate_task_inout_mixed_types(input_list, list_type, expected_error):
"""Test error handling for mixed type lists."""
if expected_error:
with pytest.raises(TypeError) as excinfo:
validate_task_inout(input_list, list_type)
assert str(excinfo.value) == expected_error
else:
# For empty lists, no error should be raised
result = validate_task_inout(input_list, list_type)
assert result == []
def test_validate_task_inout_mixed_list():
"""Test validation with a list of dictionaries."""
input_list = ["task1", {"name": "task2"}]
result = validate_task_inout(input_list, "inputs")
assert result == [{"name": "task1"}, {"name": "task2"}]


@pytest.mark.parametrize(
"input_list, list_type",
[
# Invalid type cases
([1, 2, 3], "input"),
([None, None], "output"),
([True, False], "input"),
(["task", 123], "output"),
([1, 2, 3], "inputs"),
([None, None], "outputs"),
([True, False], "inputs"),
(["task", 123], "outputs"),
],
)
def test_validate_task_inout_invalid_types(input_list, list_type):
"""Test error handling for completely invalid type lists."""
with pytest.raises(TypeError) as excinfo:
validate_task_inout(input_list, list_type)
assert "Provide either a list of" in str(excinfo.value)
assert "Wrong type provided" in str(excinfo.value)


def test_validate_task_inout_dict_with_extra_keys():
Expand All @@ -71,7 +54,7 @@ def test_validate_task_inout_dict_with_extra_keys():
{"name": "task1", "description": "first task"},
{"name": "task2", "priority": "high"},
]
result = validate_task_inout(input_list, "input")
result = validate_task_inout(input_list, "inputs")
assert result == input_list


Expand Down

0 comments on commit ab113c1

Please sign in to comment.