-
Notifications
You must be signed in to change notification settings - Fork 90
/
eval_flowers.py
129 lines (102 loc) · 5.87 KB
/
eval_flowers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import tensorflow as tf
from tensorflow.python.platform import tf_logging as logging
from tensorflow.contrib.framework.python.ops.variables import get_or_create_global_step
import xception_preprocessing
from xception import xception, xception_arg_scope
import time
import os
from train_flowers import get_split, load_batch
import matplotlib.pyplot as plt
from tensorflow.python.framework import graph_util
slim = tf.contrib.slim
plt.style.use('ggplot')  # ggplot look for the prediction figures drawn by run()

# ---- Evaluation settings (read by run()) -----------------------------------
log_dir = './log'              # training log directory holding the checkpoints to restore
log_eval = './log_eval_test'   # separate directory for evaluation summaries / checkpoint
dataset_dir = './dataset'      # directory containing the validation split
batch_size = 36                # evaluation batch size; may be larger than the training batch
num_epochs = 1                 # number of passes over the validation split

# Latest checkpoint in log_dir, resolved once at import time.
# NOTE(review): TF documents this as None when no checkpoint exists — confirm before relying on it.
checkpoint_file = tf.train.latest_checkpoint(log_dir)
def run():
    """Evaluate the latest training checkpoint on the validation split.

    Rebuilds the Xception inference graph and restores its weights from
    ``checkpoint_file``. It then runs ``num_epochs`` passes of streaming
    accuracy over the validation split and logs progress. Accuracy
    summaries go to ``log_eval`` every 10 steps. Afterwards it plots up to
    10 images of one more batch with their predicted and ground-truth
    labels, and saves a checkpoint at the Supervisor's ``save_path``.

    Reads the module-level settings ``log_dir``, ``log_eval``,
    ``dataset_dir``, ``batch_size``, ``num_epochs`` and ``checkpoint_file``.

    Raises:
        ValueError: if ``checkpoint_file`` is None (no checkpoint to restore).
    """
    # Fail fast with a clear message instead of deep inside the Supervisor's
    # init_fn. Saver.restore would also raise ValueError for a None path.
    if checkpoint_file is None:
        raise ValueError('No checkpoint found in %r; train the model first.' % log_dir)

    # Create the evaluation log directory. makedirs also handles nested paths.
    if not os.path.exists(log_eval):
        os.makedirs(log_eval)

    # Build a fresh graph for evaluation.
    with tf.Graph().as_default():
        tf.logging.set_verbosity(tf.logging.INFO)

        # Validation input tensors. is_training=False selects the
        # evaluation preprocessing.
        dataset = get_split('validation', dataset_dir)
        images, raw_images, labels = load_batch(dataset, batch_size=batch_size, is_training=False)

        # Number of full batches per epoch. Floor division keeps the Python 2
        # integer semantics on Python 3, where `/` would return a float.
        # max(1, ...) guarantees at least one step if the split is smaller
        # than a single batch.
        num_batches_per_epoch = max(1, dataset.num_samples // batch_size)
        num_steps = num_batches_per_epoch * num_epochs

        # Inference model in evaluation mode.
        with slim.arg_scope(xception_arg_scope()):
            _, end_points = xception(images, num_classes=dataset.num_classes, is_training=False)

        # Collect the variables to restore BEFORE the metric and global-step
        # variables are created below, so only the model's variables are
        # loaded from the checkpoint.
        variables_to_restore = slim.get_variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        def restore_fn(sess):
            # Passed to the Supervisor as init_fn to load the trained weights.
            return saver.restore(sess, checkpoint_file)

        # Only accuracy is tracked; no loss is needed for evaluation.
        probabilities = end_points['Predictions']
        predictions = tf.argmax(probabilities, 1)
        accuracy, accuracy_update = tf.contrib.metrics.streaming_accuracy(predictions, labels)
        metrics_op = tf.group(accuracy_update)

        # No optimizer advances the global step here, so it is incremented
        # manually once per evaluation step.
        global_step = get_or_create_global_step()
        global_step_op = tf.assign(global_step, global_step + 1)

        def eval_step(sess):
            """Run one evaluation step, log it, and return the streaming accuracy.

            Fetches ``accuracy_update`` rather than ``accuracy``. Reading the
            value tensor in the same ``run`` as its update op is unordered and
            may return the pre-update value. The update op yields the
            post-update accuracy, and it executes only once per run even
            though ``metrics_op`` also depends on it.
            """
            start_time = time.time()
            _, global_step_count, accuracy_value = sess.run([metrics_op, global_step_op, accuracy_update])
            time_elapsed = time.time() - start_time
            logging.info('Global Step %s: Streaming Accuracy: %.4f (%.2f sec/step)', global_step_count, accuracy_value, time_elapsed)
            return accuracy_value

        # The streaming accuracy is the only summary. The Supervisor is built
        # with summary_op=None, so summaries are computed and written
        # manually in the loop below.
        tf.summary.scalar('Validation_Accuracy', accuracy)
        my_summary_op = tf.summary.merge_all()

        sv = tf.train.Supervisor(logdir=log_eval, summary_op=None, init_fn=restore_fn)

        with sv.managed_session() as sess:
            for step in range(num_steps):
                # Log progress at the start of every epoch.
                if step % num_batches_per_epoch == 0:
                    logging.info('Epoch: %s/%s', step // num_batches_per_epoch + 1, num_epochs)
                    logging.info('Current Streaming Accuracy: %.4f', sess.run(accuracy))

                eval_step(sess)

                # Every 10 steps, also write the accuracy summary.
                if step % 10 == 0:
                    summaries = sess.run(my_summary_op)
                    sv.summary_computed(sess, summaries)

            logging.info('Final Streaming Accuracy: %.4f', sess.run(accuracy))

            # Visualise one more batch. metrics_op is not run here, so this
            # batch does not affect the streaming accuracy. Images and
            # predictions come from the same run, so they correspond.
            # NOTE(review): presumably this pulls a fresh batch from the input
            # pipeline rather than re-showing the last evaluated one — confirm
            # against load_batch.
            images_np, labels_np, predictions_np, probabilities_np = sess.run(
                [raw_images, labels, predictions, probabilities])

            # Show at most 10 images, but never index past the batch.
            for i in range(min(10, len(images_np))):
                prediction = predictions_np[i]
                prediction_name = dataset.labels_to_name[prediction]
                label_name = dataset.labels_to_name[labels_np[i]]
                text = 'Prediction: %s \n Ground Truth: %s \n Probability: %s' % (prediction_name, label_name, probabilities_np[i][prediction])
                img_plot = plt.imshow(images_np[i])

                # Title the plot and hide the axis ticks.
                plt.title(text)
                img_plot.axes.get_yaxis().set_ticks([])
                img_plot.axes.get_xaxis().set_ticks([])
                plt.show()

            logging.info('Model evaluation has completed! Visit TensorBoard for more information regarding your evaluation.')
            sv.saver.save(sess, sv.save_path, global_step=sv.global_step)
# Script entry point: evaluate only when executed directly, never on import.
if __name__ == "__main__":
    run()