Skip to content

Commit

Permalink
Add custom error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
zacdezgeo committed Sep 10, 2024
1 parent 8d1cd7a commit 1152d65
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 3 deletions.
6 changes: 3 additions & 3 deletions space2stats_api/src/space2stats/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from starlette_cramjam.middleware import CompressionMiddleware

from .db import close_db_connection, connect_to_db
from .errors import add_exception_handlers
from .main import (
SummaryRequest,
get_available_fields,
Expand All @@ -22,11 +23,8 @@

@asynccontextmanager
async def lifespan(app: FastAPI):
"""FastAPI Lifespan."""
# Create Connection Pool
await connect_to_db(app)
yield
# Close the Connection Pool
await close_db_connection(app)


Expand All @@ -49,6 +47,8 @@ async def lifespan(app: FastAPI):
s3_client=s3_client,
)

add_exception_handlers(app)


@app.post("/summary", response_model=List[Dict[str, Any]])
def get_summary(request: Request, body: SummaryRequest):
Expand Down
46 changes: 46 additions & 0 deletions space2stats_api/src/space2stats/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from fastapi import HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from psycopg import OperationalError


async def database_exception_handler(request: Request, exc: OperationalError):
return JSONResponse(
status_code=500,
content={"error": "Database connection failed", "detail": str(exc)},
)


async def http_exception_handler(request: Request, exc: HTTPException):
return JSONResponse(
status_code=exc.status_code,
content={"error": exc.detail},
)


# Custom handler for validation errors (422 Unprocessable Entity)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
errors = exc.errors()
detailed_errors = [
{
"field": "->".join(map(str, err.get("loc", []))),
"message": err.get("msg"),
"type": err.get("type", "Unknown error type"),
}
for err in errors
]

return JSONResponse(
status_code=422,
content={
"error": "Validation Error",
"details": detailed_errors,
"hint": "Check if 'fields' is a valid list and all provided fields exist.",
},
)


def add_exception_handlers(app):
app.add_exception_handler(OperationalError, database_exception_handler)
app.add_exception_handler(HTTPException, http_exception_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
71 changes: 71 additions & 0 deletions space2stats_api/src/tests/test_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import asyncio
import json

from fastapi import HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.testclient import TestClient
from psycopg.errors import OperationalError

from src.space2stats.app import app
from src.space2stats.errors import (
database_exception_handler,
http_exception_handler,
validation_exception_handler,
)

client = TestClient(app)


def test_database_exception_handler():
request = None
exception = OperationalError("Database connection failed")
response = asyncio.run(database_exception_handler(request, exception))

assert response.status_code == 500
response_data = json.loads(response.body.decode("utf-8"))
assert response_data == {
"error": "Database connection failed",
"detail": "Database connection failed",
}


def test_http_exception_handler():
request = None
exception = HTTPException(status_code=404, detail="Not found")
response = asyncio.run(http_exception_handler(request, exception))

assert response.status_code == 404
response_data = json.loads(response.body.decode("utf-8"))
assert response_data == {
"error": "Not found",
}


def test_validation_exception_handler():
request = None
exc = RequestValidationError(
[
{
"loc": ("body", "fields"),
"msg": "Input should be a valid list",
"type": "type_error.list",
}
]
)
response = asyncio.run(validation_exception_handler(request, exc))

expected_response = {
"error": "Validation Error",
"details": [
{
"field": "body->fields",
"message": "Input should be a valid list",
"type": "type_error.list",
}
],
"hint": "Check if 'fields' is a valid list and all provided fields exist.",
}

assert response.status_code == 422
response_data = json.loads(response.body.decode("utf-8"))
assert response_data == expected_response

0 comments on commit 1152d65

Please sign in to comment.