From aaa3039bebc107fea8d61dba1b968dde6a936ebd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?=
Date: Tue, 3 Sep 2024 00:58:32 +0200
Subject: [PATCH] fix sql ex metric

---
 benchmarks/sql/bench/metrics/sql.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/benchmarks/sql/bench/metrics/sql.py b/benchmarks/sql/bench/metrics/sql.py
index 0b5899e7..ca8a4515 100644
--- a/benchmarks/sql/bench/metrics/sql.py
+++ b/benchmarks/sql/bench/metrics/sql.py
@@ -95,6 +95,7 @@ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
         Returns:
             Execution accuracy score and valid efficiency score.
         """
+        results = [result for result in results if result.reference.sql and result.prediction.sql]
         accurate_results = [result for result in results if self._execution_accuracy(result)]
         return {
             "EX": len(accurate_results) / len(results) if results else None,
@@ -121,9 +122,6 @@ def _execution_accuracy(self, result: EvaluationResult) -> bool:
         Returns:
             True if the execution results are identical, False otherwise.
         """
-        if result.prediction.sql is None:
-            return False
-
         try:
             ref_results = self._execute_query(result.reference.sql, result.db_id)
             pred_results = self._execute_query(result.prediction.sql, result.db_id)
@@ -138,6 +136,10 @@ def _execution_accuracy(self, result: EvaluationResult) -> bool:
         if reference.shape[0] != prediction.shape[0]:
             return False
 
+        # If both dataframes have only one column, compare the values directly
+        if reference.shape[1] == prediction.shape[1] == 1:
+            return reference.iloc[:, 0].equals(prediction.iloc[:, 0])
+
         # Returned view may have the same columns, or more columns than the ground truth
         if not reference.columns.isin(prediction.columns).all():
             return False
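
Not part of the patch: the sketch below illustrates the single-column fast path added in the last hunk, using pandas and made-up dataframes. The names `reference` and `prediction` mirror the local variables in `_execution_accuracy`; the data itself is hypothetical and assumes both query results have already been fetched into dataframes, as the surrounding method does.

    import pandas as pd

    # Hypothetical query results: same values, different column labels.
    reference = pd.DataFrame({"name": ["Alice", "Bob"]})
    prediction = pd.DataFrame({"full_name": ["Alice", "Bob"]})

    # With exactly one column on each side, the values are compared directly,
    # so a mismatched column label no longer yields a false negative.
    if reference.shape[1] == prediction.shape[1] == 1:
        print(reference.iloc[:, 0].equals(prediction.iloc[:, 0]))  # True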