-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrunner.py
131 lines (105 loc) · 4.03 KB
/
runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import argparse
import random
import sys
import io
import numpy as np
import torch
import time
import yaml
import logging
import json
from collections import Counter
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score, precision_recall_curve
from tqdm import tqdm
# Datasets treated as classification-style tasks; they get different
# decision thresholds in main() than the open-ended QA datasets.
CLS_DATASET = ['play', 'StrategyQA', 'Fever', 'physics']
def find_most_frequent(lst):
    """Return the occurrence count of the most common element of *lst*.

    NOTE: despite the name, this returns the COUNT (multiplicity) of the
    modal element, not the element itself. Callers compare this count
    against numeric agreement thresholds. Returns 0 for an empty list.
    """
    if not lst:
        return 0
    (_, count), = Counter(lst).most_common(1)
    return count
def save2jsonl(name, data):
    """Write *data* (an iterable of JSON-serializable objects) to *name*
    in JSON Lines format: one JSON document per line."""
    with open(name, 'w') as fh:
        fh.writelines(json.dumps(obj) + '\n' for obj in data)
def readjsonl2list(name):
    """Load a JSON Lines file and return its records as a list of objects."""
    with open(name, 'r') as fh:
        return [json.loads(line) for line in fh]
def main():
    """CLI entry point: evaluate an R2PE scoring method (ADS or PDS).

    Loads data/<dataset>/<model>/test.jsonl, predicts a per-example
    "inconsistent" flag from the model's multiple answers/responses,
    and logs precision/recall/F1 against the ground-truth labels.
    """
    parser = argparse.ArgumentParser()
    # GPU
    parser.add_argument('--gpu', type=int, default=0, help="using gpu id")
    # Dataset
    parser.add_argument('--dataset', type=str, default='HotpotQA', help="Dataset name, gsm8k, math, StrategyQA, play, physics, Fever, 2WikiMultihop or HotpotQA.")
    # Model
    parser.add_argument('--model', type=str, default='gpt-4-0314', help="LLM name, e.g., text-davinci-003.")
    # Method
    parser.add_argument('--method', type=str, default='ADS', help="Method name for R2PE, ADS or PDS.")
    args = parser.parse_args()
    # set logger
    logger = logging.getLogger(name='R2PE')
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s [%(name)s] >> %(message)s")
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    # set seed — seed every RNG (python, hash, numpy, torch CPU/GPU) for
    # reproducible runs
    seed = 123
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    logger.info("Seed set!")
    # load dataset
    fp = os.path.join('data', args.dataset, args.model, 'test.jsonl')
    data = readjsonl2list(fp)
    logger.info('Load {} examples from {}'.format(len(data), fp))
    # load nli model — only PDS needs the sentence-level NLI scorer, so the
    # (expensive) import and model construction are gated on the method
    if args.method in ['PDS']:
        from nli import NLI
        # Device setting
        device_str = f"cuda:{args.gpu}" if torch.cuda.is_available() else 'cpu'
        device = torch.device(device_str)
        sent_nli = NLI(device=device, granularity='sentence', nli_model='mnli')
        logger.info(f"Device set to {device_str}.")
    # predict
    con = []  # predicted flags: True => example predicted inconsistent/incorrect
    gt = []   # ground truth: True when label is falsy (example marked incorrect)
    for ex in tqdm(data, desc='Predicting'):
        # NOTE(review): assumes each record carries 'answers' (list of
        # sampled answers), 'responses' (list of full responses), and a
        # boolean-ish 'label' — verify against the dataset builder.
        answers, responses, label = ex['answers'], ex['responses'], ex['label']
        #print(responses, answers)
        gt.append(not label)
        if args.method == 'ADS':
            # ADS: flag the example when the count of the modal answer
            # (answer agreement) falls below a dataset-dependent threshold.
            threshold = 4.5 if args.dataset in CLS_DATASET else 2.5
            ads = find_most_frequent(answers)
            #print(ex['id'], ads)
            con.append(ads < threshold)
        elif args.method == 'PDS':
            # compute ENS
            scores, mean_score = sent_nli.score(responses)
            ads = find_most_frequent(answers)
            pds_aux = np.min(mean_score)
            # PDS = average of the agreement count normalized around 2.5
            # and the worst (minimum) mean NLI score.
            pds = ((ads-2.5)/2.5 + pds_aux) / 2
            #print(ens, pds, ads, ex['id'])
            # Dataset/model-specific decision thresholds (presumably tuned
            # empirically — the later ifs override the defaults above).
            if args.dataset in CLS_DATASET:
                threshold = 0.4
            else:
                threshold = 0.0
            if args.dataset == 'physics' and args.model == 'text-davinci-003':
                threshold = 0.25
            if args.dataset == 'Fever' and args.model == 'gemini-pro':
                threshold = 0.15
            con.append(pds < threshold)
        else:
            raise NotImplementedError
    # Report metrics as percentages: the methods try to *detect* incorrect
    # examples (gt True), so precision/recall are over the positive class.
    logger.info('PRECISION: {:.2f}\t'.format(100*precision_score(gt, con))+\
        'RECALL: {:.2f}\t'.format(100*recall_score(gt, con))+\
        'F1: {:.2f}\t'.format(100*f1_score(gt, con)))
# Script entry point.
if __name__ == '__main__':
    main()