Skip to content

Commit

Permalink
debugging unittests
Browse files Browse the repository at this point in the history
  • Loading branch information
tclose committed Jan 29, 2025
1 parent 2d37331 commit 4da2eef
Show file tree
Hide file tree
Showing 17 changed files with 251 additions and 212 deletions.
10 changes: 5 additions & 5 deletions new-docs/source/tutorial/5-shell.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@
"print(f\"Command-line to be run: {cp.cmdline}\")\n",
"\n",
"# Run the shell-comand task\n",
"result = cp()\n",
"outputs = cp()\n",
"\n",
"print(\n",
" f\"Contents of copied file ('{result.output.destination}'): \"\n",
" f\"'{Path(result.output.destination).read_text()}'\"\n",
" f\"Contents of copied file ('{outputs.destination}'): \"\n",
" f\"'{Path(outputs.destination).read_text()}'\"\n",
")"
]
},
Expand Down Expand Up @@ -335,10 +335,10 @@
"cp_with_size = CpWithSize(in_file=File.sample())\n",
"\n",
"# Run the command\n",
"result = cp_with_size()\n",
"outputs = cp_with_size()\n",
"\n",
"\n",
"print(f\"Size of the output file is: {result.output.out_file_size}\")"
"print(f\"Size of the output file is: {outputs.out_file_size}\")"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions new-docs/tst.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
load_json = LoadJson(file=json_file)

# Run the task
result = load_json(plugin="serial")
outputs = load_json(plugin="serial")

# Print the output interface of the of the task (LoadJson.Outputs)
print(result.outputs)
print(outputs)
23 changes: 12 additions & 11 deletions pydra/design/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __bool__(self):

def convert_default_value(value: ty.Any, self_: "Field") -> ty.Any:
"""Ensure the default value has been coerced into the correct type"""
if value is EMPTY:
if value is EMPTY or isinstance(value, attrs.Factory):
return value
return TypeParser[self_.type](self_.type, label=self_.name)(value)

Expand Down Expand Up @@ -197,6 +197,10 @@ def requirements_satisfied(self, inputs: "TaskDef") -> bool:
"""Check if all the requirements are satisfied by the inputs"""
return any(req.satisfied(inputs) for req in self.requires)

@property
def mandatory(self):
return self.default is EMPTY


@attrs.define(kw_only=True)
class Arg(Field):
Expand Down Expand Up @@ -240,7 +244,7 @@ class Arg(Field):
readonly: bool = False


@attrs.define(kw_only=True, slots=False)
@attrs.define(kw_only=True)
class Out(Field):
"""Base class for output fields of task definitions
Expand All @@ -265,7 +269,7 @@ class Out(Field):
The order of the output in the output list, allows for tuple unpacking of outputs
"""

order: int = attrs.field(default=None)
pass


