-
Notifications
You must be signed in to change notification settings - Fork 5
/
optimizers.py
446 lines (397 loc) · 20.2 KB
/
optimizers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
import numpy as np
import tensorflow as tf
import memory_saving_gradients
from tensorflow.python.ops import gradients
from pprint import pformat as pps
def _build_optimizer(params, lr):
    """Instantiate the optimizer selected by params["opt_name"].

    "adam" -> tf.train.AdamOptimizer, or tf.contrib.opt.AdamWOptimizer when
    params contains "weight_decay".
    "adafactor" -> AdafactorOptimizer, or a decoupled-weight-decay subclass of
    it when params contains "weight_decay".
    In both weight-decay variants the decay handed to the optimizer is
    lr * params["weight_decay"].

    Args:
      params: hyperparameter mapping (see create_train_op).
      lr: learning rate (float or scalar tensor).

    Returns:
      A tf.train.Optimizer instance.

    Raises:
      ValueError: if params["opt_name"] is unknown, or if it is "adafactor" and
        params["decay_type"] is not one of "adam", "pow", "none".
    """
    opt_name = params["opt_name"]
    if opt_name == "adam":
        if "weight_decay" not in params:
            optimizer = tf.train.AdamOptimizer(
                learning_rate=lr,
                beta1=params["beta1"],
                beta2=params["beta2"],
                epsilon=params["epsilon"])
            tf.logging.info('create_train_op: optimizer = tf.train.AdamOptimizer(learning_rate=%s, beta1=%s, beta2=%s, epsilon=%s)', lr, params["beta1"], params["beta2"], params["epsilon"])
        else:
            optimizer = tf.contrib.opt.AdamWOptimizer(
                learning_rate=lr,
                weight_decay=lr * params["weight_decay"],
                beta1=params["beta1"],
                beta2=params["beta2"],
                epsilon=params["epsilon"])
            tf.logging.info('create_train_op: optimizer = tf.contrib.opt.AdamWOptimizer(learning_rate=%s, weight_decay=lr*%s, beta1=%s, beta2=%s, epsilon=%s)', lr, params["weight_decay"], params["beta1"], params["beta2"], params["epsilon"])
        return optimizer

    if opt_name == "adafactor":
        decay_type = params["decay_type"]
        if decay_type == "adam":
            decay_rate = adafactor_decay_rate_adam(params["beta2"])
        elif decay_type == "pow":
            decay_rate = adafactor_decay_rate_pow(params["decay_exponent"])
        elif decay_type == "none":
            # None makes AdafactorOptimizer fall back to its default schedule.
            decay_rate = None
        else:
            raise ValueError("unknown optimizer_adafactor_decay_type")
        if "weight_decay" not in params:
            optimizer = AdafactorOptimizer(
                learning_rate=lr,
                decay_rate=decay_rate,
                beta1=params["beta1"],
                name="Adafactor")
            tf.logging.info('create_train_op: optimizer = AdafactorOptimizer(learning_rate=%s, decay_rate=%s, beta1=%s)', lr, decay_rate, params["beta1"])
        else:
            AdafactorWOptimizer = tf.contrib.opt.extend_with_decoupled_weight_decay(AdafactorOptimizer)
            optimizer = AdafactorWOptimizer(
                weight_decay=params["weight_decay"] * lr,
                learning_rate=lr,
                decay_rate=decay_rate,
                beta1=params["beta1"],
                name="AdafactorW")
            tf.logging.info('create_train_op: optimizer = AdafactorWOptimizer(weight_decay=lr*%s, learning_rate=%s, decay_rate=%s, beta1=%s)', params["weight_decay"], lr, decay_rate, params["beta1"])
        return optimizer

    raise ValueError("Unknown optimizer type!")


