preprocess_data.py
# -*- coding: utf-8 -*-
"""Preprocess raw CGEC data into JSON-lines train/validation files."""
import copy
import json
import re
from pathlib import Path
from typing import Any, Dict, List, Optional

from tqdm import tqdm

RAWDATA_PATH = Path("./data/train.txt")
TARGET_DIR = Path("./data/cgec")
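# Assumed raw-line format (inferred from the tab split in
# get_gec_samples_from_str; the raw file itself is not documented here):
#   <id>\t<label>\t<source sentence>\t<correction 1>\t<correction 2>\t...
# A line may carry zero corrections, in which case the source sentence is
# treated as already correct.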
class DataProcessor:
    # Matches a leading Chinese dash ("——") and captures the remaining text.
    leading_dash_pattern = re.compile(r"^——(.*)")

    def get_gec_samples_from_str(
        self,
        text: str,
    ) -> Optional[List[Dict[str, Any]]]:
        """Parse one raw line into a list of {"source", "correct"} samples.

        Returns None when the source sentence is empty, or when corrections
        exist but all of them are empty after stripping the leading dash.
        A line without any correction is treated as already correct.
        """
        _, _, source, *corrects = text.strip().split('\t')
        samples = []
        source = DataProcessor._remove_leading_dash(source)
        if len(source) > 0:
            common = {"source": source}
            if len(corrects) > 0:
                for cor in corrects:
                    cor = DataProcessor._remove_leading_dash(cor)
                    if len(cor) > 0:
                        sample = copy.deepcopy(common)
                        sample["correct"] = cor
                        samples.append(sample)
                if len(samples) == 0:
                    return None
            else:
                common["correct"] = source
                samples.append(common)
            return samples
        return None

    @classmethod
    def _remove_leading_dash(cls, text: str) -> str:
        """Strip a single leading "——" if present; otherwise return text unchanged."""
        match = cls.leading_dash_pattern.match(text)
        return match.group(1) if match else text
    def get_trainset_valset_from_file(self, source_file: Path,
                                      save_dir: Path) -> None:
        """Split the raw file into train/val sets and write them to save_dir.

        The first val_num lines that yield exactly one sample are routed to
        the validation set; all remaining samples go to the training set.
        """
        trainset_path = save_dir / "train.txt"
        valset_path = save_dir / "val.txt"
        trainset = []
        valset = []
        val_num, i = 5000, 0
        with open(source_file, "r", encoding="utf-8") as fin:
            source_data = fin.readlines()
        for line in tqdm(source_data):
            samples = self.get_gec_samples_from_str(line)
            if samples is None:
                continue
            if i < val_num and len(samples) == 1:
                valset.extend(samples)
                i += 1
            else:
                trainset.extend(samples)
        save_dir.mkdir(parents=True, exist_ok=True)
        if len(trainset) > 0:
            save_dicts_to_file(trainset, trainset_path)
        save_dicts_to_file(valset, valset_path)
def save_dicts_to_file(dicts: List[Dict[str, Any]],
                       save_path: Path,
                       encoding: str = 'utf-8') -> None:
    """Write dicts as JSON Lines: one JSON object per line."""
    with open(save_path, "w", encoding=encoding) as fout:
        for d in dicts:
            json.dump(d, fout, ensure_ascii=False)
            fout.write('\n')
def main():
    data_processor = DataProcessor()
    data_processor.get_trainset_valset_from_file(RAWDATA_PATH, TARGET_DIR)


if __name__ == '__main__':
    main()
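# Usage sketch (the sample line below is hypothetical; the real raw format is
# only inferred from the tab split in get_gec_samples_from_str):
#
#     processor = DataProcessor()
#     processor.get_gec_samples_from_str("12\t1\t——他昨天去学校了\t他昨天去了学校")
#     # -> [{'source': '他昨天去学校了', 'correct': '他昨天去了学校'}]
#
# Running `python preprocess_data.py` writes one such JSON object per line to
# ./data/cgec/train.txt and ./data/cgec/val.txt.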