def extract_fields_from_class(
Expand Down Expand Up @@ -394,22 +398,18 @@ def make_task_def(

spec_type._check_arg_refs(inputs, outputs)

# Set positions for outputs to allow for tuple unpacking
for i, output in enumerate(outputs.values()):
output.order = i

if name is None and klass is not None:
name = klass.__name__
if reserved_names := [n for n in inputs if n in spec_type.RESERVED_FIELD_NAMES]:
raise ValueError(
f"{reserved_names} are reserved and cannot be used for {spec_type} field names"
)
outputs_klass = make_outputs_spec(out_type, outputs, outputs_bases, name)
if issubclass(klass, TaskDef) and not issubclass(klass, spec_type):
raise ValueError(f"Cannot change type of definition {klass} to {spec_type}")
if klass is None or not issubclass(klass, spec_type):
if name is None:
raise ValueError("name must be provided if klass is not")
if klass is not None and issubclass(klass, TaskDef):
raise ValueError(f"Cannot change type of definition {klass} to {spec_type}")
bases = tuple(bases)
# Ensure that TaskDef is a base class
if not any(issubclass(b, spec_type) for b in bases):
Expand Down Expand Up @@ -518,16 +518,17 @@ def make_outputs_spec(
field.name = name
field.type = base.__annotations__.get(name, ty.Any)
outputs.update(base_outputs)
assert all(o.name == n for n, o in outputs.items())
outputs_klass = type(
spec_name + "Outputs",
tuple(outputs_bases),
{
o.name: attrs.field(
n: attrs.field(
converter=make_converter(o, f"{spec_name}.Outputs"),
metadata={PYDRA_ATTR_METADATA: o},
**_get_default(o),
)
for o in outputs.values()
for n, o in outputs.items()
},
)
outputs_klass.__annotations__.update((o.name, o.type) for o in outputs.values())
Expand Down
7 changes: 6 additions & 1 deletion pydra/design/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class out(Out):
outputs
"""

pass
order: int = attrs.field(default=None)


@dataclass_transform(
Expand Down Expand Up @@ -161,6 +161,11 @@ def make(wrapped: ty.Callable | type) -> PythonDef:
name="function", type=ty.Callable, default=function
)

# Set positions for outputs to allow for tuple unpacking
output: out
for i, output in enumerate(parsed_outputs.values()):
output.order = i

interface = make_task_def(
PythonDef,
PythonOutputs,
Expand Down
31 changes: 21 additions & 10 deletions pydra/design/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,12 @@ def _validate_sep(self, attribute, value):
value is not None
and self.type is not ty.Any
and ty.get_origin(self.type) is not MultiInputObj
and not issubclass(self.type, ty.Iterable)
):
raise ValueError(
f"sep ({value!r}) can only be provided when type is iterable"
)
tp = ty.get_origin(self.type) or self.type
if not issubclass(tp, ty.Iterable):
raise ValueError(
f"sep ({value!r}) can only be provided when type is iterable"
)


@attrs.define(kw_only=True)
Expand Down Expand Up @@ -353,6 +354,12 @@ def make(
if class_name[0].isdigit():
class_name = f"_{class_name}"

# Add in fields from base classes
parsed_inputs.update({n: getattr(ShellDef, n) for n in ShellDef.BASE_NAMES})
parsed_outputs.update(
{n: getattr(ShellOutputs, n) for n in ShellOutputs.BASE_NAMES}
)

# Update the inputs (overriding inputs from base classes) with the executable
# and the output argument fields
parsed_inputs.update(
Expand All @@ -371,10 +378,12 @@ def make(
# Set positions for the remaining inputs that don't have an explicit position
position_stack = remaining_positions(list(parsed_inputs.values()))
for inpt in parsed_inputs.values():
if inpt.name == "additional_args":
continue
if inpt.position is None:
inpt.position = position_stack.pop(0)

interface = make_task_def(
defn = make_task_def(
ShellDef,
ShellOutputs,
parsed_inputs,
Expand All @@ -384,7 +393,7 @@ def make(
bases=bases,
outputs_bases=outputs_bases,
)
return interface
return defn

# If a name is provided (and hence not being used as a decorator), check to see if
# we are extending from a class that already defines an executable
Expand Down Expand Up @@ -479,17 +488,19 @@ def parse_command_line_template(
outputs = {}
parts = template.split()
executable = []
for i, part in enumerate(parts, start=1):
start_args_index = 0
for part in parts:
if part.startswith("<") or part.startswith("-"):
break
executable.append(part)
start_args_index += 1
if not executable:
raise ValueError(f"Found no executable in command line template: {template}")
if len(executable) == 1:
executable = executable[0]
if i == len(parts):
args_str = " ".join(parts[start_args_index:])
if not args_str:
return executable, inputs, outputs
args_str = " ".join(parts[i - 1 :])
tokens = re.split(r"\s+", args_str.strip())
arg_pattern = r"<([:a-zA-Z0-9_,\|\-\.\/\+]+(?:\?|=[^>]+)?)>"
opt_pattern = r"--?[a-zA-Z0-9_]+"
Expand Down Expand Up @@ -662,7 +673,7 @@ def remaining_positions(
# Check for multiple positions
positions = defaultdict(list)
for arg in args:
if arg.name == "arguments":
if arg.name == "additional_args":
continue
if arg.position is not None:
if arg.position >= 0:
Expand Down
Loading

0 comments on commit 4da2eef

Please sign in to comment.