From b4dde6b62649cdd539d354373b62deb00419267f Mon Sep 17 00:00:00 2001 From: Shuangyu Cai Date: Mon, 20 May 2024 23:03:46 -0700 Subject: [PATCH] fix dreamerv2 config --- dreamerv2/core/config.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/dreamerv2/core/config.py b/dreamerv2/core/config.py index cb4983d..22799ed 100644 --- a/dreamerv2/core/config.py +++ b/dreamerv2/core/config.py @@ -32,7 +32,8 @@ def save(self, filename): elif filename.suffix in ('.yml', '.yaml'): import ruamel.yaml as yaml with io.StringIO() as stream: - yaml.safe_dump(dict(self), stream) + yaml = yaml.YAML(typ='safe', pure=True) + yaml.dump(dict(self), stream) filename.write(stream.getvalue()) else: raise NotImplementedError(filename.suffix) @@ -44,7 +45,8 @@ def load(cls, filename): return cls(json.loads(filename.read_text())) elif filename.suffix in ('.yml', '.yaml'): import ruamel.yaml as yaml - return cls(yaml.safe_load(filename.read_text())) + yaml = yaml.YAML(typ='safe', pure=True) + return cls(yaml.load(filename.read_text())) else: raise NotImplementedError(filename.suffix) @@ -174,8 +176,8 @@ def _ensure_values(self, mapping): if len(value) == 0: message = 'Empty lists are disallowed because their type is unclear.' raise TypeError(message) - if not isinstance(value[0], (str, float, int, bool, list)): - message = 'Lists can only contain strings, floats, ints, bools, lists' + if not isinstance(value[0], (str, float, int, bool, list, dict)): + message = 'Lists can only contain strings, floats, ints, bools, lists, dict' message += f' but not {type(value[0])}' raise TypeError(message) if not all(isinstance(x, type(value[0])) for x in value[1:]):