# train_effnet.py
# Source: a fork of pbaylies/stylegan-encoder.
"""
Trains a modified EfficientNet to generate approximate dlatents using examples from a trained StyleGAN.
Props to @SimJeg on GitHub for the original code this is based on, from this thread: https://github.com/Puzer/stylegan-encoder/issues/1#issuecomment-490469454
"""
import os
import math
import numpy as np
import pickle
import cv2
import argparse
import dnnlib
import config
import dnnlib.tflib as tflib
import tensorflow
import keras.backend as K
from efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3, preprocess_input
from keras.layers import Input, LocallyConnected1D, Reshape, Permute, Conv2D, Add, Concatenate
from keras.models import Model, load_model
"""
Truncation method from @oneiroid
"""
def truncate_fancy(dlat, dlat_avg, model_scale=18, truncation_psi=0.7, minlayer=0, maxlayer=8, do_clip=False):
    """Interpolate dlatents towards the average with a per-layer coefficient.

    A coefficient array of shape (1, model_scale, 1) is built: layers with
    index in [minlayer, maxlayer) get ``truncation_psi``, every other layer
    gets 1.0. The array is then handed to ``tflib.lerp`` / ``tflib.lerp_clip``
    together with ``dlat_avg`` and ``dlat``.

    Args:
        dlat: dlatents to truncate.
        dlat_avg: average dlatent used as the interpolation origin.
        model_scale: number of layers the coefficient array covers.
        truncation_psi: coefficient applied to the truncated layers.
        minlayer: layers below this index keep coefficient 1.0.
        maxlayer: layers at or above this index keep coefficient 1.0.
        do_clip: use ``tflib.lerp_clip`` and evaluate its result.

    Returns:
        Whatever ``tflib.lerp`` returns, or the ``.eval()`` of
        ``tflib.lerp_clip`` when ``do_clip`` is set.
    """
    layer_idx = np.arange(model_scale)[np.newaxis, :, np.newaxis]
    ones = np.ones(layer_idx.shape, dtype=np.float32)
    coefs = np.where(layer_idx < maxlayer, truncation_psi * ones, ones)
    if minlayer > 0:
        # Leave the first `minlayer` layers untouched.
        coefs[0, :minlayer, :] = ones[0, :minlayer, :]
    if do_clip:
        # NOTE(review): .eval() implies lerp_clip returns a TF tensor and that a
        # default session is active -- confirm against dnnlib.tflib.
        return tflib.lerp_clip(dlat_avg, dlat, coefs).eval()
    return tflib.lerp(dlat_avg, dlat, coefs)
def truncate_normal(dlat, dlat_avg, truncation_psi=0.7):
    """Scale the offset of ``dlat`` from ``dlat_avg`` by ``truncation_psi``.

    Args:
        dlat: dlatents to truncate (scalar or array).
        dlat_avg: average dlatent the result is pulled towards.
        truncation_psi: scale factor; 1.0 returns ``dlat`` unchanged, 0.0
            returns ``dlat_avg``, negative values mirror across the average.

    Returns:
        ``(dlat - dlat_avg) * truncation_psi + dlat_avg``.
    """
    return (dlat - dlat_avg) * truncation_psi + dlat_avg