def _compute_gradients(params, loss, train_vars):
    """Differentiate `loss` with respect to `train_vars`.

    Uses memory_saving_gradients.gradients when params["memory_saving_gradients"]
    is truthy (default False), otherwise tensorflow's gradients.gradients.
    params["colocate_gradients"] (default True) controls
    colocate_gradients_with_ops. Variables that receive a None gradient
    (no path from `loss`) are logged and given tf.zeros_like(v) instead.

    Returns:
      A list of (gradient, variable) pairs, one per entry of `train_vars`.
    """
    tf.logging.info("---------")
    tf.logging.info("Gradient options:")
    use_memory_saving_gradients = params['memory_saving_gradients'] if 'memory_saving_gradients' in params else False
    colocate_gradients_with_ops = params['colocate_gradients'] if 'colocate_gradients' in params else True
    gate_gradients = None
    tf.logging.info("use_memory_saving_gradients=%s", use_memory_saving_gradients)
    tf.logging.info("colocate_gradients_with_ops=%s", colocate_gradients_with_ops)
    tf.logging.info("gate_gradients=%s", gate_gradients)
    if use_memory_saving_gradients:
        gradient_fn = memory_saving_gradients.gradients
    else:
        gradient_fn = gradients.gradients
    raw_grads = gradient_fn(loss, train_vars,
                            colocate_gradients_with_ops=colocate_gradients_with_ops,
                            gate_gradients=gate_gradients)
    pairs = list(zip(raw_grads, train_vars))
    disconnected_grads = [v for g, v in pairs if g is None]
    # Replace disconnected (None) gradients with zeros.
    pairs = [(tf.zeros_like(v), v) if g is None else (g, v) for g, v in pairs]
    tf.logging.info("---------")
    tf.logging.info("Gradient details:")
    tf.logging.info("%s", pps(pairs))
    tf.logging.info("Disconnected gradients:")
    tf.logging.info("%s", pps(disconnected_grads))
    tf.logging.info("---------")
    return pairs


def create_train_op(params, loss=None, grads=None):
    """Build the op that performs one optimizer step.

    Args:
      params: hyperparameter mapping. Always read: "lr", "opt_name", "use_tpu".
        Optimizer-specific: "beta1", "beta2", "epsilon", "decay_type",
        "decay_exponent". Optional: "warmup_steps" (then "max_steps" is
        required and lr follows cosine_decay_with_warmup), "weight_decay",
        "only_train_transformer_layers", "memory_saving_gradients",
        "colocate_gradients".
      loss: scalar loss tensor, differentiated when `grads` is None.
      grads: optional precomputed (gradient, variable) pairs passed straight to
        apply_gradients. At most one of `loss` and `grads` may be given.

    Returns:
      An op named "train_op" grouping optimizer.apply_gradients (which also
      increments the global step) with every op in the UPDATE_OPS collection.

    Raises:
      ValueError: for an unknown optimizer or Adafactor decay type.
    """
    assert loss is None or grads is None
    tf.logging.info("create_train_op(loss=%s, params=%s, grads=%s)", loss, pps(params), pps(grads))
    lr = params["lr"]
    global_step = tf.train.get_global_step()
    assert global_step is not None
    if "warmup_steps" in params:
        tf.logging.info('create_train_op: lr = cosine_decay_with_warmup(%s, %s, %s, warmup_steps=%s)', global_step, lr, params["max_steps"], params["warmup_steps"])
        lr = cosine_decay_with_warmup(global_step, lr, params["max_steps"], warmup_steps=params["warmup_steps"])

    optimizer = _build_optimizer(params, lr)
    if params["use_tpu"]:
        optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # To update batchnorm, if present

    only_train_transformer_layers = params['only_train_transformer_layers'] if 'only_train_transformer_layers' in params else False

    # Partition trainable variables in a single pass so each is classified and
    # logged exactly once. With only_train_transformer_layers, only variables
    # whose name contains '/h' or '/ln_f' are trained.
    trainable = tf.trainable_variables()
    train_vars, non_train_vars = [], []
    for v in trainable:
        if only_train_transformer_layers and '/h' not in v.name and '/ln_f' not in v.name:
            tf.logging.info("NOT training variable: %s", v)
            non_train_vars.append(v)
        else:
            tf.logging.info(" training variable: %s", v)
            train_vars.append(v)
    # Identity set: avoids O(n^2) list membership and Variable.__eq__ semantics.
    trainable_ids = {id(v) for v in trainable}
    other_vars = [v for v in tf.global_variables() if id(v) not in trainable_ids]
    local_vars = list(tf.local_variables())

    def paramcount(variables):
        return sum(int(np.prod(v.shape.as_list())) for v in variables)

    def logvars(variables, label, print_variables=False):
        if print_variables:
            tf.logging.info("%s (%s parameters): %s", label, paramcount(variables), pps(variables))
        else:
            tf.logging.info("%s (%s parameters)", label, paramcount(variables))

    n_train = paramcount(train_vars)
    n_total = paramcount(trainable)
    tf.logging.info("Training %d parameters (%.2fM) out of %d parameters (%.2fM)",
                    n_train, n_train / (1024.0 * 1024.0),
                    n_total, n_total / (1024.0 * 1024.0))
    groups = [
        (train_vars, "trainable variables"),
        (non_train_vars, "non-trainable variables"),
        (other_vars, "other global variables"),
        (local_vars, "other local variables"),
    ]
    tf.logging.info("---------")
    tf.logging.info("Variable details:")
    for variables, label in groups:
        logvars(variables, label, print_variables=True)
    tf.logging.info("---------")
    tf.logging.info("Variable summary:")
    for variables, label in groups:
        logvars(variables, label)

    if grads is None:
        grads = _compute_gradients(params, loss, train_vars)
    train_op = optimizer.apply_gradients(grads, global_step=global_step)
    # tf.group takes ops as varargs; unpack update_ops instead of nesting a list.
    return tf.group(train_op, *update_ops, name="train_op")
