forked from tslgithub/image_class
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Build_model.py
118 lines (104 loc) · 5.15 KB
/
Build_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#!/usr/bin/env python
#-*- coding:utf-8 -*-
# author:"tsl"
# email:"[email protected]"
# datetime:19-1-17 下午3:07
# software: PyCharm
from __future__ import print_function
import keras
from MODEL import MODEL,ResnetBuilder
import sys
sys.setrecursionlimit(10000)
from keras import backend as K
# import densenet #取消densenet模型
class Build_model(object):
    """Build and compile a Keras classification model described by a config.

    ``config`` must expose the attributes copied in ``__init__``.  The methods
    below use ``model_name``, ``normal_size``, ``channles``, ``classNumber`` and
    ``lr``.  The whole ``config`` object is also forwarded to the project's
    ``MODEL`` and ``ResnetBuilder`` factories.
    """

    # Architectures implemented as same-named methods on MODEL(config).
    _MODEL_METHODS = ('VGG16', 'VGG19', 'AlexNet', 'LeNet', 'ZF_Net',
                      'mnist_net', 'TSL16')
    # Architectures built by ResnetBuilder: model name -> builder method name.
    _RESNET_BUILDERS = {
        'ResNet18': 'build_resnet18',
        'ResNet34': 'build_resnet34',
        'ResNet101': 'build_resnet101',
        'ResNet152': 'build_resnet152',
    }
    # Architectures taken from keras.applications (attribute name == model name).
    _KERAS_APPLICATIONS = ('ResNet50', 'InceptionV3', 'Xception', 'MobileNet')

    def __init__(self, config):
        """Copy the training settings this builder needs from ``config``."""
        self.train_data_path = config.train_data_path
        self.checkpoints = config.checkpoints
        self.normal_size = config.normal_size  # square input side length
        self.channles = config.channles  # (sic) number of input channels
        self.epochs = config.epochs
        self.batch_size = config.batch_size
        self.classNumber = config.classNumber
        self.model_name = config.model_name
        self.lr = config.lr
        self.config = config
        self.data_augmentation = config.data_augmentation
        self.rat = config.rat
        self.cut = config.cut

    def model_confirm(self, choosed_model):
        """Instantiate the (uncompiled) model named ``choosed_model``.

        Args:
            choosed_model: Architecture name.  Must be one of
                ``_MODEL_METHODS``, ``_RESNET_BUILDERS`` or
                ``_KERAS_APPLICATIONS``.

        Returns:
            The Keras model returned by the matching builder.

        Raises:
            ValueError: if ``choosed_model`` is not a supported name.
        """
        if choosed_model in self._MODEL_METHODS:
            return getattr(MODEL(self.config), choosed_model)()
        if choosed_model in self._RESNET_BUILDERS:
            build = getattr(ResnetBuilder(), self._RESNET_BUILDERS[choosed_model])
            return build(self.config)
        if choosed_model in self._KERAS_APPLICATIONS:
            application = getattr(keras.applications, choosed_model)
            # Keras applications take the class count as ``classes``.
            # Per the Keras docs, ``pooling`` only applies when
            # include_top=False; it is kept for parity with the old call.
            # NOTE(review): input_shape is always channels-last; confirm the
            # backend's image_data_format matches.
            return application(include_top=True,
                               weights=None,
                               input_tensor=None,
                               input_shape=(self.normal_size, self.normal_size, self.channles),
                               pooling='max',
                               classes=self.classNumber)
        # DenseNet support is disabled (its import is commented out in this file).
        supported = (list(self._MODEL_METHODS) + list(self._RESNET_BUILDERS)
                     + list(self._KERAS_APPLICATIONS))
        raise ValueError('Unknown model name {!r}; expected one of: {}'.format(
            choosed_model, ', '.join(supported)))

    def _adam_optimizer(self):
        """Return an Adam optimizer with learning rate ``self.lr``.

        Newer Keras only accepts ``learning_rate``; Keras < 2.3 only accepts ``lr``.
        """
        try:
            return keras.optimizers.Adam(learning_rate=self.lr)
        except TypeError:
            # Legacy Keras rejects the unknown ``learning_rate`` keyword.
            return keras.optimizers.Adam(lr=self.lr)

    def model_compile(self, model):
        """Compile ``model`` for training and return the same object.

        Uses Adam, categorical cross-entropy and an accuracy metric.  The loss
        presumably expects one-hot labels — confirm against the data pipeline.
        Compilation configures loss/optimizer/metrics; it must happen before
        training.
        """
        model.compile(loss="categorical_crossentropy",
                      optimizer=self._adam_optimizer(),
                      metrics=["accuracy"])
        return model

    def build_model(self):
        """Build the model named by ``self.model_name`` and compile it."""
        model = self.model_confirm(self.model_name)
        return self.model_compile(model)