# el_preprocess.py (forked from taishan1994/pytorch_bert_entity_linking)
import os
import json
import logging
import random
import pickle
import numpy as np
from transformers import BertTokenizer
import el_config
from utils import utils
from utils import tokenization
logger = logging.getLogger(__name__)
args = el_config.Args().get_parser()
utils.set_seed(args.seed)
utils.set_logger(os.path.join(args.log_dir, 'el_preprocess.log'))


class InputExample:
    def __init__(self, set_type, text, seq_label, entity_label):
        self.set_type = set_type
        self.text = text
        self.seq_label = seq_label
        self.entity_label = entity_label


class BaseFeature:
    def __init__(self, token_ids, attention_masks, token_type_ids):
        # BERT inputs
        self.token_ids = token_ids
        self.attention_masks = attention_masks
        self.token_type_ids = token_type_ids


class BertFeature(BaseFeature):
    def __init__(self, token_ids, attention_masks, token_type_ids, seq_labels, entity_labels):
        super(BertFeature, self).__init__(
            token_ids=token_ids,
            attention_masks=attention_masks,
            token_type_ids=token_type_ids)
        # labels
        self.seq_labels = seq_labels
        self.entity_labels = entity_labels


class ELProcessor:
    def __init__(self):
        with open('./data/ccks2019/entity_to_ids.json', 'r', encoding='utf-8') as fp:
            self.entity_to_ids = json.loads(fp.read())
        with open('./data/ccks2019/subject_id_with_info.json', 'r', encoding='utf-8') as fp:
            self.subject_id_with_info = json.loads(fp.read())

    def read_json(self, path):
        with open(path, 'r', encoding='utf-8') as fp:
            lines = fp.readlines()
        return lines

    def get_result(self, lines, set_type):
        examples = []
        for i, line in enumerate(lines):
            # every line is a dict literal; eval keeps the original loading behaviour
            line = eval(line)
            text = line['text'].lower()
            for mention_data in line['mention_data']:
                word = mention_data['mention'].lower()
                kb_id = mention_data['kb_id']
                start_id = int(mention_data['offset'])
                end_id = start_id + len(word) - 1
                rel_texts = self.get_text_pair(word, kb_id, text)
                # the first pair returned by get_text_pair is the positive sample,
                # the remaining ones are negatives
                for j, rel_text in enumerate(rel_texts):
                    if j == 0:
                        examples.append(InputExample(
                            set_type=set_type,
                            text=rel_text,
                            seq_label=1,
                            entity_label=(kb_id, word, start_id, end_id)
                        ))
                    else:
                        examples.append(InputExample(
                            set_type=set_type,
                            text=rel_text,
                            seq_label=0,
                            entity_label=(kb_id, word, start_id, end_id)
                        ))
        return examples
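
    # A minimal sketch of the expected input format (the keys mirror those
    # accessed above; the concrete values are hypothetical):
    #   {'text': '...', 'mention_data': [{'mention': '...', 'kb_id': '10001', 'offset': '0'}, ...]}
    # Every mention yields one positive InputExample (seq_label=1) and up to
    # three negative ones (seq_label=0) via get_text_pair below, all sharing
    # entity_label = (kb_id, mention, start, end).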

    def get_text_pair(self, word, kb_id, text):
        """
        Build the positive/negative text pairs for one mention:
        one positive pair plus up to three negative pairs.
        :return:
        """
        results = []
        if kb_id != 'NIL' and word in self.entity_to_ids:
            pos_example = self.get_info(kb_id) + '#;#' + text
            results.append(pos_example)
            # copy the candidate list so the cached entity_to_ids is not mutated
            ids = list(self.entity_to_ids[word])
            if 'NIL' in ids:
                ids.remove('NIL')
            ind = ids.index(kb_id)
            ids = ids[:ind] + ids[ind + 1:]
            if len(ids) >= 3:
                ids = random.sample(ids, 3)
            for t_id in ids:
                info = self.get_info(t_id)
                neg_example = info + '#;#' + text
                results.append(neg_example)
        return results
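
    # Sketch of the pairs built above (values are illustrative only):
    #   results[0]  (positive): '<get_info(kb_id) description>#;#<query text>'
    #   results[1:] (negative): the same query text paired with descriptions of
    #                           up to three other candidate ids for the mention.
    # The '#;#' separator is split again in convert_bert_example.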

    def get_info(self, subject_id):
        """
        Look up the description of an entity by its subject_id and concatenate
        every predicate/object pair into a single description string.
        :param subject_id:
        :return:
        """
        infos = self.subject_id_with_info[subject_id]
        data = infos['data']
        res = []
        for kg in data:
            if kg['object'][-1] != '。':
                res.append("{},{}。".format(kg['predicate'], kg['object']))
            else:
                res.append("{},{}".format(kg['predicate'], kg['object']))
        return "".join(res).lower()


def convert_bert_example(ex_idx, example: InputExample, tokenizer: BertTokenizer, max_seq_len):
    set_type = example.set_type
    raw_text = example.text
    seq_label = example.seq_label
    entity_label = example.entity_label  # (subject_id, mention, start, end)
    # the text pair: knowledge-base description '#;#' query text
    text_a, text_b = raw_text.split('#;#')
    tokens_a = tokenization.BasicTokenizer().tokenize(text_a)
    # one-hot encode the sentence-pair (match / no-match) label
    seq_final_label = [0, 0]
    if seq_label == 0:
        seq_final_label[0] = 1
    else:
        seq_final_label[1] = 1
    # Use the BasicTokenizer from utils.tokenization so English words are not split
    # into sub-word pieces; the entity span indices are then recomputed on the
    # tokenized text.
    start = entity_label[2]
    end = entity_label[3]
    tokenizer_pre = tokenization.BasicTokenizer().tokenize(text_b[:start])
    tokenizer_label = tokenization.BasicTokenizer().tokenize(entity_label[1])
    tokenizer_post = tokenization.BasicTokenizer().tokenize(text_b[end + 1:])
    real_label_start = len(tokenizer_pre)
    real_label_end = len(tokenizer_pre) + len(tokenizer_label)
    tokens_b = tokenizer_pre + tokenizer_label + tokenizer_post
    try:
        encode_dict = tokenizer.encode_plus(text=tokens_a,
                                            text_pair=tokens_b,
                                            max_length=max_seq_len,
                                            padding='max_length',
                                            truncation_strategy='only_first',
                                            return_token_type_ids=True,
                                            return_attention_mask=True)
    except Exception as e:
        print(e)
        print(tokens_a)
        print(tokens_b)
        return 'encoding failed', '400'
    token_ids = encode_dict['input_ids']
    attention_masks = encode_dict['attention_mask']
    token_type_ids = encode_dict['token_type_ids']
    offset = token_type_ids.index(1)  # index where the second segment (token_type_id == 1) starts
    entity_ids = [0] * max_seq_len
    start_id = offset + real_label_start
    end_id = offset + real_label_end
    if end_id > max_seq_len:
        print('unexpected truncation: the entity span falls outside max_seq_len')
        for i in range(start_id, max_seq_len):
            entity_ids[i] = 1
    else:
        for i in range(start_id, end_id):
            entity_ids[i] = 1
    callback_info = (text_b,)
    callback_entity_labels = (entity_label[1], offset)
    callback_info += (callback_entity_labels,)
    if ex_idx < 3:
        logger.info(f"*** {set_type}_example-{ex_idx} ***")
        logger.info(f"text: {raw_text}")
        logger.info(f"token_ids: {token_ids}")
        logger.info(f"attention_masks: {attention_masks}")
        logger.info(f"token_type_ids: {token_type_ids}")
        logger.info(f"entity_ids: {entity_ids}")
        logger.info(f"seq_label: {seq_final_label}")
        logger.info(tokenizer.convert_ids_to_tokens(token_ids[start_id:end_id + 1]))
    feature = BertFeature(
        # bert inputs
        token_ids=token_ids,
        attention_masks=attention_masks,
        token_type_ids=token_type_ids,
        seq_labels=seq_final_label,
        entity_labels=entity_ids,
    )
    return feature, callback_info
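
# A rough sketch of one encoded example (values are illustrative, not taken from
# real data): token_type_ids is 0 over "[CLS] entity-description [SEP]" and 1 over
# "query-text [SEP]", so offset = token_type_ids.index(1) marks where the query
# segment begins. entity_ids is all zeros except for ones on the mention tokens,
# i.e. the half-open range [offset + real_label_start, offset + real_label_end).
# seq_labels is [0, 1] for a matching (positive) pair and [1, 0] for a negative one.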


def convert_examples_to_features(examples, max_seq_len, bert_dir):
    tokenizer = BertTokenizer(os.path.join(bert_dir, 'vocab.txt'))
    features = []
    callback_info = []
    logger.info(f'Convert {len(examples)} examples to features')
    total = len(examples)
    for i, example in enumerate(examples):
        print(i, total)
        feature, tmp_callback = convert_bert_example(
            ex_idx=i,
            example=example,
            max_seq_len=max_seq_len,
            tokenizer=tokenizer,
        )
        if tmp_callback == '400':
            continue
        if feature is None:
            continue
        features.append(feature)
        callback_info.append(tmp_callback)
    logger.info(f'Build {len(features)} features')
    out = (features,)
    if not len(callback_info):
        return out
    out += (callback_info,)
    return out


def split_train_test(examples, train_rate):
    total = len(examples)
    train_total = int(total * train_rate)
    test_total = total - train_total
    print('Total examples: {}, train split: {}, test split: {}'.format(total, train_total, test_total))
    random.shuffle(examples)
    train_examples = examples[:train_total]
    test_examples = examples[train_total:]
    return train_examples, test_examples


def get_out(processor, txt_path, args, mode):
    raw_examples = processor.read_json(txt_path)
    examples = processor.get_result(raw_examples, mode)
    for i, example in enumerate(examples):
        print(example.text)
        print(example.seq_label)
        print(example.entity_label)
        if i == 1:
            break
    train_examples, test_examples = split_train_test(examples, 0.7)
    train_out = convert_examples_to_features(train_examples, args.max_seq_len, args.bert_dir)
    test_out = convert_examples_to_features(test_examples, args.max_seq_len, args.bert_dir)
    return train_out, test_out


if __name__ == '__main__':
    args.max_seq_len = 256
    logger.info(vars(args))
    elprocessor = ELProcessor()
    train_out, test_out = get_out(elprocessor, os.path.join(args.data_dir, 'train.json'), args, 'train')
    with open(os.path.join(args.data_dir, 'train.pkl'), 'wb') as fp:
        pickle.dump(train_out, fp)
    with open(os.path.join(args.data_dir, 'test.pkl'), 'wb') as fp:
        pickle.dump(test_out, fp)
    # Only the training set is provided, so it is split into train/test above;
    # preprocessing for separate dev/test files is left commented out.
    # dev_out = get_out(elprocessor, os.path.join(args.data_dir, 'dev.json'), args, 'dev')
    # test_out = get_out(elprocessor, os.path.join(args.data_dir, 'test.json'), args, 'test')
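
    # A downstream training script would presumably load the pickles back like this
    # (sketch only; each pickle holds the (features, callback_info) tuple built by
    # convert_examples_to_features):
    # with open(os.path.join(args.data_dir, 'train.pkl'), 'rb') as fp:
    #     train_features, train_callback_info = pickle.load(fp)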