-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathpreprocess_stc_finetune.py
44 lines (33 loc) · 1.96 KB
/
preprocess_stc_finetune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import json
import os
import argparse
from tqdm import tqdm
def format_pair(pr_pair):
    """Render one (post, reply) pair as a single output line.

    All whitespace inside the post (``pr_pair[0]``) and the reply
    (``pr_pair[1]``) is removed, because ``str.split()`` with no arguments
    splits on any whitespace run and ignores leading/trailing whitespace.

    Args:
        pr_pair: An indexable whose first two items are strings.

    Returns:
        The line ``"对话上文:<post> 回复:<reply>\\n"``.
    """
    post = "".join(pr_pair[0].split())
    reply = "".join(pr_pair[1].split())
    return "对话上文:" + post + " 回复:" + reply + "\n"


def load_json(path):
    """Parse the JSON file at *path*, decoding it as UTF-8."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def write_pairs(pairs, path):
    """Write every pair in *pairs* to *path*, one formatted line per pair.

    The file is written as UTF-8 regardless of the platform locale, since
    the emitted lines always contain non-ASCII characters.

    Args:
        pairs: Iterable of pairs accepted by ``format_pair``.
        path: Destination file path; an existing file is overwritten.
    """
    with open(path, "w", encoding="utf-8") as f_out:
        f_out.writelines(format_pair(pr_pair) for pr_pair in pairs)


def main():
    """Convert the STC JSON files into plain-text fine-tuning files.

    Reads ``STC.json`` (keys ``train`` and ``valid``) and ``STC_test.json``
    (key ``test``) from ``--data_dir`` and writes ``train_all.txt``,
    ``train.txt`` (the first 10% of the training pairs), ``valid.txt`` and
    ``test.txt`` into ``--output_dir``.
    """
    parser = argparse.ArgumentParser()
    # Both directories are required: without them os.path.join(None, ...)
    # used to fail with an opaque TypeError.
    parser.add_argument("--data_dir", required=True, type=str, help="The input dir of original STC data.")
    parser.add_argument("--output_dir", required=True, type=str, help="The processed data output dir.")
    args = parser.parse_args()

    # Make sure the destination exists before any file is opened for writing.
    os.makedirs(args.output_dir, exist_ok=True)

    # STC.json holds both the train and valid splits, so parse it only once.
    stc = load_json(os.path.join(args.data_dir, 'STC.json'))

    # train
    train_pairs = stc["train"]
    write_pairs(
        tqdm(train_pairs, desc="Building Train All"),
        os.path.join(args.output_dir, 'train_all.txt'),
    )
    # The reduced training file keeps only the first 10% of the pairs.
    write_pairs(
        tqdm(train_pairs[:int(0.1 * len(train_pairs))], desc="Building Train"),
        os.path.join(args.output_dir, 'train.txt'),
    )

    # valid
    write_pairs(
        tqdm(stc["valid"], desc="Building Valid"),
        os.path.join(args.output_dir, 'valid.txt'),
    )

    # Drop the train/valid data before loading the test file to limit peak memory.
    del stc, train_pairs

    # test
    test_pairs = load_json(os.path.join(args.data_dir, 'STC_test.json'))["test"]
    write_pairs(
        tqdm(test_pairs, desc="Building Test"),
        os.path.join(args.output_dir, 'test.txt'),
    )


if __name__ == "__main__":
    main()