From 41837856fefc8ffc03dcbf5a83ba911534686a89 Mon Sep 17 00:00:00 2001 From: HIMANSHU Date: Thu, 14 Mar 2024 18:53:13 +0530 Subject: [PATCH] Feat:Code Refactoring --- app/config/__init__.py | 3 +- app/config/base.py | 4 +- app/daos/home.py | 10 +++- app/daos/users.py | 67 +++++++++++++--------- app/exceptions.py | 24 ++++++++ app/middlewares/cache_middleware.py | 20 ++++--- app/middlewares/rate_limiter_middleware.py | 2 +- app/middlewares/request_id_injection.py | 6 +- app/models/__init__.py | 10 +--- app/models/users.py | 9 ++- app/routes/cache_router/cache_samples.py | 2 +- app/routes/home/home.py | 15 +++-- app/routes/users/users.py | 40 ++++++------- app/schemas/users/users_request.py | 4 +- app/sessions/db.py | 17 ++++-- app/tests/test_basic.py | 6 +- app/utils/exception_handler.py | 16 ++++-- app/utils/user_utils.py | 14 +++-- app/wrappers/cache_wrappers.py | 5 +- 19 files changed, 173 insertions(+), 101 deletions(-) create mode 100644 app/exceptions.py diff --git a/app/config/__init__.py b/app/config/__init__.py index e08e4c6..37b96d8 100644 --- a/app/config/__init__.py +++ b/app/config/__init__.py @@ -1,7 +1,8 @@ +from __future__ import annotations + from .redis_config import get_redis_pool __all__ = [ - "engine", "get_redis_pool", ] diff --git a/app/config/base.py b/app/config/base.py index 9536421..45b083d 100644 --- a/app/config/base.py +++ b/app/config/base.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pydantic import BaseSettings @@ -25,7 +27,7 @@ class Config: class Settings(BaseSettings): SECRET_KEY: str REDIS_URL: str - SENTRY_DSN: str + SENTRY_DSN: str | None SLACK_WEBHOOK_URL: str ALLOWED_HOSTS: list = ["*"] CACHE_MAX_AGE: int = 60 diff --git a/app/daos/home.py b/app/daos/home.py index 3f722cf..89bad05 100644 --- a/app/daos/home.py +++ b/app/daos/home.py @@ -1,14 +1,18 @@ +from __future__ import annotations + import asyncio import random +from app.exceptions import ExternalServiceException + async def external_service_call(): # Simulate network delay - delay = random.uniform(0.1, 1.0) # Random delay between 0.1 to 1.0 seconds + delay = random.uniform(0.1, 1.0) # Random delay between 0.1 to 1.0 seconds #NOSONAR await asyncio.sleep(delay) # Simulate occasional failures - if random.random() < 0.2: # 20% chance of failure - raise Exception("External service failed") + if random.random() < 0.2: # 20% chance of failure #NOSONAR + raise ExternalServiceException("External Service Failed") return "Success from external service" diff --git a/app/daos/users.py b/app/daos/users.py index d84c8d5..5cfec73 100644 --- a/app/daos/users.py +++ b/app/daos/users.py @@ -1,27 +1,37 @@ +from __future__ import annotations + import json from fastapi import HTTPException +from fastapi_pagination.ext.sqlalchemy import paginate +from sqlalchemy import select from sqlalchemy.orm import Session +from werkzeug.security import check_password_hash + from app.constants import jwt_utils from app.constants.messages.users import user_messages as messages +from app.exceptions import EmailAlreadyExistException +from app.exceptions import InvalidCredentialsException +from app.exceptions import MobileAlreadyExistException +from app.exceptions import NoUserFoundException from app.models import User -from app.schemas.users.users_request import CreateUser, Login -from werkzeug.security import check_password_hash -from fastapi_pagination.ext.sqlalchemy import paginate -from sqlalchemy import select -from app.utils.user_utils import check_existing_field, responseFormatter -from app.wrappers.cache_wrappers import create_cache, retrieve_cache +from app.schemas.users.users_request import CreateUser +from app.schemas.users.users_request import Login +from app.utils.user_utils import check_existing_field +from app.utils.user_utils import response_formatter +from app.wrappers.cache_wrappers import create_cache +from app.wrappers.cache_wrappers import retrieve_cache -async def get_user(user_id: int, dbSession: Session): +async def get_user(user_id: int, db_session: Session): try: cache_key = f"user_{user_id}" - cached_user, expire = await retrieve_cache(cache_key) + cached_user, _ = await retrieve_cache(cache_key) if cached_user: return json.loads(cached_user) # Check if the user already exists in the database user = ( - dbSession.query(User) + db_session.query(User) .where(User.id == user_id) .with_entities( User.id, @@ -35,7 +45,7 @@ async def get_user(user_id: int, dbSession: Session): .first() ) if not user: - raise Exception(messages["NO_USER_FOUND_FOR_ID"]) + raise NoUserFoundException(messages['NO_USER_FOUND_FOR_ID']) await create_cache(json.dumps(user._asdict(), default=str), cache_key, 60) return user @@ -45,12 +55,12 @@ async def get_user(user_id: int, dbSession: Session): raise HTTPException(status_code=400, detail=f"{str(e)}") -def list_users(dbSession: Session): +def list_users(db_session: Session): try: query = select(User.id, User.name, User.email, User.mobile).order_by(User.created_at) # Pass the Select object to the paginate function - users = paginate(dbSession, query=query) + users = paginate(db_session, query=query) return users @@ -60,39 +70,44 @@ def list_users(dbSession: Session): raise HTTPException(status_code=400, detail=f"{str(e)}") -def create_user(data: CreateUser, dbSession: Session): +def create_user(data: CreateUser, db_session: Session): try: user_data = data.dict() # Check if the email already exists in the db - email_exists = check_existing_field(dbSession=dbSession, model=User, field="email", value=user_data["email"]) + email_exists = check_existing_field(db_session=db_session, model=User, field="email", value=user_data["email"]) if email_exists: - raise Exception(messages["EMAIL_ALREADY_EXIST"]) + raise EmailAlreadyExistException(messages['EMAIL_ALREADY_EXIST']) # Check if the mobile already exists in the db - mobile_exists = check_existing_field(dbSession=dbSession, model=User, field="mobile", value=user_data["mobile"]) + mobile_exists = check_existing_field( + db_session=db_session, + model=User, + field="mobile", + value=user_data["mobile"], + ) if mobile_exists: - raise Exception(messages["MOBILE_ALREADY_EXIST"]) + raise MobileAlreadyExistException(messages['MOBILE_ALREADY_EXIST']) user = User(**user_data) - dbSession.add(user) - dbSession.commit() - dbSession.refresh(user) + db_session.add(user) + db_session.commit() + db_session.refresh(user) - return responseFormatter(messages["CREATED_SUCCESSFULLY"]) + return response_formatter(messages["CREATED_SUCCESSFULLY"]) except Exception as e: # Return a user-friendly error message to the client raise HTTPException(status_code=400, detail=f"{str(e)}") -def login(data: Login, dbSession: Session): +def login(data: Login, db_session: Session): try: user_data = data.dict() # Check if the course already exists in the db user_details = ( - dbSession.query(User) + db_session.query(User) .where( User.email == user_data["email"], ) @@ -100,14 +115,14 @@ def login(data: Login, dbSession: Session): ) if not user_details: - raise Exception(messages["INVALID_CREDENTIALS"]) + raise InvalidCredentialsException(messages['INVALID_CREDENTIALS']) if not check_password_hash(user_details.password, user_data["password"]): - raise Exception(messages["INVALID_CREDENTIALS"]) + raise InvalidCredentialsException(messages['INVALID_CREDENTIALS']) del user_details.password token = jwt_utils.create_access_token({"sub": user_details.email, "id": user_details.id}) - return responseFormatter(messages["LOGIN_SUCCESSFULLY"], {"token": token}) + return response_formatter(messages["LOGIN_SUCCESSFULLY"], {"token": token}) except Exception as e: print(e) diff --git a/app/exceptions.py b/app/exceptions.py new file mode 100644 index 0000000..984528f --- /dev/null +++ b/app/exceptions.py @@ -0,0 +1,24 @@ +class ExternalServiceException(Exception): + pass + + +class NoUserFoundException(Exception): + pass +class EmailAlreadyExistException(Exception): + pass + +class MobileAlreadyExistException(Exception): + pass + +class InvalidCredentialsException(Exception): + pass + +class CentryTestException(Exception): + pass + +class DatabaseConnectionException(Exception): + pass + + +class RedisUrlNotFoundException(Exception): + pass \ No newline at end of file diff --git a/app/middlewares/cache_middleware.py b/app/middlewares/cache_middleware.py index 8698a50..1bf8d3c 100644 --- a/app/middlewares/cache_middleware.py +++ b/app/middlewares/cache_middleware.py @@ -15,9 +15,9 @@ class CacheMiddleware(BaseHTTPMiddleware): def __init__( - self, - app, - cached_endpoints: list[str], + self, + app, + cached_endpoints: list[str], ): super().__init__(app) self.cached_endpoints = cached_endpoints @@ -28,6 +28,10 @@ def matches_any_path(self, path_url): return True return False + async def handle_max_age(self, max_age, response_body, key): + if max_age: + await create_cache(response_body[0].decode(), key, max_age) + async def dispatch(self, request: Request, call_next) -> Response: path_url = request.url.path request_type = request.method @@ -41,12 +45,13 @@ async def dispatch(self, request: Request, call_next) -> Response: if request_type != "GET": return await call_next(request) - stored_cache = await retrieve_cache(key) + stored_cache, expire = await retrieve_cache(key) res = stored_cache and cache_control != "no-cache" if res: - headers = {"Cache-Control": f"max-age:{stored_cache[1]}"} - return StreamingResponse(iter([stored_cache[0]]), media_type="application/json", headers=headers) + headers = {"Cache-Control": f"max-age:{expire}"} + return StreamingResponse(iter([stored_cache]), media_type="application/json", headers=headers) + response: Response = await call_next(request) response_body = [chunk async for chunk in response.body_iterator] response.body_iterator = iterate_in_threadpool(iter(response_body)) @@ -57,8 +62,7 @@ async def dispatch(self, request: Request, call_next) -> Response: max_age_match = re.search(r"max-age=(\d+)", cache_control) if max_age_match: max_age = int(max_age_match.group(1)) - if max_age: - await create_cache(response_body[0].decode(), key, max_age) + await self.handle_max_age(max_age, response_body, key) elif matches: await create_cache(response_body[0].decode(), key, max_age) return response diff --git a/app/middlewares/rate_limiter_middleware.py b/app/middlewares/rate_limiter_middleware.py index 71b693c..e9c1085 100644 --- a/app/middlewares/rate_limiter_middleware.py +++ b/app/middlewares/rate_limiter_middleware.py @@ -26,7 +26,7 @@ async def dispatch(self, request: Request, call_next): pipe.expire(client_ip, TIME_WINDOW) await pipe.execute() finally: - pass + print("Finally Block in Rate Limit exceeded") response = await call_next(request) return response diff --git a/app/middlewares/request_id_injection.py b/app/middlewares/request_id_injection.py index 0bd6180..83f4455 100644 --- a/app/middlewares/request_id_injection.py +++ b/app/middlewares/request_id_injection.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import contextvars import uuid +from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request -from fastapi.responses import JSONResponse request_id_contextvar = contextvars.ContextVar("request_id", default=None) @@ -11,7 +13,7 @@ class RequestIdInjection(BaseHTTPMiddleware): def dispatch(self, request: Request, call_next): request_id = str(uuid.uuid4()) - request_id_contextvar.set(request_id) + request_id_contextvar.set(request_id) # noqa try: return call_next(request) diff --git a/app/models/__init__.py b/app/models/__init__.py index 459e0f4..2631ae3 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -1,12 +1,4 @@ -# import glob -# from os.path import basename -# from os.path import dirname -# from os.path import isfile -# from os.path import join - -# modules = glob.glob(join(dirname(__file__), "*.py")) - -# __all__ = [basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py")] +from __future__ import annotations from .users import User diff --git a/app/models/users.py b/app/models/users.py index c397ee2..b803d7d 100644 --- a/app/models/users.py +++ b/app/models/users.py @@ -1,9 +1,12 @@ -from sqlalchemy import Column, event +from __future__ import annotations + +from sqlalchemy import Column +from sqlalchemy import event +from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.sql import func from sqlalchemy.sql.sqltypes import DateTime from sqlalchemy.sql.sqltypes import Integer from sqlalchemy.sql.sqltypes import String -from sqlalchemy.ext.declarative import declarative_base from werkzeug.security import generate_password_hash from app.sessions.db import engine @@ -11,7 +14,7 @@ Base = declarative_base() -class User(Base): +class User(Base): # noqa __tablename__ = "user" id = Column(Integer, primary_key=True, index=True) diff --git a/app/routes/cache_router/cache_samples.py b/app/routes/cache_router/cache_samples.py index 9643826..a739b15 100644 --- a/app/routes/cache_router/cache_samples.py +++ b/app/routes/cache_router/cache_samples.py @@ -11,5 +11,5 @@ @cache_sample_router.get("/get-cache", tags=["Cache-Sample"]) def get_cache(): print("Request ID:", request_id_contextvar.get()) - response = random.randint(100, 1000) + response = random.randint(100, 1000) #NOSONAR return {"random value is": response} diff --git a/app/routes/home/home.py b/app/routes/home/home.py index 5a419aa..5dc4cbc 100644 --- a/app/routes/home/home.py +++ b/app/routes/home/home.py @@ -1,10 +1,15 @@ -from dependencies import circuit_breaker -from fastapi import APIRouter, HTTPException +from __future__ import annotations + +from fastapi import APIRouter +from fastapi import HTTPException from fastapi.responses import JSONResponse +from pybreaker import CircuitBreakerError + from app.config.base import settings -from app.middlewares.request_id_injection import request_id_contextvar from app.daos.home import external_service_call -from pybreaker import CircuitBreakerError +from app.exceptions import CentryTestException +from app.middlewares.request_id_injection import request_id_contextvar +from dependencies import circuit_breaker home_router = APIRouter() @@ -32,7 +37,7 @@ async def external_service_endpoint(): def sentry_endpoint(): if not settings.SENTRY_DSN: raise HTTPException(status_code=503, detail="Sentry DSN not found") - raise Exception("Testing Sentry") + raise CentryTestException("Centry Test") @home_router.get("/{path:path}", include_in_schema=False) diff --git a/app/routes/users/users.py b/app/routes/users/users.py index 4b6c1b3..96118ae 100644 --- a/app/routes/users/users.py +++ b/app/routes/users/users.py @@ -1,20 +1,23 @@ +from __future__ import annotations + +from typing import Annotated + from fastapi import APIRouter from fastapi import Depends -from sqlalchemy.orm import Session +from fastapi.security import HTTPBearer from fastapi_pagination import Page -from app.sessions.db import create_local_session -from app.daos.users import ( - create_user as create_user_dao, - get_user as get_user_dao, - list_users as list_users_dao, - login as signin, -) -from app.schemas.users.users_request import CreateUser, Login +from sqlalchemy.orm import Session + +from app.daos.users import create_user as create_user_dao +from app.daos.users import get_user as get_user_dao +from app.daos.users import list_users as list_users_dao +from app.daos.users import login as signin +from app.middlewares.request_id_injection import request_id_contextvar +from app.schemas.users.users_request import CreateUser +from app.schemas.users.users_request import Login from app.schemas.users.users_response import UserOutResponse +from app.sessions.db import create_local_session from app.utils.user_utils import get_current_user -from typing import Annotated -from fastapi.security import HTTPBearer -from app.middlewares.request_id_injection import request_id_contextvar user_router = APIRouter() @@ -23,15 +26,13 @@ @user_router.post("/register", tags=["Users"]) def register(payload: CreateUser, db: Session = Depends(create_local_session)): - print("Request ID:", request_id_contextvar.get()) - response = create_user_dao(data=payload, dbSession=db) + response = create_user_dao(data=payload, db_session=db) return response @user_router.post("/signin", tags=["Users"]) def login(payload: Login, db: Session = Depends(create_local_session)): - print("Request ID:", request_id_contextvar.get()) - response = signin(data=payload, dbSession=db) + response = signin(data=payload, db_session=db) return response @@ -41,19 +42,16 @@ async def profile( user_id, db: Session = Depends(create_local_session), ): - print("Request ID:", request_id_contextvar.get()) - response = await get_user_dao(user_id, dbSession=db) + response = await get_user_dao(user_id, db_session=db) return response @user_router.get("/", tags=["Users"], response_model=Page[UserOutResponse]) def list_users(db: Session = Depends(create_local_session)): - print("Request ID:", request_id_contextvar.get()) - response = list_users_dao(dbSession=db) + response = list_users_dao(db_session=db) return response @user_router.get("/{user_id}/secure-route/", tags=["Users"], dependencies=[Depends(get_current_user)]) def secure_route(token: Annotated[str, Depends(httpBearerScheme)], user_id: int): - print("Request ID:", request_id_contextvar.get()) return {"message": "If you see this, you're authenticated"} diff --git a/app/schemas/users/users_request.py b/app/schemas/users/users_request.py index 7551e0a..9477d9e 100644 --- a/app/schemas/users/users_request.py +++ b/app/schemas/users/users_request.py @@ -51,7 +51,7 @@ class Config: "name": "Anas Nadeem", "email": "anas@gmail.com", "mobile": "1234567890", - "password": "Test@123", + "password": "Test@123", #NOSONAR } } @@ -67,4 +67,4 @@ def validate_email(cls, email): return email class Config: - schema_extra = {"example": {"email": "anas@gmail.com", "password": "Test@123"}} + schema_extra = {"example": {"email": "anas@gmail.com", "password": "Test@123"}} #NOSONAR diff --git a/app/sessions/db.py b/app/sessions/db.py index 51c6b5d..e9ca8b4 100644 --- a/app/sessions/db.py +++ b/app/sessions/db.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import json import os import sys +from collections.abc import Generator -from app.config.base import db_settings from dotenv import load_dotenv from sqlalchemy import create_engine from sqlalchemy import MetaData @@ -11,6 +13,9 @@ from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import StaticPool +from app.config.base import db_settings +from app.exceptions import DatabaseConnectionException + load_dotenv() # Set the default values for connecting locally @@ -49,6 +54,8 @@ meta = MetaData() # Test the connection and print the status + + try: conn = engine.connect() print("-------------------------- Database connected ----------------------------") @@ -56,16 +63,16 @@ print("\nxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx") except Exception as e: print(f"Failed to connect to database. Error: {e}") - raise Exception(f"Failed to connect to database. Error: {e}") + raise DatabaseConnectionException("Failed to connect to database") localSession = Session(engine) -def create_local_session() -> Session: +def create_local_session() -> Generator[Session, None, None]: """Factory function that returns a new session object""" engine = create_engine(f"mysql+pymysql://{USERNAME}:{PASSWORD}@{HOST}:{PORT}/{DBNAME}") - SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - db = SessionLocal() + session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) + db = session_local() try: yield db finally: diff --git a/app/tests/test_basic.py b/app/tests/test_basic.py index 22ca0c5..10c0495 100644 --- a/app/tests/test_basic.py +++ b/app/tests/test_basic.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from fastapi.testclient import TestClient from app.app import app @@ -12,7 +14,9 @@ def test_read_main(): def test_example(): - assert 1 == 1 + test_value1 = 1 + test_value2 = 1 + assert test_value1 == test_value2 # diff --git a/app/utils/exception_handler.py b/app/utils/exception_handler.py index 8dea7a9..6e1062a 100644 --- a/app/utils/exception_handler.py +++ b/app/utils/exception_handler.py @@ -1,22 +1,26 @@ +from __future__ import annotations + import traceback from fastapi import Request from fastapi.encoders import jsonable_encoder -from fastapi.exceptions import HTTPException, RequestValidationError +from fastapi.exceptions import HTTPException +from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse -from app.utils.slack_notification_utils import send_slack_message from app.middlewares.request_id_injection import request_id_contextvar +from app.utils.slack_notification_utils import send_slack_message -async def validation_exception_handler(request: Request, exc: RequestValidationError): +async def validation_exception_handler(exc: RequestValidationError): return JSONResponse( status_code=400, content=jsonable_encoder({"message": "Validation error", "detail": exc.errors()[0]["msg"]}), ) -async def http_exception_handler(request: Request, exc: HTTPException): +async def http_exception_handler(request:Request, exc: HTTPException): + print(request) return JSONResponse(status_code=exc.status_code, content={"success": False, "message": exc.detail}) @@ -26,7 +30,7 @@ async def exception_handler(request: Request, exc: Exception): traceback_str = traceback.format_exc(chain=False) send_slack_message( { - "text": f"```\nRequestID: {request_id_contextvar.get()}\nRequest URL: {str(request.url)} \nRequest_method: {str(request.method)} \nTraceback: {traceback_str}```" - } + "text": f"```\nRequestID: {request_id_contextvar.get()}\nRequest URL: {str(request.url)} \nRequest_method: {str(request.method)} \nTraceback: {traceback_str}```", + }, ) return JSONResponse(status_code=500, content={"success": False, "message": error_message}) diff --git a/app/utils/user_utils.py b/app/utils/user_utils.py index dcc1be8..70fc6b3 100644 --- a/app/utils/user_utils.py +++ b/app/utils/user_utils.py @@ -1,14 +1,18 @@ +from __future__ import annotations + +from fastapi import HTTPException +from fastapi import Request from sqlalchemy.orm import Session -from fastapi import Request, HTTPException + from app.constants import jwt_utils -def responseFormatter(message, data=None): +def response_formatter(message, data=None): return {"success": True, "message": message, "data": data} -def check_existing_field(dbSession: Session, model, field, value): - existing = dbSession.query(model).filter(getattr(model, field) == value).first() +def check_existing_field(db_session: Session, model, field, value): + existing = db_session.query(model).filter(getattr(model, field) == value).first() if existing: return True @@ -21,7 +25,7 @@ async def get_current_user(request: Request): raise HTTPException(status_code=401, detail="Token not provided") try: - user_id = int(dict(request).get("path_params")["user_id"]) + user_id = int(dict(request).get("path_params")["user_id"]) # noqa token = token.split(" ")[1] payload = jwt_utils.decode_access_token(token) if user_id == int(payload["id"]): diff --git a/app/wrappers/cache_wrappers.py b/app/wrappers/cache_wrappers.py index 01e6e6e..b96ec16 100644 --- a/app/wrappers/cache_wrappers.py +++ b/app/wrappers/cache_wrappers.py @@ -1,8 +1,11 @@ +from __future__ import annotations + from app.config.base import settings from app.config.redis_config import get_redis_pool +from app.exceptions import RedisUrlNotFoundException if not settings.REDIS_URL: - raise Exception("Please add REDIS_URL in environment") + raise RedisUrlNotFoundException("Failed To get Redis URL") async def create_cache(resp, key: str, ex: int = 60):