Skip to content

Commit

Permalink
fix sql ex metric
Browse files (browse the repository at this point in the history)
  • Loading branch information
micpst committed Sep 2, 2024
1 parent 3beff6d commit aaa3039
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions benchmarks/sql/bench/metrics/sql.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -95,6 +95,7 @@ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
Returns:
Execution accuracy score and valid efficiency score.
"""
results = [result for result in results if result.reference.sql and result.prediction.sql]
accurate_results = [result for result in results if self._execution_accuracy(result)]
return {
"EX": len(accurate_results) / len(results) if results else None,
Expand All @@ -121,9 +122,6 @@ def _execution_accuracy(self, result: EvaluationResult) -> bool:
Returns:
True if the execution results are identical, False otherwise.
"""
if result.prediction.sql is None:
return False

try:
ref_results = self._execute_query(result.reference.sql, result.db_id)
pred_results = self._execute_query(result.prediction.sql, result.db_id)
Expand All @@ -138,6 +136,10 @@ def _execution_accuracy(self, result: EvaluationResult) -> bool:
if reference.shape[0] != prediction.shape[0]:
return False

# If both dataframes have only one column, compare the values directly
if reference.shape[1] == prediction.shape[1] == 1:
return reference.iloc[:, 0].equals(prediction.iloc[:, 0])

# Returned view may have the same columns, or more columns than the ground truth
if not reference.columns.isin(prediction.columns).all():
return False
Expand Down

0 comments on commit aaa3039

Please sign in to comment.