From 543a43071a8efbb0f42f1c2935d90953f18aa4c8 Mon Sep 17 00:00:00 2001
From: Juan Carlos Jose Camacho <jc@dataherald.com>
Date: Wed, 8 Nov 2023 12:44:47 -0600
Subject: [PATCH] Return sql_query_result as null only when the generate_csv
 flag is set and the result has more than 50 rows

---
 dataherald/api/fastapi.py       | 13 +++++++++++--
 dataherald/repositories/base.py |  4 ++--
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py
index a914553f..bf99cb8d 100644
--- a/dataherald/api/fastapi.py
+++ b/dataherald/api/fastapi.py
@@ -55,6 +55,8 @@
 
 logger = logging.getLogger(__name__)
 
+MAX_ROWS_TO_CREATE_CSV_FILE = 50
+
 
 def async_scanning(scanner, database, scanner_request, storage):
     scanner.scan(
@@ -176,7 +178,11 @@ def answer_question(
                 status_code=400,
                 content={"question_id": user_question.id, "error_message": str(e)},
             )
-        if generated_answer.csv_file_path:
+        if (
+            generate_csv
+            and len(generated_answer.sql_query_result.rows)
+            > MAX_ROWS_TO_CREATE_CSV_FILE
+        ):
             generated_answer.sql_query_result = None
         generated_answer.exec_time = time.time() - start_generated_answer
         response_repository = ResponseRepository(self.storage)
@@ -529,7 +535,10 @@ def create_response(
                 user_question, response, database_connection
             )
             response.confidence_score = confidence_score
-        if response.csv_file_path:
+        if (
+            generate_csv
+            and len(response.sql_query_result.rows) > MAX_ROWS_TO_CREATE_CSV_FILE
+        ):
             response.sql_query_result = None
         response.exec_time = time.time() - start_generated_answer
         response_repository.insert(response)
diff --git a/dataherald/repositories/base.py b/dataherald/repositories/base.py
index 5d589d19..511bafd2 100644
--- a/dataherald/repositories/base.py
+++ b/dataherald/repositories/base.py
@@ -11,7 +11,7 @@ def __init__(self, storage):
         self.storage = storage
 
     def insert(self, response: Response) -> Response:
-        response_dict = response.dict(exclude={"id"})
+        response_dict = response.dict(exclude={"id", "sql_query_result"})
         response_dict["question_id"] = ObjectId(response.question_id)
         response.id = str(self.storage.insert_one(DB_COLLECTION, response_dict))
         return response
@@ -25,7 +25,7 @@ def find_one(self, query: dict) -> Response | None:
         return Response(**row)
 
     def update(self, response: Response) -> Response:
-        response_dict = response.dict(exclude={"id"})
+        response_dict = response.dict(exclude={"id", "sql_query_result"})
         response_dict["question_id"] = ObjectId(response.question_id)
 
         self.storage.update_or_create(