From 5160e8d537a615751330517004d96340ad96247f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juan=20Carlos=20Jos=C3=A9=20Camacho?= Date: Mon, 18 Mar 2024 09:23:09 -0600 Subject: [PATCH] [DH-5597] Fix sql-generation (#432) --- dataherald/sql_generator/__init__.py | 7 +++---- dataherald/sql_generator/dataherald_finetuning_agent.py | 2 +- dataherald/sql_generator/dataherald_sqlagent.py | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/dataherald/sql_generator/__init__.py b/dataherald/sql_generator/__init__.py index 461fdd00..ecc80526 100644 --- a/dataherald/sql_generator/__init__.py +++ b/dataherald/sql_generator/__init__.py @@ -60,7 +60,7 @@ def remove_markdown(self, query: str) -> str: matches = re.findall(pattern, query, re.DOTALL) if matches: return matches[0].strip() - return "" + return query @classmethod def get_upper_bound_limit(cls) -> int: @@ -110,15 +110,14 @@ def extract_query_from_intermediate_steps( action = step[0] if type(action) == AgentAction and action.tool == "SqlDbQuery": sql_query = self.format_sql_query(action.tool_input) - if "```sql" in sql_query: + if "SELECT" in sql_query.upper(): sql_query = self.remove_markdown(sql_query) if sql_query == "": for step in intermediate_steps: action = step[0] sql_query = action.tool_input - if "```sql" in sql_query: + if "SELECT" in sql_query.upper(): sql_query = self.remove_markdown(sql_query) - return sql_query @abstractmethod diff --git a/dataherald/sql_generator/dataherald_finetuning_agent.py b/dataherald/sql_generator/dataherald_finetuning_agent.py index eb177268..53651031 100644 --- a/dataherald/sql_generator/dataherald_finetuning_agent.py +++ b/dataherald/sql_generator/dataherald_finetuning_agent.py @@ -585,7 +585,7 @@ def generate_response( sql_query = self.remove_markdown(result["output"]) else: sql_query = self.extract_query_from_intermediate_steps( - result["intermediate"] + result["intermediate_steps"] ) logger.info(f"cost: {str(cb.total_cost)} tokens: {str(cb.total_tokens)}") response.sql = replace_unprocessable_characters(sql_query) diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 4af5b3e9..09dccef7 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -730,7 +730,7 @@ def generate_response( sql_query = self.remove_markdown(result["output"]) else: sql_query = self.extract_query_from_intermediate_steps( - result["intermediate"] + result["intermediate_steps"] ) logger.info(f"cost: {str(cb.total_cost)} tokens: {str(cb.total_tokens)}") response.sql = replace_unprocessable_characters(sql_query)