diff --git a/beir/retrieval/evaluation.py b/beir/retrieval/evaluation.py
index 0cb2f14..435cbe0 100644
--- a/beir/retrieval/evaluation.py
+++ b/beir/retrieval/evaluation.py
@@ -1,5 +1,6 @@
 import pytrec_eval
 import logging
+from copy import deepcopy
 from typing import List, Dict, Tuple
 from .search.base import BaseSearch
 from .custom_metrics import mrr, recall_cap, hole, top_k_accuracy
@@ -37,6 +38,28 @@ def rerank(self,
 
         return self.retriever.search(new_corpus, queries, top_k, self.score_function)
 
+    @staticmethod
+    def _get_keys_lexically_sorted(qrel: Dict[str, Dict[str, int]], results: Dict[str, Dict[str, float]]) \
+            -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, float]]]:
+        """
+        pytrec_eval runs its evaluation in C, where the insertion order of the dict items is lost: documents
+        with tied scores are re-ranked by the (descending) lexical order of their keys. Because that order
+        matters for rank-sensitive metrics such as recall@1 vs. recall@3, we prefix each key with a zero-padded
+        index, assigned in reverse, so that the lexical order of the prefixed keys reproduces the dict order.
+        """
+        new_qrel = deepcopy(qrel)
+        new_results = deepcopy(results)
+        for query_id in new_results.keys():
+            result = new_results[query_id]
+            width = len(str(len(result)))  # zero-pad the index so "10-..." does not sort before "2-..."
+            for i, k_ in enumerate(reversed(list(result.keys()))):
+                new_key = f"{i:0{width}d}-{k_}"
+                result[new_key] = result.pop(k_)
+                if k_ in new_qrel.get(query_id, {}):
+                    new_qrel[query_id][new_key] = new_qrel[query_id].pop(k_)
+
+        return new_qrel, new_results
+
     @staticmethod
     def evaluate(qrels: Dict[str, Dict[str, int]],
                  results: Dict[str, Dict[str, float]],
@@ -67,9 +90,12 @@ def evaluate(qrels: Dict[str, Dict[str, int]],
         ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values])
         recall_string = "recall." + ",".join([str(k) for k in k_values])
         precision_string = "P." + ",".join([str(k) for k in k_values])
-        evaluator = pytrec_eval.RelevanceEvaluator(qrels, {map_string, ndcg_string, recall_string, precision_string})
-        scores = evaluator.evaluate(results)
-
+
+        sorted_qrel, sorted_results = EvaluateRetrieval._get_keys_lexically_sorted(qrels, results)
+        evaluator = pytrec_eval.RelevanceEvaluator(sorted_qrel,
+                                                   {map_string, ndcg_string, recall_string, precision_string})
+        scores = evaluator.evaluate(sorted_results)
+
         for query_id in scores.keys():
             for k in k_values:
                 ndcg[f"NDCG@{k}"] += scores[query_id]["ndcg_cut_" + str(k)]
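
For reviewers, a minimal sketch of what the new helper produces (not part of the patch; the query and document ids are invented, and it assumes this branch of `beir` is importable). Two documents tie at 0.5; the prefixed keys let trec_eval's descending lexical tie-break, which the reversed indexing targets, reproduce the dictionary order:

```python
from beir.retrieval.evaluation import EvaluateRetrieval

# "docA" and "docB" tie at 0.5; without prefixing, the descending
# lexical tie-break would put "docB" ahead of the relevant "docA".
qrels = {"q1": {"docA": 1}}
results = {"q1": {"docA": 0.5, "docB": 0.5, "docC": 0.1}}

sorted_qrels, sorted_results = EvaluateRetrieval._get_keys_lexically_sorted(qrels, results)

print(sorted_results)  # {'q1': {'0-docC': 0.1, '1-docB': 0.5, '2-docA': 0.5}}
print(sorted_qrels)    # {'q1': {'2-docA': 1}}
# Lexically "2-docA" > "1-docB", so the tie now resolves to docA first,
# matching the original dict order; recall@1 stays 1 instead of dropping to 0.
```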