diff --git a/tests/test_common.py b/tests/test_common.py new file mode 100644 index 0000000..e2556be --- /dev/null +++ b/tests/test_common.py @@ -0,0 +1,17 @@ +import unittest + +from wdlgen import ParameterMeta + + +class TestParamMeta(unittest.TestCase): + def test_quote_sanitise(self): + meta = ParameterMeta(foo='"bar"').get_string() + self.assertEqual('foo: "\\"bar\\""', meta) + + def test_nl_sanitise(self): + meta = ParameterMeta(foo="bar\nbaz").get_string() + self.assertEqual('foo: "bar\\nbaz"', meta) + + def test_backslackquote_sanitise(self): + meta = ParameterMeta(foo='bar\\"').get_string() + self.assertEqual('foo: "bar\\\\\\""', meta) diff --git a/tests/test_struct_generation.py b/tests/test_struct_generation.py new file mode 100644 index 0000000..1724961 --- /dev/null +++ b/tests/test_struct_generation.py @@ -0,0 +1,20 @@ +import unittest + +from wdlgen import WdlType, String, Int +from wdlgen.struct import Struct + + +class TestStructs(unittest.TestCase): + def test_spec_example_1(self): + s = Struct("Name") + s.add_field(String, "myString") + s.add_field(Int, "myInt") + + self.assertEqual( + """\ +struct Name { + String myString + Int myInt +}""", + s.get_string(), + ) diff --git a/tests/test_task_generation.py b/tests/test_task_generation.py index 0c166a2..24b781f 100644 --- a/tests/test_task_generation.py +++ b/tests/test_task_generation.py @@ -11,7 +11,9 @@ WorkflowScatter, Meta, ParameterMeta, + Int, ) +from wdlgen.struct import Struct class TestTaskGeneration(unittest.TestCase): @@ -209,26 +211,20 @@ def test_commandarg_nospace(self): self.assertEqual("arg=argVal", t.get_string()) def test_commandarg_flag(self): - t = Task.Command.CommandInput.from_fields( - name="my_value", - true="--arg" - ) - self.assertEqual("~{if (my_value) then \"--arg\" else \"\"}", t.get_string()) + t = Task.Command.CommandInput.from_fields(name="my_value", true="--arg") + self.assertEqual('~{if (my_value) then "--arg" else ""}', t.get_string()) def test_commandarg_flag_false(self): - t = Task.Command.CommandInput.from_fields( - name="my_value", - false="--arg" - ) - self.assertEqual("~{if (my_value) then \"\" else \"--arg\"}", t.get_string()) + t = Task.Command.CommandInput.from_fields(name="my_value", false="--arg") + self.assertEqual('~{if (my_value) then "" else "--arg"}', t.get_string()) def test_commandinp_array_inp(self): t = Task.Command.CommandInput.from_fields( - name="my_array", - separator=" ", - default=[] + name="my_array", separator=" ", default=[] + ) + self.assertEqual( + '~{sep(" ", if defined(my_array) then my_array else [])}', t.get_string() ) - self.assertEqual("~{sep(\" \", if defined(my_array) then my_array else [])}", t.get_string()) class TestWorkflowGeneration(unittest.TestCase): @@ -328,9 +324,7 @@ def test_parameter_meta_dict(self): w = Task( "param_meta_obj", parameter_meta=ParameterMeta( - obj_value={ - "help": "This is help text", "scalar": 96 - } + obj_value={"help": "This is help text", "scalar": 96} ), ) @@ -381,3 +375,31 @@ def test_meta_string(self): }""" derived_workflow_only = "".join(w.get_string().splitlines(keepends=True)[2:]) self.assertEqual(expected, derived_workflow_only) + + +class TestTaskWithExtra(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + s = Struct("Name") + s.add_field(String, "myString") + s.add_field(Int, "myInt") + + cls.struct = s + + def test_post_import_struct(self): + + t = Task("hello") + t.inputs.extend([Input(String, "inp1"), Input(Int, "inp2")]) + + t.pre_statements.append(self.struct) + t.noninput_declarations.append("Name value = object { inp1: inp1, inp2: inp2 }") + + t.command = Task.Command("cat ~{write_json(value)}") + + t.outputs.append( + Output( + WdlType.parse_type("Array[String]"), "matches", "read_string(stdout())" + ) + ) + + print(t.get_string()) diff --git a/tests/test_workflow_generation.py b/tests/test_workflow_generation.py index 6100aa1..22b644f 100644 --- a/tests/test_workflow_generation.py +++ b/tests/test_workflow_generation.py @@ -12,7 +12,7 @@ def test_parameter_meta_scalar(self): test: 42 } }""" - derived_workflow_only = "".join(w.get_string().splitlines(keepends=True)[4:]) + derived_workflow_only = "".join(w.get_string().splitlines(keepends=True)[2:]) self.assertEqual(expected, derived_workflow_only) def test_parameter_meta_string(self): @@ -26,7 +26,7 @@ def test_parameter_meta_string(self): other: "string value" } }""" - derived_workflow_only = "".join(w.get_string().splitlines(keepends=True)[4:]) + derived_workflow_only = "".join(w.get_string().splitlines(keepends=True)[2:]) self.assertEqual(expected, derived_workflow_only) def test_parameter_meta_bool(self): @@ -41,7 +41,7 @@ def test_parameter_meta_bool(self): neg: false } }""" - derived_task_only = "".join(w.get_string().splitlines(keepends=True)[4:]) + derived_task_only = "".join(w.get_string().splitlines(keepends=True)[2:]) self.assertEqual(expected, derived_task_only) def test_parameter_meta_obj(self): @@ -60,16 +60,14 @@ def test_parameter_meta_obj(self): obj_value: {help: "This is help text", scalar: 96} } }""" - derived_workflow_only = "".join(w.get_string().splitlines(keepends=True)[4:]) + derived_workflow_only = "".join(w.get_string().splitlines(keepends=True)[2:]) self.assertEqual(expected, derived_workflow_only) def test_parameter_meta_dict(self): w = Workflow( "param_meta_obj", parameter_meta=ParameterMeta( - obj_value={ - "help": "This is help text", "scalar": 96 - } + obj_value={"help": "This is help text", "scalar": 96} ), ) @@ -79,9 +77,10 @@ def test_parameter_meta_dict(self): obj_value: {help: "This is help text", scalar: 96} } }""" - derived_workflow_only = "".join(w.get_string().splitlines(keepends=True)[4:]) + derived_workflow_only = "".join(w.get_string().splitlines(keepends=True)[2:]) self.assertEqual(expected, derived_workflow_only) + class TestWorkflowMetaGeneration(unittest.TestCase): def test_meta_scalar(self): w = Workflow("meta_scalar", meta=Meta(arbitrary_scalar=42)) @@ -92,7 +91,7 @@ def test_meta_scalar(self): arbitrary_scalar: 42 } }""" - derived_workflow_only = "".join(w.get_string().splitlines(keepends=True)[4:]) + derived_workflow_only = "".join(w.get_string().splitlines(keepends=True)[2:]) self.assertEqual(expected, derived_workflow_only) def test_meta_string(self): @@ -104,7 +103,7 @@ def test_meta_string(self): author: "illusional" } }""" - derived_workflow_only = "".join(w.get_string().splitlines(keepends=True)[4:]) + derived_workflow_only = "".join(w.get_string().splitlines(keepends=True)[2:]) self.assertEqual(expected, derived_workflow_only) def test_meta_bool(self): @@ -117,5 +116,6 @@ def test_meta_bool(self): neg: false } }""" - derived_task_only = "".join(w.get_string().splitlines(keepends=True)[4:]) + result = w.get_string() + derived_task_only = "".join(result.splitlines(keepends=True)[2:]) self.assertEqual(expected, derived_task_only) diff --git a/wdlgen/common.py b/wdlgen/common.py index 22cfbeb..c31e2be 100644 --- a/wdlgen/common.py +++ b/wdlgen/common.py @@ -10,8 +10,9 @@ def __init__(self, condition, value_if_true, value_if_false): self.value_if_true = value_if_true self.value_if_false = value_if_false - def get_string(self): - return ( + def get_string(self, indent=0): + tb = indent * " " + return tb + ( f"if {self.condition} then {self.value_if_true} else {self.value_if_false}" ) @@ -31,16 +32,17 @@ def __init__( self.format = "{type} {name}{def_w_equals}" - def get_string(self): + def get_string(self, indent=0): if self.type is None: raise Exception( f"Could not convert wdlgen.Input ('{self.name}') to string because type was null" ) + tb = indent * " " wd = self.type.get_string() if isinstance(wd, list): - return self.get_string_from_type(wd[0]) - return self.get_string_from_type(wd) + return tb + self.get_string_from_type(wd[0]) + return tb + self.get_string_from_type(wd) def get_string_from_type(self, wdtype): expression = self.expression @@ -72,8 +74,9 @@ def __init__(self, data_type: WdlType, name: str, expression: str = None): self.name = name self.expression = expression - def get_string(self): - f = "{type} {name}{def_w_equals}" + def get_string(self, indent=0): + tb = indent * " " + f = tb + "{type} {name}{def_w_equals}" if isinstance(self.type, list): return [ f.format( diff --git a/wdlgen/struct.py b/wdlgen/struct.py new file mode 100644 index 0000000..455bc8f --- /dev/null +++ b/wdlgen/struct.py @@ -0,0 +1,30 @@ +from typing import List, Optional + +from wdlgen import WdlBase, WdlType + + +class StructField(WdlBase): + def __init__(self, type_: WdlType, name: str): + self.type_ = type_ + self.name = name + + def get_string(self, indent=0): + ind = indent * " " + return f"{ind}{self.type_.get_string()} {self.name}" + + +class Struct(WdlBase): + def __init__(self, name: str, fields: Optional[List[StructField]] = None): + self.name = name + self.fields = fields or [] + + def add_field(self, type_: WdlType, name: str): + self.fields.append(StructField(type_, name)) + + def get_string(self, indent=0): + tb = indent * " " + fields = "\n".join((f.get_string(indent=indent + 1)) for f in self.fields) + return f"""\ +{tb}struct {self.name} {{ +{fields} +{tb}}}""" diff --git a/wdlgen/task.py b/wdlgen/task.py index 8d074b2..a0e3040 100644 --- a/wdlgen/task.py +++ b/wdlgen/task.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Union from .common import Input, Output from .util import WdlBase, KvClass, Meta, ParameterMeta @@ -56,33 +56,34 @@ class Command(WdlBase): """ class CommandArgument(WdlBase): - def __init__( - self, - value, - position=None - ): + def __init__(self, value, position=None): self.value = value self.position = position @staticmethod - def from_fields(prefix: str = None, value: str = None, position: int = None, separate_value_from_prefix: bool = True,): + def from_fields( + prefix: str = None, + value: str = None, + position: int = None, + separate_value_from_prefix: bool = True, + ): pre = prefix if prefix else "" sp = " " if separate_value_from_prefix else "" val = value if value else "" - return Task.Command.CommandArgument((pre + sp + val).strip(), position=position) + return Task.Command.CommandArgument( + (pre + sp + val).strip(), position=position + ) - def get_string(self): - return self.value + def get_string(self, indent=0): + tb = indent * " " + return tb + str(self.value) class CommandInput(CommandArgument): def __init__( - self, - value, - position=None, + self, value, position=None, ): super().__init__( - value=value, - position=position, + value=value, position=position, ) @staticmethod @@ -92,7 +93,8 @@ def from_input(inp: Input, prefix: str = None, position: int = None): ) @staticmethod - def from_fields(name: str, + def from_fields( + name: str, optional: bool = False, prefix: str = None, position: int = None, @@ -101,7 +103,8 @@ def from_fields(name: str, separator=None, true=None, false=None, - separate_arrays=None): + separate_arrays=None, + ): name, array_sep, default, true, false = ( name, @@ -123,13 +126,21 @@ def from_fields(name: str, # Ugly optional workaround: https://github.com/openwdl/wdl/issues/25#issuecomment-315424063 # Additional workaround for 'length(select_first({name}, [])' as length requires a non-optional array internal_pref = f'if defined({name}) && length(select_first([{name}, []])) > 0 then "{bc}" else ""' - return Task.Command.CommandInput(f'~{{{internal_pref}}}~{{sep(" {bc}", {name})}}', position=position) - return Task.Command.CommandInput(f'~{{sep(" ", prefix("{bc}", {name}))}}', position=position) + return Task.Command.CommandInput( + f'~{{{internal_pref}}}~{{sep(" {bc}", {name})}}', + position=position, + ) + return Task.Command.CommandInput( + f'~{{sep(" ", prefix("{bc}", {name}))}}', position=position + ) elif array_sep and optional: # optional array with separator # ifdefname = f'(if defined({name}) then {name} else [])' - return Task.Command.CommandInput(f'~{{true="{bc}" false="" defined({name})}}~{{sep("{array_sep}", {name})}}', position=position) + return Task.Command.CommandInput( + f'~{{true="{bc}" false="" defined({name})}}~{{sep("{array_sep}", {name})}}', + position=position, + ) # build up new value from previous options value = name @@ -157,9 +168,14 @@ def from_fields(name: str, if (separate_value_from_prefix and prefix and prewithquotes) else f"'\"' + {prewithquotes}{value} + '\"'" ) - return Task.Command.CommandInput(f'~{{if defined({value}) then ({full_token}) else ""}}', position=position) + return Task.Command.CommandInput( + f'~{{if defined({value}) then ({full_token}) else ""}}', + position=position, + ) else: - return Task.Command.CommandInput(bc + f"~{{{value}}}", position=position) + return Task.Command.CommandInput( + bc + f"~{{{value}}}", position=position + ) def __init__( self, @@ -196,11 +212,13 @@ def get_string(self, indent: int = 0): def __init__( self, name: str, + pre_statements: Optional[List[Union[str, WdlBase]]] = None, inputs: List[Input] = None, outputs: List[Output] = None, + noninput_declarations: Optional[List[Union[str, WdlBase]]] = None, command: Command = None, runtime: Runtime = None, - version="draft-2", + version="development", meta: Meta = None, parameter_meta: ParameterMeta = None, ): @@ -214,27 +232,55 @@ def __init__( self.meta = meta self.param_meta = parameter_meta + self.pre_statements = pre_statements or [] + self.noninput_declarations = noninput_declarations or [] + self.format = """ -version {version} +{pre} task {name} {{ {blocks} }} """.strip() - def get_string(self): - tb = " " + def get_string(self, indent=0): + tb = (indent + 1) * " " name = self.name + preblocks = [f"version {self.version}"] blocks = [] + if self.pre_statements: + base_tab = indent * " " + format_obj = ( + lambda pi: pi.get_string(indent=indent) + if hasattr(pi, "get_string") + else (base_tab + str(pi)) + ) + preblocks.append("\n".join(format_obj(pi) for pi in self.pre_statements)) + if self.inputs: blocks.append( f"{tb}input {{\n" - + "\n".join(2 * tb + i.get_string() for i in self.inputs) + + "\n".join( + 2 * tb + (i.get_string() if hasattr(i, "get_string") else str(i)) + for i in self.inputs + ) + f"\n{tb}}}" ) + if self.noninput_declarations: + base_tab = " " * (indent + 1) + format_obj = ( + lambda pi: pi.get_string(indent=indent) + if hasattr(pi, "get_string") + else (base_tab + str(pi)) + ) + + blocks.append( + "\n".join(format_obj(pi) for pi in self.noninput_declarations) + ) + if self.command: if isinstance(self.command, list): @@ -245,30 +291,18 @@ def get_string(self): if self.runtime: rt = self.runtime.get_string(indent=2) - blocks.append( - "{tb}runtime {{\n{args}\n{tb}}}".format( - tb=tb, - args=rt, - ) - ) + blocks.append("{tb}runtime {{\n{args}\n{tb}}}".format(tb=tb, args=rt,)) if self.meta: mt = self.meta.get_string(indent=2) if mt: - blocks.append( - "{tb}meta {{\n{args}\n{tb}}}".format( - tb=tb, args=mt - ) - ) + blocks.append("{tb}meta {{\n{args}\n{tb}}}".format(tb=tb, args=mt)) if self.param_meta: pmt = self.param_meta.get_string(indent=2) if pmt: blocks.append( - "{tb}parameter_meta {{\n{args}\n{tb}}}".format( - tb=tb, - args=pmt - ) + "{tb}parameter_meta {{\n{args}\n{tb}}}".format(tb=tb, args=pmt) ) if self.outputs: @@ -280,5 +314,5 @@ def get_string(self): ) return self.format.format( - name=name, blocks="\n".join(blocks), version=self.version + name=name, pre="\n\n".join(preblocks), blocks="\n".join(blocks), ) diff --git a/wdlgen/types.py b/wdlgen/types.py index b91ca84..a871646 100644 --- a/wdlgen/types.py +++ b/wdlgen/types.py @@ -29,8 +29,9 @@ def __init__(self, prim_type): ) self._type = prim_type - def get_string(self): - return self._type + def get_string(self, indent=0): + tb = indent * " " + return tb + str(self._type) @staticmethod def parse(prim_type): @@ -48,9 +49,10 @@ def __init__(self, subtype, requires_multiple): self._subtype: WdlType = WdlType.parse_type(subtype, requires_type=True) self._requires_multiple: bool = requires_multiple - def get_string(self): + def get_string(self, indent=0): + tb = indent * " " - f = ArrayType.kArray + "[{t}]{quantifier}" + f = tb + ArrayType.kArray + "[{t}]{quantifier}" if isinstance(self._subtype, list): return [ @@ -111,13 +113,14 @@ def __init__(self, type_obj, optional=False): self._type = type_obj self.optional = optional - def get_string(self): + def get_string(self, indent=0): + tb = indent * " " wd = self._type.get_string() if isinstance(wd, list): - return [t + ("?" if self.optional else "") for t in wd] + return [tb + t + ("?" if self.optional else "") for t in wd] else: - return wd + ("?" if self.optional else "") + return tb + wd + ("?" if self.optional else "") @staticmethod def parse_type(t, requires_type=True): diff --git a/wdlgen/util.py b/wdlgen/util.py index a417261..4b3ea98 100644 --- a/wdlgen/util.py +++ b/wdlgen/util.py @@ -14,14 +14,16 @@ def convert_python_value_to_wdl_literal(val) -> str: if isinstance(val, bool): return "true" if val else "false" if isinstance(val, str): - return f'"{val}"' + # sanitise string here + sanitised = val.replace("\\", "\\\\").replace("\n", "\\n").replace('"', '\\"') + return f'"{sanitised}"' return str(val) class WdlBase(ABC): @abstractmethod - def get_string(self): + def get_string(self, indent=0): raise Exception("Subclass must override .get_string() method") diff --git a/wdlgen/workflow.py b/wdlgen/workflow.py index a258ef2..93ebb81 100644 --- a/wdlgen/workflow.py +++ b/wdlgen/workflow.py @@ -1,4 +1,4 @@ -from typing import List, Any +from typing import List, Any, Optional from .common import Input, Output from .util import WdlBase, Meta, ParameterMeta @@ -16,6 +16,7 @@ def __init__( version="draft-2", meta: Meta = None, parameter_meta: ParameterMeta = None, + post_import_statements: Optional[List[WdlBase]] = None, ): """ @@ -39,55 +40,69 @@ def __init__( self.meta = meta self.param_meta = parameter_meta - self.format = """ -version {version} + self.post_import_statements = post_import_statements -{imports_block} + self.format = """ +{pre} workflow {name} {{ {blocks} }}""".strip() - def get_string(self): - tb = " " + def get_string(self, indent=0): + tb = (indent + 1) * " " name = self.name - imports_block = "" + + preblocks = [] blocks = [] + if self.imports: + preblocks.append( + "\n".join(i.get_string(indent=indent) for i in self.imports) + ) + + if self.post_import_statements: + base_tab = " " * indent + format_obj = ( + lambda pi: pi.get_string(indent=indent) + if hasattr(pi, "get_string") + else (base_tab + str(pi)) + ) + + preblocks.append( + "\n".join(format_obj(pi) for pi in self.post_import_statements) + ) + if self.inputs: ins = [] for i in self.inputs: - wd = i.get_string() + wd = i.get_string(indent=indent + 1) if isinstance(wd, list): - ins.extend(2 * tb + ii for ii in wd) + ins.extend(wd) else: - ins.append(2 * tb + wd) + ins.append(wd) blocks.append(f"{tb}input {{\n" + "\n".join(ins) + f"\n{tb}}}") if self.calls: - blocks.append("\n".join(c.get_string(indent=1) for c in self.calls)) - - if self.imports: - imports_block = "\n".join(i.get_string() for i in self.imports) + base_tab = (indent + 1) * " " + format_obj = ( + lambda pi: pi.get_string(indent=indent+1) + if hasattr(pi, "get_string") + else (base_tab + str(pi)) + ) + blocks.append("\n".join(format_obj(c) for c in self.calls)) if self.meta: - mt = self.meta.get_string(indent=2) + mt = self.meta.get_string(indent=indent + 2) if mt: - blocks.append( - "{tb}meta {{\n{args}\n{tb}}}".format( - tb=tb, args=mt - ) - ) + blocks.append("{tb}meta {{\n{args}\n{tb}}}".format(tb=tb, args=mt)) if self.param_meta: pmt = self.param_meta.get_string(indent=2) if pmt: blocks.append( - "{tb}parameter_meta {{\n{args}\n{tb}}}".format( - tb=tb, - args=pmt - ) + "{tb}parameter_meta {{\n{args}\n{tb}}}".format(tb=tb, args=pmt) ) if self.outputs: @@ -108,7 +123,7 @@ def get_string(self): return self.format.format( name=name, - imports_block=imports_block, + pre="\n\n".join(preblocks), blocks="\n".join(blocks), version=self.version, ) @@ -122,9 +137,11 @@ def __init__(self, name: str, alias: str, tools_dir="tools/"): if tools_dir and not self.tools_dir.endswith("/"): tools_dir += "/" - def get_string(self): + def get_string(self, indent=0): + tb = " " * indent as_alias = " as " + self.alias if self.alias else "" - return 'import "{tools_dir}{tool}.wdl"{as_alias}'.format( + return '{tb}import "{tools_dir}{tool}.wdl"{as_alias}'.format( + tb=tb, tools_dir=self.tools_dir if self.tools_dir else "", tool=self.name, as_alias=as_alias,