diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index ccc6d31c..3b764b9a 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -6,6 +6,7 @@ from bson import json_util from bson.objectid import InvalidId, ObjectId from fastapi import BackgroundTasks, HTTPException +from fastapi.responses import JSONResponse from overrides import override from dataherald.api import API @@ -135,7 +136,13 @@ def answer_question(self, question_request: QuestionRequest) -> Response: question_request.db_connection_id ) if not database_connection: - raise HTTPException(status_code=404, detail="Database connection not found") + return JSONResponse( + status_code=404, + content={ + "question_id": user_question.id, + "error_message": "Connections doesn't exist", + }, + ) context = context_store.retrieve_context_for_question(user_question) start_generated_answer = time.time() try: @@ -146,10 +153,11 @@ def answer_question(self, question_request: QuestionRequest) -> Response: confidence_score = evaluator.get_confidence_score( user_question, generated_answer, database_connection ) - except ValueError as e: - raise HTTPException(status_code=404, detail=str(e)) from e - except SQLInjectionError as e: - raise HTTPException(status_code=404, detail=str(e)) from e + except Exception as e: + return JSONResponse( + status_code=400, + content={"question_id": user_question.id, "error_message": str(e)}, + ) generated_answer.confidence_score = confidence_score generated_answer.exec_time = time.time() - start_generated_answer response_repository = ResponseRepository(self.storage) diff --git a/dataherald/server/fastapi/__init__.py b/dataherald/server/fastapi/__init__.py index 4001fb0e..c75eb5b2 100644 --- a/dataherald/server/fastapi/__init__.py +++ b/dataherald/server/fastapi/__init__.py @@ -50,6 +50,7 @@ def __init__(self, settings: Settings): "/api/v1/database-connections", self.create_database_connection, methods=["POST"], + status_code=201, tags=["Database connections"], ) @@ -71,6 +72,7 @@ def __init__(self, settings: Settings): "/api/v1/table-descriptions/sync-schemas", self.scan_db, methods=["POST"], + status_code=201, tags=["Table descriptions"], ) @@ -106,6 +108,7 @@ def __init__(self, settings: Settings): "/api/v1/golden-records", self.add_golden_records, methods=["POST"], + status_code=201, tags=["Golden records"], ) @@ -120,6 +123,7 @@ def __init__(self, settings: Settings): "/api/v1/questions", self.answer_question, methods=["POST"], + status_code=201, tags=["Questions"], ) @@ -141,6 +145,7 @@ def __init__(self, settings: Settings): "/api/v1/responses", self.create_response, methods=["POST"], + status_code=201, tags=["Responses"], ) @@ -162,6 +167,7 @@ def __init__(self, settings: Settings): "/api/v1/sql-query-executions", self.execute_sql_query, methods=["POST"], + status_code=201, tags=["SQL queries"], ) @@ -169,6 +175,7 @@ def __init__(self, settings: Settings): "/api/v1/instructions", self.add_instruction, methods=["POST"], + status_code=201, tags=["Instructions"], )