bert_pretraining.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""An example of how to pretrain a transformer encoder with BERT."""
import collections
import itertools
import typing
import gensim.models.word2vec as word2vec
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import transformer
import transformer.bert as bert
__author__ = "Patrick Hohenecker"
__copyright__ = (
"Copyright (c) 2019, Patrick Hohenecker\n"
"All rights reserved.\n"
"\n"
"Redistribution and use in source and binary forms, with or without\n"
"modification, are permitted provided that the following conditions are met:\n"
"\n"
"1. Redistributions of source code must retain the above copyright notice, this\n"
" list of conditions and the following disclaimer.\n"
"2. Redistributions in binary form must reproduce the above copyright notice,\n"
" this list of conditions and the following disclaimer in the documentation\n"
" and/or other materials provided with the distribution.\n"
"\n"
"THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n"
"ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n"
"WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n"
"DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR\n"
"ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n"
"(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n"
"LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n"
"ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n"
"(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n"
"SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."
)
__license__ = "BSD-2-Clause"
__version__ = "2019.1"
__date__ = "23 Apr 2019"
__maintainer__ = "Patrick Hohenecker"
__email__ = "[email protected]"
__status__ = "Development"
# ==================================================================================================================== #
# C O N S T A N T S #
# ==================================================================================================================== #
Token = collections.namedtuple("Token", ["index", "word"])
"""This is used to store index-word pairs."""
DATA = [
"where the streets have no name",
"we ' re still building then burning down love",
"burning down love",
"and when i go there , i go there with you",
"it ' s all i can do"
]
"""list[str]: The already preprocessed training data."""
# SPECIAL TOKENS #####################################################################################################
SOS = Token(0, "<sos>")
"""The start-of-sequence token."""
EOS = Token(1, "<eos>")
"""The end-of-sequence token."""
PAD = Token(2, "<pad>")
"""The padding token."""
MASK = Token(3, "<mask>")
"""The mask token."""
# MODEL CONFIG #######################################################################################################
DIMENSIONS = (256, 32, 32)
"""tuple[int]: A tuple of d_model, d_k, d_v."""
DROPOUT_RATE = 0.1
"""float: The used dropout rate."""
EMBEDDING_SIZE = DIMENSIONS[0]
"""int: The used embedding size."""
NUM_LAYERS = 6
"""int: The number of layers in the trained transformer encoder."""
# TRAINING DETAILS ###################################################################################################
GPU = False # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< SET THIS TO True, IF YOU ARE USING A MACHINE WITH A GPU!
"""bool: Indicates whether to make use of a GPU."""
LEARNING_RATE = 0.0001
"""float: The used learning rate."""
NUM_EPOCHS = 500
"""int: The total number of training epochs."""
NUM_HEADS = 6
"""int: The number of attention heads to use."""
# ==================================================================================================================== #
# H E L P E R F U N C T I O N S #
# ==================================================================================================================== #
def prepare_data() -> typing.Tuple[typing.List[typing.List[str]], collections.OrderedDict]:
"""Preprocesses the training data, and creates the vocabulary.
Returns:
list[list[str]]: The training data as list of samples, each of which is a list of words.
collections.OrderedDict: The vocabulary as an ``OrderedDict`` from words to indices.
"""
# gather all words that appear in the data
all_words = set()
for sample in DATA:
all_words.update(sample.split(" "))
# create the vocabulary
vocab = collections.OrderedDict(
[
(SOS.word, SOS.index),
(EOS.word, EOS.index),
(PAD.word, PAD.index),
(MASK.word, MASK.index)
]
)
for idx, word in enumerate(sorted(all_words)):
vocab[word] = idx + 4
# split, add <sos>...<eos>, and pad the dataset
data = [[SOS.word] + sample.split(" ") + [EOS.word] for sample in DATA]
max_len = max(len(sample) for sample in data)
data = [sample + ([PAD.word] * (max_len - len(sample))) for sample in data]
return data, vocab
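
# For illustration: with the DATA above, prepare_data() pads every sample to the length of the
# longest one (13 tokens, including <sos> and <eos>), so the first sample becomes
#     ["<sos>", "where", "the", "streets", "have", "no", "name", "<eos>",
#      "<pad>", "<pad>", "<pad>", "<pad>", "<pad>"]
# while the vocabulary maps the special tokens to indices 0-3 and the alphabetically sorted
# corpus words to 4, 5, ...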
# ==================================================================================================================== #
# M A I N #
# ==================================================================================================================== #
def main():
    # fetch the training data
    data, vocab = prepare_data()

    # create the word embeddings with word2vec as well as the (trainable) positional embeddings
    emb_model = word2vec.Word2Vec(
        sentences=data,
        size=EMBEDDING_SIZE,  # NOTE: gensim >= 4.0 renamed this parameter to vector_size
        min_count=1
    )
    for word in vocab.keys():
        if word not in emb_model.wv:  # <mask> never appears in the data, so it has no word2vec vector
            emb_model.wv[word] = np.zeros((EMBEDDING_SIZE,))
    word_emb_mat = nn.Parameter(
        data=torch.FloatTensor([emb_model.wv[word] for word in vocab.keys()]),
        requires_grad=False  # -> the pretrained word2vec embeddings are not fine-tuned
    )
    word_emb = nn.Embedding(len(vocab), EMBEDDING_SIZE)
    word_emb.weight = word_emb_mat
    pos_emb = nn.Embedding(len(data[0]), EMBEDDING_SIZE)
    pos_emb.weight.requires_grad = True  # -> the positional embeddings are trained

    # turn the dataset into a tensor of word indices of shape (num-samples, max-seq-len)
    data = torch.LongTensor([[vocab[word] for word in sample] for sample in data])

    # create the encoder, the pretraining loss, and the optimizer
    encoder = transformer.Encoder(
        NUM_LAYERS,    # num_layers
        NUM_HEADS,     # num_heads
        *DIMENSIONS,   # dim_model / dim_keys / dim_values
        DROPOUT_RATE,  # residual_dropout
        DROPOUT_RATE,  # attention_dropout
        PAD.index      # pad_index
    )
    loss = bert.MLMLoss(
        encoder,
        word_emb,
        pos_emb,
        MASK.index
    )
    optimizer = optim.Adam(
        itertools.chain(encoder.parameters(), loss.parameters()),
        lr=LEARNING_RATE
    )

    # move everything to the GPU, if requested
    if GPU:
        data = data.cuda()
        encoder.cuda()
        loss.cuda()  # -> also moves the embeddings to the GPU

    # pretrain the encoder
    for epoch in range(NUM_EPOCHS):

        # compute the loss
        optimizer.zero_grad()
        current_loss = loss(data)
        print("EPOCH", epoch + 1, ": LOSS =", current_loss.item())

        # update the model
        current_loss.backward()
        optimizer.step()
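
    # A possible follow-up (not part of the original script): persist the pretrained encoder for
    # downstream use via standard PyTorch serialization, e.g.
    #     torch.save(encoder.state_dict(), "pretrained_encoder.pt")  # the file name is just an example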
if __name__ == "__main__":
    main()