diff --git a/space2stats_api/src/space2stats/app.py b/space2stats_api/src/space2stats/app.py index 50b34ed..6747e70 100644 --- a/space2stats_api/src/space2stats/app.py +++ b/space2stats_api/src/space2stats/app.py @@ -10,6 +10,7 @@ from starlette_cramjam.middleware import CompressionMiddleware from .db import close_db_connection, connect_to_db +from .errors import add_exception_handlers from .main import ( SummaryRequest, get_available_fields, @@ -22,11 +23,8 @@ @asynccontextmanager async def lifespan(app: FastAPI): - """FastAPI Lifespan.""" - # Create Connection Pool await connect_to_db(app) yield - # Close the Connection Pool await close_db_connection(app) @@ -49,6 +47,8 @@ async def lifespan(app: FastAPI): s3_client=s3_client, ) +add_exception_handlers(app) + @app.post("/summary", response_model=List[Dict[str, Any]]) def get_summary(request: Request, body: SummaryRequest): diff --git a/space2stats_api/src/space2stats/errors.py b/space2stats_api/src/space2stats/errors.py new file mode 100644 index 0000000..9dc4c65 --- /dev/null +++ b/space2stats_api/src/space2stats/errors.py @@ -0,0 +1,46 @@ +from fastapi import HTTPException, Request +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse +from psycopg import OperationalError + + +async def database_exception_handler(request: Request, exc: OperationalError): + return JSONResponse( + status_code=500, + content={"error": "Database connection failed", "detail": str(exc)}, + ) + + +async def http_exception_handler(request: Request, exc: HTTPException): + return JSONResponse( + status_code=exc.status_code, + content={"error": exc.detail}, + ) + + +# Custom handler for validation errors (422 Unprocessable Entity) +async def validation_exception_handler(request: Request, exc: RequestValidationError): + errors = exc.errors() + detailed_errors = [ + { + "field": "->".join(map(str, err.get("loc", []))), + "message": err.get("msg"), + "type": err.get("type", "Unknown error type"), + } + for err in errors + ] + + return JSONResponse( + status_code=422, + content={ + "error": "Validation Error", + "details": detailed_errors, + "hint": "Check if 'fields' is a valid list and all provided fields exist.", + }, + ) + + +def add_exception_handlers(app): + app.add_exception_handler(OperationalError, database_exception_handler) + app.add_exception_handler(HTTPException, http_exception_handler) + app.add_exception_handler(RequestValidationError, validation_exception_handler) diff --git a/space2stats_api/src/tests/test_errors.py b/space2stats_api/src/tests/test_errors.py new file mode 100644 index 0000000..94954f3 --- /dev/null +++ b/space2stats_api/src/tests/test_errors.py @@ -0,0 +1,71 @@ +import asyncio +import json + +from fastapi import HTTPException +from fastapi.exceptions import RequestValidationError +from fastapi.testclient import TestClient +from psycopg.errors import OperationalError + +from src.space2stats.app import app +from src.space2stats.errors import ( + database_exception_handler, + http_exception_handler, + validation_exception_handler, +) + +client = TestClient(app) + + +def test_database_exception_handler(): + request = None + exception = OperationalError("Database connection failed") + response = asyncio.run(database_exception_handler(request, exception)) + + assert response.status_code == 500 + response_data = json.loads(response.body.decode("utf-8")) + assert response_data == { + "error": "Database connection failed", + "detail": "Database connection failed", + } + + +def test_http_exception_handler(): + request = None + exception = HTTPException(status_code=404, detail="Not found") + response = asyncio.run(http_exception_handler(request, exception)) + + assert response.status_code == 404 + response_data = json.loads(response.body.decode("utf-8")) + assert response_data == { + "error": "Not found", + } + + +def test_validation_exception_handler(): + request = None + exc = RequestValidationError( + [ + { + "loc": ("body", "fields"), + "msg": "Input should be a valid list", + "type": "type_error.list", + } + ] + ) + response = asyncio.run(validation_exception_handler(request, exc)) + + expected_response = { + "error": "Validation Error", + "details": [ + { + "field": "body->fields", + "message": "Input should be a valid list", + "type": "type_error.list", + } + ], + "hint": "Check if 'fields' is a valid list and all provided fields exist.", + } + + assert response.status_code == 422 + response_data = json.loads(response.body.decode("utf-8")) + assert response_data == expected_response