-
Notifications
You must be signed in to change notification settings - Fork 125
/
solver.py
466 lines (392 loc) · 18.7 KB
/
solver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
"""solver.py"""
import warnings
warnings.filterwarnings("ignore")
import os
from tqdm import tqdm
import visdom
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import make_grid, save_image
from utils import cuda, grid2gif
from model import BetaVAE_H, BetaVAE_B
from dataset import return_data
def reconstruction_loss(x, x_recon, distribution):
    """Return the summed reconstruction error averaged over the batch.

    Args:
        x: target batch; the size of its first dimension is used as the
            batch size.
        x_recon: raw decoder output that is compared against ``x``.
        distribution: ``'bernoulli'`` selects binary cross-entropy computed
            on ``x_recon`` as logits; ``'gaussian'`` selects squared error
            between ``sigmoid(x_recon)`` and ``x``.

    Returns:
        The element-wise loss summed over all elements and divided by the
        batch size, or ``None`` when ``distribution`` is neither of the two
        names above.

    Raises:
        AssertionError: if the first dimension of ``x`` has size zero.
    """
    n_samples = x.size(0)
    assert n_samples != 0

    if distribution == 'bernoulli':
        summed = F.binary_cross_entropy_with_logits(x_recon, x, size_average=False)
        return summed.div(n_samples)

    if distribution == 'gaussian':
        squashed = F.sigmoid(x_recon)
        summed = F.mse_loss(squashed, x, size_average=False)
        return summed.div(n_samples)

    # Any other name is not treated as an error here: the caller gets None.
    return None
def kl_divergence(mu, logvar):
    """Summarise the KL divergence of N(mu, exp(logvar)) from N(0, 1).

    The per-element term is ``-0.5 * (1 + logvar - mu^2 - exp(logvar))``.
    4-D inputs are first reshaped to ``(size(0), size(1))``.

    Args:
        mu: posterior means, indexed ``[sample, latent]`` after reshaping.
        logvar: posterior log-variances, same layout as ``mu``.

    Returns:
        A tuple ``(total_kld, dimension_wise_kld, mean_kld)``:
        ``total_kld`` is the per-sample sum over dimension 1 averaged over
        dimension 0 (kept as a 1-element tensor), ``dimension_wise_kld`` is
        the average over dimension 0 for each latent, and ``mean_kld`` is
        the per-sample mean over dimension 1 averaged over dimension 0
        (kept as a 1-element tensor).

    Raises:
        AssertionError: if the first dimension of ``mu`` has size zero.
    """
    assert mu.size(0) != 0

    def _as_matrix(tensor):
        # Only 4-D tensors are reshaped; everything else passes through.
        if tensor.data.ndimension() == 4:
            return tensor.view(tensor.size(0), tensor.size(1))
        return tensor

    mu = _as_matrix(mu)
    logvar = _as_matrix(logvar)

    per_element = -0.5*(1 + logvar - mu.pow(2) - logvar.exp())

    summed_per_sample = per_element.sum(1)
    averaged_per_sample = per_element.mean(1)

    total = summed_per_sample.mean(0, True)
    per_dimension = per_element.mean(0)
    mean = averaged_per_sample.mean(0, True)
    return total, per_dimension, mean
class DataGather(object):
    """Collect values under a fixed set of names, one list per name."""

    def __init__(self):
        # Start with an empty list for every tracked name.
        self.data = self.get_empty_data_dict()

    def get_empty_data_dict(self):
        """Return a new dict mapping every tracked name to an empty list."""
        tracked = ('iter', 'recon_loss', 'total_kld', 'dim_wise_kld',
                   'mean_kld', 'mu', 'var', 'images')
        return {name: [] for name in tracked}

    def insert(self, **kwargs):
        """Append each keyword value to the list stored under its name.

        A keyword that is not one of the tracked names raises ``KeyError``.
        """
        for name, value in kwargs.items():
            self.data[name].append(value)

    def flush(self):
        """Discard everything collected so far."""
        self.data = self.get_empty_data_dict()
