forked from junchaoIU/DetectRL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlogRank_evaluation.py
81 lines (65 loc) · 2.93 KB
/
logRank_evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import logging
import random
import torch
import tqdm
import argparse
import json
from rank import get_rank
from metrics import get_roc_metrics
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
# Script-wide root-logger setup: INFO and above, each line prefixed with a timestamp and level.
logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s", level=logging.INFO)
def experiment(args):
# load model
logging.info(f"Loading base model of type {args.base_model}...")
base_tokenizer = AutoTokenizer.from_pretrained(args.base_model)
base_model = AutoModelForCausalLM.from_pretrained(args.base_model)
base_model.eval()
base_model.cuda()
filenames = args.test_data_path.split(",")
for filename in filenames:
logging.info(f"Test in {filename}")
test_data = json.load(open(filename, "r"))
random.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
predictions = {'human': [], 'llm': []}
for item in tqdm.tqdm(test_data):
text = item["text"]
label = item["label"]
item["text_logrank"] = -get_rank(text, args, base_tokenizer, base_model, log=True)
# result
if label == "human":
predictions['human'].append(item["text_logrank"])
elif label == "llm":
predictions['llm'].append(item["text_logrank"])
else:
raise ValueError(f"Unknown label {label}")
predictions['human'] = [i for i in predictions['human'] if np.isfinite(i)]
predictions['llm'] = [i for i in predictions['llm'] if np.isfinite(i)]
roc_auc, optimal_threshold, conf_matrix, precision, recall, f1, accuracy = get_roc_metrics(predictions['human'],
predictions['llm'])
result = {
"roc_auc": roc_auc,
"optimal_threshold": optimal_threshold,
"conf_matrix": conf_matrix,
"precision": precision,
"recall": recall,
"f1": f1,
"accuracy": accuracy
}
print(f"{result}")
with open(filename.split(".json")[0] + "_logRank_data.json", "w") as f:
json.dump(test_data, f, indent=4)
with open(filename.split(".json")[0] + "_logRank_result.json", "w") as f:
json.dump(result, f, indent=4)
if __name__ == '__main__':
    # Command-line entry point: parse options, then run the LogRank evaluation.
    parser = argparse.ArgumentParser(
        description="Score texts with negated log-rank under a causal LM and "
                    "report human-vs-LLM detection metrics.")
    parser.add_argument('--test_data_path', type=str, required=True,
                        help="Comma-separated path(s) to JSON test file(s). Each file must hold a list "
                             "of objects with a 'text' field and a 'label' of 'human' or 'llm'.")
    parser.add_argument('--base_model', default="EleutherAI/gpt-neo-2.7B", type=str, required=False,
                        help="Hugging Face model name or local path used to compute token log-ranks.")
    parser.add_argument('--DEVICE', default="cuda", type=str, required=False,
                        help="Torch device string; forwarded to the scoring code via args.")
    parser.add_argument('--seed', default=2023, type=int, required=False,
                        help="Seed for random, numpy and torch, reset before each test file.")
    args = parser.parse_args()
    experiment(args)