-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_mobilenet_v2.py
150 lines (121 loc) · 6.18 KB
/
train_mobilenet_v2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import os
import glob
import json
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tqdm import tqdm
from model_v2 import MobileNetV2
def main():
    """Train MobileNetV2 (frozen backbone + new classifier head) on the flower dataset.

    Expects the dataset under ``<cwd>/../../data_set/flower_data`` with ``train/``
    and ``val/`` subdirectories, and pretrained backbone weights at
    ``./pretrain_weights.ckpt``. Writes ``class_indices.json`` (index -> class name)
    and saves the best weights (by validation accuracy) to
    ``./save_weights/resMobileNetV2.ckpt``.
    """
    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    train_dir = os.path.join(image_path, "train")
    validation_dir = os.path.join(image_path, "val")
    # Explicit exceptions instead of assert: asserts are stripped under `python -O`.
    if not os.path.exists(train_dir):
        raise FileNotFoundError("cannot find {}".format(train_dir))
    if not os.path.exists(validation_dir):
        raise FileNotFoundError("cannot find {}".format(validation_dir))

    im_height = 224
    im_width = 224
    batch_size = 16
    epochs = 20
    num_classes = 5

    def pre_function(img):
        # Scale pixels from [0, 255] to [-1, 1] — the input range MobileNetV2 expects.
        img = img / 255.
        img = (img - 0.5) * 2.0
        return img

    # Data generators; only the training split gets augmentation (horizontal flip).
    train_image_generator = ImageDataGenerator(horizontal_flip=True,
                                               preprocessing_function=pre_function)
    validation_image_generator = ImageDataGenerator(preprocessing_function=pre_function)

    train_data_gen = train_image_generator.flow_from_directory(directory=train_dir,
                                                               batch_size=batch_size,
                                                               shuffle=True,
                                                               target_size=(im_height, im_width),
                                                               class_mode='categorical')
    total_train = train_data_gen.n

    # Invert {class_name: index} to {index: class_name} and persist it for inference.
    class_indices = train_data_gen.class_indices
    inverse_dict = dict((val, key) for key, val in class_indices.items())
    json_str = json.dumps(inverse_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    val_data_gen = validation_image_generator.flow_from_directory(directory=validation_dir,
                                                                  batch_size=batch_size,
                                                                  shuffle=False,
                                                                  target_size=(im_height, im_width),
                                                                  class_mode='categorical')
    total_val = val_data_gen.n
    print("using {} images for training, {} images for validation.".format(total_train,
                                                                           total_val))

    # Backbone without the classifier head; frozen for transfer learning.
    feature = MobileNetV2(include_top=False)
    # Download pretrained weights from: https://pan.baidu.com/s/1YgFoIKHqooMrTQg_IqI2hA  (password: 2qht)
    pre_weights_path = './pretrain_weights.ckpt'
    # TF checkpoints are stored as several files sharing this prefix, hence the glob.
    if not glob.glob(pre_weights_path + "*"):
        raise FileNotFoundError("cannot find {}".format(pre_weights_path))
    feature.load_weights(pre_weights_path)
    feature.trainable = False
    feature.summary()

    # New classification head on top of the frozen backbone.
    model = tf.keras.Sequential([feature,
                                 tf.keras.layers.GlobalAvgPool2D(),
                                 tf.keras.layers.Dropout(rate=0.5),
                                 tf.keras.layers.Dense(num_classes),
                                 tf.keras.layers.Softmax()])
    model.summary()

    # Custom training loop using the Keras low-level API.
    # Softmax is part of the model, so the loss consumes probabilities (from_logits=False).
    loss_object = tf.keras.losses.CategoricalCrossentropy(from_logits=False)
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')
    val_loss = tf.keras.metrics.Mean(name='val_loss')
    val_accuracy = tf.keras.metrics.CategoricalAccuracy(name='val_accuracy')

    @tf.function
    def train_step(images, labels):
        # One optimization step: forward pass, loss, backprop, metric update.
        with tf.GradientTape() as tape:
            output = model(images, training=True)
            loss = loss_object(labels, output)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        train_loss(loss)
        train_accuracy(labels, output)

    @tf.function
    def val_step(images, labels):
        # Forward pass only (training=False disables dropout); accumulates metrics.
        output = model(images, training=False)
        loss = loss_object(labels, output)

        val_loss(loss)
        val_accuracy(labels, output)

    # Make sure the checkpoint directory exists before the first save.
    os.makedirs("./save_weights", exist_ok=True)

    best_val_acc = 0.
    for epoch in range(epochs):
        train_loss.reset_states()  # clear history info
        train_accuracy.reset_states()
        val_loss.reset_states()
        val_accuracy.reset_states()

        # Train: the trailing partial batch is dropped by the integer division.
        train_bar = tqdm(range(total_train // batch_size))
        for step in train_bar:
            images, labels = next(train_data_gen)
            train_step(images, labels)

            # print train process
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}, acc:{:.3f}".format(epoch + 1,
                                                                                 epochs,
                                                                                 train_loss.result(),
                                                                                 train_accuracy.result())

        # validate
        val_bar = tqdm(range(total_val // batch_size))
        for step in val_bar:
            val_images, val_labels = next(val_data_gen)
            val_step(val_images, val_labels)

            # print val process
            val_bar.desc = "valid epoch[{}/{}] loss:{:.3f}, acc:{:.3f}".format(epoch + 1,
                                                                               epochs,
                                                                               val_loss.result(),
                                                                               val_accuracy.result())

        # Keep only the checkpoint with the best validation accuracy.
        if val_accuracy.result() > best_val_acc:
            best_val_acc = val_accuracy.result()
            model.save_weights("./save_weights/resMobileNetV2.ckpt", save_format="tf")
# Run training only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()