diff --git a/dataherald/eval/simple_evaluator.py b/dataherald/eval/simple_evaluator.py index be090759..f3ccd854 100644 --- a/dataherald/eval/simple_evaluator.py +++ b/dataherald/eval/simple_evaluator.py @@ -30,7 +30,7 @@ You are a {dialect} expert. Given a question, a SQL query, and the database schema, analyze the correctness of the SQL query and provide a score. Score indicates how correctly and accurately SQL query answers the question. -Note that the score should be between 0 and 100. Higher scores means the SQL Query is more accurate. +Note that the score should be between 0 and {MAX_CONFIDENCE}. Higher scores means the SQL Query is more accurate. Double check the SQL query for the common mistakes, including: - For columns that can contain NULL values, NULL values should be filtered out by using the IS NOT NULL operator in the WHERE condition - when intention of the question is to include all rows from both sets, including duplicates, using UNION ALL is better than UNION @@ -85,6 +85,26 @@ def answer_parser(self, answer: str) -> int: output = int(numbers[-1]) return output + def create_sql_results(self, result: Any) -> list: + rows = [] + if result: + for row in result: + modified_row = {} + for key, value in zip(row.keys(), row, strict=True): + if type(value) in [ + date, + datetime, + ]: # Check if the value is an instance of datetime.date + modified_row[key] = str(value) + elif ( + type(value) is Decimal + ): # Check if the value is an instance of decimal.Decimal + modified_row[key] = float(value) + else: + modified_row[key] = value + rows.append(modified_row) + return rows + @override def evaluate( self, @@ -92,6 +112,7 @@ def evaluate( sql_generation: SQLGeneration, database_connection: DatabaseConnection, ) -> Evaluation: + max_confidence = 100 database = SQLDatabase.get_sql_engine(database_connection) logger.info( f"(Simple evaluator) Generating score for the question/sql pair: {str(user_prompt.text)}/ {str(sql_generation.sql)}" @@ -142,33 +163,24 @@ def evaluate( with database._engine.connect() as connection: execution = connection.execute(text(query)) result = execution.fetchmany(TOP_K) - rows = [] - for row in result: - modified_row = {} - for key, value in zip(row.keys(), row, strict=True): - if type(value) in [ - date, - datetime, - ]: # Check if the value is an instance of datetime.date - modified_row[key] = str(value) - elif ( - type(value) is Decimal - ): # Check if the value is an instance of decimal.Decimal - modified_row[key] = float(value) - else: - modified_row[key] = value - rows.append(modified_row) + rows = self.create_sql_results(result) except SQLInjectionError as e: raise SQLInjectionError( "Sensitive SQL keyword detected in the query." ) from e + if not rows: + logger.info( + f"(Simple evaluator) SQL query: {sql} returned no results. max confidence is 70" + ) + max_confidence = 70 answer = chain.invoke( { "dialect": dialect, "question": user_question, "SQL": sql, "SQL_result": "\n".join([str(row) for row in rows]), + "MAX_CONFIDENCE": str(max_confidence), "schema": schema, } )["text"] diff --git a/dataherald/sql_generator/__init__.py b/dataherald/sql_generator/__init__.py index 081df00c..dd318bf6 100644 --- a/dataherald/sql_generator/__init__.py +++ b/dataherald/sql_generator/__init__.py @@ -99,9 +99,9 @@ def extract_query_from_intermediate_steps( sql_query = self.remove_markdown(action.tool_input) if sql_query == "": for step in intermediate_steps: - action = step[0] - if "SELECT" in action.tool_input.upper(): - sql_query = self.remove_markdown(action.tool_input) + thought = str(step[0].log).split("Action:")[0] + if "```sql" in thought: + sql_query = self.remove_markdown(thought) if not sql_query.upper().strip().startswith("SELECT"): sql_query = "" return sql_query