This repository has been archived by the owner on Apr 3, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 666
/
srez_train.py
114 lines (84 loc) · 3.57 KB
/
srez_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import numpy as np
import os.path
import scipy.misc
import tensorflow as tf
import time
# Command-line flags (tf.app.flags) shared with the rest of the program.
# This module reads train_dir, checkpoint_dir, learning_rate_start,
# learning_rate_half_life, train_time, summary_period and checkpoint_period;
# presumably they are defined elsewhere (e.g. the entry-point script) — not
# visible here, verify.
FLAGS = tf.app.flags.FLAGS
def _summarize_progress(train_data, feature, label, gene_output, batch, suffix, max_samples=8):
    """Save a comparison image of upscaling results into FLAGS.train_dir.

    Each row of the saved image shows, left to right: the nearest-neighbour
    upscale of ``feature``, the bicubic upscale of ``feature``, the generator
    output (clipped to [0, 1]) and the reference ``label``.  Up to
    ``max_samples`` samples are stacked vertically and written to
    ``batch%06d_<suffix>.png``.

    Args:
        train_data: object exposing a TensorFlow session as ``.sess``.
        feature: low-resolution input batch, indexed as (N, H, W, C) below
            (assumed NHWC array — TODO confirm against callers).
        label: high-resolution reference batch; ``label.shape[1:3]`` sets the
            output size of the upscales.
        gene_output: generator output batch, same spatial size as ``label``.
        batch: batch counter, used only in the output filename.
        suffix: string appended to the output filename.
        max_samples: maximum number of samples to include.  Clamped to the
            number of samples actually present in ``feature``.
    """
    td = train_data

    # The original built image[i] for every i < max_samples, which indexes
    # past the end of the batch when fewer samples are available.
    num_samples = min(max_samples, int(feature.shape[0]))
    if num_samples <= 0:
        # Nothing to draw; tf.concat of an empty list would fail.
        return

    size = [label.shape[1], label.shape[2]]

    def _clip01(x):
        # Clamp pixel values into the displayable [0, 1] range.
        return tf.maximum(tf.minimum(x, 1.0), 0.0)

    # NOTE(review): every call adds new ops to the default graph, so the graph
    # grows with each summary; consider building these ops once.
    nearest = _clip01(tf.image.resize_nearest_neighbor(feature, size))
    bicubic = _clip01(tf.image.resize_bicubic(feature, size))
    clipped = _clip01(gene_output)

    # Place the four variants side by side along the width axis (axis 2).
    # tf.concat takes (axis, values) here: the pre-1.0 signature this file uses.
    image = tf.concat(2, [nearest, bicubic, clipped, label])
    image = image[0:num_samples, :, :, :]
    # Stack the samples vertically into one (num_samples*H, 4*W, C) image.
    image = tf.concat(0, [image[i, :, :, :] for i in range(num_samples)])
    image = td.sess.run(image)

    filename = 'batch%06d_%s.png' % (batch, suffix)
    filename = os.path.join(FLAGS.train_dir, filename)
    # NOTE: scipy.misc.toimage was removed in SciPy 1.2; this line needs an
    # older SciPy with Pillow installed.
    scipy.misc.toimage(image, cmin=0., cmax=1.).save(filename)
    print(" Saved %s" % (filename,))
def _save_checkpoint(train_data, batch):
    """Save all session variables, keeping the previous checkpoint as a backup.

    Rotation inside FLAGS.checkpoint_dir:
      1. delete ``checkpoint_old.txt`` and ``checkpoint_old.txt.meta`` if present;
      2. rename ``checkpoint_new.txt`` (and ``.meta``) to ``checkpoint_old.txt``;
      3. write a fresh checkpoint to ``checkpoint_new.txt``.

    Steps 1 and 2 are best-effort: missing files (e.g. on the first save) are
    expected and ignored.

    Args:
        train_data: object exposing a TensorFlow session as ``.sess``.  A
            ``_checkpoint_saver`` attribute is cached on it (assumes it
            accepts new attributes — verify against the caller's class).
        batch: current batch counter (unused here; kept for callers).
    """
    td = train_data

    oldname = os.path.join(FLAGS.checkpoint_dir, 'checkpoint_old.txt')
    newname = os.path.join(FLAGS.checkpoint_dir, 'checkpoint_new.txt')

    # Depending on the TF version, tf.gfile reports failures as OSError
    # (early Python-backed gfile) or tf.errors.OpError (file_io-backed
    # gfile).  Catch only those, instead of a bare `except:` that would also
    # swallow KeyboardInterrupt and genuine bugs.
    gfile_errors = (OSError, tf.errors.OpError)

    # Each file is handled on its own, so a missing base file no longer
    # prevents the matching .meta file from being deleted or renamed.
    # NOTE(review): the V2 checkpoint format writes .index/.data-* files,
    # which this rotation does not cover.
    for ext in ('', '.meta'):
        # Delete oldest checkpoint
        try:
            tf.gfile.Remove(oldname + ext)
        except gfile_errors:
            pass

    for ext in ('', '.meta'):
        # Rename old checkpoint
        try:
            tf.gfile.Rename(newname + ext, oldname + ext)
        except gfile_errors:
            pass

    # Build the Saver once and reuse it.  Creating a new tf.train.Saver on
    # every call adds fresh save/restore ops to the graph each time.
    saver = getattr(td, '_checkpoint_saver', None)
    if saver is None:
        saver = tf.train.Saver()
        td._checkpoint_saver = saver

    # Generate new checkpoint
    saver.save(td.sess, newname)
    print(" Checkpoint saved")
