Skip to content

Commit

Permalink
DH-5688/fixing the observations code blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza committed Apr 4, 2024
1 parent 39bef01 commit 19786a0
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
7 changes: 5 additions & 2 deletions dataherald/sql_generator/dataherald_finetuning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):

name = "SqlDbQuery"
description = """
Input: SQL query.
Input: A SQL query between ```sql and ``` tags.
Output: Result from the database or an error message if the query is incorrect.
Use this tool to execute the SQL query on the database, and return the results.
"""
Expand Down Expand Up @@ -335,7 +335,8 @@ def _run(
{"role": "user", "content": user_prompt},
],
)
return response.choices[0].message.content
returned_sql = response.choices[0].message.content
return f"```sql\n{returned_sql}```"

async def _arun(
self,
Expand Down Expand Up @@ -372,6 +373,7 @@ def _run(
tables_schema = ""
for table in self.db_scan:
if table.table_name in table_names_list:
tables_schema += "```sql\n"
tables_schema += table.table_schema + "\n"
descriptions = []
if table.description is not None:
Expand All @@ -385,6 +387,7 @@ def _run(
)
if len(descriptions) > 0:
tables_schema += f"/*\n{''.join(descriptions)}*/\n"
tables_schema += "```\n"
if tables_schema == "":
tables_schema += "Tables not found in the database"
return tables_schema
Expand Down
13 changes: 7 additions & 6 deletions dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):

name = "SqlDbQuery"
description = """
Input: SQL query.
Input: A SQL query between ```sql and ``` tags.
Output: Result from the database or an error message if the query is incorrect.
If an error occurs, rewrite the query and retry.
Use this tool to execute SQL queries.
Expand Down Expand Up @@ -204,8 +204,8 @@ def _run(
run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002
) -> str:
response = "Admin: All of the generated SQL queries must follow the below instructions:\n"
for instruction in self.instructions:
response += f"{instruction['instruction']}\n"
for index, instruction in enumerate(self.instructions):
response += f"{index + 1}) {instruction['instruction']}\n"
return response

async def _arun(
Expand Down Expand Up @@ -407,6 +407,7 @@ def _run(
tables_schema = ""
for table in self.db_scan:
if table.table_name in table_names_list:
tables_schema += "```sql\n"
tables_schema += table.table_schema + "\n"
descriptions = []
if table.description is not None:
Expand All @@ -420,6 +421,7 @@ def _run(
)
if len(descriptions) > 0:
tables_schema += f"/*\n{''.join(descriptions)}*/\n"
tables_schema += "```\n"
if tables_schema == "":
tables_schema += "Tables not found in the database"
return tables_schema
Expand Down Expand Up @@ -516,9 +518,8 @@ def _run(
return "Action input for the fewshot_examples_retriever tool should be an integer"
returned_output = ""
for example in self.few_shot_examples[:number_of_samples]:
returned_output += (
f"Question: {example['prompt_text']} -> SQL: {example['sql']}\n"
)
returned_output += f"Question: {example['prompt_text']} \n"
returned_output += f"```sql\n{example['sql']}\n```\n"
if returned_output == "":
returned_output = "No previously asked Question/SQL pairs are available"
return returned_output
Expand Down

0 comments on commit 19786a0

Please sign in to comment.