diff --git a/benten/cwl/recordtype.py b/benten/cwl/recordtype.py index 8bb3901..4c6ac7c 100644 --- a/benten/cwl/recordtype.py +++ b/benten/cwl/recordtype.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019 Seven Bridges. See LICENSE +# Copyright (c) 2019-2020 Seven Bridges. See LICENSE from typing import Dict @@ -28,6 +28,10 @@ def __init__(self, name: str, doc: str, fields: Dict[str, 'CWLFieldType']): self.required_fields = set((k for k, v in self.fields.items() if v.required)) self.all_fields = set(self.fields.keys()) + def init(self): + self.required_fields = set((k for k, v in self.fields.items() if v.required)) + self.all_fields = set(self.fields.keys()) + def check(self, node, node_key: str=None, map_sp: MapSubjectPredicate=None) -> TypeCheck: if node is None: diff --git a/benten/cwl/specification.py b/benten/cwl/specification.py index 1a4cb57..63bd14d 100644 --- a/benten/cwl/specification.py +++ b/benten/cwl/specification.py @@ -1,7 +1,7 @@ """Code to load the CWL specification from JSON and represent it as a set of types""" -# Copyright (c) 2019 Seven Bridges. See LICENSE +# Copyright (c) 2019-2020 Seven Bridges. See LICENSE import json @@ -128,6 +128,8 @@ def parse_record(schema, lang_model): for field in schema.get("fields") for k, v in [parse_field(field, lang_model)] }) + lang_model[record_name].init() + return lang_model.get(record_name) diff --git a/benten/version.py b/benten/version.py index febbe00..898f55f 100644 --- a/benten/version.py +++ b/benten/version.py @@ -1,3 +1,3 @@ -# Copyright (c) 2019 Seven Bridges. See LICENSE +# Copyright (c) 2019-2020 Seven Bridges. See LICENSE -__version__ = "2019.12.06" +__version__ = "2020.01.21" diff --git a/tests/test_language_specification.py b/tests/test_language_specification.py index 89a7092..96970ef 100644 --- a/tests/test_language_specification.py +++ b/tests/test_language_specification.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019 Seven Bridges. See LICENSE +# Copyright (c) 2019-2020 Seven Bridges. See LICENSE import pathlib @@ -21,6 +21,9 @@ def test_load_language_specification(): assert "steps" in lang_model["Workflow"].fields assert lang_model["Workflow"].fields["steps"].required + # Ensure type is properly initialized after forward construction + assert "entry" in lang_model["Dirent"].required_fields + def test_forward_reference_resolution(): type_dict = parse_schema(schema_fname)