Skip to content

Commit

Permalink
DH-4770 Return the question_id if the endpoint fails
Browse files Browse the repository at this point in the history
  • Loading branch information
jcjc712 committed Oct 10, 2023
1 parent c347002 commit 50d24e3
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
18 changes: 13 additions & 5 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from bson import json_util
from bson.objectid import InvalidId, ObjectId
from fastapi import BackgroundTasks, HTTPException
from fastapi.responses import JSONResponse
from overrides import override

from dataherald.api import API
Expand Down Expand Up @@ -135,7 +136,13 @@ def answer_question(self, question_request: QuestionRequest) -> Response:
question_request.db_connection_id
)
if not database_connection:
raise HTTPException(status_code=404, detail="Database connection not found")
return JSONResponse(
status_code=404,
content={
"question_id": user_question.id,
"error_message": "Connections doesn't exist",
},
)
context = context_store.retrieve_context_for_question(user_question)
start_generated_answer = time.time()
try:
Expand All @@ -146,10 +153,11 @@ def answer_question(self, question_request: QuestionRequest) -> Response:
confidence_score = evaluator.get_confidence_score(
user_question, generated_answer, database_connection
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e)) from e
except SQLInjectionError as e:
raise HTTPException(status_code=404, detail=str(e)) from e
except Exception as e:
return JSONResponse(
status_code=400,
content={"question_id": user_question.id, "error_message": str(e)},
)
generated_answer.confidence_score = confidence_score
generated_answer.exec_time = time.time() - start_generated_answer
response_repository = ResponseRepository(self.storage)
Expand Down
7 changes: 7 additions & 0 deletions dataherald/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(self, settings: Settings):
"/api/v1/database-connections",
self.create_database_connection,
methods=["POST"],
status_code=201,
tags=["Database connections"],
)

Expand All @@ -71,6 +72,7 @@ def __init__(self, settings: Settings):
"/api/v1/table-descriptions/sync-schemas",
self.scan_db,
methods=["POST"],
status_code=201,
tags=["Table descriptions"],
)

Expand Down Expand Up @@ -106,6 +108,7 @@ def __init__(self, settings: Settings):
"/api/v1/golden-records",
self.add_golden_records,
methods=["POST"],
status_code=201,
tags=["Golden records"],
)

Expand All @@ -120,6 +123,7 @@ def __init__(self, settings: Settings):
"/api/v1/questions",
self.answer_question,
methods=["POST"],
status_code=201,
tags=["Questions"],
)

Expand All @@ -141,6 +145,7 @@ def __init__(self, settings: Settings):
"/api/v1/responses",
self.create_response,
methods=["POST"],
status_code=201,
tags=["Responses"],
)

Expand All @@ -162,13 +167,15 @@ def __init__(self, settings: Settings):
"/api/v1/sql-query-executions",
self.execute_sql_query,
methods=["POST"],
status_code=201,
tags=["SQL queries"],
)

self.router.add_api_route(
"/api/v1/instructions",
self.add_instruction,
methods=["POST"],
status_code=201,
tags=["Instructions"],
)

Expand Down

0 comments on commit 50d24e3

Please sign in to comment.