Skip to content

Commit

Permalink
Use file object directly in temporary_config() (#1598)
Browse files Browse the repository at this point in the history
The context manager uses `NamedTemporaryFile` to store the current
configuration, to later restore them. Instead of passing the file object
directly to the save function, it just passes the file name, i.e. the
save (and the load function) will open the file again, which is in
itself not a problem. However, on the Github Windows image this leads to
a permission error (using the created file object is fine).

This commit solves this by adding the `file` argument to `Config.save()`
that allows to pass a file object directly to the function. The same
change is applied to the load function of the config object.
  • Loading branch information
philip-paul-mueller authored Jun 21, 2024
1 parent 93b557f commit 6a490ec
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 14 deletions.
34 changes: 21 additions & 13 deletions dace/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import platform
import tempfile
import io
from typing import Any, Dict
import yaml
import warnings
Expand Down Expand Up @@ -39,10 +40,11 @@ def temporary_config():
Config.set("optimizer", "autooptimize", value=True)
foo()
"""
with tempfile.NamedTemporaryFile() as fp:
Config.save(fp.name)
with tempfile.NamedTemporaryFile(mode='w+t') as fp:
Config.save(file=fp)
yield
Config.load(fp.name)
fp.seek(0) # rewind to the beginning of the file.
Config.load(file=fp)


def _env2bool(envval):
Expand Down Expand Up @@ -157,19 +159,21 @@ def initialize():
Config.save(all=False)

@staticmethod
def load(filename=None):
def load(filename=None, file=None):
"""
Loads a configuration from an existing file.
:param filename: The file to load. If unspecified,
uses default configuration file.
:param file: Load the configuration from the file object.
"""
if filename is None:
filename = Config._cfg_filename

# Read configuration file
with open(filename, 'r') as f:
Config._config = yaml.load(f.read(), Loader=yaml.SafeLoader)
if file is not None:
assert filename is None
Config._config = yaml.load(file.read(), Loader=yaml.SafeLoader)
else:
with open(filename if filename else Config._cfg_filename, 'r') as f:
Config._config = yaml.load(f.read(), Loader=yaml.SafeLoader)

if Config._config is None:
Config._config = {}
Expand All @@ -191,16 +195,17 @@ def load_schema(filename=None):
Config._config_metadata = yaml.load(f.read(), Loader=yaml.SafeLoader)

@staticmethod
def save(path=None, all: bool = False):
def save(path=None, all: bool = False, file=None):
"""
Saves the current configuration to a file.
:param path: The file to save to. If unspecified,
uses default configuration file.
:param all: If False, only saves non-default configuration entries.
Otherwise saves all entries.
:param file: A file object to use directly.
"""
if path is None:
if path is None and file is None:
path = Config._cfg_filename
if path is None:
# Try to create a new config file in reversed priority order, and if all else fails keep config in memory
Expand All @@ -217,8 +222,11 @@ def save(path=None, all: bool = False):
return

# Write configuration file
with open(path, 'w') as f:
yaml.dump(Config._config if all else Config.nondefaults(), f, default_flow_style=False)
if file is not None:
yaml.dump(Config._config if all else Config.nondefaults(), file, default_flow_style=False)
else:
with open(path, 'w') as f:
yaml.dump(Config._config if all else Config.nondefaults(), f, default_flow_style=False)

@staticmethod
def get_metadata(*key_hierarchy):
Expand Down
12 changes: 11 additions & 1 deletion tests/config_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
from dace.config import set_temporary, Config
from dace.config import Config, set_temporary, temporary_config


def test_set_temporary():
Expand All @@ -10,5 +10,15 @@ def test_set_temporary():
assert Config.get(*path) == current_value


def test_temporary_config():
path = ["compiler", "build_type"]
current_value = Config.get(*path)
with temporary_config():
Config.set(*path, value="I'm not a build type")
assert Config.get(*path) == "I'm not a build type"
assert Config.get(*path) == current_value


if __name__ == '__main__':
test_set_temporary()
test_temporary_config()

0 comments on commit 6a490ec

Please sign in to comment.