From 30a2d637323398aaded100eb2f6ceac0d4c0a95c Mon Sep 17 00:00:00 2001 From: Kaushik Ghose Date: Tue, 21 Jan 2020 05:13:43 -0500 Subject: [PATCH] Ensure type is properly initialized Commit 8a8deee handles self-references in the specification by placing the still being constructed type in the type dict before recursing into its children. This is the correct thing to do, but we forgot that now the type information is incomplete when it is initialized. Therefore the `required_fields` and `all_fields` variables were not set correctly. This caused type inference to fail in particular cases such as for `Dirent` We now re-run the init (the setup) once the type is completely constructed ... Version bumped to 2020.01.21. Happy New Year! Closes #77 Closes #76 --- benten/cwl/recordtype.py | 6 +++++- benten/cwl/specification.py | 4 +++- benten/version.py | 4 ++-- tests/test_language_specification.py | 5 ++++- 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/benten/cwl/recordtype.py b/benten/cwl/recordtype.py index 8bb3901..4c6ac7c 100644 --- a/benten/cwl/recordtype.py +++ b/benten/cwl/recordtype.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019 Seven Bridges. See LICENSE +# Copyright (c) 2019-2020 Seven Bridges. See LICENSE from typing import Dict @@ -28,6 +28,10 @@ def __init__(self, name: str, doc: str, fields: Dict[str, 'CWLFieldType']): self.required_fields = set((k for k, v in self.fields.items() if v.required)) self.all_fields = set(self.fields.keys()) + def init(self): + self.required_fields = set((k for k, v in self.fields.items() if v.required)) + self.all_fields = set(self.fields.keys()) + def check(self, node, node_key: str=None, map_sp: MapSubjectPredicate=None) -> TypeCheck: if node is None: diff --git a/benten/cwl/specification.py b/benten/cwl/specification.py index 1a4cb57..63bd14d 100644 --- a/benten/cwl/specification.py +++ b/benten/cwl/specification.py @@ -1,7 +1,7 @@ """Code to load the CWL specification from JSON and represent it as a set of types""" -# Copyright (c) 2019 Seven Bridges. See LICENSE +# Copyright (c) 2019-2020 Seven Bridges. See LICENSE import json @@ -128,6 +128,8 @@ def parse_record(schema, lang_model): for field in schema.get("fields") for k, v in [parse_field(field, lang_model)] }) + lang_model[record_name].init() + return lang_model.get(record_name) diff --git a/benten/version.py b/benten/version.py index febbe00..898f55f 100644 --- a/benten/version.py +++ b/benten/version.py @@ -1,3 +1,3 @@ -# Copyright (c) 2019 Seven Bridges. See LICENSE +# Copyright (c) 2019-2020 Seven Bridges. See LICENSE -__version__ = "2019.12.06" +__version__ = "2020.01.21" diff --git a/tests/test_language_specification.py b/tests/test_language_specification.py index 89a7092..96970ef 100644 --- a/tests/test_language_specification.py +++ b/tests/test_language_specification.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019 Seven Bridges. See LICENSE +# Copyright (c) 2019-2020 Seven Bridges. See LICENSE import pathlib @@ -21,6 +21,9 @@ def test_load_language_specification(): assert "steps" in lang_model["Workflow"].fields assert lang_model["Workflow"].fields["steps"].required + # Ensure type is properly initialized after forward construction + assert "entry" in lang_model["Dirent"].required_fields + def test_forward_reference_resolution(): type_dict = parse_schema(schema_fname)