From 06612f83b0a1a61580f9cae49b66d3cf81043e09 Mon Sep 17 00:00:00 2001 From: Julian Geiger Date: Mon, 2 Dec 2024 14:00:40 +0100 Subject: [PATCH 1/2] Allow mixed `str`/`dict` inputs/outputs to tasks (#345) --- aiida_workgraph/decorator.py | 10 ++++- aiida_workgraph/utils/__init__.py | 24 +++++++----- docs/gallery/concept/autogen/task.py | 51 ++++++++++++++++++++----- tests/test_decorator.py | 12 ++++++ tests/test_utils.py | 57 ++++++++++------------------ 5 files changed, 95 insertions(+), 59 deletions(-) diff --git a/aiida_workgraph/decorator.py b/aiida_workgraph/decorator.py index 791f7d8c..c5551b48 100644 --- a/aiida_workgraph/decorator.py +++ b/aiida_workgraph/decorator.py @@ -557,8 +557,8 @@ def decorator_task( identifier: Optional[str] = None, task_type: str = "Normal", properties: Optional[List[Tuple[str, str]]] = None, - inputs: Optional[List[Tuple[str, str]]] = None, - outputs: Optional[List[Tuple[str, str]]] = None, + inputs: Optional[List[str | dict]] = None, + outputs: Optional[List[str | dict]] = None, error_handlers: Optional[List[Dict[str, Any]]] = None, catalog: str = "Others", ) -> Callable: @@ -574,6 +574,12 @@ def decorator_task( outputs (list): task outputs """ + if inputs: + inputs = validate_task_inout(inputs, "inputs") + + if outputs: + outputs = validate_task_inout(outputs, "outputs") + def decorator(func): nonlocal identifier, task_type diff --git a/aiida_workgraph/utils/__init__.py b/aiida_workgraph/utils/__init__.py index aec053b8..f57d4c2f 100644 --- a/aiida_workgraph/utils/__init__.py +++ b/aiida_workgraph/utils/__init__.py @@ -632,18 +632,22 @@ def validate_task_inout(inout_list: list[str | dict], list_type: str) -> list[di if the former convert them to a list of `dict`s with `name` as the key. :param inout_list: The input/output list to be validated. - :param list_type: "input" or "output" to indicate what is to be validated. - :raises TypeError: If a list of mixed or wrong types is provided to the task + :param list_type: "inputs" or "outputs" to indicate what is to be validated for better error message. + :raises TypeError: If wrong types are provided to the task :return: Processed `inputs`/`outputs` list. """ - if all(isinstance(item, str) for item in inout_list): - return [{"name": item} for item in inout_list] - elif all(isinstance(item, dict) for item in inout_list): - return inout_list - elif not all(isinstance(item, dict) for item in inout_list): + if not all(isinstance(item, (dict, str)) for item in inout_list): raise TypeError( - f"Provide either a list of `str` or `dict` as `{list_type}`, not mixed types." + f"Wrong type provided in the `{list_type}` list to the task, must be either `str` or `dict`." ) - else: - raise TypeError(f"Wrong type provided in the `{list_type}` list to the task.") + + processed_inout_list = [] + + for item in inout_list: + if isinstance(item, str): + processed_inout_list.append({"name": item}) + elif isinstance(item, dict): + processed_inout_list.append(item) + + return processed_inout_list diff --git a/docs/gallery/concept/autogen/task.py b/docs/gallery/concept/autogen/task.py index 0890fd4f..35eec338 100644 --- a/docs/gallery/concept/autogen/task.py +++ b/docs/gallery/concept/autogen/task.py @@ -54,15 +54,27 @@ def multiply(x, y): ###################################################################### # If you want to change the name of the output ports, or if there are more # than one output. You can define the outputs explicitly. For example: -# ``{"name": "sum", "identifier": "workgraph.Any"}``, where the ``identifier`` -# indicates the data type. The data type tell the code how to display the -# port in the GUI, validate the data, and serialize data into database. We -# use ``workgraph.Any`` for any data type. For the moment, the data validation is + + +# define the outputs explicitly +@task(outputs=["sum", "diff"]) +def add_minus(x, y): + return {"sum": x + y, "difference": x - y} + + +print("Inputs:", add_minus.task().inputs.keys()) +print("Outputs:", add_minus.task().outputs.keys()) + +###################################################################### +# One can also add an ``identifier`` to indicates the data type. The data +# type tell the code how to display the port in the GUI, validate the data, +# and serialize data into database. +# We use ``workgraph.Any`` for any data type. For the moment, the data validation is # experimentally supported, and the GUI display is not implemented. Thus, # I suggest you to always ``workgraph.Any`` for the port. # -# define add calcfunction task +# define the outputs with identifier @task( outputs=[ {"name": "sum", "identifier": "workgraph.Any"}, @@ -73,10 +85,6 @@ def add_minus(x, y): return {"sum": x + y, "difference": x - y} -print("Inputs:", add_minus.task().inputs.keys()) -print("Outputs:", add_minus.task().outputs.keys()) - - ###################################################################### # Then, one can use the task inside the WorkGraph: # @@ -125,9 +133,32 @@ def add_minus(x, y): if "." not in output.name: print(f" - {output.name}") +###################################################################### +# For specifying the outputs, the most explicit way is to provide a list of dictionaries, as shown above. In addition, +# as a shortcut, it is also possible to pass a list of strings. In that case, WorkGraph will internally convert the list +# of strings into a list of dictionaries in which case, each ``name`` key will be assigned each passed string value. +# Furthermore, also a mixed list of string and dict elements can be passed, which can be useful in cases where multiple +# outputs should be specified, but more detailed properties are only required for some of the outputs. The above also +# applies for the ``outputs`` argument of the ``@task`` decorator introduced earlier, as well as the ``inputs``, given +# that they are explicitly specified rather than derived from the signature of the ``Callable``. Finally, all lines +# below are valid specifiers for the ``outputs`` of the ``build_task`: +# + +NormTask = build_task(norm, outputs=["norm"]) +NormTask = build_task(norm, outputs=["norm", "norm2"]) +NormTask = build_task( + norm, outputs=["norm", {"name": "norm2", "identifier": "workgraph.Any"}] +) +NormTask = build_task( + norm, + outputs=[ + {"name": "norm", "identifier": "workgraph.Any"}, + {"name": "norm2", "identifier": "workgraph.Any"}, + ], +) ###################################################################### -# One can use these AiiDA component direclty in the WorkGraph. The inputs +# One can use these AiiDA component directly in the WorkGraph. The inputs # and outputs of the task is automatically generated based on the input # and output port of the AiiDA component. In case of ``calcfunction``, the # default output is ``result``. If there are more than one output task, diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 2d96c708..ff2237e6 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -3,6 +3,18 @@ from typing import Callable +def test_custom_outputs(): + """Test custom outputs.""" + + @task(outputs=["sum", {"name": "product", "identifier": "workgraph.any"}]) + def add_multiply(x, y): + return {"sum": x + y, "product": x * y} + + n = add_multiply.task() + assert "sum" in n.outputs.keys() + assert "product" in n.outputs.keys() + + @pytest.fixture(params=["decorator_factory", "decorator"]) def task_calcfunction(request): if request.param == "decorator_factory": diff --git a/tests/test_utils.py b/tests/test_utils.py index 8e42db34..8ecaaab6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,66 +3,49 @@ from aiida_workgraph.utils import validate_task_inout +def test_validate_task_inout_empty_list(): + """Test validation with a list of strings.""" + input_list = [] + result = validate_task_inout(input_list, "inputs") + assert result == [] + + def test_validate_task_inout_str_list(): """Test validation with a list of strings.""" input_list = ["task1", "task2"] - result = validate_task_inout(input_list, "input") + result = validate_task_inout(input_list, "inputs") assert result == [{"name": "task1"}, {"name": "task2"}] def test_validate_task_inout_dict_list(): """Test validation with a list of dictionaries.""" input_list = [{"name": "task1"}, {"name": "task2"}] - result = validate_task_inout(input_list, "input") + result = validate_task_inout(input_list, "inputs") assert result == input_list -@pytest.mark.parametrize( - "input_list, list_type, expected_error", - [ - # Mixed types error cases - ( - ["task1", {"name": "task2"}], - "input", - "Provide either a list of `str` or `dict` as `input`, not mixed types.", - ), - ( - [{"name": "task1"}, "task2"], - "output", - "Provide either a list of `str` or `dict` as `output`, not mixed types.", - ), - # Empty list cases - ([], "input", None), - ([], "output", None), - ], -) -def test_validate_task_inout_mixed_types(input_list, list_type, expected_error): - """Test error handling for mixed type lists.""" - if expected_error: - with pytest.raises(TypeError) as excinfo: - validate_task_inout(input_list, list_type) - assert str(excinfo.value) == expected_error - else: - # For empty lists, no error should be raised - result = validate_task_inout(input_list, list_type) - assert result == [] +def test_validate_task_inout_mixed_list(): + """Test validation with a list of dictionaries.""" + input_list = ["task1", {"name": "task2"}] + result = validate_task_inout(input_list, "inputs") + assert result == [{"name": "task1"}, {"name": "task2"}] @pytest.mark.parametrize( "input_list, list_type", [ # Invalid type cases - ([1, 2, 3], "input"), - ([None, None], "output"), - ([True, False], "input"), - (["task", 123], "output"), + ([1, 2, 3], "inputs"), + ([None, None], "outputs"), + ([True, False], "inputs"), + (["task", 123], "outputs"), ], ) def test_validate_task_inout_invalid_types(input_list, list_type): """Test error handling for completely invalid type lists.""" with pytest.raises(TypeError) as excinfo: validate_task_inout(input_list, list_type) - assert "Provide either a list of" in str(excinfo.value) + assert "Wrong type provided" in str(excinfo.value) def test_validate_task_inout_dict_with_extra_keys(): @@ -71,5 +54,5 @@ def test_validate_task_inout_dict_with_extra_keys(): {"name": "task1", "description": "first task"}, {"name": "task2", "priority": "high"}, ] - result = validate_task_inout(input_list, "input") + result = validate_task_inout(input_list, "inputs") assert result == input_list From 512d1a86c34e52bcebaa56ac9c97cdd503c5e9c1 Mon Sep 17 00:00:00 2001 From: Julian Geiger Date: Mon, 2 Dec 2024 14:25:20 +0100 Subject: [PATCH 2/2] Hide metadata inputs for tasks in HTML visualization. (#346) Add a parameter `show_socket_depth` to control the level of the input sockets to be shown in the GUI. --- .gitignore | 2 +- aiida_workgraph/task.py | 13 ++++++++--- aiida_workgraph/utils/__init__.py | 22 +++++++++++++++++++ aiida_workgraph/widget/src/widget/__init__.py | 21 ++++++++++++------ 4 files changed, 47 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index a8b996c0..3aa9dd8a 100644 --- a/.gitignore +++ b/.gitignore @@ -135,4 +135,4 @@ dmypy.json tests/work /tests/**/*.png /tests/**/*txt -.vscode +.vscode/ diff --git a/aiida_workgraph/task.py b/aiida_workgraph/task.py index cc4e421b..326f2b50 100644 --- a/aiida_workgraph/task.py +++ b/aiida_workgraph/task.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from node_graph.node import Node as GraphNode from aiida_workgraph import USE_WIDGET from aiida_workgraph.properties import property_pool @@ -56,6 +58,7 @@ def __init__( self._widget = None self.state = "PLANNED" self.action = "" + self.show_socket_depth = 0 def to_dict(self, short: bool = False) -> Dict[str, Any]: from aiida.orm.utils.serialize import serialize @@ -174,18 +177,22 @@ def _repr_mimebundle_(self, *args: Any, **kwargs: Any) -> any: print(WIDGET_INSTALLATION_MESSAGE) return # if ipywdigets > 8.0.0, use _repr_mimebundle_ instead of _ipython_display_ - self._widget.from_node(self) + self._widget.from_node(self, show_socket_depth=self.show_socket_depth) if hasattr(self._widget, "_repr_mimebundle_"): return self._widget._repr_mimebundle_(*args, **kwargs) else: return self._widget._ipython_display_(*args, **kwargs) - def to_html(self, output: str = None, **kwargs): + def to_html( + self, output: str = None, show_socket_depth: Optional[int] = None, **kwargs + ): """Write a standalone html file to visualize the task.""" + if show_socket_depth is None: + show_socket_depth = self.show_socket_depth if self._widget is None: print(WIDGET_INSTALLATION_MESSAGE) return - self._widget.from_node(self) + self._widget.from_node(node=self, show_socket_depth=show_socket_depth) return self._widget.to_html(output=output, **kwargs) diff --git a/aiida_workgraph/utils/__init__.py b/aiida_workgraph/utils/__init__.py index f57d4c2f..bee8cf6a 100644 --- a/aiida_workgraph/utils/__init__.py +++ b/aiida_workgraph/utils/__init__.py @@ -651,3 +651,25 @@ def validate_task_inout(inout_list: list[str | dict], list_type: str) -> list[di processed_inout_list.append(item) return processed_inout_list + + +def filter_keys_namespace_depth( + dict_: dict[Any, Any], max_depth: int = 0 +) -> dict[Any, Any]: + """ + Filter top-level keys of a dictionary based on the namespace nesting level (number of periods) in the key. + + :param dict dict_: The dictionary to filter. + :param int max_depth: Maximum depth of namespaces to retain (number of periods). + :return: The filtered dictionary with only keys satisfying the depth condition. + :rtype: dict + """ + result: dict[Any, Any] = {} + + for key, value in dict_.items(): + depth = key.count(".") + + if depth <= max_depth: + result[key] = value + + return result diff --git a/aiida_workgraph/widget/src/widget/__init__.py b/aiida_workgraph/widget/src/widget/__init__.py index 65a1405b..a5908534 100644 --- a/aiida_workgraph/widget/src/widget/__init__.py +++ b/aiida_workgraph/widget/src/widget/__init__.py @@ -5,6 +5,7 @@ import anywidget import traitlets from .utils import wait_to_link +from aiida_workgraph.utils import filter_keys_namespace_depth try: __version__ = importlib.metadata.version("widget") @@ -37,16 +38,22 @@ def from_workgraph(self, workgraph: Any) -> None: wgdata = workgraph_to_short_json(wgdata) self.value = wgdata - def from_node(self, node: Any) -> None: + def from_node(self, node: Any, show_socket_depth: int = 0) -> None: + tdata = node.to_dict() - tdata.pop("properties", None) - tdata.pop("executor", None) - tdata.pop("node_class", None) - tdata.pop("process", None) - tdata["label"] = tdata["identifier"] + + # Remove certain elements of the dict-representation of the Node that we don't want to show + for key in ("properties", "executor", "node_class", "process"): + tdata.pop(key, None) for input in tdata["inputs"].values(): input.pop("property") - tdata["inputs"] = list(tdata["inputs"].values()) + + tdata["label"] = tdata["identifier"] + + filtered_inputs = filter_keys_namespace_depth( + dict_=tdata["inputs"], max_depth=show_socket_depth + ) + tdata["inputs"] = list(filtered_inputs.values()) tdata["outputs"] = list(tdata["outputs"].values()) wgdata = {"name": node.name, "nodes": {node.name: tdata}, "links": []} self.value = wgdata