diff --git a/lib/dl_connector_bigquery/dl_connector_bigquery/testing/secrets.py b/lib/dl_connector_bigquery/dl_connector_bigquery/testing/secrets.py index a17ae1fb2..e14c27b5d 100644 --- a/lib/dl_connector_bigquery/dl_connector_bigquery/testing/secrets.py +++ b/lib/dl_connector_bigquery/dl_connector_bigquery/testing/secrets.py @@ -36,11 +36,11 @@ class BigQuerySecretReader(BigQuerySecretReaderBase): @_project_config.default def _make_project_config(self) -> dict: - return self._env_param_getter.get_json_value(self.KEY_CONFIG) + return self._env_param_getter.get_json_value_strict(self.KEY_CONFIG) @property def project_config(self) -> dict: return self._project_config def get_creds(self) -> str: - return self._env_param_getter.get_str_value(self.KEY_CREDS) + return self._env_param_getter.get_str_value_strict(self.KEY_CREDS) diff --git a/lib/dl_connector_snowflake/dl_connector_snowflake/core/testing/secrets.py b/lib/dl_connector_snowflake/dl_connector_snowflake/core/testing/secrets.py index c35b87b50..4de19ca6b 100644 --- a/lib/dl_connector_snowflake/dl_connector_snowflake/core/testing/secrets.py +++ b/lib/dl_connector_snowflake/dl_connector_snowflake/core/testing/secrets.py @@ -64,17 +64,17 @@ class SnowFlakeSecretReader(SnowFlakeSecretReaderBase): @_project_config.default def _make_project_config(self) -> dict: - return self._env_param_getter.get_json_value(self.KEY_CONFIG) + return self._env_param_getter.get_json_value_strict(self.KEY_CONFIG) @property def project_config(self) -> dict: return self._project_config def get_client_secret(self) -> str: - return self._env_param_getter.get_str_value(self.KEY_CLIENT_SECRET) + return self._env_param_getter.get_str_value_strict(self.KEY_CLIENT_SECRET) def get_refresh_token_expired(self) -> str: - return self._env_param_getter.get_str_value(self.KEY_REFRESH_TOKEN_EXPIRED) + return self._env_param_getter.get_str_value_strict(self.KEY_REFRESH_TOKEN_EXPIRED) def get_refresh_token_x(self) -> str: - return self._env_param_getter.get_str_value(self.KEY_REFRESH_TOKEN_X) + return self._env_param_getter.get_str_value_strict(self.KEY_REFRESH_TOKEN_X) diff --git a/lib/dl_testing/dl_testing/env_params/generic.py b/lib/dl_testing/dl_testing/env_params/generic.py index 306555893..16e9c14e1 100644 --- a/lib/dl_testing/dl_testing/env_params/generic.py +++ b/lib/dl_testing/dl_testing/env_params/generic.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Optional + import attr import yaml @@ -12,7 +14,7 @@ class GenericEnvParamGetter(EnvParamGetter): _loader: EnvParamGetterLoader = attr.ib(init=False, factory=EnvParamGetterLoader) _key_mapping: dict[str, tuple[str, str]] = attr.ib(init=False, factory=dict) # key -> (getter_name, remapped_key) - def get_str_value(self, key: str) -> str: + def get_str_value(self, key: str) -> Optional[str]: getter_name, remapped_key = self._key_mapping[key] getter = self._loader.get_getter(getter_name) value = getter.get_str_value(remapped_key) diff --git a/lib/dl_testing/dl_testing/env_params/getter.py b/lib/dl_testing/dl_testing/env_params/getter.py index f97923050..a4d859ef3 100644 --- a/lib/dl_testing/dl_testing/env_params/getter.py +++ b/lib/dl_testing/dl_testing/env_params/getter.py @@ -2,26 +2,65 @@ import abc import json +from typing import ( + NoReturn, + Optional, +) import yaml +def _raise_error_no_key(key: str) -> NoReturn: + raise ValueError(f"Key {key!r} is missing") + + class EnvParamGetter(abc.ABC): @abc.abstractmethod - def get_str_value(self, key: str) -> str: + def get_str_value(self, key: str) -> Optional[str]: raise NotImplementedError - def get_int_value(self, key: str) -> int: + def get_str_value_strict(self, key: str) -> str: + str_value = self.get_str_value(key) + if str_value is None: + _raise_error_no_key(key) + return str_value + + def get_int_value(self, key: str) -> Optional[int]: str_value = self.get_str_value(key) - return int(str_value) + if str_value is not None: + return int(str_value) + return None - def get_json_value(self, key: str) -> dict: + def get_int_value_strict(self, key: str) -> int: + int_value = self.get_int_value(key) + if int_value is None: + _raise_error_no_key(key) + return int_value + + def get_json_value(self, key: str) -> Optional[dict]: str_value = self.get_str_value(key) - return json.loads(str_value) + if str_value is not None: + return json.loads(str_value) + return None + + def get_json_value_strict(self, key: str) -> dict: + json_value = self.get_json_value(key) + if json_value is None: + _raise_error_no_key(key) + return json_value - def get_yaml_value(self, key: str) -> dict: + def get_yaml_value(self, key: str) -> Optional[dict]: str_value = self.get_str_value(key) - return yaml.safe_load(str_value) + if str_value is not None: + return yaml.safe_load(str_value) + return None + def get_yaml_value_strict(self, key: str) -> dict: + yaml_value = self.get_yaml_value(key) + if yaml_value is None: + _raise_error_no_key(key) + return yaml_value + + @abc.abstractmethod def initialize(self, config: dict) -> None: pass diff --git a/lib/dl_testing/dl_testing/env_params/loader.py b/lib/dl_testing/dl_testing/env_params/loader.py index 21e1e02bb..83224cb31 100644 --- a/lib/dl_testing/dl_testing/env_params/loader.py +++ b/lib/dl_testing/dl_testing/env_params/loader.py @@ -1,6 +1,7 @@ from typing import ( ClassVar, Mapping, + Optional, Sequence, ) @@ -40,7 +41,7 @@ def _auto_add_getter(self, name: str) -> None: getter.initialize(config={}) self._getters[name] = getter - def _resolve_setting_item(self, setting: dict, requirement_getter: EnvParamGetter) -> str: + def _resolve_setting_item(self, setting: dict, requirement_getter: EnvParamGetter) -> Optional[str]: if setting["type"] == "value": return setting["value"] if setting["type"] == "param": diff --git a/lib/dl_testing/dl_testing/env_params/main.py b/lib/dl_testing/dl_testing/env_params/main.py index 7d02f2acf..981fc20c7 100644 --- a/lib/dl_testing/dl_testing/env_params/main.py +++ b/lib/dl_testing/dl_testing/env_params/main.py @@ -1,4 +1,5 @@ import os +from typing import Optional import attr from dotenv import ( @@ -15,6 +16,9 @@ class DirectEnvParamGetter(EnvParamGetter): def get_str_value(self, key: str) -> str: return str(key) + def initialize(self, config: dict) -> None: + pass + @attr.s class OsEnvParamGetter(EnvParamGetter): @@ -26,7 +30,7 @@ def initialize(self, config: dict) -> None: env_file = os.environ.get("DL_TESTS_ENV_FILE") or find_dotenv(filename=".env") self._env_from_file = dotenv_values(env_file) - def get_str_value(self, key: str) -> str: + def get_str_value(self, key: str) -> Optional[str]: env_value = os.environ.get(key) if env_value is None: diff --git a/lib/dl_testing/dl_testing/regulated_test.py b/lib/dl_testing/dl_testing/regulated_test.py index 8bece74d4..081e99d78 100644 --- a/lib/dl_testing/dl_testing/regulated_test.py +++ b/lib/dl_testing/dl_testing/regulated_test.py @@ -193,7 +193,9 @@ def regulated_test_case(test_cls: type, /) -> type: @overload -def regulated_test_case(*, test_params: RegulatedTestParams = RegulatedTestParams()) -> Callable[[type], type]: +def regulated_test_case( + *, test_params: RegulatedTestParams = RegulatedTestParams() # noqa B008 +) -> Callable[[type], type]: ...