diff --git a/autotest/test_misc.py b/autotest/test_misc.py index 7b65641..283e5ff 100644 --- a/autotest/test_misc.py +++ b/autotest/test_misc.py @@ -9,6 +9,7 @@ import pytest from modflow_devtools.misc import ( + get_env, get_model_paths, get_namefile_paths, get_packages, @@ -280,3 +281,25 @@ def sleep1dec(): cap = capfd.readouterr() print(cap.out) assert re.match(r"sleep1dec took \d+\.\d+ ms", cap.out) + + +def test_get_env(): + assert get_env("NO_VALUE") is None + + with set_env(TEST_VALUE=str(True)): + assert get_env("NO_VALUE", True) == True + assert get_env("TEST_VALUE") == True + assert get_env("TEST_VALUE", default=False) == True + assert get_env("TEST_VALUE", default=1) == 1 + + with set_env(TEST_VALUE=str(1)): + assert get_env("NO_VALUE", 1) == 1 + assert get_env("TEST_VALUE") == 1 + assert get_env("TEST_VALUE", default=2) == 1 + assert get_env("TEST_VALUE", default=2.1) == 2.1 + + with set_env(TEST_VALUE=str(1.1)): + assert get_env("NO_VALUE", 1.1) == 1.1 + assert get_env("TEST_VALUE") == 1.1 + assert get_env("TEST_VALUE", default=2.1) == 1.1 + assert get_env("TEST_VALUE", default=False) == False diff --git a/modflow_devtools/misc.py b/modflow_devtools/misc.py index 4915d8c..c7fb0de 100644 --- a/modflow_devtools/misc.py +++ b/modflow_devtools/misc.py @@ -2,6 +2,7 @@ import socket import sys import traceback +from ast import literal_eval from contextlib import contextmanager from functools import wraps from importlib import metadata @@ -31,39 +32,6 @@ def set_dir(path: PathLike): print(f"Returned to previous directory: {origin}") -@contextmanager -def set_env(*remove, **update): - """ - Temporarily updates the ``os.environ`` dictionary in-place. - - Referenced from https://stackoverflow.com/a/34333710/6514033. - - The ``os.environ`` dictionary is updated in-place so that the modification - is sure to work in all situations. - - :param remove: Environment variables to remove. - :param update: Dictionary of environment variables and values to add/update. - """ - env = environ - update = update or {} - remove = remove or [] - - # List of environment variables being updated or removed. - stomped = (set(update.keys()) | set(remove)) & set(env.keys()) - # Environment variables and values to restore on exit. - update_after = {k: env[k] for k in stomped} - # Environment variables and values to remove on exit. - remove_after = frozenset(k for k in update if k not in env) - - try: - env.update(update) - [env.pop(k, None) for k in remove] - yield - finally: - env.update(update_after) - [env.pop(k) for k in remove_after] - - class add_sys_path: """ Context manager to add temporarily to the system path. @@ -486,3 +454,68 @@ def call(): return res return _timed + + +def get_env(name: str, default: object = None) -> Optional[object]: + """ + Try to parse the given environment variable as the type of the given + default value, if one is provided, otherwise any type is acceptable. + If the types of the parsed value and default value don't match, the + default value is returned. The environment variable is parsed as a + Python literal with `ast.literal_eval()`. + + Parameters + ---------- + name : str + The environment variable name + default : object + The default value if the environment variable does not exist + + Returns + ------- + The value of the environment variable, parsed as a Python literal, + otherwise the default value if the environment variable is not set. + """ + try: + v = environ.get(name) + if isinstance(default, bool): + v = v.lower().title() + v = literal_eval(v) + except: + return default + if default is None: + return v + return v if isinstance(v, type(default)) else default + + +@contextmanager +def set_env(*remove, **update): + """ + Temporarily updates the ``os.environ`` dictionary in-place. + + Referenced from https://stackoverflow.com/a/34333710/6514033. + + The ``os.environ`` dictionary is updated in-place so that the modification + is sure to work in all situations. + + :param remove: Environment variables to remove. + :param update: Dictionary of environment variables and values to add/update. + """ + env = environ + update = update or {} + remove = remove or [] + + # List of environment variables being updated or removed. + stomped = (set(update.keys()) | set(remove)) & set(env.keys()) + # Environment variables and values to restore on exit. + update_after = {k: env[k] for k in stomped} + # Environment variables and values to remove on exit. + remove_after = frozenset(k for k in update if k not in env) + + try: + env.update(update) + [env.pop(k, None) for k in remove] + yield + finally: + env.update(update_after) + [env.pop(k) for k in remove_after]