def cosine_decay_with_warmup(global_step,
                             learning_rate_base,
                             total_steps,
                             warmup_learning_rate=0.0,
                             warmup_steps=0,
                             hold_base_rate_steps=0,
                             name="learning_rate"):
    """Cosine learning-rate decay with optional linear warmup and hold phase.

    Piecewise schedule in terms of `global_step` (later rules override earlier):
      * cosine: 0.5 * learning_rate_base * (1 + cos(pi * (step - warmup_steps
        - hold_base_rate_steps) / (total_steps - warmup_steps - hold_base_rate_steps)))
      * if hold_base_rate_steps > 0 and step <= warmup_steps + hold_base_rate_steps:
        learning_rate_base
      * if warmup_steps > 0 and step < warmup_steps: linear ramp
        warmup_learning_rate + step * (learning_rate_base - warmup_learning_rate) / warmup_steps
      * if step > total_steps: 0.0

    Args:
      global_step: integer global step tensor.
      learning_rate_base: peak learning rate (Python number).
      total_steps: step after which the learning rate is 0.
      warmup_learning_rate: learning rate at step 0 of the warmup ramp.
      warmup_steps: length of the linear warmup.
      hold_base_rate_steps: steps to hold learning_rate_base after warmup.
      name: name of the returned tensor.

    Returns:
      A scalar float32 tensor.

    Raises:
      ValueError: if total_steps < warmup_steps, or if warmup_steps > 0 and
        learning_rate_base < warmup_learning_rate.
    """
    # Validate arguments before building any graph ops.
    if total_steps < warmup_steps:
        raise ValueError('total_steps must be larger or equal to '
                         'warmup_steps.')
    if warmup_steps > 0 and learning_rate_base < warmup_learning_rate:
        raise ValueError('learning_rate_base must be larger or equal to '
                         'warmup_learning_rate.')

    decay_steps = float(total_steps - warmup_steps - hold_base_rate_steps)
    if decay_steps <= 0:
        # With a non-positive cosine length, the cosine branch is only selected
        # at step == total_steps == warmup_steps (numerator 0) or past
        # total_steps (forced to 0 below). Dividing by 0 would yield 0/0 = NaN
        # at that single step; any positive length gives cos(0) = 1 there and
        # changes nothing else.
        decay_steps = 1.0

    elapsed = tf.cast(global_step, tf.float32) - warmup_steps - hold_base_rate_steps
    learning_rate = 0.5 * learning_rate_base * (1 + tf.cos(np.pi * elapsed / decay_steps))
    if hold_base_rate_steps > 0:
        learning_rate = tf.where(global_step > warmup_steps + hold_base_rate_steps,
                                 learning_rate, learning_rate_base)
    if warmup_steps > 0:
        slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
        warmup_rate = slope * tf.cast(global_step, tf.float32) + warmup_learning_rate
        learning_rate = tf.where(global_step < warmup_steps, warmup_rate,
                                 learning_rate)
    return tf.where(global_step > total_steps, 0.0, learning_rate,
                    name=name)