def generate_dataset_main(n=10000, save_path=None, seed=None, model_res=1024, image_size=256, minibatch_size=32, truncation=0.7, fancy_truncation=False):
    """Generate images and their dlatents from the StyleGAN returned by load_Gs().

    Half of the examples use ``truncation`` and the other half ``-truncation``,
    so ``n`` is halved first and ``2 * (n // 2)`` examples are returned.
    Several independent mapping-network outputs are stitched together per
    example (different layer groups come from different latents) to add
    variation.

    Args:
        n: requested number of examples (an odd ``n`` yields ``n - 1``).
        save_path: accepted for signature compatibility; not used here.
        seed: seed for ``np.random.RandomState``; ``None`` uses the global RNG.
        model_res: StyleGAN output resolution, used to derive the layer count.
        image_size: side length the generated images are resized to.
        minibatch_size: minibatch size for the mapping and synthesis runs.
        truncation: truncation coefficient (applied as +/- this value).
        fancy_truncation: use truncate_fancy() instead of truncate_normal().

    Returns:
        (W, X): W has shape (2 * (n // 2), model_scale, 512); X holds the
        matching images resized to (image_size, image_size) and passed
        through ``preprocess_input``.
    """
    n = n // 2  # this gets doubled because of negative truncation below
    # log2 is exact for power-of-two resolutions. For example, 1024 -> 18
    model_scale = int(2 * (math.log2(model_res) - 1))
    Gs = load_Gs()

    # Coin flip deciding how many latents are combined per example.
    if seed is not None:
        b = bool(np.random.RandomState(seed).randint(2))
    else:
        b = bool(np.random.randint(2))
    mod_l = 3 if model_scale % 3 == 0 else 2
    if b:
        mod_l = model_scale // 2
    mod_r = model_scale // mod_l  # layers taken from each latent

    # Draw the latents once, after mod_l is final. (The original drew a first
    # batch before the coin flip and then discarded it.)
    if seed is not None:
        Z = np.random.RandomState(seed).randn(n * mod_l, Gs.input_shape[1])
    else:
        Z = np.random.randn(n * mod_l, Gs.input_shape[1])

    W = Gs.components.mapping.run(Z, None, minibatch_size=minibatch_size)  # Use mapping network to get unique dlatents for more variation.
    dlatent_avg = Gs.get_var('dlatent_avg')  # [component]
    if fancy_truncation:
        W = np.append(truncate_fancy(W, dlatent_avg, model_scale, truncation), truncate_fancy(W, dlatent_avg, model_scale, -truncation), axis=0)
    else:
        W = np.append(truncate_normal(W, dlatent_avg, truncation), truncate_normal(W, dlatent_avg, -truncation), axis=0)

    # Keep mod_r layers of each latent, then fold mod_l consecutive latents
    # into one example of model_scale layers.
    W = W[:, :mod_r]
    W = W.reshape((n * 2, model_scale, 512))
    X = Gs.components.synthesis.run(W, randomize_noise=False, minibatch_size=minibatch_size, print_progress=True,
                                    output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True))
    X = np.array([cv2.resize(x, (image_size, image_size), interpolation=cv2.INTER_AREA) for x in X])
    X = preprocess_input(X)
    return W, X
def generate_dataset(n=10000, save_path=None, seed=None, model_res=1024, image_size=256, minibatch_size=16, truncation=0.7, fancy_truncation=False):
    """Generate a dataset in 16 chunks via generate_dataset_main() to save memory.

    The request is split into 15 chunks of ``n // 16`` examples plus one final
    chunk with the remainder. When ``n < 16`` only the final chunk is
    generated.

    Args:
        n: total number of examples requested.
        save_path: if not None, directory where W and X are saved with
            ``np.save`` under the names ``W_<seed>_<n>`` and ``X_<seed>_<n>``.
        seed: base seed; ``None`` leaves generation unseeded.
        model_res, image_size, minibatch_size, truncation, fancy_truncation:
            forwarded unchanged to generate_dataset_main().

    Returns:
        (W, X): the per-chunk results concatenated along axis 0.
    """
    num_chunks = 16
    inc = n // num_chunks
    left = n - (num_chunks - 1) * inc
    # Do not issue zero-size generation requests when n < num_chunks.
    chunk_sizes = [inc] * (num_chunks - 1) if inc > 0 else []
    chunk_sizes.append(left)

    W_parts = []
    X_parts = []
    for index, size in enumerate(chunk_sizes):
        # Each chunk needs its own seed: passing the same seed to every chunk
        # would make all chunks identical copies of each other.
        chunk_seed = None if seed is None else (seed + index) % (2 ** 32)
        aW, aX = generate_dataset_main(size, save_path, chunk_seed, model_res, image_size, minibatch_size, truncation, fancy_truncation)
        W_parts.append(aW)
        X_parts.append(aX)
        aW = aX = None

    # Concatenate once instead of re-copying the growing arrays per chunk.
    W = np.concatenate(W_parts, axis=0)
    del W_parts
    X = np.concatenate(X_parts, axis=0)
    del X_parts

    if save_path is not None:
        prefix = '_{}_{}'.format(seed, n)
        np.save(os.path.join(save_path, 'W' + prefix), W)
        np.save(os.path.join(save_path, 'X' + prefix), X)
    return W, X
def is_square(n):
    """Return True when ``n`` is a perfect square.

    Negative values are reported as not square instead of raising from
    ``math.sqrt``.
    """
    if n < 0:
        return False
    # Round the float root to the nearest integer and square it back.
    return (n == int(math.sqrt(n) + 0.5)**2)
