Skip to content

Commit c277f96

Browse files
committed
[DLBACK-49] Actualize signature for EnvParamGetter.get_xxx_value()
1 parent 8d45e96 commit c277f96

File tree

7 files changed

+65
-17
lines changed

7 files changed

+65
-17
lines changed

lib/dl_connector_bigquery/dl_connector_bigquery/testing/secrets.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ class BigQuerySecretReader(BigQuerySecretReaderBase):
3636

3737
@_project_config.default
3838
def _make_project_config(self) -> dict:
39-
return self._env_param_getter.get_json_value(self.KEY_CONFIG)
39+
return self._env_param_getter.get_json_value_strict(self.KEY_CONFIG)
4040

4141
@property
4242
def project_config(self) -> dict:
4343
return self._project_config
4444

4545
def get_creds(self) -> str:
46-
return self._env_param_getter.get_str_value(self.KEY_CREDS)
46+
return self._env_param_getter.get_str_value_strict(self.KEY_CREDS)

lib/dl_connector_snowflake/dl_connector_snowflake/core/testing/secrets.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -64,17 +64,17 @@ class SnowFlakeSecretReader(SnowFlakeSecretReaderBase):
6464

6565
@_project_config.default
6666
def _make_project_config(self) -> dict:
67-
return self._env_param_getter.get_json_value(self.KEY_CONFIG)
67+
return self._env_param_getter.get_json_value_strict(self.KEY_CONFIG)
6868

6969
@property
7070
def project_config(self) -> dict:
7171
return self._project_config
7272

7373
def get_client_secret(self) -> str:
74-
return self._env_param_getter.get_str_value(self.KEY_CLIENT_SECRET)
74+
return self._env_param_getter.get_str_value_strict(self.KEY_CLIENT_SECRET)
7575

7676
def get_refresh_token_expired(self) -> str:
77-
return self._env_param_getter.get_str_value(self.KEY_REFRESH_TOKEN_EXPIRED)
77+
return self._env_param_getter.get_str_value_strict(self.KEY_REFRESH_TOKEN_EXPIRED)
7878

7979
def get_refresh_token_x(self) -> str:
80-
return self._env_param_getter.get_str_value(self.KEY_REFRESH_TOKEN_X)
80+
return self._env_param_getter.get_str_value_strict(self.KEY_REFRESH_TOKEN_X)

lib/dl_testing/dl_testing/env_params/generic.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from typing import Optional
4+
35
import attr
46
import yaml
57

@@ -12,7 +14,7 @@ class GenericEnvParamGetter(EnvParamGetter):
1214
_loader: EnvParamGetterLoader = attr.ib(init=False, factory=EnvParamGetterLoader)
1315
_key_mapping: dict[str, tuple[str, str]] = attr.ib(init=False, factory=dict) # key -> (getter_name, remapped_key)
1416

15-
def get_str_value(self, key: str) -> str:
17+
def get_str_value(self, key: str) -> Optional[str]:
1618
getter_name, remapped_key = self._key_mapping[key]
1719
getter = self._loader.get_getter(getter_name)
1820
value = getter.get_str_value(remapped_key)

lib/dl_testing/dl_testing/env_params/getter.py

+46-7
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,65 @@
22

33
import abc
44
import json
5+
from typing import (
6+
NoReturn,
7+
Optional,
8+
)
59

610
import yaml
711

812

13+
def _raise_error_no_key(key: str) -> NoReturn:
14+
raise ValueError(f"Key {key!r} is missing")
15+
16+
917
class EnvParamGetter(abc.ABC):
1018
@abc.abstractmethod
11-
def get_str_value(self, key: str) -> str:
19+
def get_str_value(self, key: str) -> Optional[str]:
1220
raise NotImplementedError
1321

14-
def get_int_value(self, key: str) -> int:
22+
def get_str_value_strict(self, key: str) -> str:
23+
str_value = self.get_str_value(key)
24+
if str_value is None:
25+
_raise_error_no_key(key)
26+
return str_value
27+
28+
def get_int_value(self, key: str) -> Optional[int]:
1529
str_value = self.get_str_value(key)
16-
return int(str_value)
30+
if str_value is not None:
31+
return int(str_value)
32+
return None
1733

18-
def get_json_value(self, key: str) -> dict:
34+
def get_int_value_strict(self, key: str) -> int:
35+
int_value = self.get_int_value(key)
36+
if int_value is None:
37+
_raise_error_no_key(key)
38+
return int_value
39+
40+
def get_json_value(self, key: str) -> Optional[dict]:
1941
str_value = self.get_str_value(key)
20-
return json.loads(str_value)
42+
if str_value is not None:
43+
return json.loads(str_value)
44+
return None
45+
46+
def get_json_value_strict(self, key: str) -> dict:
47+
json_value = self.get_json_value(key)
48+
if json_value is None:
49+
_raise_error_no_key(key)
50+
return json_value
2151

22-
def get_yaml_value(self, key: str) -> dict:
52+
def get_yaml_value(self, key: str) -> Optional[dict]:
2353
str_value = self.get_str_value(key)
24-
return yaml.safe_load(str_value)
54+
if str_value is not None:
55+
return yaml.safe_load(str_value)
56+
return None
2557

58+
def get_yaml_value_strict(self, key: str) -> dict:
59+
yaml_value = self.get_yaml_value(key)
60+
if yaml_value is None:
61+
_raise_error_no_key(key)
62+
return yaml_value
63+
64+
@abc.abstractmethod
2665
def initialize(self, config: dict) -> None:
2766
pass

lib/dl_testing/dl_testing/env_params/loader.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import (
22
ClassVar,
33
Mapping,
4+
Optional,
45
Sequence,
56
)
67

@@ -40,7 +41,7 @@ def _auto_add_getter(self, name: str) -> None:
4041
getter.initialize(config={})
4142
self._getters[name] = getter
4243

43-
def _resolve_setting_item(self, setting: dict, requirement_getter: EnvParamGetter) -> str:
44+
def _resolve_setting_item(self, setting: dict, requirement_getter: EnvParamGetter) -> Optional[str]:
4445
if setting["type"] == "value":
4546
return setting["value"]
4647
if setting["type"] == "param":

lib/dl_testing/dl_testing/env_params/main.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from typing import Optional
23

34
import attr
45
from dotenv import (
@@ -15,6 +16,9 @@ class DirectEnvParamGetter(EnvParamGetter):
1516
def get_str_value(self, key: str) -> str:
1617
return str(key)
1718

19+
def initialize(self, config: dict) -> None:
20+
pass
21+
1822

1923
@attr.s
2024
class OsEnvParamGetter(EnvParamGetter):
@@ -26,7 +30,7 @@ def initialize(self, config: dict) -> None:
2630
env_file = os.environ.get("DL_TESTS_ENV_FILE") or find_dotenv(filename=".env")
2731
self._env_from_file = dotenv_values(env_file)
2832

29-
def get_str_value(self, key: str) -> str:
33+
def get_str_value(self, key: str) -> Optional[str]:
3034
env_value = os.environ.get(key)
3135

3236
if env_value is None:

lib/dl_testing/dl_testing/regulated_test.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,9 @@ def regulated_test_case(test_cls: type, /) -> type:
193193

194194

195195
@overload
196-
def regulated_test_case(*, test_params: RegulatedTestParams = RegulatedTestParams()) -> Callable[[type], type]:
196+
def regulated_test_case(
197+
*, test_params: RegulatedTestParams = RegulatedTestParams() # noqa B008
198+
) -> Callable[[type], type]:
197199
...
198200

199201

0 commit comments

Comments
 (0)