-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_centerloss.py
99 lines (72 loc) · 2.88 KB
/
run_centerloss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from __future__ import print_function
import keras
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Activation, add, Dense, Flatten, Dropout, Multiply, Embedding, Lambda
from keras.layers import Conv2D, MaxPooling2D,PReLU
from keras import backend as K
import numpy as np
import sys
from keras.callbacks import *
import visualization
from keras.optimizers import SGD, Adam
# ---------------------------------------------------------------------------
# Hyper-parameters and switches
# ---------------------------------------------------------------------------
batch_size = 128
num_classes = 10
epochs = 50
# Selects which model is trained further down in this script: the
# two-input / two-output center-loss model (True) or the plain classifier.
isCenterloss = True

# Spatial size the MNIST arrays are reshaped to below.
img_rows, img_cols = 28, 28

# ---------------------------------------------------------------------------
# Data loading and preprocessing
# ---------------------------------------------------------------------------
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Add an explicit single-channel axis, placed where the active Keras backend
# expects it.
if K.image_data_format() == 'channels_first':
    input_shape = (1, img_rows, img_cols)
else:
    input_shape = (img_rows, img_cols, 1)
x_train = x_train.reshape((x_train.shape[0],) + input_shape)
x_test = x_test.reshape((x_test.shape[0],) + input_shape)

# Cast to float32 and divide by 255.
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Keep the raw labels as returned by the loader: they are fed to the
# center-loss model's second input later in the script. The one-hot versions
# are the targets for the softmax output.
y_train_value, y_test_value = y_train, y_test
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
# ---------------------------------------------------------------------------
# Base network: convolutional feature extractor + softmax classifier
# ---------------------------------------------------------------------------
# Use the backend-dependent `input_shape` computed during preprocessing rather
# than a hard-coded (28, 28, 1): the training arrays were reshaped to
# `input_shape`, so a fixed channels-last shape would not match the data on a
# 'channels_first' backend.
inputs = Input(shape=input_shape)

# (filters, square kernel size) for each Conv2D layer; every convolution is
# followed by its own PReLU activation.
conv_stack = [(32, 3), (32, 3), (64, 3), (64, 5), (128, 5), (128, 5)]
x = inputs
for n_filters, kernel in conv_stack:
    x = Conv2D(n_filters, (kernel, kernel))(x)
    x = PReLU()(x)

x = Flatten()(x)

# Two-unit bottleneck. Its activated output is named 'ip1' and is the feature
# vector that the center-loss branch compares against the class centers.
x = Dense(2)(x)
ip1 = PReLU(name='ip1')(x)

# Class probabilities.
ip2 = Dense(num_classes, activation='softmax')(ip1)

# Plain classifier, trained when center loss is disabled.
model = Model(inputs=inputs, outputs=[ip2])
model.compile(
    loss="categorical_crossentropy",
    optimizer=SGD(lr=0.05),
    metrics=['accuracy'],
)
if isCenterloss:
    # Weight of the center-loss output relative to the cross-entropy output
    # (passed as `loss_weights` below).
    lambda_c = 0.2

    # Second model input: one label value per sample, used as the lookup
    # index into the table of class centers.
    input_target = Input(shape=(1,))

    # One trainable center per class, with the same width as the `ip1`
    # feature vector. Both sizes are derived rather than hard-coded so the
    # table keeps matching `num_classes` and the bottleneck layer if either
    # is changed.
    feature_dim = K.int_shape(ip1)[-1]
    centers = Embedding(num_classes, feature_dim)(input_target)

    # Per-sample squared L2 distance between the feature vector (x[0]) and
    # the looked-up center (x[1]). The embedding output carries an extra
    # length-1 axis from the (1,)-shaped label input, which `[:, 0]` removes.
    l2_loss = Lambda(
        lambda x: K.sum(K.square(x[0] - x[1][:, 0]), 1, keepdims=True),
        name='l2_loss',
    )([ip1, centers])

    model_centerloss = Model(
        inputs=[inputs, input_target],
        outputs=[ip2, l2_loss],
    )
    # The second output already *is* the loss value, so its loss function
    # simply returns the prediction and ignores the target.
    # NOTE(review): `metrics=['accuracy']` is applied to both outputs; the
    # accuracy reported for 'l2_loss' is presumably not meaningful - confirm
    # before relying on it.
    model_centerloss.compile(
        optimizer=SGD(lr=0.05),
        loss=["categorical_crossentropy", lambda y_true, y_pred: y_pred],
        loss_weights=[1, lambda_c],
        metrics=['accuracy'],
    )
# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
# Callback provided by the project-local `visualization` module; it is told
# which training mode is active.
histories = visualization.Histories(isCenterloss)

# Options common to both training modes.
fit_options = dict(
    batch_size=batch_size,
    epochs=epochs,
    verbose=1,
    callbacks=[histories],
)

if not isCenterloss:
    # Plain softmax classifier: images in, one-hot labels out.
    model.fit(
        x_train,
        y_train,
        validation_data=(x_test, y_test),
        **fit_options
    )
else:
    # The center-loss output needs a target array of matching length, but its
    # loss function returns the prediction and never reads the target, so
    # random placeholders of shape (n_samples, 1) are supplied.
    random_y_train = np.random.rand(x_train.shape[0], 1)
    random_y_test = np.random.rand(x_test.shape[0], 1)

    # Inputs: images plus raw labels (for the center lookup).
    # Targets: one-hot labels plus the placeholder arrays.
    model_centerloss.fit(
        [x_train, y_train_value],
        [y_train, random_y_train],
        validation_data=(
            [x_test, y_test_value],
            [y_test, random_y_test],
        ),
        **fit_options
    )