def get_effnet_model(save_path, model_res=1024, image_size=256, depth=1, size=3, activation='elu', loss='logcosh', optimizer='adam'):
    """Load the model stored at ``save_path``, or build and compile a new one.

    The new model is an EfficientNet backbone (ImageNet weights, no top)
    followed by pairs of LocallyConnected1D layers with skip connections, and
    a final reshape to (model_scale, 512).

    Args:
        save_path: model file; if it exists it is loaded and returned as-is.
        model_res: StyleGAN resolution, used to derive ``model_scale``.
        image_size: side length of the square RGB input.
        depth: number of LocallyConnected1D layer pairs; forced to at least 1
            when ``size >= 1``.
        size: backbone selector: <= 0 -> B0, 1 -> B1, 2 -> B2, >= 3 -> B3.
        activation: activation for the layers added after the backbone.
        loss: loss passed to ``model.compile``.
        optimizer: optimizer passed to ``model.compile``.

    Returns:
        The loaded model, or the newly built and compiled model.
    """
    if os.path.exists(save_path):
        print('Loading model')
        return load_model(save_path)

    # Build model
    print('Building model')
    # log2 is exact for power-of-two resolutions. For example, 1024 -> 18
    model_scale = int(2 * (math.log2(model_res) - 1))

    # Clamp `size` into the table of available backbones.
    backbones = (EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3)
    effnet = backbones[min(max(size, 0), 3)](include_top=False, weights='imagenet', input_shape=(image_size, image_size, 3))

    layer_size = model_scale*8*8*8
    if is_square(layer_size):  # work out layer dimensions
        layer_l = int(math.sqrt(layer_size)+0.5)
        layer_r = layer_l
    else:
        # Not a square: use the next power of two above the root for one
        # side and whatever is left for the other.
        layer_m = math.log(math.sqrt(layer_size), 2)
        layer_l = 2**math.ceil(layer_m)
        layer_r = layer_size // layer_l
    layer_l = int(layer_l)
    layer_r = int(layer_r)

    x_init = None
    inp = Input(shape=(image_size, image_size, 3))
    x = effnet(inp)
    # NOTE(review): the Reshape targets below must match the element count of
    # the backbone output, which depends on image_size -- confirm before using
    # an image_size other than the default.
    if (size < 1):
        x = Conv2D(model_scale*8, 1, activation=activation)(x)  # scale down
        if (depth > 0):
            x = Reshape((layer_r, layer_l))(x)  # See https://github.com/OliverRichter/TreeConnect/blob/master/cifar.py - TreeConnect inspired layers instead of dense layers.
    else:
        if (depth < 1):
            depth = 1
        if (size <= 2):
            x = Conv2D(model_scale*8*4, 1, activation=activation)(x)  # scale down a bit
            x = Reshape((layer_r*2, layer_l*2))(x)  # See https://github.com/OliverRichter/TreeConnect/blob/master/cifar.py - TreeConnect inspired layers instead of dense layers.
        else:
            x = Reshape((384, 256))(x)  # full size for B3

    while (depth > 0):
        x = LocallyConnected1D(layer_r, 1, activation=activation)(x)
        x = Permute((2, 1))(x)
        x = LocallyConnected1D(layer_l, 1, activation=activation)(x)
        x = Permute((2, 1))(x)
        if x_init is not None:
            x = Add()([x, x_init])  # add skip connection
        x_init = x
        depth -= 1

    if (size >= 2):  # add unshared layers at end for different sections of the latent space
        x_init = x
        if layer_r % 3 == 0 and layer_l % 3 == 0:
            a = LocallyConnected1D(layer_r, 1, activation=activation)(x)
            b = LocallyConnected1D(layer_r, 1, activation=activation)(x)
            c = LocallyConnected1D(layer_r, 1, activation=activation)(x)
            a = Permute((2, 1))(a)
            b = Permute((2, 1))(b)
            c = Permute((2, 1))(c)
            a = LocallyConnected1D(layer_l//3, 1, activation=activation)(a)
            b = LocallyConnected1D(layer_l//3, 1, activation=activation)(b)
            c = LocallyConnected1D(layer_l//3, 1, activation=activation)(c)
            x = Concatenate()([a, b, c])
        else:
            a = LocallyConnected1D(layer_l, 1, activation=activation)(x)
            b = LocallyConnected1D(layer_l, 1, activation=activation)(x)
            a = Permute((2, 1))(a)
            b = Permute((2, 1))(b)
            a = LocallyConnected1D(layer_r//2, 1, activation=activation)(a)
            b = LocallyConnected1D(layer_r//2, 1, activation=activation)(b)
            x = Concatenate()([a, b])
        x = Add()([x, x_init])  # add skip connection

    x = Reshape((model_scale, 512))(x)  # train against all dlatent values
    model = Model(inputs=inp, outputs=x)
    model.compile(loss=loss, metrics=[], optimizer=optimizer)  # By default: adam optimizer, logcosh used for loss.
    return model
def finetune_effnet(model, args):
    """Finetune ``model`` to predict W from X on freshly generated data.

    A fixed test set of ``args.test_size`` examples is generated once. Each
    round then generates ``args.batch_size`` training examples, fits for
    ``args.epochs`` epochs, evaluates on the test set and saves the model to
    ``args.model_path``. Rounds repeat until the test loss has failed to
    improve more than ``args.max_patience`` times in a row; at that point the
    model is also fitted on the test set itself before the final save.

    With ``args.use_ktrain`` a validation set is generated each round and
    training goes through ``ktrain``'s ``autofit``, followed by a fit on the
    validation set. ``ktrain`` must already be imported at module level.

    Args:
        model: compiled Keras model to train in place.
        args: parsed command-line namespace (model_path, model_res,
            image_size, batch_size, test_size, max_patience, epochs, seed,
            minibatch_size, truncation, fancy_truncation and ktrain options).
    """
    save_path = args.model_path
    model_res = args.model_res
    image_size = args.image_size
    batch_size = args.batch_size
    test_size = args.test_size
    max_patience = args.max_patience
    n_epochs = args.epochs
    seed = args.seed
    minibatch_size = args.minibatch_size
    truncation = args.truncation
    fancy_truncation = args.fancy_truncation
    use_ktrain = args.use_ktrain
    ktrain_max_lr = args.ktrain_max_lr
    ktrain_reduce_lr = args.ktrain_reduce_lr
    ktrain_stop_early = args.ktrain_stop_early

    assert image_size >= 224

    # Reusing `seed` for every dataset would regenerate the same latents each
    # time: the training data would contain the test set and every round would
    # train on an identical batch. Derive a fresh, reproducible seed for each
    # training/validation set instead; the test set keeps `seed` itself.
    seed_stream = None if seed is None else np.random.RandomState(seed)

    def next_seed():
        if seed_stream is None:
            return None
        return int(seed_stream.randint(0, 2 ** 31 - 1))

    # Create a test set
    np.random.seed(seed)
    print('Creating test set:')
    W_test, X_test = generate_dataset(n=test_size, model_res=model_res, image_size=image_size, seed=seed, minibatch_size=minibatch_size, truncation=truncation, fancy_truncation=fancy_truncation)

    # Iterate on batches of size batch_size
    print('Generating training set:')
    patience = 0
    best_loss = np.inf
    #loss = model.evaluate(X_test, W_test)
    #print('Initial test loss : {:.5f}'.format(loss))
    while (patience <= max_patience):
        W_train = X_train = None  # release the previous round's data first
        W_train, X_train = generate_dataset(batch_size, model_res=model_res, image_size=image_size, seed=next_seed(), minibatch_size=minibatch_size, truncation=truncation, fancy_truncation=fancy_truncation)
        if use_ktrain:
            print('Creating validation set:')
            W_val, X_val = generate_dataset(n=test_size, model_res=model_res, image_size=image_size, seed=next_seed(), minibatch_size=minibatch_size, truncation=truncation, fancy_truncation=fancy_truncation)
            learner = ktrain.get_learner(model=model,
                                         train_data=(X_train, W_train), val_data=(X_val, W_val),
                                         workers=1, use_multiprocessing=False,
                                         batch_size=minibatch_size)
            #learner.lr_find() # simulate training to find good learning rate
            #learner.lr_plot() # visually identify best learning rate
            learner.autofit(ktrain_max_lr, checkpoint_folder='/tmp', reduce_on_plateau=ktrain_reduce_lr, early_stopping=ktrain_stop_early)
            learner = None
            print('Done with current validation set.')
            # The validation set is about to be discarded, so train on it too.
            model.fit(X_val, W_val, epochs=n_epochs, verbose=True, batch_size=minibatch_size)
        else:
            model.fit(X_train, W_train, epochs=n_epochs, verbose=True, batch_size=minibatch_size)
        loss = model.evaluate(X_test, W_test, batch_size=minibatch_size)
        if loss < best_loss:
            print('New best test loss : {:.5f}'.format(loss))
            patience = 0
            best_loss = loss
        else:
            print('Test loss : {:.5f}'.format(loss))
            patience += 1
        if (patience > max_patience):  # When done with test set, train with it and discard.
            print('Done with current test set.')
            model.fit(X_test, W_test, epochs=n_epochs, verbose=True, batch_size=minibatch_size)
        # NOTE(review): the source lost its indentation; saving after every
        # round (inside the loop) is assumed here -- confirm against upstream.
        print('Saving model.')
        model.save(save_path)
def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` treats every non-empty string (including "False") as True.
    This accepts true/false, yes/no, on/off, 1/0 and, for compatibility with
    the old float-typed ``--fancy_truncation``, any number (non-zero is True).
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    try:
        return float(text) != 0.0
    except ValueError:
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='Train an EfficientNet to predict latent representations of images in a StyleGAN model from generated examples', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--model_url', default='https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ', help='Fetch a StyleGAN model to train on from this URL')
parser.add_argument('--model_res', default=1024, help='The dimension of images in the StyleGAN model', type=int)
parser.add_argument('--data_dir', default='data', help='Directory for storing the EfficientNet model')
parser.add_argument('--model_path', default='data/finetuned_effnet.h5', help='Save / load / create the EfficientNet model with this file path')
parser.add_argument('--model_depth', default=1, help='Number of TreeConnect layers to add after EfficientNet', type=int)
parser.add_argument('--model_size', default=1, help='Model size - 0 - small, 1 - medium, 2 - large, or 3 - full size.', type=int)
parser.add_argument('--use_ktrain', default=False, help='Use ktrain for training', type=_str2bool)
parser.add_argument('--ktrain_max_lr', default=0.001, help='Maximum learning rate for ktrain', type=float)
parser.add_argument('--ktrain_reduce_lr', default=1, help='Patience for reducing learning rate after a plateau for ktrain', type=float)
parser.add_argument('--ktrain_stop_early', default=3, help='Patience for early stopping for ktrain', type=float)
parser.add_argument('--activation', default='elu', help='Activation function to use after EfficientNet')
parser.add_argument('--optimizer', default='adam', help='Optimizer to use')
parser.add_argument('--loss', default='logcosh', help='Loss function to use')
parser.add_argument('--use_fp16', default=False, help='Use 16-bit floating point', type=_str2bool)
parser.add_argument('--image_size', default=256, help='Size of images for EfficientNet model', type=int)
parser.add_argument('--batch_size', default=2048, help='Batch size for training the EfficientNet model', type=int)
parser.add_argument('--test_size', default=512, help='Batch size for testing the EfficientNet model', type=int)
parser.add_argument('--truncation', default=0.7, help='Generate images using truncation trick', type=float)
parser.add_argument('--fancy_truncation', default=True, help='Use fancier truncation proposed by @oneiroid', type=_str2bool)
parser.add_argument('--max_patience', default=2, help='Number of iterations to wait while test loss does not improve', type=int)
parser.add_argument('--freeze_first', default=False, help='Start training with the pre-trained network frozen, then unfreeze', type=_str2bool)
parser.add_argument('--epochs', default=2, help='Number of training epochs to run for each batch', type=int)
parser.add_argument('--minibatch_size', default=16, help='Size of minibatches for training and generation', type=int)
parser.add_argument('--seed', default=-1, help='Pick a random seed for reproducibility (-1 for no random seed selected)', type=int)
parser.add_argument('--loop', default=-1, help='Run this many iterations (-1 for infinite, halt with CTRL-C)', type=int)

args, other_args = parser.parse_known_args()

os.makedirs(args.data_dir, exist_ok=True)

if args.seed == -1:
    args.seed = None

if args.use_fp16:
    K.set_floatx('float16')
    K.set_epsilon(1e-4)

if args.use_ktrain:
    # Imported at module level on purpose: finetune_effnet() looks it up as a global.
    import ktrain

tflib.init_tf()

# Build the model for the same image size the training data is generated at.
model = get_effnet_model(args.model_path, model_res=args.model_res, image_size=args.image_size, depth=args.model_depth, size=args.model_size, activation=args.activation, optimizer=args.optimizer, loss=args.loss)

# WARNING: pickle.load executes arbitrary code from the downloaded file;
# only point --model_url at sources you trust.
with dnnlib.util.open_url(args.model_url, cache_dir=config.cache_dir) as f:
    generator_network, discriminator_network, Gs_network = pickle.load(f)


def load_Gs():
    """Return the StyleGAN network unpickled above as ``Gs_network``."""
    return Gs_network


#K.get_session().run(tensorflow.global_variables_initializer())

if args.freeze_first:
    # layers[1] is the EfficientNet backbone (layers[0] is the Input layer).
    model.layers[1].trainable = False
    model.compile(loss=args.loss, metrics=[], optimizer=args.optimizer)

model.summary()

if args.freeze_first:  # run a training iteration first while pretrained model is frozen, then unfreeze.
    finetune_effnet(model, args)
    model.layers[1].trainable = True
    model.compile(loss=args.loss, metrics=[], optimizer=args.optimizer)
    model.summary()

if args.loop < 0:
    while True:
        finetune_effnet(model, args)
else:
    for _ in range(args.loop):
        finetune_effnet(model, args)