-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathhard_negatives_generator.py
67 lines (53 loc) · 2.33 KB
/
hard_negatives_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from evaluation import get_relevance_label_df
from evaluation import get_relevance_label
from shared.utils import load_from_json
from shared.utils import dump_to_json
from searcher import Searcher
from tqdm import tqdm
import pandas as pd
import random
class Hard_Negatives_Generator(object):
    """Collect highly ranked search hits that are not labelled as true answers.

    For each query string, the configured ``Searcher`` is asked for its top
    results; every returned document whose ``answer`` is not among the
    labelled answers for that query is reported as a "hard negative".
    """

    def __init__(self, es, index, query_by, top_k=10, query_type='faq'):
        """Store the search configuration used by ``get_hard_negatives``.

        Args:
            es: Client object forwarded unchanged to ``Searcher``
                (presumably an Elasticsearch client -- confirm with caller).
            index: Forwarded to ``Searcher`` as ``index``.
            query_by: Forwarded to ``Searcher`` as ``fields``.
            top_k: Forwarded to ``Searcher`` as ``top_k``.
            query_type: When equal to ``"faq"``, only rows whose
                ``query_type`` column equals it are queried; any other
                value queries every row of the label DataFrame.
        """
        self.es = es
        self.index = index
        self.query_by = query_by
        self.top_k = top_k
        self.query_type = query_type

    def get_hard_negatives(self, relevance_label_df):
        """Get a list of hard negative question-answer pairs.

        Args:
            relevance_label_df: DataFrame with a ``query_string`` column
                (and a ``query_type`` column when ``self.query_type`` is
                ``"faq"``). It is also handed to ``get_relevance_label``.

        Returns:
            A list of dicts, one per negative hit, with the keys
            ``query_string``, ``neg_answer``, ``question``,
            ``question_answer``, ``score``, ``label`` (always ``0``) and
            ``rank``. ``rank`` is the 1-based position of the hit among the
            negatives returned for its query; hits that are true answers
            are skipped and do not advance the counter.

        Raises:
            KeyError: If a query string has no entry in the mapping built
                by ``get_relevance_label``, or a search hit lacks one of
                the document fields read below.
        """
        searcher = Searcher(
            self.es, index=self.index, fields=self.query_by, top_k=self.top_k
        )

        # Mapping of {query_string: collection of true answers}.
        relevance_label = get_relevance_label(relevance_label_df)

        # Only the "faq" query type narrows the rows; every other value
        # falls through to the full DataFrame.
        if self.query_type == "faq":
            queries_df = relevance_label_df[
                relevance_label_df['query_type'] == self.query_type
            ]
        else:
            queries_df = relevance_label_df
        unique_questions = queries_df.query_string.unique()

        results = []
        for query_string in tqdm(unique_questions):
            # Use the question itself as the search query.
            topk_results = searcher.query(query_string=query_string)
            # True answers labelled for this query.
            answers = relevance_label[query_string]

            # NOTE(review): rank counts negatives only, not the raw search
            # position -- confirm this is the intended meaning.
            rank = 0
            for doc in topk_results:
                topk_answer = doc['answer']
                if topk_answer in answers:
                    # A true answer is not a negative; skip it.
                    continue

                rank += 1
                results.append({
                    "query_string": query_string,
                    "neg_answer": topk_answer,
                    "question": doc["question"],
                    "question_answer": doc["question_answer"],
                    "score": doc["score"],
                    "label": 0,
                    "rank": rank,
                })

        return results