Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use file object directly in temporary_config() #1598

Merged
merged 6 commits into from
Jun 21, 2024
Merged
34 changes: 21 additions & 13 deletions dace/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import platform
import tempfile
import io
from typing import Any, Dict
import yaml
import warnings
Expand Down Expand Up @@ -39,10 +40,11 @@ def temporary_config():
Config.set("optimizer", "autooptimize", value=True)
foo()
"""
with tempfile.NamedTemporaryFile() as fp:
Config.save(fp.name)
with tempfile.NamedTemporaryFile(mode='w+t') as fp:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does 't' mean in mode?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It means text mode, on UNIX \n is the newline character, but on Windows it is \r\n (https://en.wikipedia.org/wiki/Newline#Representation). In that mode \r\n is transformed into \n it is essentially a compatibility layer, see help(open).
I used it because it preserves old behaviour, but it is probably not needed.

Config.save(file=fp)
yield
Config.load(fp.name)
fp.seek(0) # rewind to the beginning of the file.
Config.load(file=fp)


def _env2bool(envval):
Expand Down Expand Up @@ -157,19 +159,21 @@ def initialize():
Config.save(all=False)

@staticmethod
def load(filename=None):
def load(filename=None, file=None):
"""
Loads a configuration from an existing file.

:param filename: The file to load. If unspecified,
uses default configuration file.
:param file: Load the configuration from the file object.
"""
if filename is None:
filename = Config._cfg_filename

# Read configuration file
with open(filename, 'r') as f:
Config._config = yaml.load(f.read(), Loader=yaml.SafeLoader)
if file is not None:
assert filename is None
Config._config = yaml.load(file.read(), Loader=yaml.SafeLoader)
else:
with open(filename if filename else Config._cfg_filename, 'r') as f:
Config._config = yaml.load(f.read(), Loader=yaml.SafeLoader)

if Config._config is None:
Config._config = {}
Expand All @@ -191,16 +195,17 @@ def load_schema(filename=None):
Config._config_metadata = yaml.load(f.read(), Loader=yaml.SafeLoader)

@staticmethod
def save(path=None, all: bool = False):
def save(path=None, all: bool = False, file=None):
"""
Saves the current configuration to a file.

:param path: The file to save to. If unspecified,
uses default configuration file.
:param all: If False, only saves non-default configuration entries.
Otherwise saves all entries.
:param file: A file object to use directly.
"""
if path is None:
if path is None and file is None:
path = Config._cfg_filename
if path is None:
# Try to create a new config file in reversed priority order, and if all else fails keep config in memory
Expand All @@ -217,8 +222,11 @@ def save(path=None, all: bool = False):
return

# Write configuration file
with open(path, 'w') as f:
yaml.dump(Config._config if all else Config.nondefaults(), f, default_flow_style=False)
if file is not None:
yaml.dump(Config._config if all else Config.nondefaults(), file, default_flow_style=False)
else:
with open(path, 'w') as f:
yaml.dump(Config._config if all else Config.nondefaults(), f, default_flow_style=False)

@staticmethod
def get_metadata(*key_hierarchy):
Expand Down
12 changes: 11 additions & 1 deletion tests/config_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
from dace.config import set_temporary, Config
from dace.config import Config, set_temporary, temporary_config


def test_set_temporary():
Expand All @@ -10,5 +10,15 @@ def test_set_temporary():
assert Config.get(*path) == current_value


def test_temporary_config():
path = ["compiler", "build_type"]
current_value = Config.get(*path)
with temporary_config():
Config.set(*path, value="I'm not a build type")
assert Config.get(*path) == "I'm not a build type"
assert Config.get(*path) == current_value


if __name__ == '__main__':
test_set_temporary()
test_temporary_config()
Loading