From 1857d98e39dc1f4298cce07d5baf3d13f821c975 Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Thu, 4 Apr 2024 12:36:35 -0400 Subject: [PATCH 1/7] DH-5688/fixing the observations code blocks --- .../sql_generator/dataherald_finetuning_agent.py | 7 +++++-- dataherald/sql_generator/dataherald_sqlagent.py | 13 +++++++------ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/dataherald/sql_generator/dataherald_finetuning_agent.py b/dataherald/sql_generator/dataherald_finetuning_agent.py index e0ff4367..b477dbae 100644 --- a/dataherald/sql_generator/dataherald_finetuning_agent.py +++ b/dataherald/sql_generator/dataherald_finetuning_agent.py @@ -250,7 +250,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): name = "SqlDbQuery" description = """ - Input: SQL query. + Input: A SQL query between ```sql and ``` tags. Output: Result from the database or an error message if the query is incorrect. Use this tool to execute the SQL query on the database, and return the results. 
""" @@ -335,7 +335,8 @@ def _run( {"role": "user", "content": user_prompt}, ], ) - return response.choices[0].message.content + returned_sql = response.choices[0].message.content + return f"```sql\n{returned_sql}```" async def _arun( self, @@ -372,6 +373,7 @@ def _run( tables_schema = "" for table in self.db_scan: if table.table_name in table_names_list: + tables_schema += "```sql\n" tables_schema += table.table_schema + "\n" descriptions = [] if table.description is not None: @@ -385,6 +387,7 @@ def _run( ) if len(descriptions) > 0: tables_schema += f"/*\n{''.join(descriptions)}*/\n" + tables_schema += "```\n" if tables_schema == "": tables_schema += "Tables not found in the database" return tables_schema diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 857cdd08..84dbd6df 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -150,7 +150,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): name = "SqlDbQuery" description = """ - Input: SQL query. + Input: A SQL query between ```sql and ``` tags. Output: Result from the database or an error message if the query is incorrect. If an error occurs, rewrite the query and retry. Use this tool to execute SQL queries. 
@@ -204,8 +204,8 @@ def _run( run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002 ) -> str: response = "Admin: All of the generated SQL queries must follow the below instructions:\n" - for instruction in self.instructions: - response += f"{instruction['instruction']}\n" + for index, instruction in enumerate(self.instructions): + response += f"{index + 1}) {instruction['instruction']}\n" return response async def _arun( @@ -407,6 +407,7 @@ def _run( tables_schema = "" for table in self.db_scan: if table.table_name in table_names_list: + tables_schema += "```sql\n" tables_schema += table.table_schema + "\n" descriptions = [] if table.description is not None: @@ -420,6 +421,7 @@ def _run( ) if len(descriptions) > 0: tables_schema += f"/*\n{''.join(descriptions)}*/\n" + tables_schema += "```\n" if tables_schema == "": tables_schema += "Tables not found in the database" return tables_schema @@ -516,9 +518,8 @@ def _run( return "Action input for the fewshot_examples_retriever tool should be an integer" returned_output = "" for example in self.few_shot_examples[:number_of_samples]: - returned_output += ( - f"Question: {example['prompt_text']} -> SQL: {example['sql']}\n" - ) + returned_output += f"Question: {example['prompt_text']} \n" + returned_output += f"```sql\n{example['sql']}\n```\n" if returned_output == "": returned_output = "No previously asked Question/SQL pairs are available" return returned_output From 9cf7fb5f24b526f78fbc9e57aa1e19239a9bc79c Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Thu, 4 Apr 2024 12:41:50 -0400 Subject: [PATCH 2/7] DATA-5688/fix the inline comments --- dataherald/sql_generator/dataherald_finetuning_agent.py | 4 ++-- dataherald/sql_generator/dataherald_sqlagent.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dataherald/sql_generator/dataherald_finetuning_agent.py b/dataherald/sql_generator/dataherald_finetuning_agent.py index b477dbae..ce0ff2f6 100644 --- 
a/dataherald/sql_generator/dataherald_finetuning_agent.py +++ b/dataherald/sql_generator/dataherald_finetuning_agent.py @@ -228,12 +228,12 @@ def _run( table_relevance = "" for _, row in df.iterrows(): table_relevance += ( - f'Table: {row["table_name"]}, relevance score: {row["similarities"]}\n' + f'Table: `{row["table_name"]}`, relevance score: {row["similarities"]}\n' ) if len(most_similar_tables) > 0: for table in most_similar_tables: table_relevance += ( - f"Table: {table}, relevance score: {max(df['similarities'])}\n" + f"Table: `{table}`, relevance score: {max(df['similarities'])}\n" ) return table_relevance diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 84dbd6df..02059dc2 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -291,12 +291,12 @@ def _run( table_relevance = "" for _, row in df.iterrows(): table_relevance += ( - f'Table: {row["table_name"]}, relevance score: {row["similarities"]}\n' + f'Table: `{row["table_name"]}`, relevance score: {row["similarities"]}\n' ) if len(most_similar_tables) > 0: for table in most_similar_tables: table_relevance += ( - f"Table: {table}, relevance score: {max(df['similarities'])}\n" + f"Table: `{table}`, relevance score: {max(df['similarities'])}\n" ) return table_relevance From 829084c6f6cd8f50c225ae98544a0a2d8c2ab6e4 Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Thu, 4 Apr 2024 12:43:19 -0400 Subject: [PATCH 3/7] DH-5688/reformat with black --- dataherald/sql_generator/dataherald_finetuning_agent.py | 4 +--- dataherald/sql_generator/dataherald_sqlagent.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/dataherald/sql_generator/dataherald_finetuning_agent.py b/dataherald/sql_generator/dataherald_finetuning_agent.py index ce0ff2f6..6d7b3486 100644 --- a/dataherald/sql_generator/dataherald_finetuning_agent.py +++ 
b/dataherald/sql_generator/dataherald_finetuning_agent.py @@ -227,9 +227,7 @@ def _run( most_similar_tables = self.similart_tables_based_on_few_shot_examples(df) table_relevance = "" for _, row in df.iterrows(): - table_relevance += ( - f'Table: `{row["table_name"]}`, relevance score: {row["similarities"]}\n' - ) + table_relevance += f'Table: `{row["table_name"]}`, relevance score: {row["similarities"]}\n' if len(most_similar_tables) > 0: for table in most_similar_tables: table_relevance += ( diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 02059dc2..1e48bcbc 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -290,9 +290,7 @@ def _run( most_similar_tables = self.similart_tables_based_on_few_shot_examples(df) table_relevance = "" for _, row in df.iterrows(): - table_relevance += ( - f'Table: `{row["table_name"]}`, relevance score: {row["similarities"]}\n' - ) + table_relevance += f'Table: `{row["table_name"]}`, relevance score: {row["similarities"]}\n' if len(most_similar_tables) > 0: for table in most_similar_tables: table_relevance += ( From 443a49e171dc761cfc768a73f0f3b7ad920f5423 Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Thu, 4 Apr 2024 13:55:45 -0400 Subject: [PATCH 4/7] Fixing the backticks of the observations --- dataherald/sql_generator/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataherald/sql_generator/__init__.py b/dataherald/sql_generator/__init__.py index dd318bf6..1c8255bb 100644 --- a/dataherald/sql_generator/__init__.py +++ b/dataherald/sql_generator/__init__.py @@ -173,7 +173,7 @@ def stream_agent_steps( # noqa: C901 queue.put(message.content + "\n") elif "steps" in chunk: for step in chunk["steps"]: - queue.put(f"**Observation:**\n `{step.observation}`\n") + queue.put(f"**Observation:**\n {step.observation}\n") elif "output" in chunk: queue.put(f'**Final Answer:**\n 
{chunk["output"]}') if "```sql" in chunk["output"]: From 51c0390f7198a03c31b6c83aa14688924b53aaab Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Thu, 4 Apr 2024 14:33:13 -0400 Subject: [PATCH 5/7] Add newlines after the observations and final answer --- dataherald/sql_generator/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dataherald/sql_generator/__init__.py b/dataherald/sql_generator/__init__.py index 1c8255bb..b270d5a7 100644 --- a/dataherald/sql_generator/__init__.py +++ b/dataherald/sql_generator/__init__.py @@ -173,9 +173,9 @@ def stream_agent_steps( # noqa: C901 queue.put(message.content + "\n") elif "steps" in chunk: for step in chunk["steps"]: - queue.put(f"**Observation:**\n {step.observation}\n") + queue.put(f"\n**Observation:**\n {step.observation}\n") elif "output" in chunk: - queue.put(f'**Final Answer:**\n {chunk["output"]}') + queue.put(f'\n**Final Answer:**\n {chunk["output"]}') if "```sql" in chunk["output"]: response.sql = replace_unprocessable_characters( self.remove_markdown(chunk["output"]) From ed249f6a25f1fd1c343a6ec42b10b0826be9cb9f Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Thu, 4 Apr 2024 14:44:52 -0400 Subject: [PATCH 6/7] removing multi DDL command in create tables --- dataherald/sql_generator/dataherald_sqlagent.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 1e48bcbc..cc2a9e9b 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -402,10 +402,9 @@ def _run( replace_unprocessable_characters(table_name) for table_name in table_names_list ] - tables_schema = "" + tables_schema = "```sql\n" for table in self.db_scan: if table.table_name in table_names_list: - tables_schema += "```sql\n" tables_schema += table.table_schema + "\n" descriptions = [] if table.description is not None: @@ 
-419,7 +418,7 @@ def _run( ) if len(descriptions) > 0: tables_schema += f"/*\n{''.join(descriptions)}*/\n" - tables_schema += "```\n" + tables_schema += "```\n" if tables_schema == "": tables_schema += "Tables not found in the database" return tables_schema From 4b34085605cba2c5fdf51ea2e238c586606588f1 Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Thu, 4 Apr 2024 15:01:43 -0400 Subject: [PATCH 7/7] adding newline for execute sql query --- dataherald/sql_generator/dataherald_finetuning_agent.py | 1 + dataherald/sql_generator/dataherald_sqlagent.py | 1 + 2 files changed, 2 insertions(+) diff --git a/dataherald/sql_generator/dataherald_finetuning_agent.py b/dataherald/sql_generator/dataherald_finetuning_agent.py index 6d7b3486..5fe0ed4d 100644 --- a/dataherald/sql_generator/dataherald_finetuning_agent.py +++ b/dataherald/sql_generator/dataherald_finetuning_agent.py @@ -251,6 +251,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): Input: A SQL query between ```sql and ``` tags. Output: Result from the database or an error message if the query is incorrect. Use this tool to execute the SQL query on the database, and return the results. + Add newline after both ```sql and ``` tags. """ args_schema: Type[BaseModel] = SQLInput diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index cc2a9e9b..3a1ecb28 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -154,6 +154,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): Output: Result from the database or an error message if the query is incorrect. If an error occurs, rewrite the query and retry. Use this tool to execute SQL queries. + Add newline after both ```sql and ``` tags. """ @catch_exceptions()