# Adafactor from tensor2tensor -------------------------------------------------------------
class AdafactorOptimizer(tf.train.Optimizer):
    """Optimizer that implements the Adafactor algorithm.

    Adafactor is described in https://arxiv.org/abs/1804.04235.

    Adafactor is most similar to Adam (Kingma and Ba); the major differences are:

    1. For a two-dimensional AxB weight matrix, Adafactor uses only A+B auxiliary
       parameters to maintain the second-moment estimator, instead of AB.
       This is advantageous on memory-limited systems. In addition, beta1
       (momentum) is set to zero by default, saving an additional auxiliary
       parameter per weight. Variables with >=3 dimensions are treated as
       collections of two-dimensional matrices - factorization is over the final
       two dimensions.

    2. Adafactor incorporates "update-clipping" - a scale-invariant analog of
       gradient clipping. This adds stability.

    3. Adafactor does not require an external "learning rate". By default, it
       incorporates a relative-update-scale schedule, corresponding to
       inverse-square-root learning-rate-decay in Adam.

    ALGORITHM:

      parameter -= absolute_update_scale * clip(grad / grad_scale)

    where:

      absolute_update_scale := relative_update_scale * parameter_scale
      relative_update_scale := min((step_num + 1)**-0.5, 1e-2)
      parameter_scale := max(rms(var), epsilon2)
      clip(x) := x / max(1.0, rms(x) / clipping_threshold)
      grad_scale := tf.sqrt(v)   (v is the second-moment estimator)

    The second-moment estimator v is maintained in a manner similar to Adam.
    We initialize

    ```
    if factored and var has >= 2 dimensions:
      v_r <- zeros(shape(var)[:-1])
      v_c <- zeros(shape(var)[:-2] + shape(var)[-1:])
    otherwise:
      v <- zeros(shape(var))
    ```

    The update rule is as follows:

    ```
    decay_rate = 1 - (step_num + 1) ^ -0.8      (default schedule)
    grad_squared = tf.square(grad) + epsilon1
    if factored and var has >= 2 dimensions:
      v_r <- decay_rate * v_r + (1 - decay_rate) * reduce_mean(grad_squared, -1)
      v_c <- decay_rate * v_c + (1 - decay_rate) * reduce_mean(grad_squared, -2)
      v = outer_prod(v_r, v_c) / reduce_mean(v_r, -1)
    otherwise:
      v <- decay_rate * v + (1 - decay_rate) * grad_squared
    ```

    When beta1 is nonzero, the (clipped, scaled) update is additionally smoothed
    with momentum: m <- beta1 * m + (1 - beta1) * update, and m is applied.

    All update arithmetic is done in float32; the new value is cast back to the
    variable's dtype before assignment.

    Several parts of this algorithm are configurable from the initializer.

      multiply_by_parameter_scale: If True, then compute absolute_update_scale
        as described above. If False, let absolute_update_scale be the
        externally supplied learning_rate.
      learning_rate: represents relative_update_scale if
        multiply_by_parameter_scale==True, or absolute_update_scale if
        multiply_by_parameter_scale==False.
      decay_rate: Decay rate of the second moment estimator (varies by step_num).
        This should be set to a function such that:
        1-1/(step_num + 1) <= decay_rate(step_num) < 1.0
      beta1: enables momentum, as in Adam. Uses extra memory if nonzero.
      clipping_threshold: should be >=1.0 or None for no update clipping
      factored: whether to factor the second-moment estimator. True means
        less memory usage.
    """

    def __init__(self,
                 multiply_by_parameter_scale=True,
                 learning_rate=None,
                 decay_rate=None,
                 beta1=0.0,
                 clipping_threshold=1.0,
                 factored=True,
                 use_locking=False,
                 name="Adafactor",
                 epsilon1=1e-30,
                 epsilon2=1e-3):
        """Construct a new Adafactor optimizer.

        See class comment.

        Args:
          multiply_by_parameter_scale: a boolean
          learning_rate: an optional Scalar. Defaults to
            min(rsqrt(step_num + 1), 0.01), further multiplied by 0.05 when
            multiply_by_parameter_scale is False.
          decay_rate: an optional Scalar. Defaults to
            1 - (step_num + 1) ** -0.8.
          beta1: a float value between 0 and 1
          clipping_threshold: an optional float >= 1
          factored: a boolean - whether to use factored second-moment estimator
            for variables with >= 2 dimensions
          use_locking: If True use locks for update operations.
          name: Optional name for the operations created when applying
            gradients. Defaults to "Adafactor".
          epsilon1: Regularization constant for squared gradient.
          epsilon2: Regularization constant for parameter scale.
        """
        super(AdafactorOptimizer, self).__init__(use_locking, name)
        self._multiply_by_parameter_scale = multiply_by_parameter_scale
        if learning_rate is None:
            learning_rate = self._learning_rate_default(multiply_by_parameter_scale)
        self._learning_rate = learning_rate
        if decay_rate is None:
            decay_rate = self._decay_rate_default()
        self._decay_rate = decay_rate
        self._beta1 = beta1
        self._clipping_threshold = clipping_threshold
        self._factored = factored
        self._epsilon1 = epsilon1
        self._epsilon2 = epsilon2

    def _should_use_factored_second_moment_estimate(self, shape):
        """Should we use a factored second moment estimator.

        Based on the shape of the variable.

        Args:
          shape: a list of integers
        Returns:
          a boolean
        """
        return self._factored and len(shape) >= 2

    def _create_slots(self, var_list):
        # Slots: "m" (momentum, only if beta1), and either "vr"/"vc" (factored
        # row/column second-moment statistics) or "v" (full second moment).
        for var in var_list:
            shape = var.get_shape().as_list()
            if self._beta1:
                self._zeros_slot(var, "m", self._name)
            if self._should_use_factored_second_moment_estimate(shape):
                r_val = tf.zeros(shape[:-1], dtype=tf.float32)
                c_val = tf.zeros(shape[:-2] + shape[-1:], dtype=tf.float32)
                self._get_or_make_slot(var, r_val, "vr", self._name)
                self._get_or_make_slot(var, c_val, "vc", self._name)
            else:
                v_val = tf.zeros(shape, dtype=tf.float32)
                self._get_or_make_slot(var, v_val, "v", self._name)

    def _apply_dense(self, grad, var):
        return self._resource_apply_dense(grad, var)

    def _apply_sparse(self, grad, var):
        # Sparse gradients are densified and applied with the dense rule.
        return self._apply_dense(tf.convert_to_tensor(grad), var)

    def _resource_apply_sparse(self, grad, handle, indices):
        return self._resource_apply_dense(
            tf.convert_to_tensor(tf.IndexedSlices(grad, indices, tf.shape(handle))),
            handle)

    def _parameter_scale(self, var):
        """Estimate the scale of the parameters from the current values.

        We include a minimum value of epsilon2 (default 1e-3) to give it a
        chance to escape 0 if it was zero-initialized.

        Instead of using the value, we could impute the scale from the shape,
        as initializers do.

        Args:
          var: a variable or Tensor.
        Returns:
          a Scalar
        """
        return tf.maximum(reduce_rms(var), self._epsilon2)

    def _resource_apply_dense(self, grad, handle):
        var = handle
        grad = tf.to_float(grad)
        grad_squared = tf.square(grad) + self._epsilon1
        grad_squared_mean = tf.reduce_mean(grad_squared)
        decay_rate = self._decay_rate
        update_scale = self._learning_rate
        old_val = var
        if var.dtype.base_dtype == tf.bfloat16:
            # Compute in float32; the result is cast back to bfloat16 below.
            old_val = tf.to_float(old_val)
        if self._multiply_by_parameter_scale:
            update_scale *= tf.to_float(self._parameter_scale(old_val))
        # HACK: Make things dependent on grad.
        # This confounds the XLA rewriter and keeps it from fusing computations
        # across different variables. This fusion is a bad for HBM usage, since
        # it causes the gradients to persist in memory.
        decay_rate += grad_squared_mean * 1e-30
        update_scale += grad_squared_mean * 1e-30
        # END HACK
        mixing_rate = 1.0 - decay_rate
        shape = var.get_shape().as_list()
        updates = []
        if self._should_use_factored_second_moment_estimate(shape):
            grad_squared_row_mean = tf.reduce_mean(grad_squared, -1)
            grad_squared_col_mean = tf.reduce_mean(grad_squared, -2)
            vr = self.get_slot(var, "vr")
            new_vr = (decay_rate * vr + mixing_rate * grad_squared_row_mean)
            vc = self.get_slot(var, "vc")
            new_vc = (decay_rate * vc + mixing_rate * grad_squared_col_mean)
            vr_update = tf.assign(vr, new_vr, use_locking=self._use_locking)
            vc_update = tf.assign(vc, new_vc, use_locking=self._use_locking)
            updates = [vr_update, vc_update]
            # Reconstruct 1/sqrt(v) for v = outer(vr, vc) / mean(vr).
            long_term_mean = tf.reduce_mean(new_vr, -1, keepdims=True)
            r_factor = tf.rsqrt(new_vr / long_term_mean)
            c_factor = tf.rsqrt(new_vc)
            x = grad * tf.expand_dims(r_factor, -1) * tf.expand_dims(c_factor, -2)
        else:
            v = self.get_slot(var, "v")
            new_v = decay_rate * v + mixing_rate * grad_squared
            v_update = tf.assign(v, new_v, use_locking=self._use_locking)
            updates = [v_update]
            x = grad * tf.rsqrt(new_v)
        if self._clipping_threshold is not None:
            clipping_denom = tf.maximum(1.0, reduce_rms(x) / self._clipping_threshold)
            x /= clipping_denom
        subtrahend = update_scale * x
        if self._beta1:
            m = self.get_slot(var, "m")
            new_m = self._beta1 * tf.to_float(m) + (1.0 - self._beta1) * subtrahend
            subtrahend = new_m
            new_m = cast_like(new_m, var)
            updates.append(tf.assign(m, new_m, use_locking=self._use_locking))
        new_val = tf.to_float(old_val) - subtrahend
        # Assign in the variable's own dtype (no-op cast for float32 variables).
        new_val = tf.cast(new_val, var.dtype.base_dtype)
        var_update = tf.assign(var, new_val, use_locking=self._use_locking)
        updates = [var_update] + updates
        return tf.group(*updates)

    def _decay_rate_default(self):
        return adafactor_decay_rate_pow(0.8)

    def _learning_rate_default(self, multiply_by_parameter_scale):
        learning_rate = tf.minimum(tf.rsqrt(step_num() + 1.0), 0.01)
        if not multiply_by_parameter_scale:
            learning_rate *= 0.05
        return learning_rate
