forked from tensorpack/tensorpack
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mnist-convnet.py
executable file
·127 lines (103 loc) · 5.1 KB
/
mnist-convnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: mnist-convnet.py
import tensorflow as tf
from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.tfutils import summary
"""
MNIST ConvNet example.
Reaches about 0.6% validation error after 30 epochs.
"""
IMAGE_SIZE = 28  # MNIST images are IMAGE_SIZE x IMAGE_SIZE pixels


class Model(ModelDesc):
    """Small convolutional classifier for single-channel MNIST digits.

    ``build_graph`` returns the training objective: the mean sparse softmax
    cross-entropy plus a 1e-5-scaled L2 penalty on the ``fc*`` weights.
    ``optimizer`` returns Adam driven by a staircase-decayed learning rate.
    """

    def inputs(self):
        """Declare every graph input as (dtype, shape, name).

        Returns:
            Two ``tf.TensorSpec`` objects: a float32 ``input`` image batch of
            shape (None, IMAGE_SIZE, IMAGE_SIZE) and an int32 ``label``
            vector of shape (None,).
        """
        side = IMAGE_SIZE
        image_spec = tf.TensorSpec((None, side, side), tf.float32, 'input')
        label_spec = tf.TensorSpec((None,), tf.int32, 'label')
        return [image_spec, label_spec]

    def build_graph(self, image, label):
        """Build the network and return the scalar cost to minimize.

        Args:
            image: float32 batch of shape (B, IMAGE_SIZE, IMAGE_SIZE), as
                declared in ``inputs``.
            label: int32 vector of shape (B,) holding class indices.

        Returns:
            The ``total_cost`` tensor: cross-entropy loss plus weight decay.
        """
        # Convolutions take NHWC input, so append a channel axis of size 1.
        net = tf.expand_dims(image, 3)
        # Center pixel values at zero
        # (NOTE(review): assumes pixels arrive in [0, 1] — confirm).
        net = net * 2 - 1

        # argscope sets defaults for every Conv2D call inside this context:
        # 3x3 kernels, 32 output channels, ReLU activation.
        with argscope(Conv2D, kernel_size=3, activation=tf.nn.relu, filters=32):
            net = Conv2D('conv0', net)
            net = MaxPooling('pool0', net, 2)
            net = Conv2D('conv1', net)
            net = Conv2D('conv2', net)
            net = MaxPooling('pool1', net, 2)
            net = Conv2D('conv3', net)
            net = FullyConnected('fc0', net, 512, activation=tf.nn.relu)
            net = Dropout('dropout', net, rate=0.5)
            logits = FullyConnected('fc1', net, 10, activation=tf.identity)

        # Per-sample cross-entropy (one value per example), then the batch mean.
        per_sample_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=label)
        cost = tf.reduce_mean(per_sample_loss, name='cross_entropy_loss')

        # 1.0 where the top-1 prediction matches the label, else 0.0.
        top1_hit = tf.nn.in_top_k(predictions=logits, targets=label, k=1)
        correct = tf.cast(top1_hit, tf.float32, name='correct')
        accuracy = tf.reduce_mean(correct, name='accuracy')
        train_error = tf.reduce_mean(1 - correct, name='train_error')
        # Moving-average monitors: written to TensorBoard and stat.json, and
        # printed after each epoch.
        summary.add_moving_summary(train_error, accuracy)

        # Weight decay on every variable matching the regex 'fc.*/W' — the
        # weight matrices of the fully-connected layers fc0 and fc1.
        fc_l2 = regularize_cost('fc.*/W', tf.nn.l2_loss)
        wd_cost = tf.multiply(1e-5, fc_l2, name='regularize_loss')
        total_cost = tf.add_n([wd_cost, cost], name='total_cost')
        summary.add_moving_summary(cost, wd_cost, total_cost)

        # Histogram + RMS summaries for every '*/W' weight (conv and fc).
        summary.add_param_summary(('.*/W', ['histogram', 'rms']))
        return total_cost

    def optimizer(self):
        """Return Adam with a staircase exponentially decayed learning rate.

        The rate starts at 1e-3 and is multiplied by 0.3 every 468 * 10
        global steps.
        """
        # NOTE(review): 468 presumably equals the number of training steps per
        # epoch for batch size 128 (see get_data), i.e. decay every 10 epochs.
        # Keep it in sync if the batch size changes.
        decay_every = 468 * 10
        lr = tf.train.exponential_decay(
            learning_rate=1e-3,
            global_step=get_global_step_var(),
            decay_steps=decay_every,
            decay_rate=0.3,
            staircase=True,
            name='learning_rate',
        )
        # Also reported to TensorBoard, stat.json and the terminal, but
        # without a moving average.
        tf.summary.scalar('lr', lr)
        return tf.train.AdamOptimizer(lr)
def get_data():
    """Build the MNIST training and test DataFlows.

    Returns:
        A ``(train, test)`` pair:

        - ``train``: batches of 128, built without ``remainder=True`` and
          wrapped in ``PrintData`` (presumably to log sample datapoints for
          inspection).
        - ``test``: batches of 256 built with ``remainder=True``, so the final
          partial batch is kept.
    """
    train_batches = BatchData(dataset.Mnist('train'), 128)
    test_batches = BatchData(dataset.Mnist('test'), 256, remainder=True)
    return PrintData(train_batches), test_batches
if __name__ == '__main__':
    # Log into train_log/mnist-convnet (directory derived automatically).
    logger.auto_set_dir()

    dataset_train, dataset_test = get_data()
    # Iterations per epoch; using the DataFlow's length is the default choice.
    steps_per_epoch = len(dataset_train)

    model = Model()
    # FeedInput is slow and used here only for demonstration; QueueInput or
    # similar input sources are preferable in practice.
    train_input = FeedInput(dataset_train)

    # Callback order matters: MaxSaver reads `val_accuracy`, which is produced
    # by the InferenceRunner, so it must come after it.
    callbacks = [ModelSaver()]  # checkpoint after every epoch
    callbacks.append(
        InferenceRunner(  # validation pass after every epoch
            dataset_test,
            # Reports `val_cross_entropy_loss` and `val_accuracy`.
            ScalarStats(['cross_entropy_loss', 'accuracy'], prefix='val'),
        )
    )
    callbacks.append(MaxSaver('val_accuracy'))  # keep best-accuracy checkpoint

    config = TrainConfig(
        model=model,
        data=train_input,
        callbacks=callbacks,
        steps_per_epoch=steps_per_epoch,
        max_epoch=100,
    )
    launch_train_with_config(config, SimpleTrainer())