diff --git a/.flake8 b/.flake8 index 21099b7..c2cf6c9 100644 --- a/.flake8 +++ b/.flake8 @@ -14,3 +14,4 @@ extend-ignore = D100, D104, D107, + E203, diff --git a/pyproject.toml b/pyproject.toml index 5812612..23aa34e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,6 @@ pandas = "^2.1.0" flake8-broken-line = "^1.0.0" loguru = "^0.7.2" importlib-metadata = "^7.2.1" -env_yaml = "^0.0.3" [tool.poetry.dev-dependencies] pyspark = "^3.4.1" diff --git a/rialto/common/env_yaml.py b/rialto/common/env_yaml.py new file mode 100644 index 0000000..e92bf67 --- /dev/null +++ b/rialto/common/env_yaml.py @@ -0,0 +1,35 @@ +import os +import re + +import yaml +from loguru import logger + +__all__ = ["EnvLoader"] + +# Regex pattern to capture variable and the rest of the string +_path_matcher = re.compile(r"(?P.*)\$\{(?P[^}^{:]+)(?::(?P[^}^{]*))?\}(?P.*)") + + +def _path_constructor(loader, node): + value = node.value + match = _path_matcher.search(value) + if match: + before = match.group("before") + after = match.group("after") + sub = os.getenv(match.group("env_name"), match.group("default_value")) + if sub is None: + raise ValueError(f"Environment variable {match.group('env_name')} has no assigned value") + new_value = before + sub + after + logger.info(f"Config: Replacing {value}, with {new_value}") + return new_value + return value + + +class EnvLoader(yaml.SafeLoader): + """Custom loader that replaces values with environment variables""" + + pass + + +EnvLoader.add_implicit_resolver("!env_substitute", _path_matcher, None) +EnvLoader.add_constructor("!env_substitute", _path_constructor) diff --git a/rialto/common/utils.py b/rialto/common/utils.py index b2e19b4..6f5ed1f 100644 --- a/rialto/common/utils.py +++ b/rialto/common/utils.py @@ -19,10 +19,11 @@ import pyspark.sql.functions as F import yaml -from env_yaml import EnvLoader from pyspark.sql import DataFrame from pyspark.sql.types import FloatType +from rialto.common.env_yaml import EnvLoader + def load_yaml(path: str) -> Any: """ diff --git a/tests/common/test_yaml.py b/tests/common/test_yaml.py new file mode 100644 index 0000000..9d63b66 --- /dev/null +++ b/tests/common/test_yaml.py @@ -0,0 +1,81 @@ +import os + +import pytest +import yaml + +from rialto.common.env_yaml import EnvLoader + + +def test_plain(): + data = {"a": "string_value", "b": 2} + cfg = """ + a: string_value + b: 2 + """ + assert yaml.load(cfg, EnvLoader) == data + + +def test_full_sub_default(): + data = {"a": "default_value", "b": 2} + cfg = """ + a: ${EMPTY_VAR:default_value} + b: 2 + """ + assert yaml.load(cfg, EnvLoader) == data + + +def test_full_sub_env(): + os.environ["FILLED_VAR"] = "env_value" + data = {"a": "env_value", "b": 2} + cfg = """ + a: ${FILLED_VAR:default_value} + b: 2 + """ + assert yaml.load(cfg, EnvLoader) == data + + +def test_partial_sub_start(): + data = {"a": "start_string", "b": 2} + cfg = """ + a: ${START_VAR:start}_string + b: 2 + """ + assert yaml.load(cfg, EnvLoader) == data + + +def test_partial_sub_end(): + data = {"a": "string_end", "b": 2} + cfg = """ + a: string_${END_VAR:end} + b: 2 + """ + assert yaml.load(cfg, EnvLoader) == data + + +def test_partial_sub_mid(): + data = {"a": "string_mid_sub", "b": 2} + cfg = """ + a: string_${MID_VAR:mid}_sub + b: 2 + """ + assert yaml.load(cfg, EnvLoader) == data + + +def test_partial_sub_no_default_no_value(): + with pytest.raises(Exception) as e: + cfg = """ + a: string_${MANDATORY_VAL_MISSING}_sub + b: 2 + """ + assert yaml.load(cfg, EnvLoader) + assert str(e.value) == "Environment variable MANDATORY_VAL_MISSING has no assigned value" + + +def test_partial_sub_no_default(): + os.environ["MANDATORY_VAL"] = "mandatory_value" + data = {"a": "string_mandatory_value_sub", "b": 2} + cfg = """ + a: string_${MANDATORY_VAL}_sub + b: 2 + """ + assert yaml.load(cfg, EnvLoader) == data