def train_model(train_data):
    """Run the GAN training loop until FLAGS.train_time minutes have elapsed.

    Each iteration runs one generator step and one discriminator step.
    Every 10 batches the loop:
      * prints progress;
      * checks the time budget;
      * halves the learning rate on multiples of FLAGS.learning_rate_half_life.

    Every FLAGS.summary_period batches it saves a sample image of generator
    output on the cached test batch.  Every FLAGS.checkpoint_period batches it
    saves a checkpoint.  When training ends, a final checkpoint is written.

    Args:
        train_data: object exposing ``sess`` and the graph handles used below:
            ``learning_rate``, ``gene_minimize``, ``disc_minimize``,
            ``gene_loss``, ``disc_real_loss``, ``disc_fake_loss``,
            ``gene_minput``, ``gene_moutput``, ``test_features``,
            ``test_labels``.

    Raises:
        ValueError: if FLAGS.learning_rate_half_life is not a multiple of 10.
            The half-life check only runs on every 10th batch, so any other
            value would fire at the wrong interval.
    """
    td = train_data

    # Validate configuration up front, with an explicit exception rather than
    # `assert` (which is stripped under `python -O`).
    if FLAGS.learning_rate_half_life % 10 != 0:
        raise ValueError('learning_rate_half_life must be a multiple of 10, got %r'
                         % (FLAGS.learning_rate_half_life,))

    td.sess.run(tf.initialize_all_variables())

    lrval = FLAGS.learning_rate_start
    start_time = time.time()
    done = False
    batch = 0

    # Cache test features and labels (they are small)
    test_feature, test_label = td.sess.run([td.test_features, td.test_labels])

    while not done:
        batch += 1

        feed_dict = {td.learning_rate: lrval}
        ops = [td.gene_minimize, td.disc_minimize,
               td.gene_loss, td.disc_real_loss, td.disc_fake_loss]
        _, _, gene_loss, disc_real_loss, disc_fake_loss = td.sess.run(ops, feed_dict=feed_dict)

        if batch % 10 == 0:
            # Show we are alive; elapsed is measured in minutes, like
            # FLAGS.train_time.
            elapsed = int(time.time() - start_time)/60
            print('Progress[%3d%%], ETA[%4dm], Batch [%4d], G_Loss[%3.3f], D_Real_Loss[%3.3f], D_Fake_Loss[%3.3f]' %
                  (int(100*elapsed/FLAGS.train_time), FLAGS.train_time - elapsed,
                   batch, gene_loss, disc_real_loss, disc_fake_loss))

            # Finished?
            current_progress = elapsed / FLAGS.train_time
            if current_progress >= 1.0:
                done = True

            # Update learning rate
            if batch % FLAGS.learning_rate_half_life == 0:
                lrval *= .5

        if batch % FLAGS.summary_period == 0:
            # Show progress with test features
            feed_dict = {td.gene_minput: test_feature}
            gene_output = td.sess.run(td.gene_moutput, feed_dict=feed_dict)
            _summarize_progress(td, test_feature, test_label, gene_output, batch, 'out')

        if batch % FLAGS.checkpoint_period == 0:
            # Save checkpoint
            _save_checkpoint(td, batch)

    # Final checkpoint, unless the last batch already saved one.  Saving twice
    # at the same batch would rotate the identical checkpoint into the "old"
    # slot and discard the previous, distinct backup.
    if batch % FLAGS.checkpoint_period != 0:
        _save_checkpoint(td, batch)
    print('Finished training!')