Skip to content

Commit

Permalink
debugged up python, workflow design unittests
Browse files Browse the repository at this point in the history
  • Loading branch information
tclose committed Feb 10, 2025
1 parent 6ed002f commit 151c325
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 12 deletions.
14 changes: 10 additions & 4 deletions pydra/design/tests/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def func(a: int) -> float:
outputs = sorted(list_fields(SampleDef.Outputs), key=sort_key)
assert inputs == [
python.arg(name="a", type=int),
python.arg(name="function", type=ty.Callable, default=func),
python.arg(name="function", type=ty.Callable, hash_eq=True, default=func),
]
assert outputs == [python.out(name="out", type=float)]
definition = SampleDef(a=1)
Expand All @@ -45,7 +45,7 @@ def func(a: int, k: float = 2.0) -> float:
outputs = sorted(list_fields(SampleDef.Outputs), key=sort_key)
assert inputs == [
python.arg(name="a", type=int),
python.arg(name="function", type=ty.Callable, default=func),
python.arg(name="function", type=ty.Callable, hash_eq=True, default=func),
python.arg(name="k", type=float, default=2.0),
]
assert outputs == [python.out(name="out", type=float)]
Expand All @@ -69,7 +69,7 @@ def func(a: int) -> float:
outputs = sorted(list_fields(SampleDef.Outputs), key=sort_key)
assert inputs == [
python.arg(name="a", type=int, help="The argument to be doubled"),
python.arg(name="function", type=ty.Callable, default=func),
python.arg(name="function", type=ty.Callable, hash_eq=True, default=func),
]
assert outputs == [
python.out(name="b", type=Decimal, help="the doubled output"),
Expand All @@ -94,7 +94,7 @@ def func(a: int) -> int:
outputs = sorted(list_fields(SampleDef.Outputs), key=sort_key)
assert inputs == [
python.arg(name="a", type=float),
python.arg(name="function", type=ty.Callable, default=func),
python.arg(name="function", type=ty.Callable, hash_eq=True, default=func),
]
assert outputs == [python.out(name="b", type=float)]
intf = SampleDef(a=1)
Expand All @@ -118,6 +118,7 @@ def SampleDef(a: int, b: float) -> tuple[float, float]:
python.arg(
name="function",
type=ty.Callable,
hash_eq=True,
default=attrs.fields(SampleDef).function.default,
),
]
Expand Down Expand Up @@ -149,6 +150,7 @@ def SampleDef(a: int, b: float) -> tuple[float, float]:
python.arg(
name="function",
type=ty.Callable,
hash_eq=True,
default=attrs.fields(SampleDef).function.default,
),
]
Expand Down Expand Up @@ -183,6 +185,7 @@ def SampleDef(a: int, b: float) -> tuple[float, float]:
python.arg(
name="function",
type=ty.Callable,
hash_eq=True,
default=attrs.fields(SampleDef).function.default,
),
]
Expand Down Expand Up @@ -225,6 +228,7 @@ def SampleDef(a: int, b: float) -> tuple[float, float]:
python.arg(
name="function",
type=ty.Callable,
hash_eq=True,
default=attrs.fields(SampleDef).function.default,
),
]
Expand Down Expand Up @@ -272,6 +276,7 @@ def function(a, b):
python.arg(
name="function",
type=ty.Callable,
hash_eq=True,
default=attrs.fields(SampleDef).function.default,
),
]
Expand Down Expand Up @@ -342,6 +347,7 @@ def function(a, b):
python.arg(
name="function",
type=ty.Callable,
hash_eq=True,
default=attrs.fields(SampleDef).function.default,
),
]
Expand Down
22 changes: 17 additions & 5 deletions pydra/design/tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def MyTestWorkflow(a, b):
assert list_fields(MyTestWorkflow) == [
workflow.arg(name="a"),
workflow.arg(name="b"),
workflow.arg(name="constructor", type=ty.Callable, default=constructor),
workflow.arg(
name="constructor", type=ty.Callable, hash_eq=True, default=constructor
),
]
assert list_fields(MyTestWorkflow.Outputs) == [
workflow.out(name="out"),
Expand Down Expand Up @@ -108,7 +110,9 @@ def MyTestShellWorkflow(
workflow.arg(name="input_video", type=video.Mp4),
workflow.arg(name="watermark", type=image.Png),
workflow.arg(name="watermark_dims", type=tuple[int, int], default=(10, 10)),
workflow.arg(name="constructor", type=ty.Callable, default=constructor),
workflow.arg(
name="constructor", type=ty.Callable, hash_eq=True, default=constructor
),
]
assert list_fields(MyTestShellWorkflow.Outputs) == [
workflow.out(name="output_video", type=video.Mp4),
Expand Down Expand Up @@ -161,7 +165,9 @@ class Outputs(WorkflowOutputs):
assert sorted(list_fields(MyTestWorkflow), key=attrgetter("name")) == [
workflow.arg(name="a", type=int),
workflow.arg(name="b", type=float, help="A float input", converter=a_converter),
workflow.arg(name="constructor", type=ty.Callable, default=constructor),
workflow.arg(
name="constructor", type=ty.Callable, hash_eq=True, default=constructor
),
]
assert list_fields(MyTestWorkflow.Outputs) == [
workflow.out(name="out", type=float),
Expand Down Expand Up @@ -290,7 +296,10 @@ def MyTestWorkflow(a: int, b: float) -> tuple[float, float]:
workflow.arg(name="a", type=int, help="An integer input"),
workflow.arg(name="b", type=float, help="A float input"),
workflow.arg(
name="constructor", type=ty.Callable, default=MyTestWorkflow().constructor
name="constructor",
type=ty.Callable,
hash_eq=True,
default=MyTestWorkflow().constructor,
),
]
assert list_fields(MyTestWorkflow.Outputs) == [
Expand Down Expand Up @@ -330,7 +339,10 @@ def MyTestWorkflow(a: int, b: float):
workflow.arg(name="a", type=int),
workflow.arg(name="b", type=float),
workflow.arg(
name="constructor", type=ty.Callable, default=MyTestWorkflow().constructor
name="constructor",
type=ty.Callable,
hash_eq=True,
default=MyTestWorkflow().constructor,
),
]
assert list_fields(MyTestWorkflow.Outputs) == [
Expand Down
2 changes: 1 addition & 1 deletion pydra/design/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def make(wrapped: ty.Callable | type) -> TaskDef:
)

parsed_inputs["constructor"] = arg(
name="constructor", type=ty.Callable, default=constructor
name="constructor", type=ty.Callable, hash_eq=True, default=constructor
)
for inpt_name in lazy:
parsed_inputs[inpt_name].lazy = True
Expand Down
4 changes: 2 additions & 2 deletions pydra/engine/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,9 +363,9 @@ def __iter__(self) -> ty.Generator[str, None, None]:

def __eq__(self, other: ty.Any) -> bool:
"""Check if two task definitions are equal"""
values = attrs.asdict(self)
values = attrs.asdict(self, recurse=False)
try:
other_values = attrs.asdict(other)
other_values = attrs.asdict(other, recurse=False)
except AttributeError:
return False
if set(values) != set(other_values):
Expand Down

0 comments on commit 151c325

Please sign in to comment.