From 151c32586dab3f3f70003209b0d104bae85def24 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Mon, 10 Feb 2025 11:04:53 +1100 Subject: [PATCH] debugged up python, workflow design unittests --- pydra/design/tests/test_python.py | 14 ++++++++++---- pydra/design/tests/test_workflow.py | 22 +++++++++++++++++----- pydra/design/workflow.py | 2 +- pydra/engine/specs.py | 4 ++-- 4 files changed, 30 insertions(+), 12 deletions(-) diff --git a/pydra/design/tests/test_python.py b/pydra/design/tests/test_python.py index dce89dbf0..e698c7949 100644 --- a/pydra/design/tests/test_python.py +++ b/pydra/design/tests/test_python.py @@ -23,7 +23,7 @@ def func(a: int) -> float: outputs = sorted(list_fields(SampleDef.Outputs), key=sort_key) assert inputs == [ python.arg(name="a", type=int), - python.arg(name="function", type=ty.Callable, default=func), + python.arg(name="function", type=ty.Callable, hash_eq=True, default=func), ] assert outputs == [python.out(name="out", type=float)] definition = SampleDef(a=1) @@ -45,7 +45,7 @@ def func(a: int, k: float = 2.0) -> float: outputs = sorted(list_fields(SampleDef.Outputs), key=sort_key) assert inputs == [ python.arg(name="a", type=int), - python.arg(name="function", type=ty.Callable, default=func), + python.arg(name="function", type=ty.Callable, hash_eq=True, default=func), python.arg(name="k", type=float, default=2.0), ] assert outputs == [python.out(name="out", type=float)] @@ -69,7 +69,7 @@ def func(a: int) -> float: outputs = sorted(list_fields(SampleDef.Outputs), key=sort_key) assert inputs == [ python.arg(name="a", type=int, help="The argument to be doubled"), - python.arg(name="function", type=ty.Callable, default=func), + python.arg(name="function", type=ty.Callable, hash_eq=True, default=func), ] assert outputs == [ python.out(name="b", type=Decimal, help="the doubled output"), @@ -94,7 +94,7 @@ def func(a: int) -> int: outputs = sorted(list_fields(SampleDef.Outputs), key=sort_key) assert inputs == [ python.arg(name="a", type=float), - python.arg(name="function", type=ty.Callable, default=func), + python.arg(name="function", type=ty.Callable, hash_eq=True, default=func), ] assert outputs == [python.out(name="b", type=float)] intf = SampleDef(a=1) @@ -118,6 +118,7 @@ def SampleDef(a: int, b: float) -> tuple[float, float]: python.arg( name="function", type=ty.Callable, + hash_eq=True, default=attrs.fields(SampleDef).function.default, ), ] @@ -149,6 +150,7 @@ def SampleDef(a: int, b: float) -> tuple[float, float]: python.arg( name="function", type=ty.Callable, + hash_eq=True, default=attrs.fields(SampleDef).function.default, ), ] @@ -183,6 +185,7 @@ def SampleDef(a: int, b: float) -> tuple[float, float]: python.arg( name="function", type=ty.Callable, + hash_eq=True, default=attrs.fields(SampleDef).function.default, ), ] @@ -225,6 +228,7 @@ def SampleDef(a: int, b: float) -> tuple[float, float]: python.arg( name="function", type=ty.Callable, + hash_eq=True, default=attrs.fields(SampleDef).function.default, ), ] @@ -272,6 +276,7 @@ def function(a, b): python.arg( name="function", type=ty.Callable, + hash_eq=True, default=attrs.fields(SampleDef).function.default, ), ] @@ -342,6 +347,7 @@ def function(a, b): python.arg( name="function", type=ty.Callable, + hash_eq=True, default=attrs.fields(SampleDef).function.default, ), ] diff --git a/pydra/design/tests/test_workflow.py b/pydra/design/tests/test_workflow.py index 4e49f0e7d..090182bc4 100644 --- a/pydra/design/tests/test_workflow.py +++ b/pydra/design/tests/test_workflow.py @@ -57,7 +57,9 @@ def MyTestWorkflow(a, b): assert list_fields(MyTestWorkflow) == [ workflow.arg(name="a"), workflow.arg(name="b"), - workflow.arg(name="constructor", type=ty.Callable, default=constructor), + workflow.arg( + name="constructor", type=ty.Callable, hash_eq=True, default=constructor + ), ] assert list_fields(MyTestWorkflow.Outputs) == [ workflow.out(name="out"), @@ -108,7 +110,9 @@ def MyTestShellWorkflow( workflow.arg(name="input_video", type=video.Mp4), workflow.arg(name="watermark", type=image.Png), workflow.arg(name="watermark_dims", type=tuple[int, int], default=(10, 10)), - workflow.arg(name="constructor", type=ty.Callable, default=constructor), + workflow.arg( + name="constructor", type=ty.Callable, hash_eq=True, default=constructor + ), ] assert list_fields(MyTestShellWorkflow.Outputs) == [ workflow.out(name="output_video", type=video.Mp4), @@ -161,7 +165,9 @@ class Outputs(WorkflowOutputs): assert sorted(list_fields(MyTestWorkflow), key=attrgetter("name")) == [ workflow.arg(name="a", type=int), workflow.arg(name="b", type=float, help="A float input", converter=a_converter), - workflow.arg(name="constructor", type=ty.Callable, default=constructor), + workflow.arg( + name="constructor", type=ty.Callable, hash_eq=True, default=constructor + ), ] assert list_fields(MyTestWorkflow.Outputs) == [ workflow.out(name="out", type=float), @@ -290,7 +296,10 @@ def MyTestWorkflow(a: int, b: float) -> tuple[float, float]: workflow.arg(name="a", type=int, help="An integer input"), workflow.arg(name="b", type=float, help="A float input"), workflow.arg( - name="constructor", type=ty.Callable, default=MyTestWorkflow().constructor + name="constructor", + type=ty.Callable, + hash_eq=True, + default=MyTestWorkflow().constructor, ), ] assert list_fields(MyTestWorkflow.Outputs) == [ @@ -330,7 +339,10 @@ def MyTestWorkflow(a: int, b: float): workflow.arg(name="a", type=int), workflow.arg(name="b", type=float), workflow.arg( - name="constructor", type=ty.Callable, default=MyTestWorkflow().constructor + name="constructor", + type=ty.Callable, + hash_eq=True, + default=MyTestWorkflow().constructor, ), ] assert list_fields(MyTestWorkflow.Outputs) == [ diff --git a/pydra/design/workflow.py b/pydra/design/workflow.py index 7043da1ca..68dfcc37d 100644 --- a/pydra/design/workflow.py +++ b/pydra/design/workflow.py @@ -165,7 +165,7 @@ def make(wrapped: ty.Callable | type) -> TaskDef: ) parsed_inputs["constructor"] = arg( - name="constructor", type=ty.Callable, default=constructor + name="constructor", type=ty.Callable, hash_eq=True, default=constructor ) for inpt_name in lazy: parsed_inputs[inpt_name].lazy = True diff --git a/pydra/engine/specs.py b/pydra/engine/specs.py index 95f5a95f1..c1b97ea06 100644 --- a/pydra/engine/specs.py +++ b/pydra/engine/specs.py @@ -363,9 +363,9 @@ def __iter__(self) -> ty.Generator[str, None, None]: def __eq__(self, other: ty.Any) -> bool: """Check if two task definitions are equal""" - values = attrs.asdict(self) + values = attrs.asdict(self, recurse=False) try: - other_values = attrs.asdict(other) + other_values = attrs.asdict(other, recurse=False) except AttributeError: return False if set(values) != set(other_values):