Skip to content

Commit

Permalink
DATA-2038/fixing the fallback and confidence score
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza committed Apr 4, 2024
1 parent 39bef01 commit 8c62807
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 20 deletions.
46 changes: 29 additions & 17 deletions dataherald/eval/simple_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
You are a {dialect} expert.
Given a question, a SQL query, and the database schema, analyze the correctness of the SQL query and provide a score.
Score indicates how correctly and accurately SQL query answers the question.
Note that the score should be between 0 and 100. Higher scores means the SQL Query is more accurate.
Note that the score should be between 0 and {MAX_CONFIDENCE}. Higher scores means the SQL Query is more accurate.
Double check the SQL query for the common mistakes, including:
- For columns that can contain NULL values, NULL values should be filtered out by using the IS NOT NULL operator in the WHERE condition
- when intention of the question is to include all rows from both sets, including duplicates, using UNION ALL is better than UNION
Expand Down Expand Up @@ -85,13 +85,34 @@ def answer_parser(self, answer: str) -> int:
output = int(numbers[-1])
return output

def create_sql_results(self, result: Any) -> list:
rows = []
if result:
for row in result:
modified_row = {}
for key, value in zip(row.keys(), row, strict=True):
if type(value) in [
date,
datetime,
]: # Check if the value is an instance of datetime.date
modified_row[key] = str(value)
elif (
type(value) is Decimal
): # Check if the value is an instance of decimal.Decimal
modified_row[key] = float(value)
else:
modified_row[key] = value
rows.append(modified_row)
return rows

@override
def evaluate(
self,
user_prompt: Prompt,
sql_generation: SQLGeneration,
database_connection: DatabaseConnection,
) -> Evaluation:
max_confidence = 100
database = SQLDatabase.get_sql_engine(database_connection)
logger.info(
f"(Simple evaluator) Generating score for the question/sql pair: {str(user_prompt.text)}/ {str(sql_generation.sql)}"
Expand Down Expand Up @@ -142,33 +163,24 @@ def evaluate(
with database._engine.connect() as connection:
execution = connection.execute(text(query))
result = execution.fetchmany(TOP_K)
rows = []
for row in result:
modified_row = {}
for key, value in zip(row.keys(), row, strict=True):
if type(value) in [
date,
datetime,
]: # Check if the value is an instance of datetime.date
modified_row[key] = str(value)
elif (
type(value) is Decimal
): # Check if the value is an instance of decimal.Decimal
modified_row[key] = float(value)
else:
modified_row[key] = value
rows.append(modified_row)
rows = self.create_sql_results(result)

except SQLInjectionError as e:
raise SQLInjectionError(
"Sensitive SQL keyword detected in the query."
) from e
if not rows:
logger.info(
f"(Simple evaluator) SQL query: {sql} returned no results. max confidence is 70"
)
max_confidence = 70
answer = chain.invoke(
{
"dialect": dialect,
"question": user_question,
"SQL": sql,
"SQL_result": "\n".join([str(row) for row in rows]),
"MAX_CONFIDENCE": str(max_confidence),
"schema": schema,
}
)["text"]
Expand Down
6 changes: 3 additions & 3 deletions dataherald/sql_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ def extract_query_from_intermediate_steps(
sql_query = self.remove_markdown(action.tool_input)
if sql_query == "":
for step in intermediate_steps:
action = step[0]
if "SELECT" in action.tool_input.upper():
sql_query = self.remove_markdown(action.tool_input)
thought = str(step[0].log).split("Action:")[0]
if "```sql" in thought:
sql_query = self.remove_markdown(thought)
if not sql_query.upper().strip().startswith("SELECT"):
sql_query = ""
return sql_query
Expand Down

0 comments on commit 8c62807

Please sign in to comment.