def adafactor_decay_rate_adam(beta2):
    """Adafactor second-moment decay rate that mirrors Adam's bias correction.

    With t = global_step + 1 (the global step is created if missing), returns
    beta2 * (1 - beta2**(t - 1)) / (1 - beta2**t). This is 0 when the global
    step is 0, and tends toward beta2 as t grows (presuming 0 < beta2 < 1).

    Args:
      beta2: Adam-style second-moment decay constant.

    Returns:
      A scalar float32 tensor.
    """
    step_plus_one = tf.to_float(tf.train.get_or_create_global_step()) + 1.0
    numerator = 1.0 - tf.pow(beta2, step_plus_one - 1.0)
    denominator = 1.0 - tf.pow(beta2, step_plus_one)
    return beta2 * numerator / denominator
def adafactor_decay_rate_pow(exponent):
    """Return the decay rate 1 - (step_num() + 1) ** -exponent as a tensor."""
    steps_so_far = step_num() + 1.0
    return 1.0 - tf.pow(steps_so_far, -exponent)
def step_num():
    """Return the global step as a float32 tensor, creating the step if needed."""
    global_step = tf.train.get_or_create_global_step()
    return tf.to_float(global_step)
def reduce_rms(x):
    """Return the root-mean-square over all elements of `x` as a scalar tensor."""
    mean_of_squares = tf.reduce_mean(tf.square(x))
    return tf.sqrt(mean_of_squares)
def cast_like(x, y):
    """Cast x to y's dtype, if necessary.

    Both arguments are converted to tensors first. If their base dtypes already
    match, `x` is returned unchanged. Otherwise the cast tensor is returned, and
    a warning is logged when the cast was placed on a different device than
    `x` (which may imply a cross-device copy).
    """
    src = tf.convert_to_tensor(x)
    ref = tf.convert_to_tensor(y)
    if src.dtype.base_dtype == ref.dtype.base_dtype:
        return src
    converted = tf.cast(src, ref.dtype)
    if converted.device == src.device:
        return converted
    # Eager tensors raise AttributeError on .name; getattr's default covers it.
    src_name = getattr(src, "name", "(eager Tensor)")
    tf.logging.warning("Cast for %s may induce copy from '%s' to '%s'", src_name,
                       src.device, converted.device)
    return converted