Skip to content

Commit

Permalink
fix dreamerv2 config
Browse files Browse the repository at this point in the history
  • Loading branch information
tonycaisy committed May 21, 2024
1 parent 0ddf4f2 commit b4dde6b
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions dreamerv2/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def save(self, filename):
elif filename.suffix in ('.yml', '.yaml'):
import ruamel.yaml as yaml
with io.StringIO() as stream:
yaml.safe_dump(dict(self), stream)
yaml = yaml.YAML(typ='safe', pure=True)
yaml.dump(dict(self), stream)
filename.write(stream.getvalue())
else:
raise NotImplementedError(filename.suffix)
Expand All @@ -44,7 +45,8 @@ def load(cls, filename):
return cls(json.loads(filename.read_text()))
elif filename.suffix in ('.yml', '.yaml'):
import ruamel.yaml as yaml
return cls(yaml.safe_load(filename.read_text()))
yaml = yaml.YAML(typ='safe', pure=True)
return cls(yaml.load(filename.read_text()))
else:
raise NotImplementedError(filename.suffix)

Expand Down Expand Up @@ -174,8 +176,8 @@ def _ensure_values(self, mapping):
if len(value) == 0:
message = 'Empty lists are disallowed because their type is unclear.'
raise TypeError(message)
if not isinstance(value[0], (str, float, int, bool, list)):
message = 'Lists can only contain strings, floats, ints, bools, lists'
if not isinstance(value[0], (str, float, int, bool, list, dict)):
message = 'Lists can only contain strings, floats, ints, bools, lists, dict'
message += f' but not {type(value[0])}'
raise TypeError(message)
if not all(isinstance(x, type(value[0])) for x in value[1:]):
Expand Down

0 comments on commit b4dde6b

Please sign in to comment.