From 19786a0583d6c0632e22a2012e5bf8a903542318 Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Thu, 4 Apr 2024 12:36:35 -0400 Subject: [PATCH] DH-5688/fixing the observations code blocks --- .../sql_generator/dataherald_finetuning_agent.py | 7 +++++-- dataherald/sql_generator/dataherald_sqlagent.py | 13 +++++++------ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/dataherald/sql_generator/dataherald_finetuning_agent.py b/dataherald/sql_generator/dataherald_finetuning_agent.py index e0ff4367..b477dbae 100644 --- a/dataherald/sql_generator/dataherald_finetuning_agent.py +++ b/dataherald/sql_generator/dataherald_finetuning_agent.py @@ -250,7 +250,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): name = "SqlDbQuery" description = """ - Input: SQL query. + Input: A SQL query between ```sql and ``` tags. Output: Result from the database or an error message if the query is incorrect. Use this tool to execute the SQL query on the database, and return the results. """ @@ -335,7 +335,8 @@ def _run( {"role": "user", "content": user_prompt}, ], ) - return response.choices[0].message.content + returned_sql = response.choices[0].message.content + return f"```sql\n{returned_sql}```" async def _arun( self, @@ -372,6 +373,7 @@ def _run( tables_schema = "" for table in self.db_scan: if table.table_name in table_names_list: + tables_schema += "```sql\n" tables_schema += table.table_schema + "\n" descriptions = [] if table.description is not None: @@ -385,6 +387,7 @@ def _run( ) if len(descriptions) > 0: tables_schema += f"/*\n{''.join(descriptions)}*/\n" + tables_schema += "```\n" if tables_schema == "": tables_schema += "Tables not found in the database" return tables_schema diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 857cdd08..84dbd6df 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -150,7 +150,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): name = "SqlDbQuery" description = """ - Input: SQL query. + Input: A SQL query between ```sql and ``` tags. Output: Result from the database or an error message if the query is incorrect. If an error occurs, rewrite the query and retry. Use this tool to execute SQL queries. @@ -204,8 +204,8 @@ def _run( run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002 ) -> str: response = "Admin: All of the generated SQL queries must follow the below instructions:\n" - for instruction in self.instructions: - response += f"{instruction['instruction']}\n" + for index, instruction in enumerate(self.instructions): + response += f"{index + 1}) {instruction['instruction']}\n" return response async def _arun( @@ -407,6 +407,7 @@ def _run( tables_schema = "" for table in self.db_scan: if table.table_name in table_names_list: + tables_schema += "```sql\n" tables_schema += table.table_schema + "\n" descriptions = [] if table.description is not None: @@ -420,6 +421,7 @@ def _run( ) if len(descriptions) > 0: tables_schema += f"/*\n{''.join(descriptions)}*/\n" + tables_schema += "```\n" if tables_schema == "": tables_schema += "Tables not found in the database" return tables_schema @@ -516,9 +518,8 @@ def _run( return "Action input for the fewshot_examples_retriever tool should be an integer" returned_output = "" for example in self.few_shot_examples[:number_of_samples]: - returned_output += ( - f"Question: {example['prompt_text']} -> SQL: {example['sql']}\n" - ) + returned_output += f"Question: {example['prompt_text']} \n" + returned_output += f"```sql\n{example['sql']}\n```\n" if returned_output == "": returned_output = "No previously asked Question/SQL pairs are available" return returned_output