# -*- coding: utf-8 -*-
"""
Adam with weight normalization for Keras.

Weight normalization reparameterizes each weight tensor as W = g * V / ||V||,
so that the direction V and the scale g are optimized separately (Salimans &
Kingma, 2016). This module provides the reparameterized Adam update and the
data-dependent initialization from the same paper.

Created on Wed Apr 24 21:30:14 2019
@author: wmy
"""

import tensorflow as tf
from keras import backend as K
from keras.optimizers import Adam
from tqdm import tqdm

class AdamWithWeightsNormalization(Adam):

    def get_updates(self, loss, params):
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        lr = self.lr
        if self.initial_decay > 0:
            lr *= (1. / (1. + self.decay * K.cast(self.iterations, K.floatx())))

        t = K.cast(self.iterations + 1, K.floatx())
        lr_t = lr * K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t))

        shapes = [K.get_variable_shape(p) for p in params]
        ms = [K.zeros(shape) for shape in shapes]
        vs = [K.zeros(shape) for shape in shapes]
        self.weights = [self.iterations] + ms + vs

        for p, g, m, v in zip(params, grads, ms, vs):
            # For weight tensors (ndim > 1) use the weight-normalized
            # parameterization; this is the only part changed w.r.t.
            # keras.optimizers.Adam.
            ps = K.get_variable_shape(p)
            if len(ps) > 1:
                # get weight normalization parameters
                V, V_norm, V_scaler, g_param, grad_g, grad_V = \
                    get_weightnorm_params_and_grads(p, g)

                # Adam containers for the 'g' parameter
                V_scaler_shape = K.get_variable_shape(V_scaler)
                m_g = K.zeros(V_scaler_shape)
                v_g = K.zeros(V_scaler_shape)

                # update g parameters
                m_g_t = (self.beta_1 * m_g) + (1. - self.beta_1) * grad_g
                v_g_t = (self.beta_2 * v_g) + (1. - self.beta_2) * K.square(grad_g)
                new_g_param = g_param - lr_t * m_g_t / (K.sqrt(v_g_t) + self.epsilon)
                self.updates.append(K.update(m_g, m_g_t))
                self.updates.append(K.update(v_g, v_g_t))

                # update V parameters
                m_t = (self.beta_1 * m) + (1. - self.beta_1) * grad_V
                v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(grad_V)
                new_V_param = V - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
                self.updates.append(K.update(m, m_t))
                self.updates.append(K.update(v, v_t))

                # if there are constraints we apply them to V, not W
                if getattr(p, 'constraint', None) is not None:
                    new_V_param = p.constraint(new_V_param)

                # weight-norm parameter updates --> W updates
                add_weightnorm_param_updates(self.updates, new_V_param, new_g_param, p, V_scaler)
            else:
                # do optimization normally (plain Adam for biases etc.)
                m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
                v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
                p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
                self.updates.append(K.update(m, m_t))
                self.updates.append(K.update(v, v_t))
                new_p = p_t

                # apply constraints
                if getattr(p, 'constraint', None) is not None:
                    new_p = p.constraint(new_p)
                self.updates.append(K.update(p, new_p))
        return self.updates
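
# Example usage (a minimal sketch; `lr` follows the Keras-2-era Adam signature
# this module is written against):
#
#     model.compile(loss='mse',
#                   optimizer=AdamWithWeightsNormalization(lr=1e-3))
#
# Only parameters with more than one dimension get the weight-normalized
# update; biases and other 1-D parameters are updated by plain Adam.
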
def get_weightnorm_params_and_grads(p, g):
    ps = K.get_variable_shape(p)

    # construct weight scaler: V_scaler = g / ||V||
    V_scaler_shape = (ps[-1],)  # assumes the last axis is the output axis (channels-last, as in TensorFlow)
    V_scaler = K.ones(V_scaler_shape)  # init to ones, so effective parameters don't change

    # get V parameters: V = ||V|| / g * W
    norm_axes = [i for i in range(len(ps) - 1)]
    V = p / tf.reshape(V_scaler, [1] * len(norm_axes) + [-1])

    # split V_scaler into ||V|| and g parameters
    V_norm = tf.sqrt(tf.reduce_sum(tf.square(V), norm_axes))
    g_param = V_scaler * V_norm

    # express the gradient w.r.t. W in terms of (V, g):
    #   dL/dg = (dL/dW . V) / ||V||
    #   dL/dV = g/||V|| * (dL/dW - (dL/dg / ||V||) * V)
    grad_g = tf.reduce_sum(g * V, norm_axes) / V_norm
    grad_V = tf.reshape(V_scaler, [1] * len(norm_axes) + [-1]) * \
        (g - tf.reshape(grad_g / V_norm, [1] * len(norm_axes) + [-1]) * V)

    return V, V_norm, V_scaler, g_param, grad_g, grad_V

def add_weightnorm_param_updates(updates, new_V_param, new_g_param, W, V_scaler):
    ps = K.get_variable_shape(new_V_param)
    norm_axes = [i for i in range(len(ps) - 1)]

    # recombine the updated (V, g) into W = g * V / ||V||,
    # and keep V_scaler = g / ||V|| in sync
    new_V_norm = tf.sqrt(tf.reduce_sum(tf.square(new_V_param), norm_axes))
    new_V_scaler = new_g_param / new_V_norm
    new_W = tf.reshape(new_V_scaler, [1] * len(norm_axes) + [-1]) * new_V_param
    updates.append(K.update(W, new_W))
    updates.append(K.update(V_scaler, new_V_scaler))

def data_based_init(model, input):
    # input can be dict, numpy array, or list of numpy arrays
    if type(input) is dict:
        feed_dict = input
    elif type(input) is list:
        feed_dict = {tf_inp: np_inp for tf_inp, np_inp in zip(model.inputs, input)}
    else:
        feed_dict = {model.inputs[0]: input}

    # add learning phase if required
    if model.uses_learning_phase and K.learning_phase() not in feed_dict:
        feed_dict.update({K.learning_phase(): 1})

    # get all layer name, output, weight, bias tuples
    layer_output_weight_bias = []
    for l in model.layers:
        trainable_weights = l.trainable_weights
        if len(trainable_weights) == 2:
            assert l.built
            W, b = trainable_weights
            # if a layer has more than one node, only use the first
            layer_output_weight_bias.append((l.name, l.get_output_at(0), W, b))

    # iterate over the list and do data-dependent init: rescale W so the
    # layer's pre-activations have unit variance on the given batch, and
    # shift b so they have zero mean
    sess = K.get_session()
    pbar = tqdm(layer_output_weight_bias)
    for l, o, W, b in pbar:
        pbar.set_description(f"Init layer {l}")
        m, v = tf.nn.moments(o, [i for i in range(len(o.get_shape()) - 1)])
        s = tf.sqrt(v + 1e-10)
        W_updated = W / tf.reshape(s, [1] * (len(W.get_shape()) - 1) + [-1])
        updates = tf.group(W.assign(W_updated), b.assign((b - m) / s))
        sess.run(updates, feed_dict)
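
# Minimal end-to-end smoke test (a sketch: the toy model and the random data
# are illustrative only, and it assumes the Keras-2 / TF-1 stack imported
# above).
if __name__ == '__main__':
    import numpy as np
    from keras.models import Sequential
    from keras.layers import Dense

    model = Sequential([
        Dense(32, activation='relu', input_shape=(16,)),
        Dense(1),
    ])
    model.compile(loss='mse', optimizer=AdamWithWeightsNormalization(lr=1e-3))

    x = np.random.randn(256, 16).astype('float32')
    y = np.random.randn(256, 1).astype('float32')

    # data-dependent init on one batch, then normal training
    data_based_init(model, x[:64])
    model.fit(x, y, batch_size=64, epochs=1)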