# -*- coding: utf-8 -*-


def change_to_bio():
    """Convert data/train.txt from word/tag format to BIO format.

    Each line is a sequence of space-separated tokens of the form
    "char1_char2_.../tag"; characters inside a token are joined by '_'.
    The tag 'o' marks untagged tokens.
    """
    train_file = 'data/train.txt'
    data = []
    with open(train_file, 'r', encoding='utf-8') as f:
        for line in f:
            tokens = line.strip().split(' ')
            sentence = []
            for token in tokens:
                words, tag = token.split('/')
                words = words.split('_')
                if tag == 'o':
                    # Untagged token: every character keeps the plain tag.
                    for word in words:
                        sentence.append(word + ' ' + tag)
                else:
                    # Tagged token: first character gets B-, the rest I-.
                    sentence.append(words[0] + ' B-' + tag)
                    for word in words[1:]:
                        sentence.append(word + ' I-' + tag)
            data.append(sentence)
    return data
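

# Hedged illustration of the conversion above on one made-up input line
# (the tokens "北_京/ns" and "的/o" are invented examples, not taken from the
# data files); it repeats the same parsing logic on a literal string.
def _bio_example():
    line = '北_京/ns 的/o'
    out = []
    for token in line.split(' '):
        words, tag = token.split('/')
        chars = words.split('_')
        if tag == 'o':
            out.extend(c + ' ' + tag for c in chars)
        else:
            out.append(chars[0] + ' B-' + tag)
            out.extend(c + ' I-' + tag for c in chars[1:])
    return out  # ['北 B-ns', '京 I-ns', '的 o']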


def create_vocab():
    """Build a character vocabulary from data/corpus.txt, most frequent first."""
    dico = {}
    with open('data/corpus.txt', 'r', encoding='utf-8') as f:
        for line in f:
            words = line.strip().split('_')
            for item in words:
                if item not in dico:
                    dico[item] = 1
                else:
                    dico[item] += 1
    # Give the special tokens artificially large counts so they sort first.
    dico['<S>'] = 1000000002
    dico['</S>'] = 1000000001
    dico['<UNK>'] = 1000000000
    # Sort by descending count, breaking ties alphabetically.
    sorted_items = sorted(dico.items(), key=lambda x: (-x[1], x[0]))
    print(sorted_items)
    # Only the tokens are written out; the counts are discarded.
    with open('data/vocab.txt', 'w', encoding='utf-8') as f:
        for item, count in sorted_items:
            f.write(item)
            f.write('\n')
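

# One plausible way to consume data/vocab.txt later (not part of this file's
# original logic, just a hedged sketch): map each token to its line index,
# so <S>, </S>, <UNK> get ids 0, 1, 2 given the sort order above.
def _load_vocab_example(path='data/vocab.txt'):
    with open(path, 'r', encoding='utf-8') as f:
        return {token.strip(): idx for idx, token in enumerate(f)}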


def ner_data_stat():
    """Count how many test/train sentences are longer than 126 characters."""
    cnt_test = 0
    with open('data/test.txt', 'r', encoding='utf-8') as f:
        for line in f:
            data = line.strip().split('_')
            if len(data) > 126:
                cnt_test += 1
    print("{} test sentences are longer than 126 characters.".format(cnt_test))
    cnt_train = 0
    with open('data/train.txt', 'r', encoding='utf-8') as f:
        for line in f:
            tokens = line.strip().split(' ')
            sentence = []
            for token in tokens:
                words, tag = token.split('/')
                words = words.split('_')
                if tag == 'o':
                    for word in words:
                        sentence.append([word, 'O'])
                else:
                    sentence.append([words[0], 'B-' + tag])
                    for word in words[1:]:
                        sentence.append([word, 'I-' + tag])
            if len(sentence) > 126:
                cnt_train += 1
    print("{} training sentences are longer than 126 characters.".format(cnt_train))


def gaopincidebiaoji():
    """Try to split over-long NER training sentences at high-frequency,
    never-tagged characters (gaopinci = high-frequency characters)."""
    gaopinci = []
    with open('data/vocab.txt', 'r', encoding='utf-8') as f:
        for i in range(10):
            line = f.readline().strip()
            gaopinci.append(line)
    # Drop the special tokens <S>, </S>, <UNK> at the top of the vocabulary.
    gaopinci = gaopinci[3:]
    confirmed_not_tag_word = gaopinci[2]  # 15274
    # print(confirmed_not_tag_word)
    # Appears in only one sentence shorter than 126, where it is tagged 'a';
    # probably a function character such as "的".
    one_not_tag_word = gaopinci[0]
    # print(one_not_tag_word)  # 21224
    # most_word = gaopinci[6]
    train, test = get_ner_long_sent()
    shorter = []
    print("{} NER training sentences are longer than 126.".format(len(train)))
    # Set aside sentences longer than 252.
    len_126_252 = []
    len_252 = []
    for item in train:
        sent, tags = item
        if len(sent) < 252:
            len_126_252.append(item)
        else:
            len_252.append(item)
    print("{} NER training sentences have length between 126 and 252.".format(len(len_126_252)))
    train_dict = []
    success = 0
    duandian = []  # candidate break points per sentence (duandian = break point)
    for sent, tags in len_126_252:
        # Collect positions of the splitter characters; they should never
        # carry an entity tag.
        index = []
        for idx, word in enumerate(sent):
            if word == one_not_tag_word or word == confirmed_not_tag_word:
                index.append(idx)
                if tags[idx] != 'O':
                    print("error")
        # Pick the candidate position closest to the middle of the sentence.
        split = int(len(sent) / 2)
        for i in range(len(index) - 1):
            if abs(index[i] - split) <= abs(index[i + 1] - split):
                split = index[i]
                break
        if split < 126 and (len(sent) - split) < 126:
            success += 1
        else:
            train_dict.append({'res': len(sent) - split,
                               'indexes': index,
                               'split': split})
        if not index:
            print(len(sent))
            print(' '.join(sent))
        duandian.append(index)
    print(success)
    # print("{} NER test sentences are longer than 126.".format(len(test)))
    # cnt = 0
    # for sent in test:
    #     if confirmed_not_tag_word in sent:
    #         cnt += 1
    #     else:
    #         if one_not_tag_word in sent:
    #             cnt += 1
    # print(cnt)
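

# Hedged toy illustration of the split heuristic above: given candidate break
# positions and a target of half the sentence length, take the first candidate
# that is at least as close to the middle as the next one. All numbers below
# are invented for the sake of the example.
def _pick_split_example(sent_len=200, candidates=(30, 95, 160)):
    split = int(sent_len / 2)  # 100
    for i in range(len(candidates) - 1):
        if abs(candidates[i] - split) <= abs(candidates[i + 1] - split):
            split = candidates[i]
            break
    return split  # 95 for the example values above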


def get_ner_long_sent():
    """Collect sentences longer than 126 characters from the NER data."""
    train_sents = []
    with open('data/train.txt', 'r', encoding='utf-8') as f:
        for line in f:
            tokens = line.strip().split(' ')
            sentence = []
            tags = []
            for token in tokens:
                words, tag = token.split('/')
                words = words.split('_')
                if tag == 'o':
                    for word in words:
                        sentence.append(word)
                        tags.append('O')
                else:
                    # First character gets B-, the remaining ones I-,
                    # keeping sentence and tags aligned.
                    sentence.append(words[0])
                    tags.append('B-' + tag)
                    for word in words[1:]:
                        sentence.append(word)
                        tags.append('I-' + tag)
            if len(sentence) > 126:
                train_sents.append([sentence, tags])
    test_sents = []
    with open('data/test.txt', 'r', encoding='utf-8') as f:
        for line in f:
            data = line.strip().split('_')
            if len(data) > 126:
                test_sents.append(data)
    return train_sents, test_sents
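

# Hedged sketch of the return shape of get_ner_long_sent(), using invented
# characters: each train entry pairs a character list with an aligned tag
# list, while test entries are plain character lists.
def _long_sent_shape_example():
    train_entry = [['今', '天', '去', '北', '京'],
                   ['O', 'O', 'O', 'B-ns', 'I-ns']]
    test_entry = ['今', '天', '去', '北', '京']
    return train_entry, test_entry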


if __name__ == '__main__':
    # sentences = change_to_bio()
    # print("{} training sentences are longer than 512.".format(count))
    # print("The longest training sentence has length {}.".format(max_len))
    gaopincidebiaoji()