-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprepare_trainset.py
126 lines (115 loc) · 4.23 KB
/
prepare_trainset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import codecs
import os
import pickle
import random
import sys
from transformers import BertTokenizer
from io_utils.io_utils import load_data
from trainset_building.trainset_building import build_trainset_for_siam
from trainset_building.trainset_building import build_trainset_for_ner
def _get_cli_arg(position, missing_msg):
    """Return ``sys.argv[position]``.

    Args:
        position: Index of the wanted item in ``sys.argv``.
        missing_msg: Text of the error raised when the item is absent.

    Raises:
        ValueError: If ``sys.argv`` has no item at ``position``.
    """
    if len(sys.argv) <= position:
        raise ValueError(missing_msg)
    return sys.argv[position]


def _parse_positive_int(raw_value, description):
    """Convert a command-line string to an integer that is at least 1.

    Args:
        raw_value: The text taken from the command line.
        description: Name of the quantity, used in the error message.

    Returns:
        The parsed integer.

    Raises:
        ValueError: If ``raw_value`` is not an integer or is less than 1.
    """
    try:
        parsed = int(raw_value)
    except ValueError:
        # Only a malformed number is treated as "inadmissible"; anything
        # else (e.g. KeyboardInterrupt) must propagate untouched.
        parsed = 0
    if parsed < 1:
        err_msg = f'{raw_value} is inadmissible value of {description}!'
        raise ValueError(err_msg)
    return parsed


def main():
    """Build a training set from the command line and pickle it.

    Positional command-line arguments (``sys.argv``):
        1. Source training file; it must exist.
        2. Destination file for the pickled result; its directory, if any
           is given, must exist.
        3. NER vocabulary file: UTF-8 text, one entity per line. Blank
           lines are ignored; duplicated entities are rejected.
        4. Training mode, ``siamese`` or ``ner`` (surrounding whitespace
           and letter case are ignored).
        5. Maximal sequence length, an integer >= 1.
        6. Pre-trained BERT model, passed to
           ``BertTokenizer.from_pretrained``.
        7. Maximal number of samples, an integer >= 1. Required only in
           the ``siamese`` mode.

    The random generator is seeded with 42 before anything else. In the
    ``ner`` mode the shapes of the prepared arrays are printed.

    Raises:
        ValueError: If an argument is missing or has an inadmissible value.
        IOError: If an input file or the destination directory does not
            exist, or the NER vocabulary is empty or contains duplicates.
    """
    random.seed(42)

    # Arguments are read strictly in positional order so that the first
    # missing or malformed one is the one reported.
    src_fname = os.path.normpath(_get_cli_arg(
        1, 'The source training file is not specified!'
    ))
    dst_fname = os.path.normpath(_get_cli_arg(
        2, 'The destination training file is not specified!'
    ))
    ners_fname = os.path.normpath(_get_cli_arg(
        3, 'The NER vocabulary file is not specified!'
    ))
    training_mode = _get_cli_arg(
        4, 'The training mode is not specified!'
    ).strip().lower()
    if not training_mode:
        err_msg = 'The training mode is not specified!'
        raise ValueError(err_msg)
    max_len = _parse_positive_int(
        _get_cli_arg(5, 'The maximal sequence length is not specified!'),
        'the maximal sequence length'
    )
    pretrained_model = _get_cli_arg(
        6, 'The pre-trained BERT model is not specified!'
    )
    if training_mode == 'siamese':
        max_samples = _parse_positive_int(
            _get_cli_arg(
                7, 'The maximal number of samples is not specified!'
            ),
            'the maximal number of samples'
        )
    else:
        # The sample limit is used by the siamese trainset builder only.
        max_samples = 0

    if not os.path.isfile(src_fname):
        raise IOError(f'The file {src_fname} does not exist!')
    if not os.path.isfile(ners_fname):
        raise IOError(f'The file {ners_fname} does not exist!')
    dname = os.path.dirname(dst_fname)
    # An empty directory name means the current working directory.
    if dname and not os.path.isdir(dname):
        raise IOError(f'The directory {dname} does not exist!')
    if training_mode not in {'siamese', 'ner'}:
        err_msg = f'The training mode {training_mode} is unknown! ' \
                  f'Possible values: siamese, ner.'
        raise ValueError(err_msg)

    with codecs.open(ners_fname, mode='r', encoding='utf-8') as fp:
        stripped_lines = (cur_line.strip() for cur_line in fp.readlines())
        possible_named_entities = [
            entity for entity in stripped_lines if entity
        ]
    if not possible_named_entities:
        err_msg = f'The file {ners_fname} is empty!'
        raise IOError(err_msg)
    if len(possible_named_entities) != len(set(possible_named_entities)):
        err_msg = f'The file {ners_fname} contains a wrong data! ' \
                  f'Some entities are duplicated!'
        raise IOError(err_msg)

    source_data = load_data(src_fname)
    bert_tokenizer = BertTokenizer.from_pretrained(pretrained_model)
    if training_mode == 'ner':
        prep_data = build_trainset_for_ner(
            data=source_data,
            tokenizer=bert_tokenizer,
            entities=possible_named_entities,
            max_seq_len=max_len
        )
        print('')
        print(f'X.shape = {prep_data[0].shape}')
        for output_idx, cur_output in enumerate(prep_data[1]):
            print(f'y[{output_idx}].shape = {cur_output.shape}')
    else:
        prep_data = build_trainset_for_siam(
            data=source_data,
            tokenizer=bert_tokenizer,
            entities=possible_named_entities,
            max_seq_len=max_len,
            max_samples=max_samples
        )

    with open(dst_fname, 'wb') as fp:
        pickle.dump(prep_data, fp, protocol=pickle.HIGHEST_PROTOCOL)
# Run the command-line entry point only when this file is executed as
# a script, not when it is imported as a module.
if __name__ == "__main__":
    main()