From 543a43071a8efbb0f42f1c2935d90953f18aa4c8 Mon Sep 17 00:00:00 2001 From: Juan Carlos Jose Camacho <jc@dataherald.com> Date: Wed, 8 Nov 2023 12:44:47 -0600 Subject: [PATCH] Only returns sql_query_result as null when generate_csv flag is set and it has more than 50 rows --- dataherald/api/fastapi.py | 13 +++++++++++-- dataherald/repositories/base.py | 4 ++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index a914553f..bf99cb8d 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -55,6 +55,8 @@ logger = logging.getLogger(__name__) +MAX_ROWS_TO_CREATE_CSV_FILE = 50 + def async_scanning(scanner, database, scanner_request, storage): scanner.scan( @@ -176,7 +178,11 @@ def answer_question( status_code=400, content={"question_id": user_question.id, "error_message": str(e)}, ) - if generated_answer.csv_file_path: + if ( + generate_csv + and len(generated_answer.sql_query_result.rows) + > MAX_ROWS_TO_CREATE_CSV_FILE + ): generated_answer.sql_query_result = None generated_answer.exec_time = time.time() - start_generated_answer response_repository = ResponseRepository(self.storage) @@ -529,7 +535,10 @@ def create_response( user_question, response, database_connection ) response.confidence_score = confidence_score - if response.csv_file_path: + if ( + generate_csv + and len(response.sql_query_result.rows) > MAX_ROWS_TO_CREATE_CSV_FILE + ): response.sql_query_result = None response.exec_time = time.time() - start_generated_answer response_repository.insert(response) diff --git a/dataherald/repositories/base.py b/dataherald/repositories/base.py index 5d589d19..511bafd2 100644 --- a/dataherald/repositories/base.py +++ b/dataherald/repositories/base.py @@ -11,7 +11,7 @@ def __init__(self, storage): self.storage = storage def insert(self, response: Response) -> Response: - response_dict = response.dict(exclude={"id"}) + response_dict = response.dict(exclude={"id", "sql_query_result"}) response_dict["question_id"] = ObjectId(response.question_id) response.id = str(self.storage.insert_one(DB_COLLECTION, response_dict)) return response @@ -25,7 +25,7 @@ def find_one(self, query: dict) -> Response | None: return Response(**row) def update(self, response: Response) -> Response: - response_dict = response.dict(exclude={"id"}) + response_dict = response.dict(exclude={"id", "sql_query_result"}) response_dict["question_id"] = ObjectId(response.question_id) self.storage.update_or_create(