-
Notifications
You must be signed in to change notification settings - Fork 237
/
train.py
124 lines (95 loc) · 3.57 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
Utility used by the Network class to actually train.
Based on:
https://github.com/fchollet/keras/blob/master/examples/mnist_mlp.py
"""
from keras.datasets import mnist, cifar10
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.utils.np_utils import to_categorical
from keras.callbacks import EarlyStopping
# Helper: early stopping.
# A single module-level EarlyStopping callback, handed to model.fit() by
# train_and_score(), so every training run goes through this same instance.
# patience=5: per the Keras docs, the number of epochs without improvement
# that are tolerated before training is stopped.
early_stopper = EarlyStopping(patience=5)
def get_cifar10():
    """Load CIFAR-10 and prepare it for a dense (fully connected) network.

    Images are flattened to 3072-value rows, cast to float32 and divided by
    255; labels are converted to binary class matrices.

    Returns:
        tuple: (nb_classes, batch_size, input_shape,
                x_train, x_test, y_train, y_test)
    """
    # Dataset defaults.
    num_classes = 10
    minibatch = 64
    flat_shape = (3072,)

    # Fetch the raw train/test splits.
    train_split, test_split = cifar10.load_data()
    train_images, train_labels = train_split
    test_images, test_labels = test_split

    # Flatten each image to one row, switch to float32, then scale by 1/255.
    train_images = train_images.reshape(50000, 3072).astype('float32') / 255
    test_images = test_images.reshape(10000, 3072).astype('float32') / 255

    # Turn the class vectors into binary class matrices.
    train_targets = to_categorical(train_labels, num_classes)
    test_targets = to_categorical(test_labels, num_classes)

    return (num_classes, minibatch, flat_shape,
            train_images, test_images, train_targets, test_targets)
def get_mnist():
    """Load MNIST and prepare it for a dense (fully connected) network.

    Images are flattened to 784-value rows, cast to float32 and divided by
    255; labels are converted to binary class matrices.

    Returns:
        tuple: (nb_classes, batch_size, input_shape,
                x_train, x_test, y_train, y_test)
    """
    # Dataset defaults.
    num_classes = 10
    minibatch = 128
    flat_shape = (784,)

    # Fetch the raw train/test splits.
    train_split, test_split = mnist.load_data()
    train_images, train_labels = train_split
    test_images, test_labels = test_split

    # Flatten each image to one row, switch to float32, then scale by 1/255.
    train_images = train_images.reshape(60000, 784).astype('float32') / 255
    test_images = test_images.reshape(10000, 784).astype('float32') / 255

    # Turn the class vectors into binary class matrices.
    train_targets = to_categorical(train_labels, num_classes)
    test_targets = to_categorical(test_labels, num_classes)

    return (num_classes, minibatch, flat_shape,
            train_images, test_images, train_targets, test_targets)
def compile_model(network, nb_classes, input_shape, dropout=0.2):
    """Compile a sequential model.

    The model is a stack of ``nb_layers`` Dense layers, each followed by a
    Dropout layer, topped with a softmax Dense output layer. It is compiled
    with categorical cross-entropy loss and an accuracy metric.

    Args:
        network (dict): the parameters of the network; must provide the keys
            'nb_layers', 'nb_neurons', 'activation' and 'optimizer'
        nb_classes (int): number of units in the softmax output layer
        input_shape (tuple): input shape handed to the first layer
        dropout (float): rate used by the Dropout layer that follows each
            hidden layer; defaults to 0.2, the value that used to be
            hard-coded

    Returns:
        a compiled network.

    """
    # Get our network parameters.
    nb_layers = network['nb_layers']
    nb_neurons = network['nb_neurons']
    activation = network['activation']
    optimizer = network['optimizer']

    model = Sequential()

    # Add each hidden layer.
    for layer_index in range(nb_layers):
        # Need input shape for first layer.
        if layer_index == 0:
            model.add(Dense(nb_neurons, activation=activation,
                            input_shape=input_shape))
        else:
            model.add(Dense(nb_neurons, activation=activation))

        model.add(Dropout(dropout))

    # Output layer. When there are no hidden layers it is also the first
    # layer of the model, so it is the one that has to carry the input shape.
    if nb_layers > 0:
        model.add(Dense(nb_classes, activation='softmax'))
    else:
        model.add(Dense(nb_classes, activation='softmax',
                        input_shape=input_shape))

    model.compile(loss='categorical_crossentropy', optimizer=optimizer,
                  metrics=['accuracy'])

    return model
def train_and_score(network, dataset):
    """Train the model and return its accuracy on the test data.

    Args:
        network (dict): the parameters of the network
        dataset (str): Dataset to use for training/evaluating; either
            'cifar10' or 'mnist'

    Returns:
        The accuracy entry (index 1) of ``model.evaluate`` on the test data.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.

    """
    if dataset == 'cifar10':
        nb_classes, batch_size, input_shape, x_train, \
            x_test, y_train, y_test = get_cifar10()
    elif dataset == 'mnist':
        nb_classes, batch_size, input_shape, x_train, \
            x_test, y_train, y_test = get_mnist()
    else:
        # Fail early and clearly instead of hitting an unbound-variable
        # error further down when nothing was loaded.
        raise ValueError(
            "Unknown dataset {!r}; expected 'cifar10' or 'mnist'.".format(
                dataset))

    model = compile_model(network, nb_classes, input_shape)

    # NOTE(review): the test split doubles as validation data here, so early
    # stopping is driven by the same data that is scored below.
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=10000,  # using early stopping, so no real limit
              verbose=0,
              validation_data=(x_test, y_test),
              callbacks=[early_stopper])

    score = model.evaluate(x_test, y_test, verbose=0)

    return score[1]  # 1 is accuracy. 0 is loss.