class Solver(object):
    """Train a beta-VAE and report progress to visdom and/or disk.

    The solver builds the network (``BetaVAE_H`` or ``BetaVAE_B``), an Adam
    optimizer and a data loader from ``args``, optionally restores a
    checkpoint, and exposes ``train()`` plus the visualisation and
    checkpoint helpers that ``train()`` calls periodically.

    The tensor handling uses the ``Variable`` / ``volatile=True`` autograd
    API and reads scalars with ``.data[0]``.
    """

    def __init__(self, args):
        """Configure the solver from an argument namespace.

        Attributes read from ``args``: ``cuda``, ``max_iter``, ``z_dim``,
        ``beta``, ``gamma``, ``C_max``, ``C_stop_iter``, ``objective``,
        ``model``, ``lr``, ``beta1``, ``beta2``, ``dataset``, ``viz_name``,
        ``viz_port``, ``viz_on``, ``ckpt_dir``, ``ckpt_name``,
        ``save_output``, ``output_dir``, ``gather_step``, ``display_step``,
        ``save_step``, ``dset_dir`` and ``batch_size``. ``args`` itself is
        also handed to ``return_data``.

        Raises:
            NotImplementedError: for a dataset other than dsprites,
                3dchairs or celeba (case-insensitive), or a model other
                than ``'H'`` / ``'B'``.
        """
        self.use_cuda = args.cuda and torch.cuda.is_available()
        self.max_iter = args.max_iter
        self.global_iter = 0

        self.z_dim = args.z_dim
        self.beta = args.beta
        self.gamma = args.gamma
        self.C_max = args.C_max
        self.C_stop_iter = args.C_stop_iter
        self.objective = args.objective
        self.model = args.model
        self.lr = args.lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2

        # Channel count and decoder likelihood are fixed per dataset.
        if args.dataset.lower() == 'dsprites':
            self.nc = 1
            self.decoder_dist = 'bernoulli'
        elif args.dataset.lower() == '3dchairs':
            self.nc = 3
            self.decoder_dist = 'gaussian'
        elif args.dataset.lower() == 'celeba':
            self.nc = 3
            self.decoder_dist = 'gaussian'
        else:
            raise NotImplementedError

        if args.model == 'H':
            net = BetaVAE_H
        elif args.model == 'B':
            net = BetaVAE_B
        else:
            raise NotImplementedError('only support model H or B')

        self.net = cuda(net(self.z_dim, self.nc), self.use_cuda)
        self.optim = optim.Adam(self.net.parameters(), lr=self.lr,
                                betas=(self.beta1, self.beta2))

        self.viz_name = args.viz_name
        self.viz_port = args.viz_port
        self.viz_on = args.viz_on
        # Visdom window handles; None until the first plot (or a restored
        # checkpoint) provides them.
        self.win_recon = None
        self.win_kld = None
        self.win_mu = None
        self.win_var = None
        if self.viz_on:
            self.viz = visdom.Visdom(port=self.viz_port)

        self.ckpt_dir = os.path.join(args.ckpt_dir, args.viz_name)
        os.makedirs(self.ckpt_dir, exist_ok=True)
        self.ckpt_name = args.ckpt_name
        if self.ckpt_name is not None:
            # Needs net, optim, ckpt_dir and use_cuda, all set above.
            self.load_checkpoint(self.ckpt_name)

        self.save_output = args.save_output
        self.output_dir = os.path.join(args.output_dir, args.viz_name)
        os.makedirs(self.output_dir, exist_ok=True)

        self.gather_step = args.gather_step
        self.display_step = args.display_step
        self.save_step = args.save_step

        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.data_loader = return_data(args)

        self.gather = DataGather()

    def train(self):
        """Run optimisation until ``global_iter`` reaches ``max_iter``.

        Objective ``'H'`` minimises ``recon_loss + beta * total_kld``.
        Objective ``'B'`` minimises ``recon_loss + gamma * |total_kld - C|``
        where ``C`` is ``C_max / C_stop_iter * global_iter`` clamped to
        ``[0, C_max]``.

        Periodic work inside the loop: statistics are gathered every
        ``gather_step`` iterations (only with visdom on); progress is
        printed and visualised every ``display_step`` iterations; a
        checkpoint named ``'last'`` is written every ``save_step``
        iterations, and one named after the iteration every 50000.

        Raises:
            NotImplementedError: if ``objective`` is neither ``'H'`` nor
                ``'B'``.
        """
        # Fail up front instead of hitting an unbound loss variable after
        # the first forward pass.
        if self.objective not in ('H', 'B'):
            raise NotImplementedError('only support objective H or B')

        self.net_mode(train=True)
        # Keep the tensor local so self.C_max stays the configured number
        # and train() can be called again.
        c_max = Variable(cuda(torch.FloatTensor([self.C_max]), self.use_cuda))
        out = False

        pbar = tqdm(total=self.max_iter)
        pbar.update(self.global_iter)
        while not out:
            for x in self.data_loader:
                self.global_iter += 1
                pbar.update(1)

                x = Variable(cuda(x, self.use_cuda))
                x_recon, mu, logvar = self.net(x)
                recon_loss = reconstruction_loss(x, x_recon, self.decoder_dist)
                total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)

                if self.objective == 'H':
                    beta_vae_loss = recon_loss + self.beta*total_kld
                elif self.objective == 'B':
                    C = torch.clamp(c_max/self.C_stop_iter*self.global_iter, 0, c_max.data[0])
                    beta_vae_loss = recon_loss + self.gamma*(total_kld-C).abs()

                self.optim.zero_grad()
                beta_vae_loss.backward()
                self.optim.step()

                if self.viz_on and self.global_iter%self.gather_step == 0:
                    self.gather.insert(iter=self.global_iter,
                                       mu=mu.mean(0).data, var=logvar.exp().mean(0).data,
                                       recon_loss=recon_loss.data, total_kld=total_kld.data,
                                       dim_wise_kld=dim_wise_kld.data, mean_kld=mean_kld.data)

                if self.global_iter%self.display_step == 0:
                    pbar.write('[{}] recon_loss:{:.3f} total_kld:{:.3f} mean_kld:{:.3f}'.format(
                        self.global_iter, recon_loss.data[0], total_kld.data[0], mean_kld.data[0]))

                    var = logvar.exp().mean(0).data
                    var_str = ''.join('var{}:{:.4f} '.format(j+1, var_j)
                                      for j, var_j in enumerate(var))
                    pbar.write(var_str)

                    if self.objective == 'B':
                        pbar.write('C:{:.3f}'.format(C.data[0]))

                    if self.viz_on:
                        # images[0] is the input batch, images[1] its
                        # reconstruction; viz_reconstruction relies on that.
                        self.gather.insert(images=x.data)
                        self.gather.insert(images=F.sigmoid(x_recon).data)
                        self.viz_reconstruction()
                        self.viz_lines()
                        self.gather.flush()

                    if self.viz_on or self.save_output:
                        self.viz_traverse()

                if self.global_iter%self.save_step == 0:
                    self.save_checkpoint('last')
                    pbar.write('Saved checkpoint(iter:{})'.format(self.global_iter))

                if self.global_iter%50000 == 0:
                    self.save_checkpoint(str(self.global_iter))

                if self.global_iter >= self.max_iter:
                    out = True
                    break

        pbar.write("[Training Finished]")
        pbar.close()

    def viz_reconstruction(self):
        """Send the gathered input batch and its reconstruction to visdom.

        Uses at most the first 100 images of ``gather.data['images'][0]``
        (inputs) and ``[1]`` (reconstructions), each turned into one grid.
        """
        self.net_mode(train=False)
        x = self.gather.data['images'][0][:100]
        x = make_grid(x, normalize=True)
        x_recon = self.gather.data['images'][1][:100]
        x_recon = make_grid(x_recon, normalize=True)
        images = torch.stack([x, x_recon], dim=0).cpu()
        self.viz.images(images, env=self.viz_name+'_reconstruction',
                        opts=dict(title=str(self.global_iter)), nrow=10)
        self.net_mode(train=True)

    def _plot_line(self, win, X, Y, title, legend=None):
        """Create a visdom line plot, or append to ``win`` if it exists.

        Returns whatever ``viz.line`` returns, which callers store as the
        window handle for the next update.
        """
        opts = dict(width=400, height=400, xlabel='iteration', title=title)
        if legend is not None:
            opts['legend'] = legend
        kwargs = dict(X=X, Y=Y, env=self.viz_name+'_lines', opts=opts)
        if win is not None:
            kwargs.update(win=win, update='append')
        return self.viz.line(**kwargs)

    def viz_lines(self):
        """Plot the gathered loss / KL / posterior statistics in visdom.

        Four windows are maintained: reconstruction loss, KL divergence
        (one curve per latent plus mean and total), posterior mean and
        posterior variance. Nothing is plotted when no statistics have
        been gathered since the last flush.
        """
        self.net_mode(train=False)

        # torch.stack cannot handle an empty list, so skip plotting when
        # no gather step has fired yet.
        if self.gather.data['iter']:
            recon_losses = torch.stack(self.gather.data['recon_loss']).cpu()

            mus = torch.stack(self.gather.data['mu']).cpu()
            variances = torch.stack(self.gather.data['var']).cpu()

            dim_wise_klds = torch.stack(self.gather.data['dim_wise_kld'])
            mean_klds = torch.stack(self.gather.data['mean_kld'])
            total_klds = torch.stack(self.gather.data['total_kld'])
            klds = torch.cat([dim_wise_klds, mean_klds, total_klds], 1).cpu()
            iters = torch.Tensor(self.gather.data['iter'])

            latent_legend = ['z_{}'.format(z_j) for z_j in range(self.z_dim)]
            kld_legend = latent_legend + ['mean', 'total']

            self.win_recon = self._plot_line(
                self.win_recon, iters, recon_losses, 'reconsturction loss')
            self.win_kld = self._plot_line(
                self.win_kld, iters, klds, 'kl divergence', legend=kld_legend)
            # The mean window always receives the means, both when it is
            # created and when it is appended to.
            self.win_mu = self._plot_line(
                self.win_mu, iters, mus, 'posterior mean', legend=latent_legend)
            self.win_var = self._plot_line(
                self.win_var, iters, variances, 'posterior variance', legend=latent_legend)

        self.net_mode(train=True)

    def _encode_dataset_item(self, encoder, idx):
        """Encode dataset item ``idx`` and keep the first ``z_dim`` outputs."""
        img = self.data_loader.dataset[idx]
        img = Variable(cuda(img, self.use_cuda), volatile=True).unsqueeze(0)
        return encoder(img)[:, :self.z_dim]

    def viz_traverse(self, limit=3, inter=2/3, loc=-1):
        """Decode latent traversals and show and/or save them.

        For each starting code, every traversed latent is swept over
        ``torch.arange(-limit, limit+0.1, inter)`` while the other latents
        keep their starting values.

        Args:
            limit: half-width of the sweep.
            inter: step of the sweep.
            loc: index of the single latent to traverse, or ``-1`` to
                traverse all ``z_dim`` latents.

        With ``viz_on`` the decoded images go to visdom; with
        ``save_output`` they are written as JPEGs under
        ``output_dir/<global_iter>`` and passed to ``grid2gif``.
        """
        self.net_mode(train=False)
        import random

        decoder = self.net.decoder
        encoder = self.net.encoder
        interpolation = torch.arange(-limit, limit+0.1, inter)

        n_dsets = len(self.data_loader.dataset)
        rand_idx = random.randint(1, n_dsets-1)

        random_img_z = self._encode_dataset_item(encoder, rand_idx)
        random_z = Variable(cuda(torch.rand(1, self.z_dim), self.use_cuda), volatile=True)

        # Case-insensitive, matching how __init__ recognises the dataset.
        if self.dataset.lower() == 'dsprites':
            fixed_idx1 = 87040 # square
            fixed_idx2 = 332800 # ellipse
            fixed_idx3 = 578560 # heart

            fixed_img_z1 = self._encode_dataset_item(encoder, fixed_idx1)
            fixed_img_z2 = self._encode_dataset_item(encoder, fixed_idx2)
            fixed_img_z3 = self._encode_dataset_item(encoder, fixed_idx3)

            Z = {'fixed_square':fixed_img_z1, 'fixed_ellipse':fixed_img_z2,
                 'fixed_heart':fixed_img_z3, 'random_img':random_img_z}
        else:
            fixed_idx = 0
            fixed_img_z = self._encode_dataset_item(encoder, fixed_idx)

            Z = {'fixed_img':fixed_img_z, 'random_img':random_img_z, 'random_z':random_z}

        # Latents actually traversed: all of them, or only `loc`.
        rows = [row for row in range(self.z_dim) if loc == -1 or row == loc]

        gifs = []
        for key in Z:
            z_ori = Z[key]
            samples = []
            for row in rows:
                z = z_ori.clone()
                for val in interpolation:
                    z[:, row] = val
                    sample = F.sigmoid(decoder(z)).data
                    samples.append(sample)
                    gifs.append(sample)
            samples = torch.cat(samples, dim=0).cpu()
            title = '{}_latent_traversal(iter:{})'.format(key, self.global_iter)

            if self.viz_on:
                self.viz.images(samples, env=self.viz_name+'_traverse',
                                opts=dict(title=title), nrow=len(interpolation))

        if self.save_output:
            output_dir = os.path.join(self.output_dir, str(self.global_iter))
            os.makedirs(output_dir, exist_ok=True)
            gifs = torch.cat(gifs)
            # Group as (start code, latent, step, *image dims), using the
            # number of latents actually traversed and the decoder's real
            # output size, then put the step axis before the latent axis.
            gifs = gifs.view(len(Z), len(rows), len(interpolation),
                             *gifs.size()[1:]).transpose(1, 2)
            for i, key in enumerate(Z):
                for j in range(len(interpolation)):
                    # Path is passed positionally so both the `filename`
                    # and the `fp` spelling of the parameter are accepted.
                    save_image(gifs[i][j].cpu(),
                               os.path.join(output_dir, '{}_{}.jpg'.format(key, j)),
                               nrow=self.z_dim, pad_value=1)

                grid2gif(os.path.join(output_dir, key+'*.jpg'),
                         os.path.join(output_dir, key+'.gif'), delay=10)

        self.net_mode(train=True)

    def net_mode(self, train):
        """Switch the network between training and evaluation mode.

        Args:
            train: ``True`` calls ``net.train()``, ``False`` calls
                ``net.eval()``.

        Raises:
            TypeError: if ``train`` is not a ``bool``.
        """
        if not isinstance(train, bool):
            raise TypeError('Only bool type is supported. True or False')

        if train:
            self.net.train()
        else:
            self.net.eval()

    def save_checkpoint(self, filename, silent=True):
        """Write iteration count, window handles, model and optimizer state.

        Args:
            filename: file name inside ``ckpt_dir``.
            silent: when ``False``, print a confirmation line.
        """
        model_states = {'net':self.net.state_dict(),}
        optim_states = {'optim':self.optim.state_dict(),}
        win_states = {'recon':self.win_recon,
                      'kld':self.win_kld,
                      'mu':self.win_mu,
                      'var':self.win_var,}
        states = {'iter':self.global_iter,
                  'win_states':win_states,
                  'model_states':model_states,
                  'optim_states':optim_states}

        file_path = os.path.join(self.ckpt_dir, filename)
        with open(file_path, mode='wb+') as f:
            torch.save(states, f)
        if not silent:
            print("=> saved checkpoint '{}' (iter {})".format(file_path, self.global_iter))

    def load_checkpoint(self, filename):
        """Restore solver state from ``ckpt_dir/filename`` if it exists.

        Restores ``global_iter``, the four visdom window handles, and the
        network and optimizer state. A missing file only prints a notice.

        NOTE: ``torch.load`` unpickles the file; only load checkpoints from
        a trusted source.
        """
        file_path = os.path.join(self.ckpt_dir, filename)
        if os.path.isfile(file_path):
            # Without CUDA, remap every storage to CPU so a checkpoint
            # written on a GPU machine can still be restored.
            map_location = None if self.use_cuda else (lambda storage, loc: storage)
            checkpoint = torch.load(file_path, map_location=map_location)
            self.global_iter = checkpoint['iter']
            self.win_recon = checkpoint['win_states']['recon']
            self.win_kld = checkpoint['win_states']['kld']
            self.win_var = checkpoint['win_states']['var']
            self.win_mu = checkpoint['win_states']['mu']
            self.net.load_state_dict(checkpoint['model_states']['net'])
            self.optim.load_state_dict(checkpoint['optim_states']['optim'])
            print("=> loaded checkpoint '{} (iter {})'".format(file_path, self.global_iter))
        else:
            print("=> no checkpoint found at '{}'".format(file_path))