Skip to content

Commit

Permalink
Dh 5584/fixing the running query forever (#427)
Browse files Browse the repository at this point in the history
* Dh-5584/fixing the sql query stucking for ever issue

* DH-5584/updating the timeout
  • Loading branch information
MohammadrezaPourreza authored and DishenWang2023 committed May 7, 2024
1 parent f80b749 commit 40ff844
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 5 deletions.
2 changes: 2 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ LLM_MODEL = "gpt-4-turbo-preview"

#timeout in seconds for the engine to return a response. Defaults to 150 seconds
DH_ENGINE_TIMEOUT =
#timeout in seconds for SQL execution. Our agents execute the SQL query to recover from errors; this is the timeout for that execution. Defaults to 60 seconds
SQL_EXECUTION_TIMEOUT =
#The upper limit on number of rows returned from the query engine (equivalent to using LIMIT N in PostgreSQL/MySQL/SQLite). Defaults to 50
UPPER_LIMIT_QUERY_RETURN_ROWS =
#Encryption key for storing DB connection data in Mongo
Expand Down
15 changes: 12 additions & 3 deletions dataherald/sql_generator/create_sql_query_status.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from sqlalchemy import text
import os

from dataherald.sql_database.base import SQLDatabase, SQLInjectionError
from dataherald.types import SQLGeneration
from dataherald.utils.timeout_utils import run_with_timeout


def format_error_message(
Expand Down Expand Up @@ -30,10 +31,18 @@ def create_sql_query_status(
else:
try:
query = db.parser_to_filter_commands(query)
with db._engine.connect() as connection:
connection.execute(text(query))
run_with_timeout(
db.run_sql,
args=(query,),
timeout_duration=int(os.getenv("SQL_EXECUTION_TIMEOUT", "60")),
)
sql_generation.status = "VALID"
sql_generation.error = None
except TimeoutError:
sql_generation = format_error_message(
sql_generation,
"The query execution exceeded the timeout.",
)
except SQLInjectionError as e:
raise SQLInjectionError(
"Sensitive SQL keyword detected in the query."
Expand Down
12 changes: 11 additions & 1 deletion dataherald/sql_generator/dataherald_finetuning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
FORMAT_INSTRUCTIONS,
)
from dataherald.utils.models_context_window import OPENAI_FINETUNING_MODELS_WINDOW_SIZES
from dataherald.utils.timeout_utils import run_with_timeout

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -241,7 +242,16 @@ def _run(
query = replace_unprocessable_characters(query)
if "```sql" in query:
query = query.replace("```sql", "").replace("```", "")
return self.db.run_sql(query, top_k=TOP_K)[0]

try:
return run_with_timeout(
self.db.run_sql,
args=(query,),
kwargs={"top_k": TOP_K},
timeout_duration=int(os.getenv("SQL_EXECUTION_TIMEOUT", "60")),
)
except TimeoutError:
return "SQL query execution time exceeded, proceed without query execution"

async def _arun(
self,
Expand Down
12 changes: 11 additions & 1 deletion dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
SUFFIX_WITH_FEW_SHOT_SAMPLES,
SUFFIX_WITHOUT_FEW_SHOT_SAMPLES,
)
from dataherald.utils.timeout_utils import run_with_timeout

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -159,7 +160,16 @@ def _run(
query = replace_unprocessable_characters(query)
if "```sql" in query:
query = query.replace("```sql", "").replace("```", "")
return self.db.run_sql(query, top_k=top_k)[0]

try:
return run_with_timeout(
self.db.run_sql,
args=(query,),
kwargs={"top_k": top_k},
timeout_duration=int(os.getenv("SQL_EXECUTION_TIMEOUT", "60")),
)
except TimeoutError:
return "SQL query execution time exceeded, proceed without query execution"

async def _arun(
self,
Expand Down
1 change: 1 addition & 0 deletions dataherald/utils/agent_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,5 @@
Final Answer: the final answer to the original input question
If you know the final answer and do not need to use any tools, you can directly return the Final Answer: <your final answer>.
If there is a consistent parsing error, please return "I don't know" as your final answer.
"""
24 changes: 24 additions & 0 deletions dataherald/utils/timeout_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import threading


def run_with_timeout(func, args=(), kwargs=None, timeout_duration=60):
    """Run ``func(*args, **kwargs)`` in a worker thread with a time limit.

    The call is executed in a separate thread and the caller waits for at
    most ``timeout_duration`` seconds. Python cannot forcibly stop a thread,
    so on timeout the worker is abandoned and keeps running in the
    background; it is started as a daemon thread so that a call which never
    returns cannot keep the interpreter alive at shutdown.

    Args:
        func: Callable to execute.
        args: Positional arguments passed to ``func``.
        kwargs: Keyword arguments passed to ``func``; ``None`` means no
            keyword arguments.
        timeout_duration: Maximum number of seconds to wait for ``func``.

    Returns:
        Whatever ``func`` returns. An exception *instance* returned as a
        normal value is handed back unchanged rather than raised.

    Raises:
        TimeoutError: If ``func`` has not finished within
            ``timeout_duration`` seconds, or if the worker ended without
            producing either a result or an ``Exception``.
        Exception: Any ``Exception`` raised by ``func`` is re-raised in the
            calling thread.
    """
    if kwargs is None:
        kwargs = {}

    # Separate containers so a returned value is never mistaken for a raised
    # exception (and vice versa).
    results = []
    errors = []

    def func_wrapper():
        try:
            results.append(func(*args, **kwargs))
        except Exception as e:
            errors.append(e)

    # daemon=True: a worker stuck past the timeout must not block process exit.
    thread = threading.Thread(target=func_wrapper, daemon=True)
    thread.start()
    thread.join(timeout=timeout_duration)
    if thread.is_alive():
        raise TimeoutError("Function execution exceeded the timeout")
    if errors:
        raise errors[0]
    if results:
        return results[0]
    # The worker ended without a result or an Exception (e.g. it exited via a
    # non-Exception BaseException). Kept as TimeoutError because that is the
    # failure type existing callers already handle.
    raise TimeoutError("Function execution exceeded the timeout")

0 comments on commit 40ff844

Please sign in to comment.