import heapq

import numpy as np


def load_subword_nmt_table(path):
"""
:param path: path to merge_table with subword-nmt format
"""
    table = dict()
    cur_priority = 1
    with open(path) as f:
        for line in f:
            # Skip the '#version: ...' header line written by subword-nmt
            if '#version' in line:
                continue
            token_1, token_2 = line.rstrip('\n').split(' ')
            table[(token_1, token_2)] = cur_priority
            cur_priority += 1
    return table
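
# Example (illustrative sketch, not part of this file's data): a merge file
# whose first lines are
#
#     #version: 0.2
#     e s
#     es t
#
# loads as {('e', 's'): 1, ('es', 't'): 2}.
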
def load_merge_table(path):
"""
:param path: path to merge_table
"""
table = dict()
with open(path) as f:
for line in f:
token_1, token_2, priority = line.split('\t')
table[(token_1, token_2)] = int(priority)
return table
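
# Example (illustrative sketch, hypothetical file contents): a tab-separated
# merge table such as
#
#     e<TAB>s<TAB>1
#     es<TAB>t<TAB>2
#
# loads as {('e', 's'): 1, ('es', 't'): 2}.
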
def tokenize_word(merge_rules, word, dropout=0.0,
                  random_generator=np.random.RandomState(),
                  sentinels=('^', '$'),
                  regime='begin',
                  bpe_symbol='`',
                  always_merge_sentinels=True):
""" Tokenize word using bpe merge rules
:param merge_rules: dict [(a,b)] -> id, merge table, ids are in increasing order
:param word: string
:param dropout: float, dropout rate
:param random_generator: random generator with .rand() method
:param sentinels: list of two strings, beginning of word sentinel and end of word sentinel (empty string means that no corresponding sentinel is applied)
:param regime:
'begin' -- add bpe symbol to the beginning of bpe token
'end' -- add bpe symbol to the end of bpe token
:param bpe_symbol: str, could be one of '`', '@@', '▁'
:param always_merge_sentinels: bool, if True, sentinels are always concatenated
to the first and last characters before applying BPE merges (True is equivalent to subword-nmt>=0.2, False is equivalent to subword-nmt<0.2)
"""
# Subword tokens
sw_tokens = list(word)
# Add sentinels
if always_merge_sentinels:
sw_tokens = [sentinels[0] + sw_tokens[0]] + sw_tokens[1:]
sw_tokens = sw_tokens[:-1] + [sw_tokens[-1] + sentinels[1]]
else:
beg_sentinel = [sentinels[0]] if len(sentinels[0]) > 0 else []
end_sentinel = [sentinels[1]] if len(sentinels[1]) > 0 else []
sw_tokens = beg_sentinel + sw_tokens + end_sentinel
    # Build the initial heap of candidate merges;
    # entries are [priority, position] pairs
merge_heap = []
for pos in range(len(sw_tokens) - 1):
cur_nxt_pair = (sw_tokens[pos], sw_tokens[pos + 1])
if cur_nxt_pair in merge_rules:
cur_priority = merge_rules[cur_nxt_pair]
merge_heap.append([cur_priority, pos])
heapq.heapify(merge_heap)
sw_length = len(sw_tokens)
dropped_merges = []
    while merge_heap:
        cur_priority, cur_pos = heapq.heappop(merge_heap)

        # Skip stale entries: the position fell off the (shrinking) token
        # list, or the pair at this position changed since it was pushed
        if cur_pos > sw_length - 2:
            continue
        cur = sw_tokens[cur_pos]
        nxt = sw_tokens[cur_pos + 1]
        if merge_rules.get((cur, nxt), None) != cur_priority:
            continue
        # BPE-dropout: with probability `dropout`, skip (drop) this merge for now
if random_generator.rand() < dropout:
dropped_merges.append([cur_priority, cur_pos])
continue
        # Perform the merge and shift positions of later heap entries left by
        # one (only positions change, so the heap order is preserved)
        sw_tokens[cur_pos:cur_pos + 2] = [cur + nxt]
        sw_length -= 1

        for pair in merge_heap:
            if pair[1] > cur_pos:
                pair[1] -= 1
        # Push dropped merges back (with shifted positions) so they can be
        # reconsidered now that the token sequence has changed
for priority, position in dropped_merges:
if position > cur_pos:
position -= 1
heapq.heappush(merge_heap, [priority, position])
dropped_merges = []
        # Add new candidate merges formed with the left and right neighbors
new_cur = sw_tokens[cur_pos]
if cur_pos > 0:
prev = sw_tokens[cur_pos - 1]
if (prev, new_cur) in merge_rules:
heapq.heappush(merge_heap, [merge_rules[(prev, new_cur)], cur_pos - 1])
if cur_pos < (sw_length - 1):
new_next = sw_tokens[cur_pos + 1]
if (new_cur, new_next) in merge_rules:
heapq.heappush(merge_heap, [merge_rules[(new_cur, new_next)], cur_pos])
    # Strip sentinels from the first and last tokens
    sw_tokens[0] = sw_tokens[0].replace(sentinels[0], '')
    sw_tokens[-1] = sw_tokens[-1].replace(sentinels[1], '')
    # Attach the BPE continuation symbol and drop tokens that consisted
    # solely of a sentinel
    if regime == 'begin':
        for i in range(1, sw_length):
            sw_tokens[i] = bpe_symbol + sw_tokens[i]

        if sw_tokens[0] == '':
            sw_tokens = sw_tokens[1:]
            sw_tokens[0] = sw_tokens[0].lstrip(bpe_symbol)
        if sw_tokens[-1] == bpe_symbol:
            sw_tokens.pop()
    elif regime == 'end':
        for i in range(sw_length - 1):
            sw_tokens[i] = sw_tokens[i] + bpe_symbol

        if sw_tokens[0] == bpe_symbol:
            sw_tokens.pop(0)
        if sw_tokens[-1] == '':
            sw_tokens = sw_tokens[:-1]
            sw_tokens[-1] = sw_tokens[-1].rstrip(bpe_symbol)
return sw_tokens
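
# Example (illustrative sketch with a hypothetical one-rule merge table):
# with dropout disabled, tokenize_word({('^l', 'o'): 1}, 'lowest') merges
# '^l' + 'o' and returns ['lo', '`w', '`e', '`s', '`t'] under the defaults
# (regime='begin', bpe_symbol='`', sentinels ('^', '$')).
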
def tokenize_text(rules, line, dropout=0.0, random_generator=np.random.RandomState(), **kwargs):
    return ' '.join(
        ' '.join(tokenize_word(rules, word, dropout, random_generator, **kwargs))
        for word in line.split(' ')
    )
class BpeOnlineTokenizer:
    """
    Apply BPE tokenization (with dropout) to a str line
    """
def __init__(self, bpe_dropout_rate, merge_table, random_seed=None):
"""
:param bpe_dropout_rate: float [0,1)
:param merge_table: dict [(token_1, token_2)] -> priority
"""
self.random_generator = np.random.RandomState(random_seed)
self.bpe_dropout_rate = bpe_dropout_rate
self.merge_table = merge_table
    def __call__(self, line, **kwargs):
        """
        :param line: str
        :return: tokenized line, str
        """
        return tokenize_text(self.merge_table, line,
                             self.bpe_dropout_rate, self.random_generator, **kwargs)
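
# Usage sketch (hypothetical file name and sample line):
#
#     merge_table = load_subword_nmt_table('merge_table.bpe')
#     tokenizer = BpeOnlineTokenizer(bpe_dropout_rate=0.1,
#                                    merge_table=merge_table,
#                                    random_seed=42)
#     tokenizer('a sample line')  # -> BPE-split string
#
# With dropout > 0, repeated calls on the same line may yield different
# segmentations -- that is the point of BPE-dropout.
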
class BpeOnlineParallelApplier:
    """
    Apply BPE online to parallel data (e.g. the source and target sides of a corpus)
    """
    def __init__(self, bpe_dropout_rates, merge_tables, random_seed=42):
        """
        :param bpe_dropout_rates: list of float [0,1)
        :param merge_tables: list of dict [(token_1, token_2)] -> priority;
            None means the corresponding stream is passed through unchanged
        """
        assert len(bpe_dropout_rates) == len(merge_tables)
        self.bpe_appliers = []
        for rate, table in zip(bpe_dropout_rates, merge_tables):
            if table is not None:
                self.bpe_appliers.append(BpeOnlineTokenizer(rate, table, random_seed))
            else:
                # No merge table for this stream: pass lines through unchanged
                self.bpe_appliers.append(lambda x: x)
def __call__(self, lines):
assert len(self.bpe_appliers) == len(lines)
return tuple(applier(l) for applier, l in zip(self.bpe_appliers, lines))
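
# Minimal runnable sketch (added for illustration; the toy merge table below
# is hypothetical, not shipped with this file).
if __name__ == '__main__':
    toy_table = {('^l', 'o'): 1, ('^lo', 'w'): 2}
    applier = BpeOnlineParallelApplier(
        bpe_dropout_rates=(0.0, 0.0),
        merge_tables=(toy_table, None),  # None: second stream passes through
    )
    src, tgt = applier(('low lowest', 'kept as-is'))
    print(src)  # lo `w low `e `s `t
    print(tgt)  # kept as-is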