From 6ca429d7bebe77468df3010328e263d0b31fe7d5 Mon Sep 17 00:00:00 2001
From: Marek Dobransky
Date: Thu, 5 Sep 2024 12:37:30 +0200
Subject: [PATCH 1/2] custom env loader

---
 .flake8                   |  1 +
 pyproject.toml            |  1 -
 rialto/common/env_yaml.py | 28 ++++++++++++++++++++++++++++
 rialto/common/utils.py    |  3 ++-
 4 files changed, 31 insertions(+), 2 deletions(-)
 create mode 100644 rialto/common/env_yaml.py

diff --git a/.flake8 b/.flake8
index 21099b7..c2cf6c9 100644
--- a/.flake8
+++ b/.flake8
@@ -14,3 +14,4 @@ extend-ignore =
     D100,
     D104,
     D107,
+    E203,
diff --git a/pyproject.toml b/pyproject.toml
index 5812612..23aa34e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,7 +31,6 @@ pandas = "^2.1.0"
 flake8-broken-line = "^1.0.0"
 loguru = "^0.7.2"
 importlib-metadata = "^7.2.1"
-env_yaml = "^0.0.3"
 
 [tool.poetry.dev-dependencies]
 pyspark = "^3.4.1"
diff --git a/rialto/common/env_yaml.py b/rialto/common/env_yaml.py
new file mode 100644
index 0000000..ec2f591
--- /dev/null
+++ b/rialto/common/env_yaml.py
@@ -0,0 +1,28 @@
+import os
+import re
+
+import yaml
+from loguru import logger
+
+__all__ = ["EnvLoader"]
+
+_path_matcher = re.compile(r"\$\{(?P<env_name>[^}^{:]+)(?::(?P<default_value>[^}^{]*))?\}")
+
+
+def _path_constructor(loader, node):
+    value = node.value
+    match = _path_matcher.match(value)
+    sub = os.getenv(match.group("env_name"), match.group("default_value"))
+    new_value = value[0 : match.start()] + sub + value[match.end() :]
+    logger.info(f"Config: Replacing {value}, with {new_value}")
+    return new_value
+
+
+class EnvLoader(yaml.SafeLoader):
+    """Custom loader that replaces values with environment variables"""
+
+    pass
+
+
+EnvLoader.add_implicit_resolver("!env_substitute", _path_matcher, None)
+EnvLoader.add_constructor("!env_substitute", _path_constructor)
diff --git a/rialto/common/utils.py b/rialto/common/utils.py
index b2e19b4..6f5ed1f 100644
--- a/rialto/common/utils.py
+++ b/rialto/common/utils.py
@@ -19,10 +19,11 @@
 
 import pyspark.sql.functions as F
 import yaml
-from env_yaml import EnvLoader
 from pyspark.sql import DataFrame
 from pyspark.sql.types import FloatType
 
+from rialto.common.env_yaml import EnvLoader
+
 
 def load_yaml(path: str) -> Any:
     """

From 3b6c5540f60e37ff6fa22cfe60c10803221b7dbe Mon Sep 17 00:00:00 2001
From: Marek Dobransky
Date: Thu, 5 Sep 2024 14:45:28 +0200
Subject: [PATCH 2/2] allow for env in the middle of string

---
 rialto/common/env_yaml.py | 19 ++++++++++++-------
 tests/common/test_yaml.py | 81 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 94 insertions(+), 6 deletions(-)
 create mode 100644 tests/common/test_yaml.py

diff --git a/rialto/common/env_yaml.py b/rialto/common/env_yaml.py
index ec2f591..e92bf67 100644
--- a/rialto/common/env_yaml.py
+++ b/rialto/common/env_yaml.py
@@ -6,16 +6,23 @@
 
 __all__ = ["EnvLoader"]
 
-_path_matcher = re.compile(r"\$\{(?P<env_name>[^}^{:]+)(?::(?P<default_value>[^}^{]*))?\}")
+# Regex pattern to capture variable and the rest of the string
+_path_matcher = re.compile(r"(?P<before>.*)\$\{(?P<env_name>[^}^{:]+)(?::(?P<default_value>[^}^{]*))?\}(?P<after>.*)")
 
 
 def _path_constructor(loader, node):
     value = node.value
-    match = _path_matcher.match(value)
-    sub = os.getenv(match.group("env_name"), match.group("default_value"))
-    new_value = value[0 : match.start()] + sub + value[match.end() :]
-    logger.info(f"Config: Replacing {value}, with {new_value}")
-    return new_value
+    match = _path_matcher.search(value)
+    if match:
+        before = match.group("before")
+        after = match.group("after")
+        sub = os.getenv(match.group("env_name"), match.group("default_value"))
+        if sub is None:
+            raise ValueError(f"Environment variable {match.group('env_name')} has no assigned value")
+        new_value = before + sub + after
+        logger.info(f"Config: Replacing {value}, with {new_value}")
+        return new_value
+    return value
 
 
 class EnvLoader(yaml.SafeLoader):
diff --git a/tests/common/test_yaml.py b/tests/common/test_yaml.py
new file mode 100644
index 0000000..9d63b66
--- /dev/null
+++ b/tests/common/test_yaml.py
@@ -0,0 +1,81 @@
+import os
+
+import pytest
+import yaml
+
+from rialto.common.env_yaml import EnvLoader
+
+
+def test_plain():
+    data = {"a": "string_value", "b": 2}
+    cfg = """
+    a: string_value
+    b: 2
+    """
+    assert yaml.load(cfg, EnvLoader) == data
+
+
+def test_full_sub_default():
+    data = {"a": "default_value", "b": 2}
+    cfg = """
+    a: ${EMPTY_VAR:default_value}
+    b: 2
+    """
+    assert yaml.load(cfg, EnvLoader) == data
+
+
+def test_full_sub_env():
+    os.environ["FILLED_VAR"] = "env_value"
+    data = {"a": "env_value", "b": 2}
+    cfg = """
+    a: ${FILLED_VAR:default_value}
+    b: 2
+    """
+    assert yaml.load(cfg, EnvLoader) == data
+
+
+def test_partial_sub_start():
+    data = {"a": "start_string", "b": 2}
+    cfg = """
+    a: ${START_VAR:start}_string
+    b: 2
+    """
+    assert yaml.load(cfg, EnvLoader) == data
+
+
+def test_partial_sub_end():
+    data = {"a": "string_end", "b": 2}
+    cfg = """
+    a: string_${END_VAR:end}
+    b: 2
+    """
+    assert yaml.load(cfg, EnvLoader) == data
+
+
+def test_partial_sub_mid():
+    data = {"a": "string_mid_sub", "b": 2}
+    cfg = """
+    a: string_${MID_VAR:mid}_sub
+    b: 2
+    """
+    assert yaml.load(cfg, EnvLoader) == data
+
+
+def test_partial_sub_no_default_no_value():
+    with pytest.raises(Exception) as e:
+        cfg = """
+        a: string_${MANDATORY_VAL_MISSING}_sub
+        b: 2
+        """
+        assert yaml.load(cfg, EnvLoader)
+    assert str(e.value) == "Environment variable MANDATORY_VAL_MISSING has no assigned value"
+
+
+def test_partial_sub_no_default():
+    os.environ["MANDATORY_VAL"] = "mandatory_value"
+    data = {"a": "string_mandatory_value_sub", "b": 2}
+    cfg = """
+    a: string_${MANDATORY_VAL}_sub
+    b: 2
+    """
+    assert yaml.load(cfg, EnvLoader) == data
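
Usage note (reviewer addition, not part of the patches above): a minimal sketch of how the EnvLoader introduced by these commits is meant to be consumed, mirroring what the new tests exercise. The DB_HOST variable and the inline config are hypothetical, chosen only to illustrate the ${VAR:default} syntax; the sketch also assumes DB_HOST is not already set in the calling environment.

import os

import yaml

from rialto.common.env_yaml import EnvLoader

# Hypothetical config snippet; ${DB_HOST:localhost} is resolved at load time.
cfg = """
host: ${DB_HOST:localhost}
port: 2020
"""

# With DB_HOST unset, the default after ":" is substituted.
assert yaml.load(cfg, EnvLoader) == {"host": "localhost", "port": 2020}

# With DB_HOST set, the environment value takes precedence over the default.
os.environ["DB_HOST"] = "db.internal"
assert yaml.load(cfg, EnvLoader) == {"host": "db.internal", "port": 2020}

A placeholder with no default and no environment value (e.g. ${DB_HOST} with DB_HOST unset) raises the ValueError added in PATCH 2/2, as covered by test_partial_sub_no_default_no_value.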