From 5e898af7709dfdbb521b927e9873199e850fc6bc Mon Sep 17 00:00:00 2001 From: Tom Close Date: Fri, 13 Dec 2024 16:28:27 +1100 Subject: [PATCH] shell tasks now execute --- pydra/design/base.py | 28 +- pydra/design/shell.py | 1 + pydra/design/tests/test_shell.py | 19 +- pydra/design/tests/test_workflow.py | 8 +- pydra/engine/boutiques.py | 10 +- pydra/engine/core.py | 234 +++-------- pydra/engine/environments.py | 13 +- pydra/engine/specs.py | 384 ++++++++++++++++-- pydra/engine/task.py | 301 +------------- pydra/engine/tests/test_boutiques.py | 12 +- pydra/engine/tests/test_dockertask.py | 4 +- pydra/engine/tests/test_helpers_file.py | 6 +- pydra/engine/tests/test_nipype1_convert.py | 6 +- pydra/engine/tests/test_shelltask.py | 88 ++-- .../engine/tests/test_shelltask_inputspec.py | 38 +- pydra/engine/tests/test_task.py | 4 +- pydra/engine/workflow/node.py | 97 ++++- pydra/utils/typing.py | 11 + 18 files changed, 660 insertions(+), 604 deletions(-) diff --git a/pydra/design/base.py b/pydra/design/base.py index feff4a4467..61394e7a68 100644 --- a/pydra/design/base.py +++ b/pydra/design/base.py @@ -9,7 +9,7 @@ import attrs.validators from attrs.converters import default_if_none from fileformats.generic import File -from pydra.utils.typing import TypeParser, is_optional, is_fileset_or_union +from pydra.utils.typing import TypeParser, is_optional, is_fileset_or_union, is_type from pydra.engine.helpers import ( from_list_if_single, ensure_list, @@ -52,11 +52,6 @@ def __bool__(self): EMPTY = _Empty.EMPTY # To provide a blank placeholder for the default field -def is_type(_, __, val: ty.Any) -> bool: - """check that the value is a type or generic""" - return inspect.isclass(val) or ty.get_origin(val) - - def convert_default_value(value: ty.Any, self_: "Field") -> ty.Any: """Ensure the default value has been coerced into the correct type""" if value is EMPTY: @@ -400,6 +395,10 @@ def make_task_spec( if name is None and klass is not None: name = klass.__name__ + if reserved_names := [n for n in inputs if n in spec_type.RESERVED_FIELD_NAMES]: + raise ValueError( + f"{reserved_names} are reserved and cannot be used for {spec_type} field names" + ) outputs_klass = make_outputs_spec(out_type, outputs, outputs_bases, name) if klass is None or not issubclass(klass, spec_type): if name is None: @@ -503,7 +502,7 @@ def make_outputs_spec( outputs_bases = bases + (spec_type,) if reserved_names := [n for n in outputs if n in spec_type.RESERVED_FIELD_NAMES]: raise ValueError( - f"{reserved_names} are reserved and cannot be used for output field names" + f"{reserved_names} are reserved and cannot be used for {spec_type} field names" ) # Add in any fields in base classes that haven't already been converted into attrs # fields (e.g. stdout, stderr and return_code) @@ -585,12 +584,25 @@ def ensure_field_objects( arg.name = input_name if not arg.help_string: arg.help_string = input_helps.get(input_name, "") - else: + elif is_type(arg): inputs[input_name] = arg_type( type=arg, name=input_name, help_string=input_helps.get(input_name, ""), ) + elif isinstance(arg, dict): + arg_kwds = copy(arg) + if "help_string" not in arg_kwds: + arg_kwds["help_string"] = input_helps.get(input_name, "") + inputs[input_name] = arg_type( + name=input_name, + **arg_kwds, + ) + else: + raise ValueError( + f"Input {input_name} must be an instance of {Arg}, a type, or a dictionary " + f" of keyword arguments to pass to {Arg}, not {arg}" + ) for output_name, out in list(outputs.items()): if isinstance(out, Out): diff --git a/pydra/design/shell.py b/pydra/design/shell.py index 346ebd2eed..2410a410ae 100644 --- a/pydra/design/shell.py +++ b/pydra/design/shell.py @@ -343,6 +343,7 @@ def make( argstr="", position=0, default=executable, + validator=attrs.validators.min_len(1), help_string=EXECUTABLE_HELP_STRING, ) diff --git a/pydra/design/tests/test_shell.py b/pydra/design/tests/test_shell.py index 1cbae39be8..e25d2c7a5d 100644 --- a/pydra/design/tests/test_shell.py +++ b/pydra/design/tests/test_shell.py @@ -32,6 +32,7 @@ def test_interface_template(): assert sorted_fields(SampleInterface) == [ shell.arg( name="executable", + validator=attrs.validators.min_len(1), default="cp", type=str | ty.Sequence[str], position=0, @@ -80,6 +81,7 @@ def test_interface_template_w_types_and_path_template_ext(): assert sorted_fields(SampleInterface) == [ shell.arg( name="executable", + validator=attrs.validators.min_len(1), default="trim-png", type=str | ty.Sequence[str], position=0, @@ -119,6 +121,7 @@ def test_interface_template_w_modify(): assert sorted_fields(SampleInterface) == [ shell.arg( name="executable", + validator=attrs.validators.min_len(1), default="trim-png", type=str | ty.Sequence[str], position=0, @@ -176,6 +179,7 @@ def test_interface_template_more_complex(): assert sorted_fields(SampleInterface) == [ shell.arg( name="executable", + validator=attrs.validators.min_len(1), default="cp", type=str | ty.Sequence[str], position=0, @@ -273,6 +277,7 @@ def test_interface_template_with_overrides_and_optionals(): == [ shell.arg( name="executable", + validator=attrs.validators.min_len(1), default="cp", type=str | ty.Sequence[str], position=0, @@ -347,6 +352,7 @@ def test_interface_template_with_defaults(): assert sorted_fields(SampleInterface) == [ shell.arg( name="executable", + validator=attrs.validators.min_len(1), default="cp", type=str | ty.Sequence[str], position=0, @@ -414,6 +420,7 @@ def test_interface_template_with_type_overrides(): assert sorted_fields(SampleInterface) == [ shell.arg( name="executable", + validator=attrs.validators.min_len(1), default="cp", type=str | ty.Sequence[str], position=0, @@ -545,6 +552,7 @@ class Outputs(ShellOutputs): type=bool, help_string="Show complete date in long format", argstr="-T", + default=False, requires=["long_format"], xor=["date_format_str"], ), @@ -606,7 +614,7 @@ def test_shell_pickle_roundtrip(Ls, tmp_path): assert RereadLs is Ls -@pytest.mark.xfail(reason="Still need to update tasks to use new shell interface") +# @pytest.mark.xfail(reason="Still need to update tasks to use new shell interface") def test_shell_run(Ls, tmp_path): Path.touch(tmp_path / "a") Path.touch(tmp_path / "b") @@ -615,16 +623,16 @@ def test_shell_run(Ls, tmp_path): ls = Ls(directory=tmp_path, long_format=True) # Test cmdline - assert ls.inputs.directory == tmp_path - assert not ls.inputs.hidden - assert ls.inputs.long_format + assert ls.directory == Directory(tmp_path) + assert not ls.hidden + assert ls.long_format assert ls.cmdline == f"ls -l {tmp_path}" # Drop Long format flag to make output simpler ls = Ls(directory=tmp_path) result = ls() - assert result.output.entries == ["a", "b", "c"] + assert sorted(result.output.entries) == ["a", "b", "c"] @pytest.fixture(params=["static", "dynamic"]) @@ -721,6 +729,7 @@ class Outputs: assert sorted_fields(A) == [ shell.arg( name="executable", + validator=attrs.validators.min_len(1), default="cp", type=str | ty.Sequence[str], argstr="", diff --git a/pydra/design/tests/test_workflow.py b/pydra/design/tests/test_workflow.py index 34cd564770..d6b11ab565 100644 --- a/pydra/design/tests/test_workflow.py +++ b/pydra/design/tests/test_workflow.py @@ -93,14 +93,8 @@ def MyTestShellWorkflow( ) output_video = workflow.add( shell.define( - "HandBrakeCLI -i -o " + "HandBrakeCLI -i -o " "--width --height ", - # By default any input/output specified with a flag (e.g. -i ) - # is considered optional, i.e. of type `FsObject | None`, and therefore - # won't be used by default. By overriding this with non-optional types, - # the fields are specified as being required. - inputs={"in_video": video.Mp4}, - outputs={"out_video": video.Mp4}, )(in_video=add_watermark.out_video, width=1280, height=720), name="resize", ).out_video diff --git a/pydra/engine/boutiques.py b/pydra/engine/boutiques.py index 8d7782b3e5..d12d30f1d4 100644 --- a/pydra/engine/boutiques.py +++ b/pydra/engine/boutiques.py @@ -182,9 +182,9 @@ def _command_args_single(self, state_ind=None, index=None): """Get command line arguments for a single state""" input_filepath = self._bosh_invocation_file(state_ind=state_ind, index=index) cmd_list = ( - self.inputs.executable + self.spec.executable + [str(self.bosh_file), input_filepath] - + self.inputs.args + + self.spec.args + self.bindings ) return cmd_list @@ -192,11 +192,11 @@ def _command_args_single(self, state_ind=None, index=None): def _bosh_invocation_file(self, state_ind=None, index=None): """creating bosh invocation file - json file with inputs values""" input_json = {} - for f in attrs_fields(self.inputs, exclude_names=("executable", "args")): + for f in attrs_fields(self.spec, exclude_names=("executable", "args")): if self.state and f"{self.name}.{f.name}" in state_ind: - value = getattr(self.inputs, f.name)[state_ind[f"{self.name}.{f.name}"]] + value = getattr(self.spec, f.name)[state_ind[f"{self.name}.{f.name}"]] else: - value = getattr(self.inputs, f.name) + value = getattr(self.spec, f.name) # adding to the json file if specified by the user if value is not attr.NOTHING and value != "NOTHING": if is_local_file(f): diff --git a/pydra/engine/core.py b/pydra/engine/core.py index 18631d38ea..f74264dfb8 100644 --- a/pydra/engine/core.py +++ b/pydra/engine/core.py @@ -7,7 +7,7 @@ import sys from pathlib import Path import typing as ty -from copy import deepcopy, copy +from copy import deepcopy from uuid import uuid4 from filelock import SoftFileLock import shutil @@ -22,6 +22,7 @@ RuntimeSpec, Result, TaskHook, + TaskSpec, ) from .helpers import ( create_checksum, @@ -76,6 +77,9 @@ class Task: _cache_dir = None # Working directory in which to operate _references = None # List of references for a task + name: str + spec: TaskSpec + def __init__( self, spec, @@ -131,23 +135,12 @@ def __init__( if Task._etelemetry_version_data is None: Task._etelemetry_version_data = check_latest_version() - self.interface = spec - # raise error if name is same as of attributes + self.spec = spec self.name = name - if not self.input_spec: - raise Exception("No input_spec in class: %s" % self.__class__.__name__) - - self.inputs = self.interface( - **{ - # in attrs names that starts with "_" could be set when name provided w/o "_" - (f.name[1:] if f.name.startswith("_") else f.name): f.default - for f in attr.fields(type(self.interface)) - } - ) self.input_names = [ field.name - for field in attr.fields(type(self.interface)) + for field in attr.fields(type(self.spec)) if field.name not in ["_func", "_graph_checksums"] ] @@ -164,17 +157,11 @@ def __init__( raise ValueError(f"Unknown input set {inputs!r}") inputs = self._input_sets[inputs] - self.inputs = attr.evolve(self.inputs, **inputs) + self.spec = attr.evolve(self.spec, **inputs) # checking if metadata is set properly - self.inputs.check_metadata() - # dictionary to save the connections with lazy fields - self.inp_lf = {} - self.state = None - # container dimensions provided by the user - self.cont_dim = cont_dim - # container dimension for inner input if needed (e.g. for inner splitter) - self._inner_cont_dim = {} + self.spec._check_resolved() + self.spec._check_rules() self._output = {} self._result = {} # flag that says if node finished all jobs @@ -206,18 +193,11 @@ def __str__(self): def __getstate__(self): state = self.__dict__.copy() - state["interface"] = cp.dumps(state["interface"]) - inputs = {} - for k, v in attr.asdict(state["inputs"], recurse=False).items(): - if k.startswith("_"): - k = k[1:] - inputs[k] = v - state["inputs"] = inputs + state["spec"] = cp.dumps(state["spec"]) return state def __setstate__(self, state): - state["interface"] = cp.loads(state["interface"]) - state["inputs"] = self.interface(**state["inputs"]) + state["spec"] = cp.loads(state["spec"]) self.__dict__.update(state) def help(self, returnhelp=False): @@ -243,63 +223,10 @@ def checksum(self): and to create nodes checksums needed for graph checksums (before the tasks have inputs etc.) """ - input_hash = self.inputs.hash - if self.state is None: - self._checksum = create_checksum(self.__class__.__name__, input_hash) - else: - splitter_hash = hash_function(self.state.splitter) - self._checksum = create_checksum( - self.__class__.__name__, hash_function([input_hash, splitter_hash]) - ) + input_hash = self.spec._hash + self._checksum = create_checksum(self.__class__.__name__, input_hash) return self._checksum - def checksum_states(self, state_index=None): - """ - Calculate a checksum for the specific state or all of the states of the task. - Replaces state-arrays in the inputs fields with a specific values for states. - Used to recreate names of the task directories, - - Parameters - ---------- - state_index : - TODO - - """ - if is_workflow(self) and self.inputs._graph_checksums is attr.NOTHING: - self.inputs._graph_checksums = { - nd.name: nd.checksum for nd in self.graph_sorted - } - - if state_index is not None: - inputs_copy = copy(self.inputs) - for key, ind in self.state.inputs_ind[state_index].items(): - val = self._extract_input_el( - inputs=self.inputs, inp_nm=key.split(".")[1], ind=ind - ) - setattr(inputs_copy, key.split(".")[1], val) - # setting files_hash again in case it was cleaned by setting specific element - # that might be important for outer splitter of input variable with big files - # the file can be changed with every single index even if there are only two files - input_hash = inputs_copy.hash - if is_workflow(self): - con_hash = hash_function(self._connections) - # TODO: hash list is not used - hash_list = [input_hash, con_hash] # noqa: F841 - checksum_ind = create_checksum( - self.__class__.__name__, self._checksum_wf(input_hash) - ) - else: - checksum_ind = create_checksum(self.__class__.__name__, input_hash) - return checksum_ind - else: - checksum_list = [] - if not hasattr(self.state, "inputs_ind"): - self.state.prepare_states(self.inputs, cont_dim=self.cont_dim) - self.state.prepare_inputs() - for ind in range(len(self.state.inputs_ind)): - checksum_list.append(self.checksum_states(state_index=ind)) - return checksum_list - @property def uid(self): """the unique id number for the task @@ -333,7 +260,7 @@ def output_names(self): """Get the names of the outputs from the task's output_spec (not everything has to be generated, see generated_output_names). """ - return [f.name for f in attr.fields(self.interface.Outputs)] + return [f.name for f in attr.fields(self.spec.Outputs)] @property def generated_output_names(self): @@ -342,13 +269,13 @@ def generated_output_names(self): it uses output_names. The results depends on the input provided to the task """ - output_klass = self.interface.Outputs + output_klass = self.spec.Outputs if hasattr(output_klass, "_generated_output_names"): output = output_klass( **{f.name: attr.NOTHING for f in attr.fields(output_klass)} ) # using updated input (after filing the templates) - _inputs = deepcopy(self.inputs) + _inputs = deepcopy(self.spec) modified_inputs = template_update(_inputs, self.output_dir) if modified_inputs: _inputs = attr.evolve(_inputs, **modified_inputs) @@ -397,8 +324,6 @@ def cache_locations(self, locations): @property def output_dir(self): """Get the filesystem path where outputs will be written.""" - if self.state: - return [self._cache_dir / checksum for checksum in self.checksum_states()] return self._cache_dir / self.checksum @property @@ -434,7 +359,7 @@ def __call__( pass # if there is plugin provided or the task is a Workflow or has a state, # the submitter will be created using provided plugin, self.plugin or "cf" - elif plugin or self.state or is_workflow(self): + elif plugin: plugin = plugin or self.plugin or "cf" if plugin_kwargs is None: plugin_kwargs = {} @@ -442,7 +367,7 @@ def __call__( if submitter: with submitter as sub: - self.inputs = attr.evolve(self.inputs, **kwargs) + self.spec = attr.evolve(self.spec, **kwargs) res = sub(self, environment=environment) else: # tasks without state could be run without a submitter res = self._run(rerun=rerun, environment=environment, **kwargs) @@ -462,10 +387,10 @@ def _modify_inputs(self): from pydra.utils.typing import TypeParser orig_inputs = { - k: v for k, v in attrs_values(self.inputs).items() if not k.startswith("_") + k: v for k, v in attrs_values(self.spec).items() if not k.startswith("_") } map_copyfiles = {} - input_fields = attr.fields(type(self.inputs)) + input_fields = attr.fields(type(self.spec)) for name, value in orig_inputs.items(): fld = getattr(input_fields, name) copy_mode, copy_collation = parse_copyfile( @@ -484,7 +409,7 @@ def _modify_inputs(self): if value is not copied_value: map_copyfiles[name] = copied_value modified_inputs = template_update( - self.inputs, self.output_dir, map_copyfiles=map_copyfiles + self.spec, self.output_dir, map_copyfiles=map_copyfiles ) assert all(m in orig_inputs for m in modified_inputs), ( "Modified inputs contain fields not present in original inputs. " @@ -497,7 +422,7 @@ def _modify_inputs(self): # Ensure we pass a copy not the original just in case inner # attributes are modified during execution value = deepcopy(orig_value) - setattr(self.inputs, name, value) + setattr(self.spec, name, value) return orig_inputs def _populate_filesystem(self, checksum, output_dir): @@ -517,10 +442,7 @@ def _populate_filesystem(self, checksum, output_dir): shutil.rmtree(output_dir) output_dir.mkdir(parents=False, exist_ok=self.can_resume) - def _run(self, rerun=False, environment=None, **kwargs): - self.inputs = attr.evolve(self.inputs, **kwargs) - self.inputs.check_fields_input_spec() - + def _run(self, rerun=False, environment=None): checksum = self.checksum output_dir = self.output_dir lockfile = self.cache_dir / (checksum + ".lock") @@ -535,7 +457,6 @@ def _run(self, rerun=False, environment=None, **kwargs): cwd = os.getcwd() self._populate_filesystem(checksum, output_dir) os.chdir(output_dir) - orig_inputs = self._modify_inputs() result = Result(output=None, runtime=None, errored=False) self.hooks.pre_run_task(self) self.audit.start_audit(odir=output_dir) @@ -544,7 +465,7 @@ def _run(self, rerun=False, environment=None, **kwargs): try: self.audit.monitor() self._run_task(environment=environment) - result.output = self._collect_outputs(output_dir=output_dir) + result.output = self.spec.Outputs.from_task(self) except Exception: etype, eval, etr = sys.exc_info() traceback = format_exception(etype, eval, etr) @@ -558,8 +479,6 @@ def _run(self, rerun=False, environment=None, **kwargs): # removing the additional file with the checksum (self.cache_dir / f"{self.uid}_info.json").unlink() # Restore original values to inputs - for field_name, field_value in orig_inputs.items(): - setattr(self.inputs, field_name, field_value) os.chdir(cwd) self.hooks.post_run(self, result) # Check for any changes to the input hashes that have occurred during the execution @@ -567,16 +486,6 @@ def _run(self, rerun=False, environment=None, **kwargs): self._check_for_hash_changes() return result - def _collect_outputs(self, output_dir): - output_klass = self.interface.Outputs - output = output_klass( - **{f.name: attr.NOTHING for f in attr.fields(output_klass)} - ) - other_output = output.collect_additional_outputs( - self.inputs, output_dir, self.output_ - ) - return attr.evolve(output, **self.output_, **other_output) - def _extract_input_el(self, inputs, inp_nm, ind): """ Extracting element of the inputs taking into account @@ -603,7 +512,7 @@ def get_input_el(self, ind): for inp in set(self.input_names): if f"{self.name}.{inp}" in input_ind: inputs_dict[inp] = self._extract_input_el( - inputs=self.inputs, + inputs=self.spec, inp_nm=inp, ind=input_ind[f"{self.name}.{inp}"], ) @@ -626,7 +535,7 @@ def pickle_task(self): def done(self): """Check whether the tasks has been finalized and all outputs are stored.""" # if any of the field is lazy, there is no need to check results - if has_lazy(self.inputs): + if has_lazy(self.spec): return False _result = self.result() if self.state: @@ -699,73 +608,40 @@ def result(self, state_index=None, return_inputs=False): # return a future if not if self.errored: return Result(output=None, runtime=None, errored=True) - if self.state: - if state_index is None: - # if state_index=None, collecting all results - if self.state.combiner: - return self._combined_output(return_inputs=return_inputs) - else: - results = [] - for ind in range(len(self.state.inputs_ind)): - checksum = self.checksum_states(state_index=ind) - result = load_result(checksum, self.cache_locations) - if result is None: - return None - results.append(result) - if return_inputs is True or return_inputs == "val": - return list(zip(self.state.states_val, results)) - elif return_inputs == "ind": - return list(zip(self.state.states_ind, results)) - else: - return results - else: # state_index is not None - if self.state.combiner: - return self._combined_output(return_inputs=return_inputs)[ - state_index - ] - result = load_result( - self.checksum_states(state_index), self.cache_locations - ) - if return_inputs is True or return_inputs == "val": - return (self.state.states_val[state_index], result) - elif return_inputs == "ind": - return (self.state.states_ind[state_index], result) - else: - return result + + if state_index is not None: + raise ValueError("Task does not have a state") + checksum = self.checksum + result = load_result(checksum, self.cache_locations) + if result and result.errored: + self._errored = True + if return_inputs is True or return_inputs == "val": + inputs_val = { + f"{self.name}.{inp}": getattr(self.spec, inp) + for inp in self.input_names + } + return (inputs_val, result) + elif return_inputs == "ind": + inputs_ind = {f"{self.name}.{inp}": None for inp in self.input_names} + return (inputs_ind, result) else: - if state_index is not None: - raise ValueError("Task does not have a state") - checksum = self.checksum - result = load_result(checksum, self.cache_locations) - if result and result.errored: - self._errored = True - if return_inputs is True or return_inputs == "val": - inputs_val = { - f"{self.name}.{inp}": getattr(self.inputs, inp) - for inp in self.input_names - } - return (inputs_val, result) - elif return_inputs == "ind": - inputs_ind = {f"{self.name}.{inp}": None for inp in self.input_names} - return (inputs_ind, result) - else: - return result + return result def _reset(self): """Reset the connections between inputs and LazyFields.""" - for field in attrs_fields(self.inputs): + for field in attrs_fields(self.spec): if field.name in self.inp_lf: - setattr(self.inputs, field.name, self.inp_lf[field.name]) + setattr(self.spec, field.name, self.inp_lf[field.name]) if is_workflow(self): for task in self.graph.nodes: task._reset() def _check_for_hash_changes(self): - hash_changes = self.inputs.hash_changes() + hash_changes = self.spec._hash_changes() details = "" for changed in hash_changes: - field = getattr(attr.fields(type(self.inputs)), changed) - val = getattr(self.inputs, changed) + field = getattr(attr.fields(type(self.spec)), changed) + val = getattr(self.spec, changed) field_type = type(val) if issubclass(field.type, FileSet): details += ( @@ -797,8 +673,8 @@ def _check_for_hash_changes(self): "Input values and hashes for '%s' %s node:\n%s\n%s", self.name, type(self).__name__, - self.inputs, - self.inputs._hashes, + self.spec, + self.spec._hashes, ) SUPPORTED_COPY_MODES = FileSet.CopyMode.any @@ -906,12 +782,12 @@ def checksum(self): (before the tasks have inputs etc.) """ # if checksum is called before run the _graph_checksums is not ready - if is_workflow(self) and self.inputs._graph_checksums is attr.NOTHING: - self.inputs._graph_checksums = { + if is_workflow(self) and self.spec._graph_checksums is attr.NOTHING: + self.spec._graph_checksums = { nd.name: nd.checksum for nd in self.graph_sorted } - input_hash = self.inputs.hash + input_hash = self.spec.hash if not self.state: self._checksum = create_checksum( self.__class__.__name__, self._checksum_wf(input_hash) @@ -1190,7 +1066,7 @@ async def _run_task(self, submitter, rerun=False, environment=None): # logger.info("Added %s to %s", self.output_spec, self) def _collect_outputs(self): - output_klass = self.interface.Outputs + output_klass = self.spec.Outputs output = output_klass( **{f.name: attr.NOTHING for f in attr.fields(output_klass)} ) diff --git a/pydra/engine/environments.py b/pydra/engine/environments.py index 0c57008058..80193c87db 100644 --- a/pydra/engine/environments.py +++ b/pydra/engine/environments.py @@ -1,7 +1,10 @@ +import typing as ty from .helpers import execute - from pathlib import Path +if ty.TYPE_CHECKING: + from pydra.engine.task import ShellTask + class Environment: """ @@ -14,7 +17,7 @@ class Environment: def setup(self): pass - def execute(self, task): + def execute(self, task: "ShellTask"): """ Execute the task in the environment. @@ -39,7 +42,7 @@ class Native(Environment): Native environment, i.e. the tasks are executed in the current python environment. """ - def execute(self, task): + def execute(self, task: "ShellTask"): keys = ["return_code", "stdout", "stderr"] values = execute(task.command_args(), strip=task.strip) output = dict(zip(keys, values)) @@ -87,7 +90,7 @@ def bind(self, loc, mode="ro"): class Docker(Container): """Docker environment.""" - def execute(self, task): + def execute(self, task: "ShellTask"): docker_img = f"{self.image}:{self.tag}" # mounting all input locations mounts = task.get_bindings(root=self.root) @@ -123,7 +126,7 @@ def execute(self, task): class Singularity(Container): """Singularity environment.""" - def execute(self, task): + def execute(self, task: "ShellTask"): singularity_img = f"{self.image}:{self.tag}" # mounting all input locations mounts = task.get_bindings(root=self.root) diff --git a/pydra/engine/specs.py b/pydra/engine/specs.py index a643f0fef2..b026c342cc 100644 --- a/pydra/engine/specs.py +++ b/pydra/engine/specs.py @@ -3,8 +3,11 @@ import os from pathlib import Path import re +from copy import copy import inspect import itertools +import platform +import shlex import typing as ty from glob import glob from copy import deepcopy @@ -13,22 +16,56 @@ from fileformats.generic import File from pydra.engine.audit import AuditFlag from pydra.utils.typing import TypeParser, MultiOutputObj -from .helpers import attrs_fields, attrs_values, is_lazy, list_fields -from .helpers_file import template_update_single +from .helpers import ( + attrs_fields, + attrs_values, + is_lazy, + list_fields, + position_sort, + ensure_list, + parse_format_string, +) +from .helpers_file import template_update_single, template_update from pydra.utils.hash import hash_function, Cache -from pydra.design.base import Field, Arg, Out, RequirementSet +from pydra.design.base import Field, Arg, Out, RequirementSet, EMPTY from pydra.design import shell +if ty.TYPE_CHECKING: + from pydra.engine.core import Task + from pydra.engine.task import ShellTask + def is_set(value: ty.Any) -> bool: """Check if a value has been set.""" - return value is not attrs.NOTHING + return value not in (attrs.NOTHING, EMPTY) class Outputs: """Base class for all output specifications""" - RESERVED_FIELD_NAMES = ("split", "combine") + RESERVED_FIELD_NAMES = ("inputs", "split", "combine") + + @classmethod + def from_task(cls, task: "Task") -> Self: + """Collect the outputs of a task from a combination of the provided inputs, + the objects in the output directory, and the stdout and stderr of the process. + + Parameters + ---------- + task : Task + The task whose outputs are being collected. + + Returns + ------- + outputs : Outputs + The outputs of the task + """ + return cls(**{f.name: attrs.NOTHING for f in attrs_fields(cls)}) + + @property + def inputs(self): + """The inputs object associated with a lazy-outputs object""" + return self._get_node().inputs def split( self, @@ -62,7 +99,9 @@ def split( self : TaskBase a reference to the task """ - self._node.split(splitter, overwrite=overwrite, cont_dim=cont_dim, **inputs) + self._get_node().split( + splitter, overwrite=overwrite, cont_dim=cont_dim, **inputs + ) return self def combine( @@ -89,9 +128,17 @@ def combine( self : Self a reference to the outputs object """ - self._node.combine(combiner, overwrite=overwrite) + self._get_node().combine(combiner, overwrite=overwrite) return self + def _get_node(self): + try: + return self._node + except AttributeError: + raise AttributeError( + f"{self} outputs object is not a lazy output of a workflow node" + ) + OutputsType = ty.TypeVar("OutputType", bound=Outputs) @@ -101,6 +148,8 @@ class TaskSpec(ty.Generic[OutputsType]): Task: "ty.Type[core.Task]" + RESERVED_FIELD_NAMES = () + def __call__( self, name: str | None = None, @@ -116,7 +165,7 @@ def __call__( ): self._check_rules() task = self.Task( - self, + spec=self, name=name, audit_flags=audit_flags, cache_dir=cache_dir, @@ -175,6 +224,9 @@ def _check_rules(self): for field in list_fields(self): value = getattr(self, field.name) + if is_lazy(value): + continue + # Collect alternative fields associated with this field. if field.xor: alternative_fields = { @@ -182,9 +234,7 @@ def _check_rules(self): for name in field.xor if name != field.name } - set_alternatives = { - n: v for n, v in alternative_fields.items() if is_set(v) - } + set_alternatives = {n: v for n, v in alternative_fields.items() if v} # Raise error if no field in mandatory alternative group is set. if not is_set(value): @@ -206,10 +256,19 @@ def _check_rules(self): ) # Raise error if any required field is unset. - if field.requires and not any(rs.satisfied(self) for rs in field.requires): + if ( + value + and field.requires + and not any(rs.satisfied(self) for rs in field.requires) + ): + if len(field.requires) > 1: + qualification = ( + " at least one of the following requirements to be satisfied: " + ) + else: + qualification = "" raise ValueError( - f"{field.name} requires at least one of the requirement sets to be " - f"satisfied: {[str(r) for r in field.requires]}" + f"{field.name!r} requires{qualification} {[str(r) for r in field.requires]}" ) @classmethod @@ -235,6 +294,14 @@ def _check_arg_refs(cls, inputs: list[Arg], outputs: list[Out]) -> None: f"of {inpt} " + str(list(unrecognised)) ) + def _check_resolved(self): + """Checks that all the fields in the spec have been resolved""" + if has_lazy_values := [n for n, v in attrs_values(self).items() if is_lazy(v)]: + raise ValueError( + f"Cannot execute {self} because the following fields " + f"still have lazy values {has_lazy_values}" + ) + @attrs.define(kw_only=True) class Runtime: @@ -257,7 +324,7 @@ class Result: errored: bool = False def __getstate__(self): - state = self.__dict__.copy() + state = attrs_values(self) if state["output"] is not None: fields = tuple((el.name, el.type) for el in attrs_fields(state["output"])) state["output_spec"] = (state["output"].__class__.__name__, fields) @@ -348,13 +415,9 @@ class ShellOutputs(Outputs): stderr: str = shell.out(help_string=STDERR_HELP) @classmethod - def collect_outputs( + def from_task( cls, - inputs: "ShellSpec", - output_dir: Path, - stdout: str, - stderr: str, - return_code: int, + task: "ShellTask", ) -> Self: """Collect the outputs of a shell process from a combination of the provided inputs, the objects in the output directory, and the stdout and stderr of the process. @@ -378,9 +441,15 @@ def collect_outputs( The outputs of the shell process """ - outputs = cls(return_code=return_code, stdout=stdout, stderr=stderr) + outputs = cls( + return_code=task.output_["return_code"], + stdout=task.output_["stdout"], + stderr=task.output_["stderr"], + ) fld: shell.out for fld in list_fields(cls): + if fld.name in ["return_code", "stdout", "stderr"]: + continue if not TypeParser.is_subclass( fld.type, ( @@ -399,17 +468,17 @@ def collect_outputs( ) # Get the corresponding value from the inputs if it exists, which will be # passed through to the outputs, to permit manual overrides - if isinstance(fld, shell.outarg) and is_set(getattr(inputs, fld.name)): - resolved_value = getattr(inputs, fld.name) + if isinstance(fld, shell.outarg) and is_set(getattr(task.inputs, fld.name)): + resolved_value = getattr(task.spec, fld.name) elif is_set(fld.default): - resolved_value = cls._resolve_default_value(fld, output_dir) + resolved_value = cls._resolve_default_value(fld, task.output_dir) else: if fld.type in [int, float, bool, str, list] and not fld.callable: raise AttributeError( f"{fld.type} has to have a callable in metadata" ) resolved_value = cls._generate_implicit_value( - fld, inputs, output_dir, outputs, stdout, stderr + fld, task.spec, task.output_dir, outputs.stdout, outputs.stderr ) # Set the resolved value setattr(outputs, fld.name, resolved_value) @@ -543,10 +612,204 @@ def _required_fields_satisfied(cls, fld: shell.out, inputs: "ShellSpec") -> bool class ShellSpec(TaskSpec[ShellOutputsType]): - pass + + RESERVED_FIELD_NAMES = ("cmdline",) + + @property + def cmdline(self) -> str: + """The equivalent command line that would be submitted if the task were run on + the current working directory.""" + # checking the inputs fields before returning the command line + self._check_resolved() + # Skip the executable, which can be a multi-part command, e.g. 'docker run'. + cmd_args = self._command_args() + cmdline = cmd_args[0] + for arg in cmd_args[1:]: + # If there are spaces in the arg, and it is not enclosed by matching + # quotes, add quotes to escape the space. Not sure if this should + # be expanded to include other special characters apart from spaces + if " " in arg: + cmdline += " '" + arg + "'" + else: + cmdline += " " + arg + return cmdline + + def _command_args( + self, + output_dir: Path | None = None, + input_updates: dict[str, ty.Any] | None = None, + root: Path | None = None, + ) -> list[str]: + """Get command line arguments""" + if output_dir is None: + output_dir = Path.cwd() + self._check_resolved() + inputs = attrs_values(self) + modified_inputs = template_update(self, output_dir=output_dir) + if input_updates: + inputs.update(input_updates) + inputs.update(modified_inputs) + pos_args = [] # list for (position, command arg) + self._positions_provided = [] + for field in list_fields(self): + name = field.name + value = inputs[name] + if value is None: + continue + if name == "executable": + pos_args.append(self._command_shelltask_executable(field, value)) + elif name == "args": + pos_val = self._command_shelltask_args(field, value) + if pos_val: + pos_args.append(pos_val) + else: + if name in modified_inputs: + pos_val = self._command_pos_args( + field, value, output_dir, root=root + ) + else: + pos_val = self._command_pos_args(field, value, output_dir, inputs) + if pos_val: + pos_args.append(pos_val) + + # Sort command and arguments by position + cmd_args = position_sort(pos_args) + # pos_args values are each a list of arguments, so concatenate lists after sorting + return sum(cmd_args, []) + + def _command_shelltask_executable( + self, field: shell.arg, value: ty.Any + ) -> tuple[int, ty.Any]: + """Returning position and value for executable ShellTask input""" + pos = 0 # executable should be the first el. of the command + assert value + return pos, ensure_list(value, tuple2list=True) + + def _command_shelltask_args( + self, field: shell.arg, value: ty.Any + ) -> tuple[int, ty.Any]: + """Returning position and value for args ShellTask input""" + pos = -1 # assuming that args is the last el. of the command + if value is None: + return None + else: + return pos, ensure_list(value, tuple2list=True) + + def _command_pos_args( + self, + field: shell.arg, + value: ty.Any, + inputs: dict[str, ty.Any], + output_dir: Path, + root: Path | None = None, + ) -> tuple[int, ty.Any]: + """ + Checking all additional input fields, setting pos to None, if position not set. + Creating a list with additional parts of the command that comes from + the specific field. + """ + if field.argstr is None and field.formatter is None: + # assuming that input that has no argstr is not used in the command, + # or a formatter is not provided too. + return None + if field.position is not None: + if not isinstance(field.position, int): + raise Exception( + f"position should be an integer, but {field.position} given" + ) + # checking if the position is not already used + if field.position in self._positions_provided: + raise Exception( + f"{field.name} can't have provided position, {field.position} is already used" + ) + + self._positions_provided.append(field.position) + + # Shift non-negatives up to allow executable to be 0 + # Shift negatives down to allow args to be -1 + field.position += 1 if field.position >= 0 else -1 + + if value: + if root: # values from templates + value = value.replace(str(output_dir), f"{root}{output_dir}") + + if field.readonly and value is not None: + raise Exception(f"{field.name} is read only, the value can't be provided") + elif value is None and not field.readonly and field.formatter is None: + return None + + cmd_add = [] + # formatter that creates a custom command argument + # it can take the value of the field, all inputs, or the value of other fields. + if field.formatter: + call_args = inspect.getfullargspec(field.formatter) + call_args_val = {} + for argnm in call_args.args: + if argnm == "field": + call_args_val[argnm] = value + elif argnm == "inputs": + call_args_val[argnm] = inputs + else: + if argnm in inputs: + call_args_val[argnm] = inputs[argnm] + else: + raise AttributeError( + f"arguments of the formatter function from {field.name} " + f"has to be in inputs or be field or output_dir, " + f"but {argnm} is used" + ) + cmd_el_str = field.formatter(**call_args_val) + cmd_el_str = cmd_el_str.strip().replace(" ", " ") + if cmd_el_str != "": + cmd_add += split_cmd(cmd_el_str) + elif field.type is bool and "{" not in field.argstr: + # if value is simply True the original argstr is used, + # if False, nothing is added to the command. + if value is True: + cmd_add.append(field.argstr) + else: + if ( + field.argstr.endswith("...") + and isinstance(value, ty.Iterable) + and not isinstance(value, (str, bytes)) + ): + field.argstr = field.argstr.replace("...", "") + # if argstr has a more complex form, with "{input_field}" + if "{" in field.argstr and "}" in field.argstr: + argstr_formatted_l = [] + for val in value: + argstr_f = argstr_formatting( + field.argstr, self, value_updates={field.name: val} + ) + argstr_formatted_l.append(f" {argstr_f}") + cmd_el_str = field.sep.join(argstr_formatted_l) + else: # argstr has a simple form, e.g. "-f", or "--f" + cmd_el_str = field.sep.join( + [f" {field.argstr} {val}" for val in value] + ) + else: + # in case there are ... when input is not a list + field.argstr = field.argstr.replace("...", "") + if isinstance(value, ty.Iterable) and not isinstance( + value, (str, bytes) + ): + cmd_el_str = field.sep.join([str(val) for val in value]) + value = cmd_el_str + # if argstr has a more complex form, with "{input_field}" + if "{" in field.argstr and "}" in field.argstr: + cmd_el_str = field.argstr.replace(f"{{{field.name}}}", str(value)) + cmd_el_str = argstr_formatting(cmd_el_str, self.spec) + else: # argstr has a simple form, e.g. "-f", or "--f" + if value: + cmd_el_str = f"{field.argstr} {value}" + else: + cmd_el_str = "" + if cmd_el_str: + cmd_add += split_cmd(cmd_el_str) + return field.position, cmd_add -def donothing(*args, **kwargs): +def donothing(*args: ty.Any, **kwargs: ty.Any) -> None: return None @@ -569,4 +832,69 @@ def reset(self): setattr(self, val, donothing) +def split_cmd(cmd: str): + """Splits a shell command line into separate arguments respecting quotes + + Parameters + ---------- + cmd : str + Command line string or part thereof + + Returns + ------- + str + the command line string split into process args + """ + # Check whether running on posix or Windows system + on_posix = platform.system() != "Windows" + args = shlex.split(cmd, posix=on_posix) + cmd_args = [] + for arg in args: + match = re.match("(['\"])(.*)\\1$", arg) + if match: + cmd_args.append(match.group(2)) + else: + cmd_args.append(arg) + return cmd_args + + +def argstr_formatting( + argstr: str, inputs: dict[str, ty.Any], value_updates: dict[str, ty.Any] = None +): + """formatting argstr that have form {field_name}, + using values from inputs and updating with value_update if provided + """ + # if there is a value that has to be updated (e.g. single value from a list) + # getting all fields that should be formatted, i.e. {field_name}, ... + if value_updates: + inputs = copy(inputs) + inputs.update(value_updates) + inp_fields = parse_format_string(argstr) + val_dict = {} + for fld_name in inp_fields: + fld_value = inputs[fld_name] + fld_attr = getattr(attrs.fields(type(inputs)), fld_name) + if fld_value is None or ( + fld_value is False + and fld_attr.type is not bool + and TypeParser.matches_type(fld_attr.type, ty.Union[Path, bool]) + ): + # if value is NOTHING, nothing should be added to the command + val_dict[fld_name] = "" + else: + val_dict[fld_name] = fld_value + + # formatting string based on the val_dict + argstr_formatted = argstr.format(**val_dict) + # removing extra commas and spaces after removing the field that have NOTHING + argstr_formatted = ( + argstr_formatted.replace("[ ", "[") + .replace(" ]", "]") + .replace("[,", "[") + .replace(",]", "]") + .strip() + ) + return argstr_formatted + + from pydra.engine import core # noqa: E402 diff --git a/pydra/engine/task.py b/pydra/engine/task.py index 78ab415d38..ab226e8ba4 100644 --- a/pydra/engine/task.py +++ b/pydra/engine/task.py @@ -41,31 +41,21 @@ from __future__ import annotations -import platform -import re import attr -import attrs -import inspect -import typing as ty -import shlex from pathlib import Path import cloudpickle as cp from fileformats.core import FileSet from .core import Task from pydra.utils.messenger import AuditFlag from .specs import ( + PythonSpec, ShellSpec, attrs_fields, ) from .helpers import ( attrs_values, - is_lazy, - parse_format_string, - position_sort, - ensure_list, parse_copyfile, ) -from .helpers_file import template_update from pydra.utils.typing import TypeParser from .environments import Native @@ -73,12 +63,14 @@ class PythonTask(Task): """Wrap a Python callable as a task element.""" + spec: PythonSpec + def _run_task(self, environment=None): - inputs = attrs_values(self.inputs) + inputs = attrs_values(self.spec) del inputs["_func"] self.output_ = None - output = cp.loads(self.inputs._func)(**inputs) - output_names = [f.name for f in attr.fields(self.interface.Outputs)] + output = cp.loads(self.spec._func)(**inputs) + output_names = [f.name for f in attr.fields(self.spec.Outputs)] if output is None: self.output_ = {nm: None for nm in output_names} elif len(output_names) == 1: @@ -98,6 +90,8 @@ def _run_task(self, environment=None): class ShellTask(Task): """Wrap a shell command as a task element.""" + spec: ShellSpec + def __init__( self, spec: ShellSpec, @@ -137,7 +131,11 @@ def __init__( strip : :obj:`bool` TODO """ + self.return_code = None + self.stdout = None + self.stderr = None super().__init__( + spec=spec, name=name, inputs=kwargs, cont_dim=cont_dim, @@ -174,212 +172,8 @@ def get_bindings(self, root: str | None = None) -> dict[str, tuple[str, str]]: self._prepare_bindings(root=root) return self.bindings - def command_args(self, root=None): - """Get command line arguments""" - if is_lazy(self.inputs): - raise Exception("can't return cmdline, self.inputs has LazyFields") - if self.state: - raise NotImplementedError - - modified_inputs = template_update(self.inputs, output_dir=self.output_dir) - for field_name, field_value in modified_inputs.items(): - setattr(self.inputs, field_name, field_value) - - pos_args = [] # list for (position, command arg) - self._positions_provided = [] - for field in attrs_fields(self.inputs): - name, meta = field.name, field.metadata - if ( - getattr(self.inputs, name) is attr.NOTHING - and not meta.get("readonly") - and not meta.get("formatter") - ): - continue - if name == "executable": - pos_args.append(self._command_shelltask_executable(field)) - elif name == "args": - pos_val = self._command_shelltask_args(field) - if pos_val: - pos_args.append(pos_val) - else: - if name in modified_inputs: - pos_val = self._command_pos_args(field, root=root) - else: - pos_val = self._command_pos_args(field) - if pos_val: - pos_args.append(pos_val) - - # Sort command and arguments by position - cmd_args = position_sort(pos_args) - # pos_args values are each a list of arguments, so concatenate lists after sorting - return sum(cmd_args, []) - - def _field_value(self, field, check_file=False): - """ - Checking value of the specific field, if value is not set, None is returned. - check_file has no effect, but subclasses can use it to validate or modify - filenames. - """ - value = getattr(self.inputs, field.name) - if value == attr.NOTHING: - value = None - return value - - def _command_shelltask_executable(self, field): - """Returning position and value for executable ShellTask input""" - pos = 0 # executable should be the first el. of the command - value = self._field_value(field) - if value is None: - raise ValueError("executable has to be set") - return pos, ensure_list(value, tuple2list=True) - - def _command_shelltask_args(self, field): - """Returning position and value for args ShellTask input""" - pos = -1 # assuming that args is the last el. of the command - value = self._field_value(field, check_file=True) - if value is None: - return None - else: - return pos, ensure_list(value, tuple2list=True) - - def _command_pos_args(self, field, root=None): - """ - Checking all additional input fields, setting pos to None, if position not set. - Creating a list with additional parts of the command that comes from - the specific field. - """ - argstr = field.metadata.get("argstr", None) - formatter = field.metadata.get("formatter", None) - if argstr is None and formatter is None: - # assuming that input that has no argstr is not used in the command, - # or a formatter is not provided too. - return None - pos = field.metadata.get("position", None) - if pos is not None: - if not isinstance(pos, int): - raise Exception(f"position should be an integer, but {pos} given") - # checking if the position is not already used - if pos in self._positions_provided: - raise Exception( - f"{field.name} can't have provided position, {pos} is already used" - ) - - self._positions_provided.append(pos) - - # Shift non-negatives up to allow executable to be 0 - # Shift negatives down to allow args to be -1 - pos += 1 if pos >= 0 else -1 - - value = self._field_value(field, check_file=True) - - if value: - if field.name in self.inputs_mod_root: - value = self.inputs_mod_root[field.name] - elif root: # values from templates - value = value.replace(str(self.output_dir), f"{root}{self.output_dir}") - - if field.metadata.get("readonly", False) and value is not None: - raise Exception(f"{field.name} is read only, the value can't be provided") - elif ( - value is None - and not field.metadata.get("readonly", False) - and formatter is None - ): - return None - - inputs_dict = attrs_values(self.inputs) - - cmd_add = [] - # formatter that creates a custom command argument - # it can take the value of the field, all inputs, or the value of other fields. - if "formatter" in field.metadata: - call_args = inspect.getfullargspec(field.metadata["formatter"]) - call_args_val = {} - for argnm in call_args.args: - if argnm == "field": - call_args_val[argnm] = value - elif argnm == "inputs": - call_args_val[argnm] = inputs_dict - else: - if argnm in inputs_dict: - call_args_val[argnm] = inputs_dict[argnm] - else: - raise AttributeError( - f"arguments of the formatter function from {field.name} " - f"has to be in inputs or be field or output_dir, " - f"but {argnm} is used" - ) - cmd_el_str = field.metadata["formatter"](**call_args_val) - cmd_el_str = cmd_el_str.strip().replace(" ", " ") - if cmd_el_str != "": - cmd_add += split_cmd(cmd_el_str) - elif field.type is bool and "{" not in argstr: - # if value is simply True the original argstr is used, - # if False, nothing is added to the command. - if value is True: - cmd_add.append(argstr) - else: - sep = field.metadata.get("sep", " ") - if ( - argstr.endswith("...") - and isinstance(value, ty.Iterable) - and not isinstance(value, (str, bytes)) - ): - argstr = argstr.replace("...", "") - # if argstr has a more complex form, with "{input_field}" - if "{" in argstr and "}" in argstr: - argstr_formatted_l = [] - for val in value: - argstr_f = argstr_formatting( - argstr, self.inputs, value_updates={field.name: val} - ) - argstr_formatted_l.append(f" {argstr_f}") - cmd_el_str = sep.join(argstr_formatted_l) - else: # argstr has a simple form, e.g. "-f", or "--f" - cmd_el_str = sep.join([f" {argstr} {val}" for val in value]) - else: - # in case there are ... when input is not a list - argstr = argstr.replace("...", "") - if isinstance(value, ty.Iterable) and not isinstance( - value, (str, bytes) - ): - cmd_el_str = sep.join([str(val) for val in value]) - value = cmd_el_str - # if argstr has a more complex form, with "{input_field}" - if "{" in argstr and "}" in argstr: - cmd_el_str = argstr.replace(f"{{{field.name}}}", str(value)) - cmd_el_str = argstr_formatting(cmd_el_str, self.inputs) - else: # argstr has a simple form, e.g. "-f", or "--f" - if value: - cmd_el_str = f"{argstr} {value}" - else: - cmd_el_str = "" - if cmd_el_str: - cmd_add += split_cmd(cmd_el_str) - return pos, cmd_add - - @property - def cmdline(self): - """Get the actual command line that will be submitted - Returns a list if the task has a state. - """ - if is_lazy(self.inputs): - raise Exception("can't return cmdline, self.inputs has LazyFields") - # checking the inputs fields before returning the command line - self.inputs.check_fields_input_spec() - if self.state: - raise NotImplementedError - # Skip the executable, which can be a multi-part command, e.g. 'docker run'. - cmdline = self.command_args()[0] - for arg in self.command_args()[1:]: - # If there are spaces in the arg, and it is not enclosed by matching - # quotes, add quotes to escape the space. Not sure if this should - # be expanded to include other special characters apart from spaces - if " " in arg: - cmdline += " '" + arg + "'" - else: - cmdline += " " + arg - return cmdline + def command_args(self, root: Path | None = None) -> list[str]: + return self.spec._command_args(input_updates=self.inputs_mod_root, root=root) def _run_task(self, environment=None): if environment is None: @@ -392,9 +186,9 @@ def _prepare_bindings(self, root: str): This updates the ``bindings`` attribute of the current task to make files available in an ``Environment``-defined ``root``. """ - for fld in attrs_fields(self.inputs): + for fld in attrs_fields(self.spec): if TypeParser.contains_type(FileSet, fld.type): - fileset = getattr(self.inputs, fld.name) + fileset = getattr(self.spec, fld.name) copy = parse_copyfile(fld)[0] == FileSet.CopyMode.copy host_path, env_path = fileset.parent, Path(f"{root}{fileset.parent}") @@ -409,66 +203,3 @@ def _prepare_bindings(self, root: str): ) DEFAULT_COPY_COLLATION = FileSet.CopyCollation.adjacent - - -def split_cmd(cmd: str): - """Splits a shell command line into separate arguments respecting quotes - - Parameters - ---------- - cmd : str - Command line string or part thereof - - Returns - ------- - str - the command line string split into process args - """ - # Check whether running on posix or Windows system - on_posix = platform.system() != "Windows" - args = shlex.split(cmd, posix=on_posix) - cmd_args = [] - for arg in args: - match = re.match("(['\"])(.*)\\1$", arg) - if match: - cmd_args.append(match.group(2)) - else: - cmd_args.append(arg) - return cmd_args - - -def argstr_formatting(argstr, inputs, value_updates=None): - """formatting argstr that have form {field_name}, - using values from inputs and updating with value_update if provided - """ - inputs_dict = attrs_values(inputs) - # if there is a value that has to be updated (e.g. single value from a list) - if value_updates: - inputs_dict.update(value_updates) - # getting all fields that should be formatted, i.e. {field_name}, ... - inp_fields = parse_format_string(argstr) - val_dict = {} - for fld_name in inp_fields: - fld_value = inputs_dict[fld_name] - fld_attr = getattr(attrs.fields(type(inputs)), fld_name) - if fld_value is attr.NOTHING or ( - fld_value is False - and fld_attr.type is not bool - and TypeParser.matches_type(fld_attr.type, ty.Union[Path, bool]) - ): - # if value is NOTHING, nothing should be added to the command - val_dict[fld_name] = "" - else: - val_dict[fld_name] = fld_value - - # formatting string based on the val_dict - argstr_formatted = argstr.format(**val_dict) - # removing extra commas and spaces after removing the field that have NOTHING - argstr_formatted = ( - argstr_formatted.replace("[ ", "[") - .replace(" ]", "]") - .replace("[,", "[") - .replace(",]", "]") - .strip() - ) - return argstr_formatted diff --git a/pydra/engine/tests/test_boutiques.py b/pydra/engine/tests/test_boutiques.py index 28da1f176a..c951091887 100644 --- a/pydra/engine/tests/test_boutiques.py +++ b/pydra/engine/tests/test_boutiques.py @@ -29,8 +29,8 @@ def test_boutiques_1(maskfile, plugin, results_function, tmpdir, data_tests_dir): """simple task to run fsl.bet using BoshTask""" btask = BoshTask(name="NA", zenodo_id="1482743") - btask.inputs.infile = data_tests_dir / "test.nii.gz" - btask.inputs.maskfile = maskfile + btask.spec.infile = data_tests_dir / "test.nii.gz" + btask.spec.maskfile = maskfile btask.cache_dir = tmpdir res = results_function(btask, plugin) @@ -60,8 +60,8 @@ def test_boutiques_spec_1(data_tests_dir): assert len(btask.input_spec.fields) == 2 assert btask.input_spec.fields[0][0] == "infile" assert btask.input_spec.fields[1][0] == "maskfile" - assert hasattr(btask.inputs, "infile") - assert hasattr(btask.inputs, "maskfile") + assert hasattr(btask.spec, "infile") + assert hasattr(btask.spec, "maskfile") assert len(btask.output_spec.fields) == 2 assert btask.output_spec.fields[0][0] == "outfile" @@ -84,9 +84,9 @@ def test_boutiques_spec_2(data_tests_dir): assert len(btask.input_spec.fields) == 1 assert btask.input_spec.fields[0][0] == "infile" - assert hasattr(btask.inputs, "infile") + assert hasattr(btask.spec, "infile") # input doesn't see maskfile - assert not hasattr(btask.inputs, "maskfile") + assert not hasattr(btask.spec, "maskfile") assert len(btask.output_spec.fields) == 0 diff --git a/pydra/engine/tests/test_dockertask.py b/pydra/engine/tests/test_dockertask.py index 5f69584d60..cc196cd87c 100644 --- a/pydra/engine/tests/test_dockertask.py +++ b/pydra/engine/tests/test_dockertask.py @@ -77,7 +77,7 @@ def test_docker_2a(results_function, plugin): args=cmd_args, environment=Docker(image="busybox"), ) - assert docky.inputs.executable == "echo" + assert docky.spec.executable == "echo" assert docky.cmdline == f"{cmd_exec} {' '.join(cmd_args)}" res = results_function(docky, plugin) @@ -332,7 +332,7 @@ def test_docker_inputspec_2a_except(plugin, tmp_path): input_spec=my_input_spec, strip=True, ) - assert docky.inputs.file2.fspath == filename_2 + assert docky.spec.file2.fspath == filename_2 res = docky() assert res.output.stdout == "hello from pydra\nhave a nice one" diff --git a/pydra/engine/tests/test_helpers_file.py b/pydra/engine/tests/test_helpers_file.py index 915d183973..7db3f8d34f 100644 --- a/pydra/engine/tests/test_helpers_file.py +++ b/pydra/engine/tests/test_helpers_file.py @@ -394,11 +394,11 @@ class MyCommand(ShellTask): task = MyCommand(in_file=filename) assert task.cmdline == f"my {filename}" - task.inputs.optional = True + task.spec.optional = True assert task.cmdline == f"my {filename} --opt {task.output_dir / 'file.out'}" - task.inputs.optional = False + task.spec.optional = False assert task.cmdline == f"my {filename}" - task.inputs.optional = "custom-file-out.txt" + task.spec.optional = "custom-file-out.txt" assert task.cmdline == f"my {filename} --opt custom-file-out.txt" diff --git a/pydra/engine/tests/test_nipype1_convert.py b/pydra/engine/tests/test_nipype1_convert.py index 4dc6f80369..2f5abbfb76 100644 --- a/pydra/engine/tests/test_nipype1_convert.py +++ b/pydra/engine/tests/test_nipype1_convert.py @@ -92,7 +92,7 @@ def test_interface_executable_1(): """testing if the class executable is properly set and used in the command line""" task = Interf_2() assert task.executable == "testing command" - assert task.inputs.executable == "testing command" + assert task.spec.executable == "testing command" assert task.cmdline == "testing command" @@ -103,14 +103,14 @@ def test_interface_executable_2(): task = Interf_2(executable="i want a different command") assert task.executable == "testing command" # task.executable stays the same, but input.executable is changed, so the cmd is changed - assert task.inputs.executable == "i want a different command" + assert task.spec.executable == "i want a different command" assert task.cmdline == "i want a different command" def test_interface_cmdline_with_spaces(): task = Interf_3(in_file="/path/to/file/with spaces") assert task.executable == "testing command" - assert task.inputs.executable == "testing command" + assert task.spec.executable == "testing command" assert task.cmdline == "testing command '/path/to/file/with spaces'" diff --git a/pydra/engine/tests/test_shelltask.py b/pydra/engine/tests/test_shelltask.py index 631f72ff73..b8591092f4 100644 --- a/pydra/engine/tests/test_shelltask.py +++ b/pydra/engine/tests/test_shelltask.py @@ -78,7 +78,7 @@ def test_shell_cmd_2a(plugin, results_function, tmp_path): # separate command into exec + args shelly = ShellTask(name="shelly", executable=cmd_exec, args=cmd_args) shelly.cache_dir = tmp_path - assert shelly.inputs.executable == "echo" + assert shelly.spec.executable == "echo" assert shelly.cmdline == "echo " + " ".join(cmd_args) res = results_function(shelly, plugin) @@ -95,7 +95,7 @@ def test_shell_cmd_2b(plugin, results_function, tmp_path): # separate command into exec + args shelly = ShellTask(name="shelly", executable=cmd_exec, args=cmd_args) shelly.cache_dir = tmp_path - assert shelly.inputs.executable == "echo" + assert shelly.spec.executable == "echo" assert shelly.cmdline == "echo pydra" res = results_function(shelly, plugin) @@ -307,8 +307,8 @@ def test_shell_cmd_inputspec_1(plugin, results_function, tmp_path): input_spec=my_input_spec, cache_dir=tmp_path, ) - assert shelly.inputs.executable == cmd_exec - assert shelly.inputs.args == cmd_args + assert shelly.spec.executable == cmd_exec + assert shelly.spec.args == cmd_args assert shelly.cmdline == "echo -n 'hello from pydra'" res = results_function(shelly, plugin) @@ -356,8 +356,8 @@ def test_shell_cmd_inputspec_2(plugin, results_function, tmp_path): input_spec=my_input_spec, cache_dir=tmp_path, ) - assert shelly.inputs.executable == cmd_exec - assert shelly.inputs.args == cmd_args + assert shelly.spec.executable == cmd_exec + assert shelly.spec.args == cmd_args assert shelly.cmdline == "echo -n HELLO 'from pydra'" res = results_function(shelly, plugin) assert res.output.stdout == "HELLO from pydra" @@ -395,7 +395,7 @@ def test_shell_cmd_inputspec_3(plugin, results_function, tmp_path): input_spec=my_input_spec, cache_dir=tmp_path, ) - assert shelly.inputs.executable == cmd_exec + assert shelly.spec.executable == cmd_exec assert shelly.cmdline == "echo HELLO" res = results_function(shelly, plugin) assert res.output.stdout == "HELLO\n" @@ -428,7 +428,7 @@ def test_shell_cmd_inputspec_3a(plugin, results_function, tmp_path): input_spec=my_input_spec, cache_dir=tmp_path, ) - assert shelly.inputs.executable == cmd_exec + assert shelly.spec.executable == cmd_exec assert shelly.cmdline == "echo HELLO" res = results_function(shelly, plugin) assert res.output.stdout == "HELLO\n" @@ -462,9 +462,9 @@ def test_shell_cmd_inputspec_3b(plugin, results_function, tmp_path): shelly = ShellTask( name="shelly", executable=cmd_exec, input_spec=my_input_spec, cache_dir=tmp_path ) - shelly.inputs.text = hello + shelly.spec.text = hello - assert shelly.inputs.executable == cmd_exec + assert shelly.spec.executable == cmd_exec assert shelly.cmdline == "echo HELLO" res = results_function(shelly, plugin) assert res.output.stdout == "HELLO\n" @@ -530,7 +530,7 @@ def test_shell_cmd_inputspec_3c(plugin, results_function, tmp_path): name="shelly", executable=cmd_exec, input_spec=my_input_spec, cache_dir=tmp_path ) - assert shelly.inputs.executable == cmd_exec + assert shelly.spec.executable == cmd_exec assert shelly.cmdline == "echo" res = results_function(shelly, plugin) assert res.output.stdout == "\n" @@ -560,7 +560,7 @@ def test_shell_cmd_inputspec_4(plugin, results_function, tmp_path): name="shelly", executable=cmd_exec, input_spec=my_input_spec, cache_dir=tmp_path ) - assert shelly.inputs.executable == cmd_exec + assert shelly.spec.executable == cmd_exec assert shelly.cmdline == "echo Hello" res = results_function(shelly, plugin) @@ -586,7 +586,7 @@ def test_shell_cmd_inputspec_4a(plugin, results_function, tmp_path): name="shelly", executable=cmd_exec, input_spec=my_input_spec, cache_dir=tmp_path ) - assert shelly.inputs.executable == cmd_exec + assert shelly.spec.executable == cmd_exec assert shelly.cmdline == "echo Hello" res = results_function(shelly, plugin) @@ -617,7 +617,7 @@ def test_shell_cmd_inputspec_4b(plugin, results_function, tmp_path): name="shelly", executable=cmd_exec, input_spec=my_input_spec, cache_dir=tmp_path ) - assert shelly.inputs.executable == cmd_exec + assert shelly.spec.executable == cmd_exec assert shelly.cmdline == "echo Hi" res = results_function(shelly, plugin) @@ -728,7 +728,7 @@ def test_shell_cmd_inputspec_5_nosubm(plugin, results_function, tmp_path): input_spec=my_input_spec, cache_dir=tmp_path, ) - assert shelly.inputs.executable == cmd_exec + assert shelly.spec.executable == cmd_exec assert shelly.cmdline == "ls -t" results_function(shelly, plugin) @@ -825,7 +825,7 @@ def test_shell_cmd_inputspec_6(plugin, results_function, tmp_path): input_spec=my_input_spec, cache_dir=tmp_path, ) - assert shelly.inputs.executable == cmd_exec + assert shelly.spec.executable == cmd_exec assert shelly.cmdline == "ls -l -t" results_function(shelly, plugin) @@ -913,8 +913,8 @@ def test_shell_cmd_inputspec_6b(plugin, results_function, tmp_path): input_spec=my_input_spec, cache_dir=tmp_path, ) - shelly.inputs.opt_l = cmd_l - assert shelly.inputs.executable == cmd_exec + shelly.spec.opt_l = cmd_l + assert shelly.spec.executable == cmd_exec assert shelly.cmdline == "ls -l -t" results_function(shelly, plugin) @@ -1505,7 +1505,7 @@ def test_shell_cmd_inputspec_10(plugin, results_function, tmp_path): cache_dir=tmp_path, ) - assert shelly.inputs.executable == cmd_exec + assert shelly.spec.executable == cmd_exec res = results_function(shelly, plugin) assert res.output.stdout == "hello from boston" @@ -1590,7 +1590,7 @@ def test_shell_cmd_inputspec_11(tmp_path): wf = Workflow(name="wf", input_spec=["inputFiles"], inputFiles=["test1", "test2"]) - task.inputs.inputFiles = wf.lzin.inputFiles + task.spec.inputFiles = wf.lzin.inputFiles wf.add(task) wf.set_output([("out", wf.echoMultiple.lzout.outputFiles)]) @@ -1704,8 +1704,8 @@ def test_shell_cmd_inputspec_with_iterable(): task = ShellTask(name="test", input_spec=input_spec, executable="test") for iterable_type in (list, tuple): - task.inputs.iterable_1 = iterable_type(range(3)) - task.inputs.iterable_2 = iterable_type(["bar", "foo"]) + task.spec.iterable_1 = iterable_type(range(3)) + task.spec.iterable_2 = iterable_type(["bar", "foo"]) assert task.cmdline == "test --in1 0 1 2 --in2 bar --in2 foo" @@ -3445,8 +3445,8 @@ def test_shell_cmd_inputspec_outputspec_1(): input_spec=my_input_spec, output_spec=my_output_spec, ) - shelly.inputs.file1 = "new_file_1.txt" - shelly.inputs.file2 = "new_file_2.txt" + shelly.spec.file1 = "new_file_1.txt" + shelly.spec.file2 = "new_file_2.txt" res = shelly() assert res.output.stdout == "" @@ -3499,7 +3499,7 @@ def test_shell_cmd_inputspec_outputspec_1a(): input_spec=my_input_spec, output_spec=my_output_spec, ) - shelly.inputs.file1 = "new_file_1.txt" + shelly.spec.file1 = "new_file_1.txt" res = shelly() assert res.output.stdout == "" @@ -3560,8 +3560,8 @@ def test_shell_cmd_inputspec_outputspec_2(): input_spec=my_input_spec, output_spec=my_output_spec, ) - shelly.inputs.file1 = "new_file_1.txt" - shelly.inputs.file2 = "new_file_2.txt" + shelly.spec.file1 = "new_file_1.txt" + shelly.spec.file2 = "new_file_2.txt" # all fields from output_spec should be in output_names and generated_output_names assert ( shelly.output_names @@ -3627,7 +3627,7 @@ def test_shell_cmd_inputspec_outputspec_2a(): input_spec=my_input_spec, output_spec=my_output_spec, ) - shelly.inputs.file1 = "new_file_1.txt" + shelly.spec.file1 = "new_file_1.txt" # generated_output_names should know that newfile2 will not be generated assert shelly.output_names == [ "return_code", @@ -3699,9 +3699,9 @@ def test_shell_cmd_inputspec_outputspec_3(): input_spec=my_input_spec, output_spec=my_output_spec, ) - shelly.inputs.file1 = "new_file_1.txt" - shelly.inputs.file2 = "new_file_2.txt" - shelly.inputs.additional_inp = 2 + shelly.spec.file1 = "new_file_1.txt" + shelly.spec.file2 = "new_file_2.txt" + shelly.spec.additional_inp = 2 res = shelly() assert res.output.stdout == "" @@ -3760,8 +3760,8 @@ def test_shell_cmd_inputspec_outputspec_3a(): input_spec=my_input_spec, output_spec=my_output_spec, ) - shelly.inputs.file1 = "new_file_1.txt" - shelly.inputs.file2 = "new_file_2.txt" + shelly.spec.file1 = "new_file_1.txt" + shelly.spec.file2 = "new_file_2.txt" # generated_output_names should know that newfile2 will not be generated assert shelly.output_names == [ "return_code", @@ -3824,8 +3824,8 @@ def test_shell_cmd_inputspec_outputspec_4(): input_spec=my_input_spec, output_spec=my_output_spec, ) - shelly.inputs.file1 = "new_file_1.txt" - shelly.inputs.additional_inp = 2 + shelly.spec.file1 = "new_file_1.txt" + shelly.spec.additional_inp = 2 # generated_output_names should be the same as output_names assert ( shelly.output_names @@ -3879,9 +3879,9 @@ def test_shell_cmd_inputspec_outputspec_4a(): input_spec=my_input_spec, output_spec=my_output_spec, ) - shelly.inputs.file1 = "new_file_1.txt" + shelly.spec.file1 = "new_file_1.txt" # the value is not in the list from requires - shelly.inputs.additional_inp = 1 + shelly.spec.additional_inp = 1 res = shelly() assert res.output.stdout == "" @@ -3934,8 +3934,8 @@ def test_shell_cmd_inputspec_outputspec_5(): input_spec=my_input_spec, output_spec=my_output_spec, ) - shelly.inputs.file1 = "new_file_1.txt" - shelly.inputs.additional_inp_A = 2 + shelly.spec.file1 = "new_file_1.txt" + shelly.spec.additional_inp_A = 2 res = shelly() assert res.output.stdout == "" @@ -3988,8 +3988,8 @@ def test_shell_cmd_inputspec_outputspec_5a(): input_spec=my_input_spec, output_spec=my_output_spec, ) - shelly.inputs.file1 = "new_file_1.txt" - shelly.inputs.additional_inp_B = 2 + shelly.spec.file1 = "new_file_1.txt" + shelly.spec.additional_inp_B = 2 res = shelly() assert res.output.stdout == "" @@ -4042,7 +4042,7 @@ def test_shell_cmd_inputspec_outputspec_5b(): input_spec=my_input_spec, output_spec=my_output_spec, ) - shelly.inputs.file1 = "new_file_1.txt" + shelly.spec.file1 = "new_file_1.txt" res = shelly() assert res.output.stdout == "" @@ -4091,7 +4091,7 @@ def test_shell_cmd_inputspec_outputspec_6_except(): input_spec=my_input_spec, output_spec=my_output_spec, ) - shelly.inputs.file1 = "new_file_1.txt" + shelly.spec.file1 = "new_file_1.txt" with pytest.raises(Exception, match="requires field can be"): shelly() @@ -4338,7 +4338,7 @@ def change_name(file): name="bet_task", executable="bet", in_file=in_file, input_spec=bet_input_spec ) out_file = shelly.output_dir / "test_brain.nii.gz" - assert shelly.inputs.executable == "bet" + assert shelly.spec.executable == "bet" assert shelly.cmdline == f"bet {in_file} {out_file}" # res = shelly(plugin="cf") diff --git a/pydra/engine/tests/test_shelltask_inputspec.py b/pydra/engine/tests/test_shelltask_inputspec.py index 53071d65c5..7b95ea558f 100644 --- a/pydra/engine/tests/test_shelltask_inputspec.py +++ b/pydra/engine/tests/test_shelltask_inputspec.py @@ -719,7 +719,7 @@ def test_shell_cmd_inputs_not_given_1(): ) shelly = ShellTask(name="shelly", executable="executable", input_spec=my_input_spec) - shelly.inputs.arg2 = "argument2" + shelly.spec.arg2 = "argument2" assert shelly.cmdline == "executable --arg2 argument2" @@ -1747,7 +1747,7 @@ def test_shell_cmd_inputs_template_requires_1(): assert "--tpl" not in shelly.cmdline # When requirements are met. - shelly.inputs.with_tpl = True + shelly.spec.with_tpl = True assert "tpl.in.file" in shelly.cmdline @@ -2212,27 +2212,27 @@ class SimpleTaskXor(ShellTask): def test_task_inputs_mandatory_with_xOR_one_mandatory_is_OK(): """input spec with mandatory inputs""" task = SimpleTaskXor() - task.inputs.input_1 = "Input1" - task.inputs.input_2 = attr.NOTHING - task.inputs.check_fields_input_spec() + task.spec.input_1 = "Input1" + task.spec.input_2 = attr.NOTHING + task.spec.check_fields_input_spec() def test_task_inputs_mandatory_with_xOR_one_mandatory_out_3_is_OK(): """input spec with mandatory inputs""" task = SimpleTaskXor() - task.inputs.input_1 = attr.NOTHING - task.inputs.input_2 = attr.NOTHING - task.inputs.input_3 = True - task.inputs.check_fields_input_spec() + task.spec.input_1 = attr.NOTHING + task.spec.input_2 = attr.NOTHING + task.spec.input_3 = True + task.spec.check_fields_input_spec() def test_task_inputs_mandatory_with_xOR_zero_mandatory_raises_error(): """input spec with mandatory inputs""" task = SimpleTaskXor() - task.inputs.input_1 = attr.NOTHING - task.inputs.input_2 = attr.NOTHING + task.spec.input_1 = attr.NOTHING + task.spec.input_2 = attr.NOTHING with pytest.raises(Exception) as excinfo: - task.inputs.check_fields_input_spec() + task.spec.check_fields_input_spec() assert "input_1 is mandatory" in str(excinfo.value) assert "no alternative provided by ['input_2', 'input_3']" in str(excinfo.value) assert excinfo.type is AttributeError @@ -2241,11 +2241,11 @@ def test_task_inputs_mandatory_with_xOR_zero_mandatory_raises_error(): def test_task_inputs_mandatory_with_xOR_two_mandatories_raises_error(): """input spec with mandatory inputs""" task = SimpleTaskXor() - task.inputs.input_1 = "Input1" - task.inputs.input_2 = True + task.spec.input_1 = "Input1" + task.spec.input_2 = True with pytest.raises(Exception) as excinfo: - task.inputs.check_fields_input_spec() + task.spec.check_fields_input_spec() assert "input_1 is mutually exclusive with ['input_2']" in str(excinfo.value) assert excinfo.type is AttributeError @@ -2253,12 +2253,12 @@ def test_task_inputs_mandatory_with_xOR_two_mandatories_raises_error(): def test_task_inputs_mandatory_with_xOR_3_mandatories_raises_error(): """input spec with mandatory inputs""" task = SimpleTaskXor() - task.inputs.input_1 = "Input1" - task.inputs.input_2 = True - task.inputs.input_3 = False + task.spec.input_1 = "Input1" + task.spec.input_2 = True + task.spec.input_3 = False with pytest.raises(Exception) as excinfo: - task.inputs.check_fields_input_spec() + task.spec.check_fields_input_spec() assert "input_1 is mutually exclusive with ['input_2', 'input_3']" in str( excinfo.value ) diff --git a/pydra/engine/tests/test_task.py b/pydra/engine/tests/test_task.py index 8699eb1711..4a481e9d92 100644 --- a/pydra/engine/tests/test_task.py +++ b/pydra/engine/tests/test_task.py @@ -57,7 +57,7 @@ def test_numpy(): fft = mark.annotate({"a": np.ndarray, "return": np.ndarray})(np.fft.fft) fft = mark.task(fft)() arr = np.array([[1, 10], [2, 20]]) - fft.inputs.a = arr + fft.spec.a = arr res = fft() assert np.allclose(np.fft.fft(arr), res.output.out) @@ -1319,7 +1319,7 @@ def test_shell_cmd(tmpdir): # separate command into exec + args shelly = ShellTask(executable=cmd[0], args=cmd[1:]) - assert shelly.inputs.executable == "echo" + assert shelly.spec.executable == "echo" assert shelly.cmdline == " ".join(cmd) res = shelly._run() assert res.output.return_code == 0 diff --git a/pydra/engine/workflow/node.py b/pydra/engine/workflow/node.py index 36efa95af1..189fc0cebc 100644 --- a/pydra/engine/workflow/node.py +++ b/pydra/engine/workflow/node.py @@ -1,11 +1,14 @@ import typing as ty -from copy import deepcopy +from copy import deepcopy, copy from enum import Enum +from pathlib import Path import attrs from pydra.utils.typing import TypeParser, StateArray from . import lazy -from ..specs import TaskSpec, Outputs -from ..helpers import ensure_list, attrs_values, is_lazy +from ..specs import TaskSpec, Outputs, WorkflowSpec +from ..task import Task +from ..helpers import ensure_list, attrs_values, is_lazy, load_result, create_checksum +from pydra.utils.hash import hash_function from .. import helpers_state as hlpst from ..state import State @@ -273,6 +276,94 @@ def combiner(self): return () return self._state.combiner + def _get_tasks( + self, + cache_locations: Path | list[Path], + state_index: int | None = None, + return_inputs: bool = False, + ) -> list["Task"]: + raise NotImplementedError + if self.state: + if state_index is None: + # if state_index=None, collecting all results + if self.state.combiner: + return self._combined_output(return_inputs=return_inputs) + else: + results = [] + for ind in range(len(self.state.inputs_ind)): + checksum = self.checksum_states(state_index=ind) + result = load_result(checksum, cache_locations) + if result is None: + return None + results.append(result) + if return_inputs is True or return_inputs == "val": + return list(zip(self.state.states_val, results)) + elif return_inputs == "ind": + return list(zip(self.state.states_ind, results)) + else: + return results + else: # state_index is not None + if self.state.combiner: + return self._combined_output(return_inputs=return_inputs)[ + state_index + ] + result = load_result(self.checksum_states(state_index), cache_locations) + if return_inputs is True or return_inputs == "val": + return (self.state.states_val[state_index], result) + elif return_inputs == "ind": + return (self.state.states_ind[state_index], result) + else: + return result + else: + return load_result(self._spec._checksum, cache_locations) + + def _checksum_states(self, state_index=None): + """ + Calculate a checksum for the specific state or all of the states of the task. + Replaces state-arrays in the inputs fields with a specific values for states. + Used to recreate names of the task directories, + + Parameters + ---------- + state_index : + TODO + + """ + # if is_workflow(self) and self.spec._graph_checksums is attr.NOTHING: + # self.spec._graph_checksums = { + # nd.name: nd.checksum for nd in self.graph_sorted + # } + + if state_index is not None: + inputs_copy = copy(self.spec) + for key, ind in self.state.inputs_ind[state_index].items(): + val = self._extract_input_el( + inputs=self.spec, inp_nm=key.split(".")[1], ind=ind + ) + setattr(inputs_copy, key.split(".")[1], val) + # setting files_hash again in case it was cleaned by setting specific element + # that might be important for outer splitter of input variable with big files + # the file can be changed with every single index even if there are only two files + input_hash = inputs_copy.hash + if isinstance(self._spec, WorkflowSpec): + con_hash = hash_function(self._connections) + # TODO: hash list is not used + hash_list = [input_hash, con_hash] # noqa: F841 + checksum_ind = create_checksum( + self.__class__.__name__, self._checksum_wf(input_hash) + ) + else: + checksum_ind = create_checksum(self.__class__.__name__, input_hash) + return checksum_ind + else: + checksum_list = [] + if not hasattr(self.state, "inputs_ind"): + self.state.prepare_states(self.spec, cont_dim=self.cont_dim) + self.state.prepare_inputs() + for ind in range(len(self.state.inputs_ind)): + checksum_list.append(self._checksum_states(state_index=ind)) + return checksum_list + def _check_if_outputs_have_been_used(self, msg): used = [] if self._lzout: diff --git a/pydra/utils/typing.py b/pydra/utils/typing.py index ee21d26db3..2ce2efd1ff 100644 --- a/pydra/utils/typing.py +++ b/pydra/utils/typing.py @@ -1042,3 +1042,14 @@ def is_fileset_or_union(type_: type) -> bool: if is_union(type_): return any(is_fileset_or_union(t) for t in ty.get_args(type_)) return issubclass(type_, core.FileSet) + + +def is_type(*args: ty.Any) -> bool: + """check that the value is a type or generic""" + if len(args) == 3: # attrs validator + val = args[2] + elif len(args) != 1: + raise TypeError(f"is_type() takes 1 or 3 arguments, not {args}") + else: + val = args[0] + return inspect.isclass(val) or ty.get_origin(val)