diff --git a/dataherald/api/__init__.py b/dataherald/api/__init__.py index a4b80317..21ef7df7 100644 --- a/dataherald/api/__init__.py +++ b/dataherald/api/__init__.py @@ -39,6 +39,12 @@ def scan_db( def answer_question(self, question_request: QuestionRequest) -> Response: pass + @abstractmethod + def answer_question_with_timeout( + self, question_request: QuestionRequest + ) -> Response: + pass + @abstractmethod def get_questions(self, db_connection_id: str | None = None) -> list[Question]: pass diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index 3b764b9a..298fc3a3 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -1,5 +1,7 @@ import json import logging +import os +import threading import time from typing import List @@ -163,6 +165,37 @@ def answer_question(self, question_request: QuestionRequest) -> Response: response_repository = ResponseRepository(self.storage) return response_repository.insert(generated_answer) + @override + def answer_question_with_timeout( + self, question_request: QuestionRequest + ) -> Response: + result = None + exception = None + user_question = Question( + question=question_request.question, + db_connection_id=question_request.db_connection_id, + ) + stop_event = threading.Event() + + def run_and_catch_exceptions(): + nonlocal result, exception + if not stop_event.is_set(): + result = self.answer_question(question_request) + + thread = threading.Thread(target=run_and_catch_exceptions) + thread.start() + thread.join(timeout=int(os.getenv("DH_ENGINE_TIMEOUT"))) + if thread.is_alive(): + stop_event.set() + return JSONResponse( + status_code=400, + content={ + "question_id": user_question.id, + "error_message": "Timeout Error", + }, + ) + return result + @override def create_database_connection( self, database_connection_request: DatabaseConnectionRequest diff --git a/dataherald/server/fastapi/__init__.py b/dataherald/server/fastapi/__init__.py index c75eb5b2..89da742e 100644 --- a/dataherald/server/fastapi/__init__.py +++ b/dataherald/server/fastapi/__init__.py @@ -1,3 +1,4 @@ +import os from typing import Any, List import fastapi @@ -216,6 +217,8 @@ def scan_db( return self._api.scan_db(scanner_request, background_tasks) def answer_question(self, question_request: QuestionRequest) -> Response: + if os.getenv("DH_ENGINE_TIMEOUT", None): + return self._api.answer_question_with_timeout(question_request) return self._api.answer_question(question_request) def get_questions(self, db_connection_id: str | None = None) -> list[Question]: