Skip to content

Commit 2b295a3

Browse files
authored
Merge pull request #64 from stac-labs/nested-secrets
Nested secrets & S3 secret loading
2 parents 7f62e91 + 2855020 commit 2b295a3

File tree

4 files changed

+256
-16
lines changed

4 files changed

+256
-16
lines changed

src/stac_utils/aws.py

+22-4
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
import json
2+
import logging
23
import os.path as op
34
from tempfile import TemporaryDirectory
45

56
import boto3
67
from botocore.exceptions import ClientError
78

89

10+
logger = logging.getLogger(__name__)
11+
12+
913
def get_secret(region_name: str, secret_name: str) -> dict:
1014
# Create a Secrets Manager client
1115
"""
@@ -43,7 +47,18 @@ def write_secret(region_name: str, secret_name: str, secret: dict):
4347
)
4448

4549

46-
def load_from_s3(bucket: str, path: str, file_name: str) -> dict:
50+
def split_s3_url(url: str) -> tuple[str, str, str]:
51+
prefix = "s3://"
52+
if url.startswith(prefix):
53+
url = url[len(prefix) :]
54+
55+
bucket, _, fpath = url.partition("/")
56+
path, _, file_name = fpath.rpartition("/")
57+
58+
return bucket, path, file_name
59+
60+
61+
def load_from_s3(bucket: str, path: [str, None], file_name: str) -> dict:
4762
"""
4863
Returns data from s3 given bucket, path, and file name
4964
@@ -52,8 +67,9 @@ def load_from_s3(bucket: str, path: str, file_name: str) -> dict:
5267
:param file_name: Name of file to load
5368
:return: Data from specified file
5469
"""
70+
path = path or ""
5571
s3 = boto3.resource("s3").Bucket(bucket)
56-
key = path.strip("/") + "/" + file_name
72+
key = (path.strip("/") + "/" + file_name).lstrip("/")
5773

5874
data = {}
5975
with TemporaryDirectory() as temp:
@@ -66,12 +82,13 @@ def load_from_s3(bucket: str, path: str, file_name: str) -> dict:
6682
raise e
6783
data = {}
6884
except json.JSONDecodeError:
85+
logger.warning(f"{key} is not a JSON file!")
6986
data = {}
7087

7188
return data
7289

7390

74-
def save_to_s3(data: dict, bucket: str, path: str, file_name: str):
91+
def save_to_s3(data: dict, bucket: str, path: [str, None], file_name: str):
7592
"""
7693
Saves data to s3 in specified location
7794
@@ -81,8 +98,9 @@ def save_to_s3(data: dict, bucket: str, path: str, file_name: str):
8198
:param file_name: Desired file name
8299
:return: Data
83100
"""
101+
path = path or ""
84102
s3 = boto3.resource("s3").Bucket(bucket)
85-
key = path.strip("/") + "/" + file_name
103+
key = (path.strip("/") + "/" + file_name).lstrip("/")
86104

87105
with TemporaryDirectory() as temp:
88106
temp_file = op.join(temp, file_name)

src/stac_utils/secret_context.py

+65-3
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,23 @@
22
import os
33
from unittest.mock import patch
44

5-
from .aws import get_secret
5+
from .aws import get_secret, split_s3_url, load_from_s3
66

77

88
def secrets(
99
file_name: str = None,
1010
secret_name: str = None,
1111
aws_region: str = None,
1212
dictionary: dict = None,
13+
s3_url: str = None,
1314
):
1415
"""
15-
Takes either a file, a Python dict, or an AWS secret in Secrets Manager and loads it all into the os.environ as a context.
16+
Takes any combination of
17+
* local JSON file
18+
* Python dictionary
19+
* AWS secret in Secrets Manager
20+
* JSON file on S3
21+
and loads it all into the os.environ as a context.
1622
1723
Usage is typically:
1824
@@ -23,8 +29,10 @@ def secrets(
2329
:param secret_name: Desired secret_name
2430
:param aws_region: Desired AWS region for secret
2531
:param dictionary: Specified dictionary
32+
:param s3_url: S3 URL to JSON file
2633
:return:
2734
"""
35+
2836
values = {}
2937
if not secret_name and os.environ.get("SECRET_NAME"):
3038
secret_name = os.environ.get("SECRET_NAME")
@@ -39,12 +47,66 @@ def secrets(
3947
secret_name or os.environ["SECRET_NAME"],
4048
)
4149
)
50+
if not s3_url and os.environ.get("SECRET_S3_URL"):
51+
s3_url = os.environ.get("SECRET_S3_URL")
52+
# blank secret s3 url in the context, so it doesn't get loaded a second time
53+
# if we nest secrets
54+
values["SECRET_S3_URL"] = ""
55+
56+
if s3_url:
57+
values.update(load_from_s3(*split_s3_url(s3_url)))
4258

4359
if file_name:
4460
values.update(json.load(open(file_name, "rt")))
4561
if dictionary:
4662
values.update(dictionary)
4763

4864
# the patcher doesn't like non-string keys OR values
49-
values = {str(k): str(v) if v is not None else "" for k, v in values.items()}
65+
values = {
66+
str(k): safe_dump_json_to_string(v) if v is not None else ""
67+
for k, v in values.items()
68+
}
69+
5070
return patch.dict(os.environ, values=values)
71+
72+
73+
def safe_dump_json_to_string(value: [list, dict, str, int, tuple, float, None]) -> str:
74+
"""Utility function to encode values to string, working through nested dictionaries & list"""
75+
76+
if type(value) in [dict]:
77+
return json.dumps(
78+
{str(k): safe_dump_json_to_string(v) for k, v in value.items()}
79+
)
80+
81+
if type(value) in [list, tuple]:
82+
return json.dumps([safe_dump_json_to_string(v) for v in value])
83+
84+
if value is None:
85+
return "null"
86+
87+
return str(value)
88+
89+
90+
def safe_load_string_to_json(value: str) -> [list, dict, str, int, float, None]:
91+
"""Utility function to decode values from string, restoring all nested dictionaries & list"""
92+
93+
try:
94+
loaded = json.loads(value)
95+
except (json.decoder.JSONDecodeError, TypeError):
96+
loaded = value
97+
98+
if type(loaded) in [dict]:
99+
return {k: safe_load_string_to_json(v) for k, v in loaded.items()}
100+
101+
if type(loaded) in [list]:
102+
return [safe_load_string_to_json(v) for v in loaded]
103+
104+
return loaded
105+
106+
107+
def get_env(key: str, default=None) -> [list, dict, str]:
108+
"""Utility function combining os.environ.get & safe_load_from_json"""
109+
110+
value = os.environ.get(key, default=default)
111+
112+
return safe_load_string_to_json(value)

src/tests/test_aws.py

+53-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
import json
22
import unittest
3-
from unittest.mock import MagicMock, patch, call
3+
from unittest.mock import MagicMock, patch
44
from botocore.exceptions import ClientError
55

6-
from src.stac_utils.aws import get_secret, write_secret, load_from_s3, save_to_s3
6+
from src.stac_utils.aws import (
7+
get_secret,
8+
write_secret,
9+
load_from_s3,
10+
save_to_s3,
11+
split_s3_url,
12+
)
713

814

915
class TestAWS(unittest.TestCase):
@@ -62,6 +68,27 @@ def test_load_from_s3(
6268
result_data = load_from_s3("foo", "bar", "spam")
6369
self.assertEqual(mock_data, result_data)
6470
mock_boto.return_value.Bucket.return_value.download_file.assert_called_once()
71+
self.assertEqual(
72+
mock_boto.return_value.Bucket.return_value.download_file.call_args[0][0],
73+
"bar/spam",
74+
)
75+
76+
@patch("json.load")
77+
@patch("src.stac_utils.aws.open")
78+
@patch("boto3.resource")
79+
def test_load_from_s3_no_path(
80+
self, mock_boto: MagicMock, mock_open: MagicMock, mock_load: MagicMock
81+
):
82+
"""Test load from s3 when path is None"""
83+
mock_data = {"foo": "bar"}
84+
mock_load.return_value = mock_data
85+
result_data = load_from_s3("foo", None, "spam")
86+
self.assertEqual(mock_data, result_data)
87+
mock_boto.return_value.Bucket.return_value.download_file.assert_called_once()
88+
self.assertEqual(
89+
mock_boto.return_value.Bucket.return_value.download_file.call_args[0][0],
90+
"spam",
91+
)
6592

6693
@patch("json.load")
6794
@patch("src.stac_utils.aws.open")
@@ -128,6 +155,30 @@ def test_save_to_s3_client_error(
128155
)
129156
self.assertRaises(ClientError, save_to_s3, mock_data, "foo", "bar", "spam")
130157

158+
def test_split_s3_url(self):
159+
"""Test split S3 url"""
160+
161+
test_url = "s3://foo-bucket/bar-path/spam-key.json"
162+
self.assertTupleEqual(
163+
split_s3_url(test_url), ("foo-bucket", "bar-path", "spam-key.json")
164+
)
165+
166+
def test_split_s3_url_no_prefix(self):
167+
"""Test split S3 url with no prefix"""
168+
169+
test_url = "foo-bucket/bar-path/spam-key.json"
170+
self.assertTupleEqual(
171+
split_s3_url(test_url), ("foo-bucket", "bar-path", "spam-key.json")
172+
)
173+
174+
def test_split_s3_url_no_path(self):
175+
"""Test split S3 url with no path"""
176+
177+
test_url = "s3://foo-bucket/spam-key.json"
178+
self.assertTupleEqual(
179+
split_s3_url(test_url), ("foo-bucket", "", "spam-key.json")
180+
)
181+
131182

132183
if __name__ == "__main__":
133184
unittest.main()

0 commit comments

Comments
 (0)