"""Defines Trainer, Evaluator and Inferencer class that wraps a TransformerXL
model and performs training, evaluation and inference, respectively.
"""
import functools
import os

import numpy as np
import tensorflow as tf

from commons import beam_search
from commons import utils
from commons.tokenization import EOS_ID


class TransformerXLModelTrainer(object):
  """Trains a TransformerXL model."""

  def __init__(self,
               model,
               m_seq_len,
               batch_size,
               adaptive_embedding):
    """Constructor.

    Args:
      model: an instance of TransformerXL model.
      m_seq_len: int scalar, length of the memory sequence.
      batch_size: int scalar, batch size.
      adaptive_embedding: bool scalar, whether to use adaptive embedding (and
        softmax) layer.
    """
    self._model = model
    self._m_seq_len = m_seq_len
    self._batch_size = batch_size
    self._adaptive_embedding = adaptive_embedding
  def train(self,
            dataset,
            optimizer,
            ckpt,
            ckpt_path,
            num_iterations,
            persist_per_iterations,
            clip_norm=None,
            log_per_iterations=100,
            logdir='log'):
    """Run training iterations.

    Args:
      dataset: a tf.data.Dataset instance, the input data generator.
      optimizer: a tf.keras.optimizer.Optimizer instance, applies gradient
        updates.
      ckpt: a tf.train.Checkpoint instance, saves or loads weights to/from
        checkpoint files.
      ckpt_path: string scalar, the path to the directory that the checkpoint
        files will be written to or loaded from.
      num_iterations: int scalar, num of iterations to train the model.
      persist_per_iterations: int scalar, saves weights to checkpoint files
        every `persist_per_iterations` iterations.
      clip_norm: float scalar, the maximum global norm that the gradient
        tensors will be clipped to; no clipping is applied if None.
      log_per_iterations: int scalar, prints log info every
        `log_per_iterations` iterations.
      logdir: string scalar, the directory that the tensorboard log data will
        be written to.
    """
    batch_size = self._batch_size
    stack_size = self._model._stack_size
    m_seq_len = self._m_seq_len
    hidden_size = self._model._hidden_size

    train_step_signature = [
        tf.TensorSpec(shape=(batch_size, None), dtype='int32'),
        tf.TensorSpec(shape=(batch_size, stack_size, m_seq_len, hidden_size),
                      dtype='float32'),
        tf.TensorSpec(shape=(batch_size, None), dtype='int32')]

    @tf.function(input_signature=train_step_signature)
    def train_step(inputs, memories, labels):
      with tf.GradientTape() as tape:
        outputs, new_memories = self._model(inputs, memories, training=True)
        if self._adaptive_embedding:
          losses = self._model._embedding_layer(outputs, labels, mode='loss')
        else:
          logits = self._model._embedding_layer(outputs, mode='logits')
          losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
              labels=labels, logits=logits)
        loss = tf.reduce_mean(losses)

      trainable_variables = self._model.trainable_variables
      gradients = tape.gradient(loss, trainable_variables)
      if clip_norm is not None:
        gradients, _ = tf.clip_by_global_norm(gradients, clip_norm)
      optimizer.apply_gradients(
          zip(gradients, trainable_variables))

      step = optimizer.iterations
      lr = optimizer.learning_rate
      return loss, new_memories, step - 1, lr

    summary_writer = tf.summary.create_file_writer(logdir)

    # Resume from the latest checkpoint if one exists.
    latest_ckpt = tf.train.latest_checkpoint(ckpt_path)
    if latest_ckpt:
      print('Restoring from checkpoint: %s ...' % latest_ckpt)
      ckpt.restore(latest_ckpt)
    else:
      print('Training from scratch...')

    # The memory sequence embeddings start out as zeros and are refreshed with
    # the new memories returned by each training step.
    memories = tf.zeros((batch_size, stack_size, m_seq_len, hidden_size))
    for inputs, labels in dataset:
      loss, memories, step, lr = train_step(inputs, memories, labels)

      with summary_writer.as_default():
        tf.summary.scalar('train_loss', loss, step=step)
        tf.summary.scalar('learning_rate', lr, step=step)

      if step.numpy() % log_per_iterations == 0:
        print('global step: %d, loss: %f, learning rate:' %
            (step.numpy(), loss.numpy()), lr.numpy())
      if step.numpy() % persist_per_iterations == 0:
        print('Saving checkpoint at global step %d ...' % step.numpy())
        ckpt.save(os.path.join(ckpt_path, 'transformerxl'))
      if step.numpy() == num_iterations:
        break
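
# ---------------------------------------------------------------------------
# Hypothetical usage sketch for the trainer above (not part of the original
# module). It assumes `model` is an already-built TransformerXL instance and
# `train_dataset` is a tf.data.Dataset yielding (inputs, labels) int32 pairs
# of shape [batch_size, seq_len]; the variable names and hyperparameter values
# are illustrative only.
#
#   optimizer = tf.keras.optimizers.Adam(learning_rate=2.5e-4)
#   ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
#   trainer = TransformerXLModelTrainer(
#       model, m_seq_len=384, batch_size=16, adaptive_embedding=True)
#   trainer.train(train_dataset, optimizer, ckpt, 'ckpt_dir',
#                 num_iterations=200000, persist_per_iterations=10000,
#                 clip_norm=0.25)
# ---------------------------------------------------------------------------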


class TransformerXLModelEvaluator(object):
  """Evaluates a trained TransformerXL model in terms of per-token perplexity.
  """

  def __init__(self,
               model,
               m_seq_len,
               batch_size,
               vocab_size,
               adaptive_embedding):
    """Constructor.

    Args:
      model: an instance of TransformerXL model.
      m_seq_len: int scalar, length of the memory sequence.
      batch_size: int scalar, batch size.
      vocab_size: int scalar, vocabulary size.
      adaptive_embedding: bool scalar, whether to use adaptive embedding (and
        softmax) layer.
    """
    self._model = model
    self._m_seq_len = m_seq_len
    self._batch_size = batch_size
    self._adaptive_embedding = adaptive_embedding
  def evaluate(self, dataset):
    """Iterate through the validation dataset and compute the perplexity.

    Args:
      dataset: a tf.data.Dataset instance, the input data generator.

    Returns:
      perplexity: float scalar, the average per-token perplexity.
    """
    batch_size = self._batch_size
    stack_size = self._model._stack_size
    m_seq_len = self._m_seq_len
    hidden_size = self._model._hidden_size

    memories = tf.zeros((batch_size, stack_size, m_seq_len, hidden_size))
    loss_list = []

    def eval_step(inputs, memories, labels):
      outputs, memories = self._model(inputs, memories, training=False)
      if self._adaptive_embedding:
        losses = self._model._embedding_layer(outputs, labels, mode='loss')
      else:
        logits = self._model._embedding_layer(outputs, mode='logits')
        losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits)
      loss = tf.reduce_mean(losses)
      return loss, memories

    for inputs, labels in dataset:
      loss, memories = eval_step(inputs, memories, labels)
      loss_list.append(loss.numpy())

    # Per-token perplexity is the exponential of the average cross-entropy.
    perplexity = np.exp(np.mean(loss_list))
    return perplexity
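
# Hypothetical usage sketch for the evaluator above (not part of the original
# module). It assumes `model` is a trained TransformerXL instance and
# `valid_dataset` is a tf.data.Dataset yielding (inputs, labels) pairs; the
# names and values are illustrative only.
#
#   evaluator = TransformerXLModelEvaluator(
#       model, m_seq_len=384, batch_size=16, vocab_size=32000,
#       adaptive_embedding=True)
#   perplexity = evaluator.evaluate(valid_dataset)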


class TransformerXLModelInferencer(object):
  """Makes inference on the (approximately) most likely sequence of text that
  logically and coherently follows a prompt (i.e. a piece of text that
  provides the "context"), using a trained TransformerXL model.
  """

  def __init__(self,
               model,
               m_seq_len,
               batch_size,
               adaptive_embedding,
               decoding_method,
               num_tokens=512,
               beam_width=4,
               alpha=0.6,
               batch_memory_processing=False):
    """Constructor.

    Args:
      model: an instance of TransformerXL model.
      m_seq_len: int scalar, length of the memory sequence.
      batch_size: int scalar, batch size.
      adaptive_embedding: bool scalar, whether to use adaptive embedding (and
        softmax) layer.
      decoding_method: string scalar, decoding method. Must be 'nucleus',
        'topk' or 'beam_search'.
      num_tokens: int scalar, num of tokens to be generated.
      beam_width: int scalar, number of beams for beam search. Ignored if
        decoding method is not beam search.
      alpha: float scalar, defining the strength of length normalization.
        Ignored if decoding method is not beam search.
      batch_memory_processing: bool scalar, whether to compute the sequence
        embeddings in the memory segment batchwise, or one at a time.
    """
    if decoding_method not in ('nucleus', 'topk', 'beam_search'):
      raise ValueError('`decoding_method` must be either nucleus, topk or '
                       'beam_search, got %s' % decoding_method)
    self._model = model
    self._m_seq_len = m_seq_len
    self._batch_size = batch_size
    self._adaptive_embedding = adaptive_embedding
    self._decoding_method = decoding_method
    self._num_tokens = num_tokens
    self._beam_width = beam_width
    self._alpha = alpha
    self._batch_memory_processing = batch_memory_processing
  def infer(self, prompt_token_ids):
    """Generate text based on the prompted text.

    Args:
      prompt_token_ids: int tensor of shape [1, seq_len], token ids of the
        prompted text.

    Returns:
      token_id_list: a list of integers, the token ids of the generated text.
    """
    batch_size = self._batch_size
    stack_size = self._model._stack_size
    m_seq_len = self._m_seq_len
    hidden_size = self._model._hidden_size

    # Prime the memory with all but the last prompt token, either in a single
    # batched forward pass or one token at a time.
    memories = tf.zeros((batch_size, stack_size, m_seq_len, hidden_size))
    if self._batch_memory_processing:
      _, memories = self._model(
          prompt_token_ids[:, :-1], memories, training=False)
    else:
      for pos in prompt_token_ids[0, :-1]:
        _, memories = self._model(pos[tf.newaxis, tf.newaxis], memories,
                                  training=False)

    if self._decoding_method != 'beam_search':
      # Sampling-based decoding: generate one token at a time by sampling from
      # the predicted next-token distribution.
      if self._decoding_method == 'nucleus':
        sampling_fn = utils.nucleus_sampling
      else:
        sampling_fn = utils.topk_sampling

      token_id_list = []
      for i in range(self._num_tokens):
        if i == 0:
          init_ids = prompt_token_ids[:, -1:]

        outputs, memories = self._model(init_ids, memories, training=False)
        if self._adaptive_embedding:
          scores = self._model._embedding_layer(outputs, mode='softmax')
        else:
          scores = self._model._embedding_layer(outputs, mode='logits')
          scores = tf.nn.softmax(scores, axis=-1)

        next_token_id = sampling_fn(scores.numpy()[0, 0])
        token_id_list.append(next_token_id)
        init_ids = tf.constant([[next_token_id]])
        if next_token_id == EOS_ID:
          break
    else:
      # Beam search decoding: score candidate continuations with the model's
      # softmax probabilities and keep the `beam_width` best sequences.
      if self._adaptive_embedding:
        scoring_fn = functools.partial(
            self._model._embedding_layer, mode='softmax')
      else:
        def scoring_fn(inputs):
          logits = self._model._embedding_layer(inputs, mode='logits')
          return tf.nn.softmax(logits, axis=-1)

      initial_ids = prompt_token_ids[:, -1]
      decoding_fn = self._model._build_decoding_fn(scoring_fn)
      decoding_cache = {'memories': memories}
      bs = beam_search.BeamSearch(decoding_fn,
                                  self._model._vocab_size,
                                  batch_size,
                                  self._beam_width,
                                  self._alpha,
                                  self._num_tokens,
                                  EOS_ID,
                                  logits_as_scores=False)
      outputs, _, _ = bs.search(initial_ids, decoding_cache)
      token_id_list = outputs[0, 0].numpy().tolist()
    return token_id_list
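
# Hypothetical usage sketch for the inferencer above (not part of the original
# module). It assumes `model` is a trained TransformerXL instance and that a
# subword `tokenizer` with an `encode` method is available; the tokenizer API
# shown here is an assumption, used for illustration only.
#
#   inferencer = TransformerXLModelInferencer(
#       model, m_seq_len=384, batch_size=1, adaptive_embedding=True,
#       decoding_method='nucleus', num_tokens=256)
#   prompt_token_ids = tf.constant([tokenizer.encode('Once upon a time')])
#   token_id_list = inferencer.infer(prompt_token_ids)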