Skip to content

Commit

Permalink
fix: catching AWS client error to allow for no-infra testing in CI pi…
Browse files Browse the repository at this point in the history
…pelines.
  • Loading branch information
svange committed Nov 13, 2023
1 parent e413078 commit 7f5c401
Showing 1 changed file with 34 additions and 13 deletions.
47 changes: 34 additions & 13 deletions openbrain/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Metrics,
Tracer,
)
from botocore.exceptions import ClientError
from botocore.exceptions import ClientError, NoCredentialsError
from dotenv import load_dotenv

from openbrain.exceptions import ObMissingEnvironmentVariable
Expand Down Expand Up @@ -47,6 +47,7 @@ def detect_aws_region() -> str:

class Defaults(Enum):
"""Default values for environment variables and other constants."""

SESSION_TABLE_NAME = None
LEAD_TABLE_NAME = None
AGENT_CONFIG_TABLE_NAME = None
Expand Down Expand Up @@ -85,7 +86,9 @@ class Config:

OB_MODE: str = field(default=os.environ.get(Defaults.OB_MODE.name, Defaults.OB_MODE.value))
AWS_REGION: str = field(default_factory=detect_aws_region)
INFRA_STACK_NAME: str = field(default=os.environ.get(Defaults.INFRA_STACK_NAME.name, Defaults.INFRA_STACK_NAME.value))
INFRA_STACK_NAME: str = field(
default=os.environ.get(Defaults.INFRA_STACK_NAME.name, Defaults.INFRA_STACK_NAME.value)
)

# DB TABLES
SESSION_TABLE_NAME: str = field(
Expand All @@ -101,14 +104,20 @@ class Config:
)

SESSION_TABLE_PUBLISHED_NAME: str = field(
default=os.environ.get(Defaults.SESSION_TABLE_PUBLISHED_NAME.name, Defaults.SESSION_TABLE_PUBLISHED_NAME.value)
default=os.environ.get(
Defaults.SESSION_TABLE_PUBLISHED_NAME.name, Defaults.SESSION_TABLE_PUBLISHED_NAME.value
)
)
LEAD_TABLE_PUBLISHED_NAME: str = field(
default=os.environ.get(Defaults.LEAD_TABLE_PUBLISHED_NAME.name, Defaults.LEAD_TABLE_PUBLISHED_NAME.value)
default=os.environ.get(
Defaults.LEAD_TABLE_PUBLISHED_NAME.name, Defaults.LEAD_TABLE_PUBLISHED_NAME.value
)
)
AGENT_CONFIG_TABLE_PUBLISHED_NAME: str = field(
default=os.environ.get(
Defaults.AGENT_CONFIG_TABLE_PUBLISHED_NAME.name, Defaults.AGENT_CONFIG_TABLE_PUBLISHED_NAME.value)
Defaults.AGENT_CONFIG_TABLE_PUBLISHED_NAME.name,
Defaults.AGENT_CONFIG_TABLE_PUBLISHED_NAME.value,
)
)

# MISC RESOURCES
Expand Down Expand Up @@ -143,7 +152,9 @@ def __post_init__(self):
def set_dynamic_values(self):
_logger = get_logger()
dynamic_attributes = [
attrib.replace("_PUBLISHED_NAME", "_NAME") for attrib in self.__dict__ if attrib.endswith("_PUBLISHED_NAME")
attrib.replace("_PUBLISHED_NAME", "_NAME")
for attrib in self.__dict__
if attrib.endswith("_PUBLISHED_NAME")
]
defaults = asdict(self)
undefined_resources = []
Expand All @@ -154,11 +165,15 @@ def set_dynamic_values(self):
continue

attrib_published_name = getattr(self, attrib.replace("_NAME", "_PUBLISHED_NAME"))
_logger.debug(f"{attrib} not defined in environment variables, searching for friendly name {attrib_published_name}")
_logger.debug(
f"{attrib} not defined in environment variables, searching for friendly name {attrib_published_name}"
)

# If the friendly name is using the default, emit a warning
if attrib_published_name == defaults[attrib]:
_logger.debug(f"{attrib} is using the default friendly name. Do you have this infrastructure deployed?")
_logger.debug(
f"{attrib} is using the default friendly name. Do you have this infrastructure deployed?"
)

if not attrib_published_name:
_logger.warning(f"{attrib} not defined in environment variables")
Expand All @@ -172,7 +187,9 @@ def set_dynamic_values(self):
resource_name = self._get_resource_from_central_infra(attrib_published_name)
setattr(self, attrib, resource_name)
except ClientError as e:
print(f"ERROR: Can't find {attrib} in environment variables or central infrastructure")
print(
f"ERROR: Can't find {attrib} in environment variables or central infrastructure"
)
undefined_resources.append((attrib, attrib_published_name))

# Run through accumulated errors and raise
Expand All @@ -183,7 +200,9 @@ def set_dynamic_values(self):
)

for attrib, attrib_published_name in undefined_resources:
print(f"ERROR: Can't find {attrib_published_name} values from your central infrastructure")
print(
f"ERROR: Can't find {attrib_published_name} values from your central infrastructure"
)

# raise ObMissingEnvironmentVariable(
# "Missing environment variables or central infrastructure. Please define all resource names in "
Expand All @@ -200,8 +219,10 @@ def _get_resource_from_central_infra(self, friendly_name):

try:
response = cf_client.describe_stacks(StackName=self.INFRA_STACK_NAME)
self._central_infra_outputs = {x["OutputKey"]: x["OutputValue"] for x in response["Stacks"][0]["Outputs"]}
except ClientError as e:
self._central_infra_outputs = {
x["OutputKey"]: x["OutputValue"] for x in response["Stacks"][0]["Outputs"]
}
except NoCredentialsError or ClientError:
logger.warning(
f"Can't find central infrastructure stack {self.INFRA_STACK_NAME} - to find resources from "
f"friendly names. Please define all resource names in the environment OR define the central "
Expand All @@ -225,12 +246,12 @@ def __new__(cls):
def _initialize(cls):
cls._instance = Config()


logger = get_logger()
metrics = get_metrics()
tracer = get_tracer()
config = ConfigSingleton()
if __name__ == "__main__":

logger.debug(config)
# if config.OB_MODE == Defaults.OB_MODE_LOCAL.value:
# identity = boto3.client("sts").get_caller_identity()
Expand Down

0 comments on commit 7f5c401

Please sign in to comment.