diff --git a/srez_input.py b/srez_input.py
index 4422a9e..e7b485e 100644
--- a/srez_input.py
+++ b/srez_input.py
@@ -12,7 +12,7 @@ def setup_inputs(sess, filenames, image_size=None, capacity_factor=3):
     filename_queue = tf.train.string_input_producer(filenames)
     key, value = reader.read(filename_queue)
     channels = 3
-    image = tf.image.decode_jpeg(value, channels=channels, name="dataset_image")
+    image = tf.image.decode_image(value, channels=channels, name="dataset_image")
     image.set_shape([None, None, channels])
 
     # Crop and other random augmentations
diff --git a/srez_main.py b/srez_main.py
index beaea9e..ad23473 100644
--- a/srez_main.py
+++ b/srez_main.py
@@ -22,13 +22,13 @@
                             "Number of batches in between checkpoints")
 
-tf.app.flags.DEFINE_string('dataset', 'dataset',
+tf.app.flags.DEFINE_string('dataset', '/home/data/CelebA/Img/img_align_celeba_png.7z/img_align_celeba_png',
                            "Path to the dataset directory.")
 
 tf.app.flags.DEFINE_float('epsilon', 1e-8,
                           "Fuzz term to avoid numerical instability")
 
-tf.app.flags.DEFINE_string('run', 'demo',
-                           "Which operation to run. [demo|train]")
+tf.app.flags.DEFINE_string('run', 'train',
+                           "Which operation to run. [demo|train|test]")
 
 tf.app.flags.DEFINE_float('gene_l1_factor', .90,
@@ -61,7 +61,10 @@
 tf.app.flags.DEFINE_string('train_dir', 'train',
                            "Output folder where training logs are dumped.")
 
-tf.app.flags.DEFINE_integer('train_time', 20,
+tf.app.flags.DEFINE_string('test_dataset', 'test',
+                           "Path to the testing dataset. Warning: existing files will be overwritten.")
+
+tf.app.flags.DEFINE_integer('train_time', 60,
                             "Time in minutes to train the model")
 
 def prepare_dirs(delete_train_dir=False):
@@ -100,7 +103,7 @@ def setup_tensorflow():
     random.seed(FLAGS.random_seed)
     np.random.seed(FLAGS.random_seed)
 
-    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)
+    summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
 
     return sess, summary_writer
 
@@ -133,6 +136,36 @@ def _demo():
     # Execute demo
     srez_demo.demo1(sess)
 
+def _test():
+    # Load checkpoint
+    if not tf.gfile.IsDirectory(FLAGS.checkpoint_dir):
+        raise FileNotFoundError("Could not find folder `%s'" % (FLAGS.checkpoint_dir,))
+    if not tf.gfile.IsDirectory(FLAGS.test_dataset):
+        raise FileNotFoundError("Could not find folder `%s'" % (FLAGS.test_dataset,))
+
+    # Setup global tensorflow state
+    sess = tf.Session()
+
+    # Prepare directories
+    filenames = prepare_dirs(delete_train_dir=False)
+
+    # Setup async input queues
+    features, labels = srez_input.setup_inputs(sess, filenames)
+
+    # Create and initialize model
+    [gene_minput, gene_moutput,
+     gene_output, gene_var_list,
+     disc_real_output, disc_fake_output, disc_var_list] = \
+            srez_model.create_model(sess, features, labels)
+
+    # Restore variables from checkpoint
+    saver = tf.train.Saver()
+    filename = 'checkpoint_new.txt'
+    filename = os.path.join(FLAGS.checkpoint_dir, filename)
+    saver.restore(sess, filename)
+
+
+
 class TrainData(object):
     def __init__(self, dictionary):
         self.__dict__.update(dictionary)
@@ -178,6 +211,8 @@ def _train():
     train_data = TrainData(locals())
     srez_train.train_model(train_data)
 
+
+
 def main(argv=None):
     # Training or showing off?
@@ -185,6 +220,8 @@ def main(argv=None):
         _demo()
     elif FLAGS.run == 'train':
         _train()
+    elif FLAGS.run == 'test':
+        _test()
 
 if __name__ == '__main__':
     tf.app.run()
diff --git a/srez_model.py b/srez_model.py
index 3075ae7..04a778f 100644
--- a/srez_model.py
+++ b/srez_model.py
@@ -326,7 +326,7 @@ def _discriminator_model(sess, features, disc_input):
     mapsize = 3
     layers  = [64, 128, 256, 512]
 
-    old_vars = tf.all_variables()
+    old_vars = tf.global_variables()
 
     model = Model('DIS', 2*disc_input - 1)
 
@@ -352,7 +352,7 @@ def _discriminator_model(sess, features, disc_input):
     model.add_conv2d(1, mapsize=1, stride=1, stddev_factor=stddev_factor)
     model.add_mean()
 
-    new_vars = tf.all_variables()
+    new_vars = tf.global_variables()
     disc_vars = list(set(new_vars) - set(old_vars))
 
     return model.get_output(), disc_vars
@@ -363,7 +363,7 @@ def _generator_model(sess, features, labels, channels):
     mapsize = 3
     res_units  = [256, 128, 96]
 
-    old_vars = tf.all_variables()
+    old_vars = tf.global_variables()
 
     # See Arxiv 1603.05027
     model = Model('GEN', features)
@@ -396,7 +396,7 @@ def _generator_model(sess, features, labels, channels):
     model.add_conv2d(channels, mapsize=1, stride=1, stddev_factor=1.)
     model.add_sigmoid()
 
-    new_vars = tf.all_variables()
+    new_vars = tf.global_variables()
     gene_vars = list(set(new_vars) - set(old_vars))
 
     return model.get_output(), gene_vars
@@ -449,7 +449,7 @@ def _downscale(images, K):
 
 def create_generator_loss(disc_output, gene_output, features):
     # I.e. did we fool the discriminator?
-    cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(disc_output, tf.ones_like(disc_output))
+    cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_output, labels=tf.ones_like(disc_output))
     gene_ce_loss  = tf.reduce_mean(cross_entropy, name='gene_ce_loss')
 
     # I.e. does the result look like the feature?
@@ -466,10 +466,10 @@ def create_generator_loss(disc_output, gene_output, features):
 
 def create_discriminator_loss(disc_real_output, disc_fake_output):
     # I.e. did we correctly identify the input as real or not?
-    cross_entropy_real = tf.nn.sigmoid_cross_entropy_with_logits(disc_real_output, tf.ones_like(disc_real_output))
+    cross_entropy_real = tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real_output, labels=tf.ones_like(disc_real_output))
     disc_real_loss     = tf.reduce_mean(cross_entropy_real, name='disc_real_loss')
 
-    cross_entropy_fake = tf.nn.sigmoid_cross_entropy_with_logits(disc_fake_output, tf.zeros_like(disc_fake_output))
+    cross_entropy_fake = tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake_output, labels=tf.zeros_like(disc_fake_output))
     disc_fake_loss     = tf.reduce_mean(cross_entropy_fake, name='disc_fake_loss')
 
     return disc_real_loss, disc_fake_loss
diff --git a/srez_train.py b/srez_train.py
index 8f90343..b1d95d5 100644
--- a/srez_train.py
+++ b/srez_train.py
@@ -19,10 +19,10 @@ def _summarize_progress(train_data, feature, label, gene_output, batch, suffix,
 
     clipped = tf.maximum(tf.minimum(gene_output, 1.0), 0.0)
 
-    image   = tf.concat(2, [nearest, bicubic, clipped, label])
+    image   = tf.concat([nearest, bicubic, clipped, label], 2)
 
     image = image[0:max_samples,:,:,:]
-    image = tf.concat(0, [image[i,:,:,:] for i in range(max_samples)])
+    image = tf.concat([image[i,:,:,:] for i in range(max_samples)], 0)
     image = td.sess.run(image)
 
     filename = 'batch%06d_%s.png' % (batch, suffix)
@@ -62,8 +62,8 @@ def _save_checkpoint(train_data, batch):
 def train_model(train_data):
     td = train_data
 
-    summaries = tf.merge_all_summaries()
-    td.sess.run(tf.initialize_all_variables())
+    summaries = tf.summary.merge_all()
+    td.sess.run(tf.global_variables_initializer())
 
     lrval = FLAGS.learning_rate_start
 
     start_time  = time.time()
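
Note on the new `test` run mode (reviewer sketch, not part of the commit): `_test()` above stops after `saver.restore()`, so nothing is evaluated yet. Below is a minimal sketch of one possible continuation, assuming the queue runners registered by `srez_input.setup_inputs()` and the `gene_output` tensor returned by `srez_model.create_model()`; the helper name `_run_restored_generator` and the batch count are hypothetical.

    # Hypothetical sketch (assumption, not in this commit): drive the restored
    # generator over a few batches pulled from the async input queue.
    import tensorflow as tf

    def _run_restored_generator(sess, gene_output, num_batches=10):
        # Start the queue runners registered by srez_input.setup_inputs()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for _ in range(num_batches):
                # Each run dequeues one batch and super-resolves it;
                # output shape is [batch_size, height, width, 3] in [0, 1].
                images = sess.run(gene_output)
        finally:
            coord.request_stop()
            coord.join(threads)

With the dispatch added to main(), the mode would then be invoked along the lines of:

    python srez_main.py --run test --checkpoint_dir checkpoint --test_dataset test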