Skip to content

Commit

Permalink
add test_vector_socket
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Dec 2, 2024
1 parent ad83b4b commit a193376
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 90 deletions.
33 changes: 0 additions & 33 deletions aiida_workgraph/executors/builtins.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,5 @@
from aiida.engine import WorkChain
from aiida import orm
from aiida.engine.processes.workchains.workchain import WorkChainSpec


def select(condition, true, false):
"""Select the data based on the condition."""
if condition:
return true
return false


class GatherWorkChain(WorkChain):
@classmethod
def define(cls, spec: WorkChainSpec) -> None:
"""Define the process specification."""

super().define(spec)
spec.input_namespace(
"datas",
dynamic=True,
help=('Dynamic namespace for the datas, "{key}" : {Data}".'),
)
spec.outline(
cls.gather,
)
spec.output(
"result",
valid_type=orm.List,
required=True,
help="A list of the uuid of the outputs.",
)

def gather(self) -> None:
datas = self.inputs.datas.values()
uuids = [data.uuid for data in datas]
# uuids = gather(uuids)
self.out("result", orm.List(uuids).store())
9 changes: 0 additions & 9 deletions aiida_workgraph/executors/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,6 @@ def add(
return {"sum": x + y}


@calcfunction
def greater(
x: Union[Int, Float], y: Union[Int, Float], t: Union[Int, Float] = 1.0
) -> Dict[str, bool]:
"""Compare node."""
time.sleep(t.value)
return {"result": x > y}


@calcfunction
def sum_diff(
x: Union[Int, Float], y: Union[Int, Float], t: Union[Int, Float] = 1.0
Expand Down
22 changes: 0 additions & 22 deletions aiida_workgraph/tasks/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,28 +73,6 @@ def create_sockets(self) -> None:
self.outputs.new("workgraph.any", "_wait")


class Gather(Task):
"""Gather"""

identifier = "workgraph.aiida_gather"
name = "Gather"
node_type = "WORKCHAIN"
catalog = "Control"

_executor = {
"module": "aiida_workgraph.executors.builtins",
"name": "GatherWorkChain",
}
kwargs = ["datas"]

def create_sockets(self) -> None:
self.inputs.clear()
self.outputs.clear()
inp = self.inputs.new("workgraph.any", "datas")
inp.link_limit = 100000
self.outputs.new("workgraph.any", "result")


class SetContext(Task):
"""SetContext"""

Expand Down
24 changes: 0 additions & 24 deletions aiida_workgraph/tasks/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,30 +28,6 @@ def create_sockets(self) -> None:
self.outputs.new("workgraph.aiida_float", "sum")


class TestGreater(Task):

identifier: str = "workgraph.test_greater"
name = "TestGreater"
node_type = "CALCFUNCTION"
catalog = "Test"

_executor = {
"module": "aiida_workgraph.executors.test",
"name": "greater",
}
kwargs = ["x", "y"]

def create_properties(self) -> None:
pass

def create_sockets(self) -> None:
self.inputs.clear()
self.outputs.clear()
self.inputs.new("workgraph.aiida_float", "x")
self.inputs.new("workgraph.aiida_float", "y")
self.outputs.new("workgraph.aiida_bool", "result")


class TestSumDiff(Task):

identifier: str = "workgraph.test_sum_diff"
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph"
"workgraph.aiida_node" = "aiida_workgraph.tasks.builtins:AiiDANode"
"workgraph.aiida_code" = "aiida_workgraph.tasks.builtins:AiiDACode"
"workgraph.test_add" = "aiida_workgraph.tasks.test:TestAdd"
"workgraph.test_greater" = "aiida_workgraph.tasks.test:TestGreater"
"workgraph.test_sum_diff" = "aiida_workgraph.tasks.test:TestSumDiff"
"workgraph.test_arithmetic_multiply_add" = "aiida_workgraph.tasks.test:TestArithmeticMultiplyAdd"
"workgraph.pythonjob" = "aiida_workgraph.tasks.pythonjob:PythonJob"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_pause_play_task(wg_calcjob):
wg.pause_tasks(["add1"])
wg.submit()
# wait for the workgraph to launch add1
wg.wait(tasks={"add1": ["CREATED"]}, timeout=20)
wg.wait(tasks={"add1": ["CREATED"]}, timeout=40)
assert wg.tasks["add1"].node.process_state.value.upper() == "CREATED"
assert wg.tasks["add1"].node.process_status == "Paused through WorkGraph"
# pause add2 after submit
Expand Down
16 changes: 16 additions & 0 deletions tests/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,22 @@ def add(x: data_type):
add_task.set({"x": "{{variable}}"})


def test_vector_socket() -> None:
"""Test the vector data type."""
from aiida_workgraph import Task

t = Task()
t.inputs.new(
"workgraph.aiida_int_vector",
"vector2d",
property_data={"size": 2, "default": [1, 2]},
)
try:
t.inputs["vector2d"].value = [1, 2, 3]
except Exception as e:
assert "Invalid size: Expected 2, got 3 instead." in str(e)


def test_aiida_data_socket() -> None:
"""Test the mapping of data types to socket types."""
from aiida import orm, load_profile
Expand Down

0 comments on commit a193376

Please sign in to comment.