preprocess_chid_zeroshot.py
import json
import re
import os
import argparse
from tqdm import tqdm
from data_utils.tokenization_gpt2 import GPT2Tokenizer


def process_one_sent_eval(tokenizer, sent, answers, candidates):
    pattern = re.compile(r"#idiom\d+#")
    res = pattern.findall(sent)
    start = 0
    L = []
    # fill the candidate idioms into the sentence to create candidate passages
    # NOTE: there may exist more than one blank in a sentence
    while True:
        m = pattern.search(sent, start)
        if m is None:
            break
        L.append({
            "cands": [],
            "truth": answers[m.group()]
        })
        for idm in candidates:
            cand = sent[:m.start()] + idm + sent[m.end():]
            # replace other blanks by ""
            cand = re.sub(pattern, "", cand)
            ids = tokenizer.encode(cand)
            L[-1]["cands"].append(ids)
        start = m.end()
    return L
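
# Illustrative example (hypothetical data, not taken from the ChID release):
# for sent = "他做事#idiom000001#，从不拖泥带水。" with
# candidates = ["雷厉风行", "有条不紊", ...] and answers = {"#idiom000001#": 0},
# process_one_sent_eval returns one entry per blank, roughly
#   [{"cands": [<token ids of the sentence with each candidate filled in>, ...],
#     "truth": 0}]
# i.e. every candidate idiom yields a full tokenized passage, and "truth" is the
# index of the gold candidate taken from the answer file.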


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", default=None, type=str, help="The input dir of original ChID data.")
    parser.add_argument("--tokenizer_path", type=str, help="The tokenizer path.", default="./bpe_3w_new")
    parser.add_argument("--output_dir", type=str, help="The processed data output dir.")
    args = parser.parse_args()
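
    # Example invocation (the paths below are placeholders, adjust to your setup):
    #   python preprocess_chid_zeroshot.py \
    #       --data_dir /path/to/chid \
    #       --tokenizer_path ./bpe_3w_new \
    #       --output_dir /path/to/output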

    tokenizer = GPT2Tokenizer(os.path.join(args.tokenizer_path, 'vocab.json'), os.path.join(args.tokenizer_path, 'chinese_vocab.model'))

    os.makedirs(args.output_dir, exist_ok=True)

    with open(os.path.join(args.data_dir, "idiomDict.json"), "r") as f:
        idiom_d = json.load(f)
    idioms = list(idiom_d.keys())
for split in ["test"]:
# for zero-shot setting, we only consider test set
with open(os.path.join(args.data_dir, "{}.json".format(split)), "r") as f:
lines = f.readlines()
with open(os.path.join(args.data_dir, "{}_answer.json".format(split)), "r") as f:
ans_d = json.load(f)
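
        # contents: token ids of each candidate passage
        # sids: index of the blank each passage belongs to
        # cids: index of the candidate within its blank's candidate list
        # labels: index of the gold candidate for each blank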
        all_data = {
            "contents": [],
            "sids": [],
            "labels": [],
            "cids": []
        }

        sid = 0
        for line in tqdm(lines, desc="Preprocessing {}".format(split)):
            jobj = json.loads(line)
            for sent in jobj["content"]:
                sample_L = process_one_sent_eval(tokenizer, sent, ans_d, jobj["candidates"])
                for samp in sample_L:
                    all_data["contents"].extend(samp["cands"])
                    all_data["sids"].extend([sid for _ in samp["cands"]])
                    all_data["cids"].extend([i for i in range(len(samp["cands"]))])
                    all_data["labels"].append(samp["truth"])
                    sid += 1
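
        # Presumably, a downstream zero-shot evaluator scores every passage with
        # the language model and, for each sid, picks the cid with the best score,
        # comparing it against the corresponding entry in labels.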

        with open(os.path.join(args.output_dir, "{}.json".format(split)), "w") as f:
            json.dump(all_data, f, indent=4, ensure_ascii=False)
        print(len(all_data["contents"]))

    with open(os.path.join(args.output_dir, "idioms.json"), "w") as f:
        json.dump(idioms, f, indent=4, ensure_ascii=False)