Skip to content

Commit

Permalink
[DLBACK-49] Actualize signature for EnvParamGetter.get_xxx_value()
Browse files Browse the repository at this point in the history
  • Loading branch information
kc41 committed Nov 21, 2023
1 parent 8d45e96 commit 14c32f3
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ class BigQuerySecretReader(BigQuerySecretReaderBase):

@_project_config.default
def _make_project_config(self) -> dict:
return self._env_param_getter.get_json_value(self.KEY_CONFIG)
return self._env_param_getter.get_json_value_strict(self.KEY_CONFIG)

@property
def project_config(self) -> dict:
return self._project_config

def get_creds(self) -> str:
return self._env_param_getter.get_str_value(self.KEY_CREDS)
return self._env_param_getter.get_str_value_strict(self.KEY_CREDS)
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,17 @@ class SnowFlakeSecretReader(SnowFlakeSecretReaderBase):

@_project_config.default
def _make_project_config(self) -> dict:
return self._env_param_getter.get_json_value(self.KEY_CONFIG)
return self._env_param_getter.get_json_value_strict(self.KEY_CONFIG)

@property
def project_config(self) -> dict:
return self._project_config

def get_client_secret(self) -> str:
return self._env_param_getter.get_str_value(self.KEY_CLIENT_SECRET)
return self._env_param_getter.get_str_value_strict(self.KEY_CLIENT_SECRET)

def get_refresh_token_expired(self) -> str:
return self._env_param_getter.get_str_value(self.KEY_REFRESH_TOKEN_EXPIRED)
return self._env_param_getter.get_str_value_strict(self.KEY_REFRESH_TOKEN_EXPIRED)

def get_refresh_token_x(self) -> str:
return self._env_param_getter.get_str_value(self.KEY_REFRESH_TOKEN_X)
return self._env_param_getter.get_str_value_strict(self.KEY_REFRESH_TOKEN_X)
4 changes: 3 additions & 1 deletion lib/dl_testing/dl_testing/env_params/generic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Optional

import attr
import yaml

Expand All @@ -12,7 +14,7 @@ class GenericEnvParamGetter(EnvParamGetter):
_loader: EnvParamGetterLoader = attr.ib(init=False, factory=EnvParamGetterLoader)
_key_mapping: dict[str, tuple[str, str]] = attr.ib(init=False, factory=dict) # key -> (getter_name, remapped_key)

def get_str_value(self, key: str) -> str:
def get_str_value(self, key: str) -> Optional[str]:
getter_name, remapped_key = self._key_mapping[key]
getter = self._loader.get_getter(getter_name)
value = getter.get_str_value(remapped_key)
Expand Down
49 changes: 42 additions & 7 deletions lib/dl_testing/dl_testing/env_params/getter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,61 @@

import abc
import json
from typing import Optional, NoReturn

import yaml


def _raise_error_no_key(key: str) -> NoReturn:
raise ValueError(f"Key {key!r} is missing")


class EnvParamGetter(abc.ABC):
@abc.abstractmethod
def get_str_value(self, key: str) -> str:
def get_str_value(self, key: str) -> Optional[str]:
raise NotImplementedError

def get_int_value(self, key: str) -> int:
def get_str_value_strict(self, key: str) -> str:
str_value = self.get_str_value(key)
if str_value is None:
_raise_error_no_key(key)
return str_value

def get_int_value(self, key: str) -> Optional[int]:
str_value = self.get_str_value(key)
return int(str_value)
if str_value is not None:
return int(str_value)
return None

def get_int_value_strict(self, key: str) -> int:
int_value = self.get_int_value(key)
if int_value is None:
_raise_error_no_key(key)
return int_value

def get_json_value(self, key: str) -> dict:
def get_json_value(self, key: str) -> Optional[dict]:
str_value = self.get_str_value(key)
return json.loads(str_value)
if str_value is not None:
return json.loads(str_value)
return None

def get_json_value_strict(self, key: str) -> dict:
json_value = self.get_json_value(key)
if json_value is None:
_raise_error_no_key(key)
return json_value

def get_yaml_value(self, key: str) -> dict:
def get_yaml_value(self, key: str) -> Optional[dict]:
str_value = self.get_str_value(key)
return yaml.safe_load(str_value)
if str_value is not None:
return yaml.safe_load(str_value)
return None

def get_yaml_value_strict(self, key: str) -> dict:
yaml_value = self.get_yaml_value(key)
if yaml_value is None:
_raise_error_no_key(key)
return yaml_value

def initialize(self, config: dict) -> None:
pass
3 changes: 2 additions & 1 deletion lib/dl_testing/dl_testing/env_params/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
ClassVar,
Mapping,
Sequence,
Optional,
)

import attr
Expand Down Expand Up @@ -40,7 +41,7 @@ def _auto_add_getter(self, name: str) -> None:
getter.initialize(config={})
self._getters[name] = getter

def _resolve_setting_item(self, setting: dict, requirement_getter: EnvParamGetter) -> str:
def _resolve_setting_item(self, setting: dict, requirement_getter: EnvParamGetter) -> Optional[str]:
if setting["type"] == "value":
return setting["value"]
if setting["type"] == "param":
Expand Down
3 changes: 2 additions & 1 deletion lib/dl_testing/dl_testing/env_params/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from typing import Optional

import attr
from dotenv import (
Expand Down Expand Up @@ -26,7 +27,7 @@ def initialize(self, config: dict) -> None:
env_file = os.environ.get("DL_TESTS_ENV_FILE") or find_dotenv(filename=".env")
self._env_from_file = dotenv_values(env_file)

def get_str_value(self, key: str) -> str:
def get_str_value(self, key: str) -> Optional[str]:
env_value = os.environ.get(key)

if env_value is None:
Expand Down

0 comments on commit 14c32f3

Please sign in to comment.