-
Notifications
You must be signed in to change notification settings - Fork 124
/
restore_model.py
62 lines (46 loc) · 1.86 KB
/
restore_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
'''Restores a model from checkpoint and evaluates it on CIFAR-10'''
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import time
from datetime import datetime
import data_helpers
import two_layer_fc
# Model parameters, exposed as command-line flags.
flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_integer(
    'hidden1', 120, 'Number of units in hidden layer 1.')
flags.DEFINE_string(
    'train_dir', 'tf_logs', 'Directory to put the training data.')
flags.DEFINE_float(
    'reg_constant', 0.1, 'Regularization constant.')

# NOTE(review): _parse_flags() and __flags are non-public members of the
# flags object -- confirm they exist in the TensorFlow version in use.
FLAGS._parse_flags()

# Echo every flag as "NAME = value" in sorted order, framed by blank lines.
print('\nParameters:')
for flag_name, flag_value in sorted(FLAGS.__flags.items()):
    summary_line = '{} = {}'.format(flag_name.upper(), flag_value)
    print(summary_line)
print()
# Width of one row fed to the images placeholder.
IMAGE_PIXELS = 3072
# Number of classes handed to the model (presumably the ten CIFAR-10 labels).
CLASSES = 10

# Start the wall-clock timer before loading data so that the load time is
# part of the total reported at the end of the script.
beginTime = time.time()

data_sets = data_helpers.load_data()

# Graph inputs; the leading None dimension leaves the batch size open.
images_placeholder = tf.placeholder(
    tf.float32, name='images', shape=[None, IMAGE_PIXELS])
labels_placeholder = tf.placeholder(
    tf.int64, name='image-labels', shape=[None])

# Model output op, built by two_layer_fc.inference from the flag values.
logits = two_layer_fc.inference(
    images_placeholder,
    IMAGE_PIXELS,
    FLAGS.hidden1,
    CLASSES,
    reg_constant=FLAGS.reg_constant)

# Step counter variable; excluded from training.
global_step = tf.Variable(0, trainable=False, name='global_step')

# Accuracy op comparing the logits against the fed labels.
accuracy = two_layer_fc.evaluation(logits, labels_placeholder)

# Constructed last, after every variable above has been created.
saver = tf.train.Saver()
# Restore the most recent checkpoint found in FLAGS.train_dir, report the
# step it was saved at, and print the model's accuracy on the test split.
with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)

    # This script never runs a variable initializer, so without a checkpoint
    # there is nothing valid to evaluate. Stop with a message that names the
    # directory that was searched instead of failing later inside TensorFlow.
    if not (ckpt and ckpt.model_checkpoint_path):
        raise SystemExit(
            'No checkpoint found in {!r}; train the model first or pass '
            '--train_dir.'.format(FLAGS.train_dir))

    print('Restoring variables from checkpoint')
    saver.restore(sess, ckpt.model_checkpoint_path)
    current_step = tf.train.global_step(sess, global_step)
    print('Current step: {}'.format(current_step))

    # accuracy.eval() is given no session argument, so it relies on the
    # default session installed by the enclosing `with` block.
    test_feed = {
        images_placeholder: data_sets['images_test'],
        labels_placeholder: data_sets['labels_test'],
    }
    print('Test accuracy {:g}'.format(accuracy.eval(feed_dict=test_feed)))

# Elapsed wall-clock time since beginTime was taken.
endTime = time.time()
print('Total time: {:5.2f}s'.format(endTime - beginTime))