From bcbfb0cf7d67fc9eed3b785e785ff1d51d1ae2f4 Mon Sep 17 00:00:00 2001 From: lucaordronneau Date: Wed, 14 Aug 2024 17:17:33 +0200 Subject: [PATCH] [feat] Sort SQL result tuples for fair comparison during evaluation #159 --- bird/llm/src/evaluation.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/bird/llm/src/evaluation.py b/bird/llm/src/evaluation.py index 0459b75c..2b6e175a 100644 --- a/bird/llm/src/evaluation.py +++ b/bird/llm/src/evaluation.py @@ -13,6 +13,8 @@ def load_json(dir): def result_callback(result): exec_result.append(result) +def sort_tuple(t): + return tuple(sorted(t, key=lambda x: str(x))) def execute_sql(predicted_sql,ground_truth, db_path): conn = sqlite3.connect(db_path) @@ -22,6 +24,11 @@ def execute_sql(predicted_sql,ground_truth, db_path): predicted_res = cursor.fetchall() cursor.execute(ground_truth) ground_truth_res = cursor.fetchall() + + # sort the results to allow fair comparison + predicted_res = [sort_tuple(t) for t in predicted_res] + ground_truth_res = [sort_tuple(t) for t in ground_truth_res] + res = 0 if set(predicted_res) == set(ground_truth_res): res = 1