Skip to content

Commit

Permalink
refactor: added Genric call to get_value
Browse files Browse the repository at this point in the history
  • Loading branch information
Killg0d committed Jan 30, 2025
1 parent 2bc9e39 commit 817c48f
Show file tree
Hide file tree
Showing 14 changed files with 34 additions and 34 deletions.
6 changes: 3 additions & 3 deletions src/apps/backend/modules/access_token/access_token_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def create_access_token_by_phone_number(*, params: OTPBasedAuthAccessTokenReques

@staticmethod
def __generate_access_token(*, account: Account) -> AccessToken:
jwt_signing_key:str = ConfigService.get_value(key="accounts.token_signing_key")
jwt_expiry = timedelta(days=ConfigService.get_value(key="accounts.token_expiry_days"))
jwt_signing_key = ConfigService[str].get_value(key="accounts.token_signing_key")
jwt_expiry = timedelta(days=ConfigService[int].get_value(key="accounts.token_expiry_days"))
expiry_time = datetime.now() + jwt_expiry
payload = {"account_id": account.id, "exp": (expiry_time).timestamp()}
jwt_token = jwt.encode(payload, jwt_signing_key, algorithm="HS256")
Expand All @@ -52,7 +52,7 @@ def __generate_access_token(*, account: Account) -> AccessToken:
@staticmethod
def verify_access_token(*, token: str) -> AccessTokenPayload:

jwt_signing_key:str = ConfigService.get_value(key="accounts.token_signing_key")
jwt_signing_key = ConfigService[str].get_value(key="accounts.token_signing_key")

try:
verified_token = jwt.decode(token, jwt_signing_key, algorithms=["HS256"])
Expand Down
4 changes: 2 additions & 2 deletions src/apps/backend/modules/application/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class ApplicationRepositoryClient:

@classmethod
def get_client(cls) -> MongoClient:
connection_caching:bool = ConfigService.get_value(key="mongodb.conn_caching")
connection_caching = ConfigService[bool].get_value(key="mongodb.conn_caching")

if connection_caching:
if cls._client is None:
Expand All @@ -27,7 +27,7 @@ def get_client(cls) -> MongoClient:

@staticmethod
def _create_client() -> MongoClient:
connection_uri:str = ConfigService.get_value(key="mongodb.uri")
connection_uri = ConfigService[str].get_value(key="mongodb.uri")
Logger.info(message=f"connecting to database - {connection_uri}")
client = MongoClient(connection_uri, server_api=ServerApi("1"))
Logger.info(message=f"connected to database - {connection_uri}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@ def send_email(params: SendEmailParams) -> None:
@staticmethod
def get_client() -> sendgrid.SendGridAPIClient:
if not SendGridService.__client:
api_key:str = ConfigService.get_value(key="sendgrid.api_key")
api_key = ConfigService[str].get_value(key="sendgrid.api_key")
SendGridService.__client = sendgrid.SendGridAPIClient(api_key=api_key)
return SendGridService.__client
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def send_sms(params: SendSMSParams) -> None:
# Send SMS
client.messages.create(
to=params.recipient_phone,
messaging_service_sid=ConfigService.get_value(key="twilio.messaging_service_sid"),
messaging_service_sid=ConfigService[str].get_value(key="twilio.messaging_service_sid"),
body=params.message_body,
)

Expand All @@ -32,8 +32,8 @@ def send_sms(params: SendSMSParams) -> None:
@staticmethod
def get_client() -> Client:
if not TwilioService.__client:
account_sid:str = ConfigService.get_value(key="twilio.account_sid")
auth_token:str = ConfigService.get_value(key="twilio.auth_token")
account_sid = ConfigService[str].get_value(key="twilio.account_sid")
auth_token = ConfigService[str].get_value(key="twilio.auth_token")

# Initialize the Twilio client
TwilioService.__client = Client(account_sid, auth_token)
Expand Down
2 changes: 1 addition & 1 deletion src/apps/backend/modules/communication/sms_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
class SMSService:
@staticmethod
def send_sms(*, params: SendSMSParams) -> None:
is_sms_enabled:bool = ConfigService.get_value(key="sms.enabled")
is_sms_enabled = ConfigService[bool].get_value(key="sms.enabled")
if not is_sms_enabled:
Logger.warn(message=f"SMS is disabled. Could not send message - {params.message_body}")
return
Expand Down
11 changes: 5 additions & 6 deletions src/apps/backend/modules/config/config_service.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
from typing import TypeVar
from typing import Generic,TypeVar
from modules.common.types import ErrorCode
from modules.error.custom_errors import MissingKeyError
from modules.config.config import Config

T = TypeVar('T')

class ConfigService:

class ConfigService(Generic[T]):
@staticmethod
def load_config() -> None:
Config.load_config()

@staticmethod
def get_value(key: str) -> T: # type: ignore
@classmethod
def get_value(cls,key: str) -> T:
value = Config.get(key)
if value is None:
raise MissingKeyError(missing_key=key, error_code=ErrorCode.MISSING_KEY)
return value # type: ignore
return value # type: ignore

@staticmethod
def has_value(key: str) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion src/apps/backend/modules/logger/internal/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class Loggers:

@staticmethod
def initialize_loggers() -> None:
logger_transports:list = ConfigService.get_value(key='logger.transports')
logger_transports = ConfigService[list[str]].get_value(key='logger.transports')
for logger_transport in logger_transports:
if logger_transport == LoggerTransports.CONSOLE:
Loggers._loggers.append(Loggers.__get_console_logger())
Expand Down
4 changes: 2 additions & 2 deletions src/apps/backend/modules/logger/internal/papertrail_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ def __init__(self) -> None:

# Create a console handler and set the level to INFO
logger_config = PapertrailConfig(
host=ConfigService.get_value(key='papertrail.host'),
port=ConfigService.get_value(key='papertrail.port')
host=ConfigService[str].get_value(key='papertrail.host'),
port=ConfigService[int].get_value(key='papertrail.port')
)
papertrail_handler = SysLogHandler(address=(logger_config.host, logger_config.port))
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
Expand Down
6 changes: 3 additions & 3 deletions src/apps/backend/modules/otp/internal/otp_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@ class OtpUtil:
@staticmethod
def is_default_phone_number(phone_number: str) -> bool:
default_phone_number = None
if ConfigService.has_value(key="otp.default_phone_number"):
default_phone_number = ConfigService.get_value(key="otp.default_phone_number")
if ConfigService[str].has_value(key="otp.default_phone_number"):
default_phone_number = ConfigService[str].get_value(key="otp.default_phone_number")
if default_phone_number and phone_number == default_phone_number:
return True
return False

@staticmethod
def generate_otp(length: int, phone_number: str) -> str:
if OtpUtil.is_default_phone_number(phone_number):
default_otp:str = ConfigService.get_value(key="otp.default_otp")
default_otp = ConfigService[str].get_value(key="otp.default_otp")
return default_otp
return "".join(random.choices(string.digits, k=length))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def hash_password_reset_token(reset_token: str) -> str:

@staticmethod
def get_token_expires_at() -> datetime:
default_token_expire_time_in_seconds:float = ConfigService.get_value(key="accounts.token_expires_in_seconds")
default_token_expire_time_in_seconds = ConfigService[int].get_value(key="accounts.token_expires_in_seconds")
return datetime.now() + timedelta(seconds=default_token_expire_time_in_seconds)

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ def verify_password_reset_token(account_id: str, token: str) -> PasswordResetTok
@staticmethod
def send_password_reset_email(account_id: str, first_name: str, username: str, password_reset_token: str) -> None:

web_app_host:str = ConfigService.get_value(key="web_app_host")
default_email:str = ConfigService.get_value(key="mailer.default_email")
default_email_name:str = ConfigService.get_value(key="mailer.default_email_name")
forgot_password_mail_template_id:str = ConfigService.get_value(key="mailer.forgot_password_mail_template_id")
web_app_host = ConfigService[str].get_value(key="web_app_host")
default_email = ConfigService[str].get_value(key="mailer.default_email")
default_email_name = ConfigService[str].get_value(key="mailer.default_email_name")
forgot_password_mail_template_id = ConfigService[str].get_value(key="mailer.forgot_password_mail_template_id")

template_data = {
"first_name": first_name,
Expand Down
2 changes: 1 addition & 1 deletion src/apps/backend/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

# Apply ProxyFix to interpret `X-Forwarded` headers if enabled in configuration
# Visit: https://flask.palletsprojects.com/en/stable/deploying/proxy_fix/ for more information
if ConfigService.has_value("is_server_running_behind_proxy") and ConfigService.get_value("is_server_running_behind_proxy"):
if ConfigService.has_value("is_server_running_behind_proxy") and ConfigService[bool].get_value("is_server_running_behind_proxy"):
app.wsgi_app = ProxyFix(app.wsgi_app) # type: ignore

# Register access token apis
Expand Down
4 changes: 2 additions & 2 deletions tests/modules/account/test_account_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ def test_get_account_with_expired_access_token(self) -> None:

# Create an expired token by setting the expiry to a date in the past using same method as in the
# access token service
jwt_signing_key = ConfigService.get_value(key="accounts.token_signing_key")
jwt_expiry = timedelta(days=ConfigService.get_value(key="accounts.token_expiry_days") - 1)
jwt_signing_key = ConfigService[str].get_value(key="accounts.token_signing_key")
jwt_expiry = timedelta(days=ConfigService[int].get_value(key="accounts.token_expiry_days") - 1)
payload = {"account_id": account.id, "exp": (datetime.now() - jwt_expiry).timestamp()}
expired_token = jwt.encode(payload, jwt_signing_key, algorithm="HS256")

Expand Down
9 changes: 5 additions & 4 deletions tests/modules/config/test_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from typing import List

from modules.config.types import PapertrailConfig
from modules.common.types import ErrorCode
Expand All @@ -9,20 +10,20 @@

class TestConfig(BaseTestConfig):
def test_db_config_is_loaded(self) -> None:
uri = ConfigService.get_value(key="mongodb.uri")
uri = ConfigService[str].get_value(key="mongodb.uri")
assert uri.split(":")[0] == "mongodb"
assert uri.split("/")[-1] == "frm-boilerplate-test"

def test_logger_config_is_loaded(self) -> None:
loggers = ConfigService.get_value(key="logger.transports")
loggers = ConfigService[List[str]].get_value(key="logger.transports")
assert type(loggers) == list
assert "console" in loggers

def test_papertrail_config_is_loaded(self) -> None:
try:
PapertrailConfig(
host=ConfigService.get_value(key="papertrail.host"),
port=ConfigService.get_value(key="papertrail.port"),
host=ConfigService[str].get_value(key="papertrail.host"),
port=ConfigService[int].get_value(key="papertrail.port"),
)
except MissingKeyError as exc:
assert exc.code == ErrorCode.MISSING_KEY
Expand Down

0 comments on commit 817c48f

Please sign in to comment.