Skip to content

Commit

Permalink
DH-4796/handling the bug with schema missing in evaluator (#198)
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza authored and DishenWang2023 committed May 7, 2024
1 parent 40584b7 commit b0eee7e
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions dataherald/eval/simple_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import time
from typing import Any

from bson.objectid import ObjectId
from langchain.chains import LLMChain
from langchain.prompts.chat import (
ChatPromptTemplate,
Expand All @@ -11,9 +12,11 @@
)
from overrides import override
from sql_metadata import Parser
from sqlalchemy.exc import SQLAlchemyError

from dataherald.config import System
from dataherald.db import DB
from dataherald.db_scanner.models.types import TableDescriptionStatus
from dataherald.db_scanner.repository.base import TableDescriptionRepository
from dataherald.eval import Evaluation, Evaluator
from dataherald.sql_database.base import SQLDatabase
from dataherald.sql_database.models.types import DatabaseConnection
Expand Down Expand Up @@ -48,6 +51,7 @@
Question: {question}
Evaluate the following SQL query:
SQL Query: {SQL}
SQL Query Result: {SQL_result}
give me a one or two lines explanation and the score after 'Score: '.
"""

Expand Down Expand Up @@ -88,6 +92,14 @@ def evaluate(
logger.info(
f"(Simple evaluator) Generating score for the question/sql pair: {str(question.question)}/ {str(generated_answer.sql_query)}"
)
storage = self.system.instance(DB)
repository = TableDescriptionRepository(storage)
db_scan = repository.get_all_tables_by_db(
{
"db_connection_id": ObjectId(database_connection.id),
"status": TableDescriptionStatus.SYNCHRONIZED.value,
}
)
self.llm = self.model.get_model(
database_connection=database_connection, temperature=0
)
Expand All @@ -103,17 +115,11 @@ def evaluate(
sql = generated_answer.sql_query
dialect = database.dialect
tables = Parser(sql).tables
database._sample_rows_in_table_info = 0
schema = database.get_table_info_no_throw(tables)
try:
run_result = database.run_sql(sql)[0]
except SQLAlchemyError as e:
"""Format the error message"""
run_result = f"Error: {e}"
except Exception as e:
logger.info(f"(Simple evaluator) Error: {e}")
run_result = f"Error: {e}"
if run_result == "[]" or "Error:" in run_result:
schema = ""
for scanned_table in db_scan:
if scanned_table.table_name in tables:
schema += f"Table: {scanned_table.table_schema}\n"
if generated_answer.sql_query_result is None:
logger.info(
f"(Simple evaluator) SQL query: {sql} is not valid. Returning score 0"
)
Expand All @@ -126,6 +132,7 @@ def evaluate(
"dialect": dialect,
"question": user_question,
"SQL": sql,
"SQL_result": str(generated_answer.sql_query_result.json()),
"schema": schema,
}
)
Expand Down

0 comments on commit b0eee7e

Please sign in to comment.