diff --git a/pydra/engine/helpers.py b/pydra/engine/helpers.py index 5b411a9a95..4d8e84132b 100644 --- a/pydra/engine/helpers.py +++ b/pydra/engine/helpers.py @@ -263,7 +263,9 @@ def make_klass(spec): **kwargs, ) checker_label = f"'{name}' field of {spec.name}" - type_checker = TypeParser[newfield.type](newfield.type, label=checker_label) + type_checker = TypeParser[newfield.type]( + newfield.type, label=checker_label, superclass_auto_cast=True + ) if newfield.type in (MultiInputObj, MultiInputFile): converter = attr.converters.pipe(ensure_list, type_checker) elif newfield.type in (MultiOutputObj, MultiOutputFile): diff --git a/pydra/engine/specs.py b/pydra/engine/specs.py index 289eb8d3cd..d9bd2269d9 100644 --- a/pydra/engine/specs.py +++ b/pydra/engine/specs.py @@ -445,7 +445,7 @@ def collect_additional_outputs(self, inputs, output_dir, outputs): ), ): raise TypeError( - f"Support for {fld.type} type, required for {fld.name} in {self}, " + f"Support for {fld.type} type, required for '{fld.name}' in {self}, " "has not been implemented in collect_additional_output" ) # assuming that field should have either default or metadata, but not both diff --git a/pydra/engine/tests/test_node_task.py b/pydra/engine/tests/test_node_task.py index 4e182781b0..37ed90d037 100644 --- a/pydra/engine/tests/test_node_task.py +++ b/pydra/engine/tests/test_node_task.py @@ -133,21 +133,7 @@ def test_task_init_3a( def test_task_init_4(): - """task with interface and inputs. splitter set using split method""" - nn = fun_addtwo(name="NA") - nn.split(splitter="a", a=[3, 5]) - assert np.allclose(nn.inputs.a, [3, 5]) - - assert nn.state.splitter == "NA.a" - assert nn.state.splitter_rpn == ["NA.a"] - - nn.state.prepare_states(nn.inputs) - assert nn.state.states_ind == [{"NA.a": 0}, {"NA.a": 1}] - assert nn.state.states_val == [{"NA.a": 3}, {"NA.a": 5}] - - -def test_task_init_4a(): - """task with a splitter and inputs set in the split method""" + """task with interface splitter and inputs set in the split method""" nn = fun_addtwo(name="NA") nn.split(splitter="a", a=[3, 5]) assert np.allclose(nn.inputs.a, [3, 5]) diff --git a/pydra/utils/tests/test_typing.py b/pydra/utils/tests/test_typing.py index f88aeafe15..665d79327d 100644 --- a/pydra/utils/tests/test_typing.py +++ b/pydra/utils/tests/test_typing.py @@ -1,5 +1,6 @@ import os import itertools +import sys import typing as ty from pathlib import Path import tempfile @@ -8,13 +9,16 @@ from ...engine.specs import File, LazyOutField from ..typing import TypeParser from pydra import Workflow -from fileformats.application import Json +from fileformats.application import Json, Yaml, Xml from .utils import ( generic_func_task, GenericShellTask, specific_func_task, SpecificShellTask, + other_specific_func_task, + OtherSpecificShellTask, MyFormatX, + MyOtherFormatX, MyHeader, ) @@ -152,8 +156,12 @@ def test_type_check_nested6(): def test_type_check_nested7(): + TypeParser(ty.Tuple[float, float, float])(lz(ty.List[int])) + + +def test_type_check_nested7a(): with pytest.raises(TypeError, match="Wrong number of type arguments"): - TypeParser(ty.Tuple[float, float, float])(lz(ty.List[int])) + TypeParser(ty.Tuple[float, float, float])(lz(ty.Tuple[int])) def test_type_check_nested8(): @@ -164,6 +172,18 @@ def test_type_check_nested8(): )(lz(ty.List[float])) +def test_type_check_permit_superclass(): + # Typical case as Json is subclass of File + TypeParser(ty.List[File])(lz(ty.List[Json])) + # Permissive super class, as File is superclass of Json + TypeParser(ty.List[Json], 
superclass_auto_cast=True)(lz(ty.List[File])) + with pytest.raises(TypeError, match="Cannot coerce"): + TypeParser(ty.List[Json], superclass_auto_cast=False)(lz(ty.List[File])) + # Fails because Yaml is neither sub or super class of Json + with pytest.raises(TypeError, match="Cannot coerce"): + TypeParser(ty.List[Json], superclass_auto_cast=True)(lz(ty.List[Yaml])) + + def test_type_check_fail1(): with pytest.raises(TypeError, match="Wrong number of type arguments in tuple"): TypeParser(ty.Tuple[int, int, int])(lz(ty.Tuple[float, float, float, float])) @@ -490,14 +510,29 @@ def test_matches_type_tuple(): assert not TypeParser.matches_type(ty.Tuple[int], ty.Tuple[int, int]) -def test_matches_type_tuple_ellipsis(): +def test_matches_type_tuple_ellipsis1(): assert TypeParser.matches_type(ty.Tuple[int], ty.Tuple[int, ...]) + + +def test_matches_type_tuple_ellipsis2(): assert TypeParser.matches_type(ty.Tuple[int, int], ty.Tuple[int, ...]) + + +def test_matches_type_tuple_ellipsis3(): assert not TypeParser.matches_type(ty.Tuple[int, float], ty.Tuple[int, ...]) - assert not TypeParser.matches_type(ty.Tuple[int, ...], ty.Tuple[int]) + + +def test_matches_type_tuple_ellipsis4(): + assert TypeParser.matches_type(ty.Tuple[int, ...], ty.Tuple[int]) + + +def test_matches_type_tuple_ellipsis5(): assert TypeParser.matches_type( ty.Tuple[int], ty.List[int], coercible=[(tuple, list)] ) + + +def test_matches_type_tuple_ellipsis6(): assert TypeParser.matches_type( ty.Tuple[int, ...], ty.List[int], coercible=[(tuple, list)] ) @@ -538,7 +573,17 @@ def specific_task(request): assert False -def test_typing_cast(tmp_path, generic_task, specific_task): +@pytest.fixture(params=["func", "shell"]) +def other_specific_task(request): + if request.param == "func": + return other_specific_func_task + elif request.param == "shell": + return OtherSpecificShellTask + else: + assert False + + +def test_typing_implicit_cast_from_super(tmp_path, generic_task, specific_task): """Check the casting of lazy fields and whether specific file-sets can be recovered from generic `File` classes""" @@ -562,33 +607,86 @@ def test_typing_cast(tmp_path, generic_task, specific_task): ) ) + wf.add( + specific_task( + in_file=wf.generic.lzout.out, + name="specific2", + ) + ) + + wf.set_output( + [ + ("out_file", wf.specific2.lzout.out), + ] + ) + + in_file = MyFormatX.sample() + + result = wf(in_file=in_file, plugin="serial") + + out_file: MyFormatX = result.output.out_file + assert type(out_file) is MyFormatX + assert out_file.parent != in_file.parent + assert type(out_file.header) is MyHeader + assert out_file.header.parent != in_file.header.parent + + +def test_typing_cast(tmp_path, specific_task, other_specific_task): + """Check the casting of lazy fields and whether specific file-sets can be recovered + from generic `File` classes""" + + wf = Workflow( + name="test", + input_spec={"in_file": MyFormatX}, + output_spec={"out_file": MyFormatX}, + ) + + wf.add( + specific_task( + in_file=wf.lzin.in_file, + name="entry", + ) + ) + + with pytest.raises(TypeError, match="Cannot coerce"): + # No cast of generic task output to MyFormatX + wf.add( # Generic task + other_specific_task( + in_file=wf.entry.lzout.out, + name="inner", + ) + ) + + wf.add( # Generic task + other_specific_task( + in_file=wf.entry.lzout.out.cast(MyOtherFormatX), + name="inner", + ) + ) + with pytest.raises(TypeError, match="Cannot coerce"): # No cast of generic task output to MyFormatX wf.add( specific_task( - in_file=wf.generic.lzout.out, - name="specific2", + 
in_file=wf.inner.lzout.out, + name="exit", ) ) wf.add( specific_task( - in_file=wf.generic.lzout.out.cast(MyFormatX), - name="specific2", + in_file=wf.inner.lzout.out.cast(MyFormatX), + name="exit", ) ) wf.set_output( [ - ("out_file", wf.specific2.lzout.out), + ("out_file", wf.exit.lzout.out), ] ) - my_fspath = tmp_path / "in_file.my" - hdr_fspath = tmp_path / "in_file.hdr" - my_fspath.write_text("my-format") - hdr_fspath.write_text("my-header") - in_file = MyFormatX([my_fspath, hdr_fspath]) + in_file = MyFormatX.sample() result = wf(in_file=in_file, plugin="serial") @@ -611,6 +709,63 @@ def test_type_is_subclass3(): assert TypeParser.is_subclass(ty.Type[Json], ty.Type[File]) +def test_union_is_subclass1(): + assert TypeParser.is_subclass(ty.Union[Json, Yaml], ty.Union[Json, Yaml, Xml]) + + +def test_union_is_subclass2(): + assert not TypeParser.is_subclass(ty.Union[Json, Yaml, Xml], ty.Union[Json, Yaml]) + + +def test_union_is_subclass3(): + assert TypeParser.is_subclass(Json, ty.Union[Json, Yaml]) + + +def test_union_is_subclass4(): + assert not TypeParser.is_subclass(ty.Union[Json, Yaml], Json) + + +def test_generic_is_subclass1(): + assert TypeParser.is_subclass(ty.List[int], list) + + +def test_generic_is_subclass2(): + assert not TypeParser.is_subclass(list, ty.List[int]) + + +def test_generic_is_subclass3(): + assert not TypeParser.is_subclass(ty.List[float], ty.List[int]) + + +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="Cannot subscript tuple in < Py3.9" +) +def test_generic_is_subclass4(): + class MyTuple(tuple): + pass + + class A: + pass + + class B(A): + pass + + assert TypeParser.is_subclass(MyTuple[A], ty.Tuple[A]) + assert TypeParser.is_subclass(ty.Tuple[B], ty.Tuple[A]) + assert TypeParser.is_subclass(MyTuple[B], ty.Tuple[A]) + assert not TypeParser.is_subclass(ty.Tuple[A], ty.Tuple[B]) + assert not TypeParser.is_subclass(ty.Tuple[A], MyTuple[A]) + assert not TypeParser.is_subclass(MyTuple[A], ty.Tuple[B]) + assert TypeParser.is_subclass(MyTuple[A, int], ty.Tuple[A, int]) + assert TypeParser.is_subclass(ty.Tuple[B, int], ty.Tuple[A, int]) + assert TypeParser.is_subclass(MyTuple[B, int], ty.Tuple[A, int]) + assert TypeParser.is_subclass(MyTuple[int, B], ty.Tuple[int, A]) + assert not TypeParser.is_subclass(MyTuple[B, int], ty.Tuple[int, A]) + assert not TypeParser.is_subclass(MyTuple[int, B], ty.Tuple[A, int]) + assert not TypeParser.is_subclass(MyTuple[B, int], ty.Tuple[A]) + assert not TypeParser.is_subclass(MyTuple[B], ty.Tuple[A, int]) + + def test_type_is_instance1(): assert TypeParser.is_instance(File, ty.Type[File]) diff --git a/pydra/utils/tests/utils.py b/pydra/utils/tests/utils.py index eb452edf91..3582fa9eda 100644 --- a/pydra/utils/tests/utils.py +++ b/pydra/utils/tests/utils.py @@ -1,12 +1,13 @@ from fileformats.generic import File -from fileformats.core.mixin import WithSeparateHeader +from fileformats.core.mixin import WithSeparateHeader, WithMagicNumber from pydra import mark from pydra.engine.task import ShellCommandTask from pydra.engine import specs -class MyFormat(File): +class MyFormat(WithMagicNumber, File): ext = ".my" + magic_number = b"MYFORMAT" class MyHeader(File): @@ -17,6 +18,12 @@ class MyFormatX(WithSeparateHeader, MyFormat): header_type = MyHeader +class MyOtherFormatX(WithMagicNumber, WithSeparateHeader, File): + magic_number = b"MYFORMAT" + ext = ".my" + header_type = MyHeader + + @mark.task def generic_func_task(in_file: File) -> File: return in_file @@ -118,3 +125,57 @@ class SpecificShellTask(ShellCommandTask): input_spec = 
specific_shell_input_spec output_spec = specific_shelloutput_spec executable = "echo" + + +@mark.task +def other_specific_func_task(in_file: MyOtherFormatX) -> MyOtherFormatX: + return in_file + + +other_specific_shell_input_fields = [ + ( + "in_file", + MyOtherFormatX, + { + "help_string": "the input file", + "argstr": "", + "copyfile": "copy", + "sep": " ", + }, + ), + ( + "out", + str, + { + "help_string": "output file name", + "argstr": "", + "position": -1, + "output_file_template": "{in_file}", # Pass through un-altered + }, + ), +] + +other_specific_shell_input_spec = specs.SpecInfo( + name="Input", fields=other_specific_shell_input_fields, bases=(specs.ShellSpec,) +) + +other_specific_shell_output_fields = [ + ( + "out", + MyOtherFormatX, + { + "help_string": "output file", + }, + ), +] +other_specific_shelloutput_spec = specs.SpecInfo( + name="Output", + fields=other_specific_shell_output_fields, + bases=(specs.ShellOutSpec,), +) + + +class OtherSpecificShellTask(ShellCommandTask): + input_spec = other_specific_shell_input_spec + output_spec = other_specific_shelloutput_spec + executable = "echo" diff --git a/pydra/utils/typing.py b/pydra/utils/typing.py index ceddc7e219..ee8e733e44 100644 --- a/pydra/utils/typing.py +++ b/pydra/utils/typing.py @@ -4,6 +4,7 @@ import os import sys import typing as ty +import logging import attr from ..engine.specs import ( LazyField, @@ -19,6 +20,7 @@ # Python < 3.8 from typing_extensions import get_origin, get_args # type: ignore +logger = logging.getLogger("pydra") NO_GENERIC_ISSUBCLASS = sys.version_info.major == 3 and sys.version_info.minor < 10 @@ -56,6 +58,9 @@ class TypeParser(ty.Generic[T]): the tree of more complex nested container types. Overrides 'coercible' to enable you to carve out exceptions, such as TypeParser(list, coercible=[(ty.Iterable, list)], not_coercible=[(str, list)]) + superclass_auto_cast : bool + Allow lazy fields to pass the type check if their types are superclasses of the + specified pattern (instead of matching or being subclasses of the pattern) label : str the label to be used to identify the type parser in error messages. Especially useful when TypeParser is used as a converter in attrs.fields @@ -64,6 +69,7 @@ class TypeParser(ty.Generic[T]): tp: ty.Type[T] coercible: ty.List[ty.Tuple[TypeOrAny, TypeOrAny]] not_coercible: ty.List[ty.Tuple[TypeOrAny, TypeOrAny]] + superclass_auto_cast: bool label: str COERCIBLE_DEFAULT: ty.Tuple[ty.Tuple[type, type], ...] = ( @@ -107,6 +113,7 @@ def __init__( not_coercible: ty.Optional[ ty.Iterable[ty.Tuple[TypeOrAny, TypeOrAny]] ] = NOT_COERCIBLE_DEFAULT, + superclass_auto_cast: bool = False, label: str = "", ): def expand_pattern(t): @@ -135,6 +142,7 @@ def expand_pattern(t): ) self.not_coercible = list(not_coercible) if not_coercible is not None else [] self.pattern = expand_pattern(tp) + self.superclass_auto_cast = superclass_auto_cast def __call__(self, obj: ty.Any) -> ty.Union[T, LazyField[T]]: """Attempts to coerce the object to the specified type, unless the value is @@ -161,7 +169,27 @@ def __call__(self, obj: ty.Any) -> ty.Union[T, LazyField[T]]: if obj is attr.NOTHING: coerced = attr.NOTHING # type: ignore[assignment] elif isinstance(obj, LazyField): - self.check_type(obj.type) + try: + self.check_type(obj.type) + except TypeError as e: + if self.superclass_auto_cast: + try: + # Check whether the type of the lazy field isn't a superclass of + # the type to check against, and if so, allow it due to permissive + # typing rules. 
+ TypeParser(obj.type).check_type(self.tp) + except TypeError: + raise e + else: + logger.info( + "Connecting lazy field %s to %s%s via permissive typing that " + "allows super-to-sub type connections", + obj, + self.tp, + self.label_str, + ) + else: + raise e coerced = obj # type: ignore elif isinstance(obj, StateArray): coerced = StateArray(self(o) for o in obj) # type: ignore[assignment] @@ -421,6 +449,10 @@ def check_tuple(tp_args, pattern_args): for arg in tp_args: expand_and_check(arg, pattern_args[0]) return + elif tp_args[-1] is Ellipsis: + for pattern_arg in pattern_args: + expand_and_check(tp_args[0], pattern_arg) + return if len(tp_args) != len(pattern_args): raise TypeError( f"Wrong number of type arguments in tuple {tp_args} compared to pattern " @@ -464,8 +496,17 @@ def check_coercible( explicit inclusions and exclusions set in the `coercible` and `not_coercible` member attrs """ + # Short-circuit the basic cases where the source and target are the same if source is target: return + if self.superclass_auto_cast and self.is_subclass(target, type(source)): + logger.info( + "Attempting to coerce %s into %s due to super-to-sub class coercion " + "being permitted", + source, + target, + ) + return source_origin = get_origin(source) if source_origin is not None: source = source_origin @@ -562,7 +603,7 @@ def matches_type( def is_instance( cls, obj: object, - candidates: ty.Union[ty.Type[ty.Any], ty.Iterable[ty.Type[ty.Any]]], + candidates: ty.Union[ty.Type[ty.Any], ty.Sequence[ty.Type[ty.Any]]], ) -> bool: """Checks whether the object is an instance of cls or that cls is typing.Any, extending the built-in isinstance to check nested type args @@ -574,7 +615,7 @@ def is_instance( candidates : type or ty.Iterable[type] the candidate types to check the object against """ - if not isinstance(candidates, (tuple, list)): + if not isinstance(candidates, ty.Sequence): candidates = [candidates] for candidate in candidates: if candidate is ty.Any: @@ -600,7 +641,7 @@ def is_instance( def is_subclass( cls, klass: ty.Type[ty.Any], - candidates: ty.Union[ty.Type[ty.Any], ty.Iterable[ty.Type[ty.Any]]], + candidates: ty.Union[ty.Type[ty.Any], ty.Sequence[ty.Type[ty.Any]]], any_ok: bool = False, ) -> bool: """Checks whether the class a is either the same as b, a subclass of b or b is @@ -617,16 +658,23 @@ def is_subclass( """ if not isinstance(candidates, ty.Sequence): candidates = [candidates] + if ty.Any in candidates: + return True + if klass is ty.Any: + return any_ok + + origin = get_origin(klass) + args = get_args(klass) for candidate in candidates: + candidate_origin = get_origin(candidate) + candidate_args = get_args(candidate) # Handle ty.Type[*] types in klass and candidates - if ty.get_origin(klass) is type and ( - candidate is type or ty.get_origin(candidate) is type - ): + if origin is type and (candidate is type or candidate_origin is type): if candidate is type: return True - return cls.is_subclass(ty.get_args(klass)[0], ty.get_args(candidate)[0]) - elif ty.get_origin(klass) is type or ty.get_origin(candidate) is type: + return cls.is_subclass(args[0], candidate_args[0]) + elif origin is type or candidate_origin is type: return False if NO_GENERIC_ISSUBCLASS: if klass is type and candidate is not type: @@ -636,27 +684,29 @@ def is_subclass( ): return True else: - if klass is ty.Any: - if ty.Any in candidates: # type: ignore - return True - else: - return any_ok - origin = get_origin(klass) if origin is ty.Union: - args = get_args(klass) - if get_origin(candidate) is ty.Union: - 
candidate_args = get_args(candidate) - else: - candidate_args = [candidate] - return all( - any(cls.is_subclass(a, c) for a in args) for c in candidate_args + union_args = ( + candidate_args if candidate_origin is ty.Union else (candidate,) ) - if origin is not None: - klass = origin - if klass is candidate or candidate is ty.Any: - return True - if issubclass(klass, candidate): - return True + matches = all( + any(cls.is_subclass(a, c) for c in union_args) for a in args + ) + if matches: + return True + else: + if candidate_args and candidate_origin is not ty.Union: + if ( + origin + and issubclass(origin, candidate_origin) # type: ignore[arg-type] + and len(args) == len(candidate_args) + and all( + issubclass(a, c) for a, c in zip(args, candidate_args) + ) + ): + return True + else: + if issubclass(origin if origin else klass, candidate): + return True return False @classmethod
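
For context, the changes above can be summarised as follows: `TypeParser` gains a `superclass_auto_cast` flag which, when set, lets a lazy field whose declared type is a superclass of the required type pass the type check (the connection is logged and the value is assumed to be of the more specific type at runtime), and `check_coercible` short-circuits when the target is a subclass of the source value's type. `make_klass` enables this flag on the converters attached to input fields, so a workflow connection from a generic `File` output into an input that expects a specific format (e.g. `MyFormatX`) no longer needs an explicit `.cast()`, while connections between unrelated sibling formats still fail. In addition, `is_subclass`/`is_instance` now accept sequences of candidates and handle unions and subscripted generics, and tuple checking handles `Tuple[X, ...]` ellipsis patterns in both directions; the typing tests are split up and extended accordingly, with new `MyOtherFormatX`/`OtherSpecificShellTask` helpers in `pydra/utils/tests/utils.py`.

A minimal sketch of the new flag at the `TypeParser` level, modelled on `test_type_check_permit_superclass` above. The `lz` stand-in mirrors the helper defined in `test_typing.py`; the exact `LazyOutField` keyword arguments are an assumption based on that helper, not part of this diff.

import typing as ty
from pydra.engine.specs import LazyOutField
from pydra.utils.typing import TypeParser
from fileformats.generic import File
from fileformats.application import Json, Yaml

def lz(tp):
    # stand-in for the test-module helper: a lazy output field declared as type `tp`
    return LazyOutField(name="foo", field="boo", type=tp)

# sub-to-super connection: always allowed, since Json is a subclass of File
TypeParser(ty.List[File])(lz(ty.List[Json]))

# super-to-sub connection: only permitted when superclass_auto_cast=True
TypeParser(ty.List[Json], superclass_auto_cast=True)(lz(ty.List[File]))

# sibling formats are still rejected, even with the flag set
try:
    TypeParser(ty.List[Json], superclass_auto_cast=True)(lz(ty.List[Yaml]))
except TypeError:
    pass  # "Cannot coerce ...", as asserted in the test above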