This repository has been archived by the owner on Apr 3, 2022. It is now read-only.

API compatibility with TensorFlow 1.0 and some other modifications #25

Open · wants to merge 2 commits into master
2 changes: 1 addition & 1 deletion srez_input.py
@@ -12,7 +12,7 @@ def setup_inputs(sess, filenames, image_size=None, capacity_factor=3):
filename_queue = tf.train.string_input_producer(filenames)
key, value = reader.read(filename_queue)
channels = 3
-    image = tf.image.decode_jpeg(value, channels=channels, name="dataset_image")
+    image = tf.image.decode_image(value, channels=channels, name="dataset_image")
image.set_shape([None, None, channels])

# Crop and other random augmentations
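For context, `decode_image` dispatches on the file header, so the same input pipeline can read PNG as well as JPEG files. A minimal sketch (not part of the PR; the file name is hypothetical):

```python
import tensorflow as tf

# decode_image inspects the file header, so this works for JPEG and PNG alike.
value = tf.read_file('example.png')  # hypothetical path
image = tf.image.decode_image(value, channels=3)
# decode_image returns a tensor with no static shape (and a 4-D tensor for
# GIFs), which is why the pipeline above still calls
# image.set_shape([None, None, channels]) afterwards.
```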
45 changes: 41 additions & 4 deletions srez_main.py
@@ -22,13 +22,13 @@
tf.app.flags.DEFINE_integer('checkpoint_period', 10000,
"Number of batches in between checkpoints")

-tf.app.flags.DEFINE_string('dataset', 'dataset',
+tf.app.flags.DEFINE_string('dataset', '/home/data/CelebA/Img/img_align_celeba_png.7z/img_align_celeba_png',
"Path to the dataset directory.")

tf.app.flags.DEFINE_float('epsilon', 1e-8,
"Fuzz term to avoid numerical instability")

-tf.app.flags.DEFINE_string('run', 'demo',
+tf.app.flags.DEFINE_string('run', 'train',
"Which operation to run. [demo|train]")

tf.app.flags.DEFINE_float('gene_l1_factor', .90,
@@ -61,7 +61,10 @@
tf.app.flags.DEFINE_string('train_dir', 'train',
"Output folder where training logs are dumped.")

-tf.app.flags.DEFINE_integer('train_time', 20,
+tf.app.flags.DEFINE_string('test_dataset', 'test',
+                           "Testing dataset. Warning: files will be overwritten")
+
+tf.app.flags.DEFINE_integer('train_time', 60,
"Time in minutes to train the model")

def prepare_dirs(delete_train_dir=False):
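For readers unfamiliar with `tf.app.flags`: each `DEFINE_*` call above registers a command-line flag whose default is used unless overridden at launch. A minimal sketch (the flag name and values here are illustrative, not from the PR):

```python
import tensorflow as tf

tf.app.flags.DEFINE_string('run', 'train', "Which operation to run. [demo|train|test]")
FLAGS = tf.app.flags.FLAGS

def main(argv=None):
    # Overridable at launch, e.g.: python srez_main.py --run test
    print(FLAGS.run)

if __name__ == '__main__':
    tf.app.run()  # parses flags, then calls main()
```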
@@ -100,7 +103,7 @@ def setup_tensorflow():
random.seed(FLAGS.random_seed)
np.random.seed(FLAGS.random_seed)

-    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)
+    summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

return sess, summary_writer
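The rename above is part of TF 1.0's consolidation of the summary API under `tf.summary`. A minimal sketch of the corresponding calls (the log directory and tag name are assumptions):

```python
import tensorflow as tf

x = tf.constant(0.5)
tf.summary.scalar('loss', x)        # was tf.scalar_summary
summaries = tf.summary.merge_all()  # was tf.merge_all_summaries
with tf.Session() as sess:
    writer = tf.summary.FileWriter('train', sess.graph)  # was tf.train.SummaryWriter
    writer.add_summary(sess.run(summaries), global_step=0)
    writer.close()
```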

@@ -133,6 +136,36 @@ def _demo():
# Execute demo
srez_demo.demo1(sess)

+def _test():
+    # Load checkpoint
+    if not tf.gfile.IsDirectory(FLAGS.checkpoint_dir):
+        raise FileNotFoundError("Could not find folder `%s'" % (FLAGS.checkpoint_dir,))
+    if not tf.gfile.IsDirectory(FLAGS.test_dataset):
+        raise FileNotFoundError("Could not find folder `%s'" % (FLAGS.test_dataset,))
+
+    # Setup global tensorflow state
+    sess = tf.Session()
+
+    # Prepare directories
+    filenames = prepare_dirs(delete_train_dir=False)
+
+    # Setup async input queues
+    features, labels = srez_input.setup_inputs(sess, filenames)
+
+    # Create and initialize model
+    [gene_minput, gene_moutput,
+     gene_output, gene_var_list,
+     disc_real_output, disc_fake_output, disc_var_list] = \
+        srez_model.create_model(sess, features, labels)
+
+    # Restore variables from checkpoint
+    saver = tf.train.Saver()
+    filename = 'checkpoint_new.txt'
+    filename = os.path.join(FLAGS.checkpoint_dir, filename)
+    saver.restore(sess, filename)
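The restore path is hard-coded to `checkpoint_new.txt`, the name `_save_checkpoint` in srez_train.py writes. A common alternative (a sketch, not part of the PR) is to let TensorFlow locate the newest checkpoint recorded in the directory:

```python
import tensorflow as tf

def restore_latest(sess, saver, checkpoint_dir):
    """Restore the newest checkpoint listed in checkpoint_dir, if any."""
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        return True
    return False
```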




class TrainData(object):
def __init__(self, dictionary):
self.__dict__.update(dictionary)
@@ -178,13 +211,17 @@ def _train():
train_data = TrainData(locals())
srez_train.train_model(train_data)



def main(argv=None):
# Training or showing off?

if FLAGS.run == 'demo':
_demo()
elif FLAGS.run == 'train':
_train()
+    elif FLAGS.run == 'test':
+        _test()

if __name__ == '__main__':
tf.app.run()
14 changes: 7 additions & 7 deletions srez_model.py
@@ -326,7 +326,7 @@ def _discriminator_model(sess, features, disc_input):
mapsize = 3
layers = [64, 128, 256, 512]

-    old_vars = tf.all_variables()
+    old_vars = tf.global_variables()

model = Model('DIS', 2*disc_input - 1)

@@ -352,7 +352,7 @@ def _discriminator_model(sess, features, disc_input):
model.add_conv2d(1, mapsize=1, stride=1, stddev_factor=stddev_factor)
model.add_mean()

-    new_vars = tf.all_variables()
+    new_vars = tf.global_variables()
disc_vars = list(set(new_vars) - set(old_vars))

return model.get_output(), disc_vars
@@ -363,7 +363,7 @@ def _generator_model(sess, features, labels, channels):
mapsize = 3
res_units = [256, 128, 96]

-    old_vars = tf.all_variables()
+    old_vars = tf.global_variables()

# See Arxiv 1603.05027
model = Model('GEN', features)
@@ -396,7 +396,7 @@ def _generator_model(sess, features, labels, channels):
model.add_conv2d(channels, mapsize=1, stride=1, stddev_factor=1.)
model.add_sigmoid()

-    new_vars = tf.all_variables()
+    new_vars = tf.global_variables()
gene_vars = list(set(new_vars) - set(old_vars))

return model.get_output(), gene_vars
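The `old_vars`/`new_vars` set difference above collects exactly the variables each sub-model created. An equivalent TF 1.0 idiom (a sketch, not used by this PR) filters the global collection by scope name, assuming the model's variable scopes are prefixed with the model name:

```python
import tensorflow as tf

# The scope argument is a regex matched against the start of variable names,
# so 'GEN' would also catch scopes such as 'GEN_L001'. Assumes that naming
# convention holds for the Model class used here.
gene_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='GEN')
disc_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='DIS')
```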
@@ -449,7 +449,7 @@ def _downscale(images, K):

def create_generator_loss(disc_output, gene_output, features):
# I.e. did we fool the discriminator?
-    cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(disc_output, tf.ones_like(disc_output))
+    cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_output, labels=tf.ones_like(disc_output))
gene_ce_loss = tf.reduce_mean(cross_entropy, name='gene_ce_loss')

# I.e. does the result look like the feature?
@@ -466,10 +466,10 @@ def create_discriminator_loss(disc_real_output, disc_fake_output):

def create_discriminator_loss(disc_real_output, disc_fake_output):
# I.e. did we correctly identify the input as real or not?
-    cross_entropy_real = tf.nn.sigmoid_cross_entropy_with_logits(disc_real_output, tf.ones_like(disc_real_output))
+    cross_entropy_real = tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real_output, labels=tf.ones_like(disc_real_output))
disc_real_loss = tf.reduce_mean(cross_entropy_real, name='disc_real_loss')

-    cross_entropy_fake = tf.nn.sigmoid_cross_entropy_with_logits(disc_fake_output, tf.zeros_like(disc_fake_output))
+    cross_entropy_fake = tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake_output, labels=tf.zeros_like(disc_fake_output))
disc_fake_loss = tf.reduce_mean(cross_entropy_fake, name='disc_fake_loss')

return disc_real_loss, disc_fake_loss
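TF 1.0 renamed `targets` to `labels` in `sigmoid_cross_entropy_with_logits` and made both arguments keyword-only, which is why every call site in this hunk now names its arguments. A minimal sketch:

```python
import tensorflow as tf

logits = tf.constant([2.0, -1.0])
# TF 1.0 rejects positional arguments here (an internal sentinel enforces it):
loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                               labels=tf.ones_like(logits))
```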
8 changes: 4 additions & 4 deletions srez_train.py
@@ -19,10 +19,10 @@ def _summarize_progress(train_data, feature, label, gene_output, batch, suffix,

clipped = tf.maximum(tf.minimum(gene_output, 1.0), 0.0)

-    image = tf.concat(2, [nearest, bicubic, clipped, label])
+    image = tf.concat([nearest, bicubic, clipped, label], 2)

image = image[0:max_samples,:,:,:]
-    image = tf.concat(0, [image[i,:,:,:] for i in range(max_samples)])
+    image = tf.concat([image[i,:,:,:] for i in range(max_samples)], 0)
image = td.sess.run(image)

filename = 'batch%06d_%s.png' % (batch, suffix)
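These two edits track the `tf.concat` signature change in TF 1.0: the tensor list now comes first and the axis second, i.e. `tf.concat(values, axis)` instead of `tf.concat(axis, values)`. A minimal sketch:

```python
import tensorflow as tf

a = tf.zeros([1, 2])
b = tf.ones([1, 2])
# Pre-1.0: tf.concat(1, [a, b]); TF 1.0 and later:
c = tf.concat([a, b], axis=1)  # shape (1, 4)
```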
@@ -62,8 +62,8 @@ def _save_checkpoint(train_data, batch):
def train_model(train_data):
td = train_data

-    summaries = tf.merge_all_summaries()
-    td.sess.run(tf.initialize_all_variables())
+    summaries = tf.summary.merge_all()
+    td.sess.run(tf.global_variables_initializer())

lrval = FLAGS.learning_rate_start
start_time = time.time()
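Likewise, `tf.initialize_all_variables` was deprecated in favor of `tf.global_variables_initializer` in TF 1.0. A minimal sketch of the new pattern (the variable here is hypothetical):

```python
import tensorflow as tf

v = tf.Variable(0, name='counter')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # was tf.initialize_all_variables()
    print(sess.run(v))  # -> 0
```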