forked from tslgithub/image_class
-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
92 lines (80 loc) · 3.09 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""
author:tslgithub
email:[email protected]
time:2018-12-12
msg: You can choose one of the following models to train on your images; just switch it in config.py:
VGG16,VGG19,InceptionV3,Xception,MobileNet,AlexNet,LeNet,ZF_Net,
ResNet18,ResNet34,ResNet50,ResNet101,ResNet152,mnist_net
TSL16
"""
from __future__ import print_function
from config import config
import sys,copy,shutil
import cv2
import os,time
from keras.preprocessing.image import img_to_array
import numpy as np
import tensorflow as tf
# Let TensorFlow allocate GPU memory on demand instead of reserving it all
# up front.
config1 = tf.ConfigProto()
config1.gpu_options.allow_growth = True
# Register the configured session with Keras so the models built later in
# this script run inside it. Creating a Session and discarding it (as before)
# does not make Keras use that session's configuration.
from keras import backend as K
K.set_session(tf.Session(config=config1))
from Build_model import Build_model
class PREDICT(Build_model):
    """Evaluate a trained classifier on one class folder of the test set.

    Every image in ``<config.test_data_path>/<className>`` is treated as
    belonging to ``className``. Each prediction is compared against it, and
    the overall accuracy is printed.
    """

    def __init__(self, config):
        """Initialise model settings and choose the class folder to test.

        Args:
            config: project configuration object, forwarded to
                ``Build_model.__init__``. Its ``test_data_path`` is read here.

        The class name is taken from ``sys.argv[2]``. It falls back to
        ``"dog"`` when that argument is missing.
        """
        Build_model.__init__(self, config)
        try:
            className = sys.argv[2]
        except IndexError:
            # Only a missing argument should trigger the fallback. A bare
            # ``except`` would also swallow KeyboardInterrupt/SystemExit.
            print("use default className")
            className = "dog"
        self.className = className
        self.test_data_path = os.path.join(config.test_data_path, self.className)

    def classes_id(self):
        """Return the class names listed in ``train_class_idx.txt``.

        Line ``k`` of the file, with trailing whitespace stripped, is used
        as the label for output index ``k`` of the model. The path is
        resolved relative to the current working directory.
        """
        with open('train_class_idx.txt', 'r') as f:
            return [line.rstrip() for line in f]

    def mkdir(self, path):
        """Create ``path`` (including missing parents) if needed; return it.

        An already existing path is accepted as-is. The existence check
        happens after the create attempt, so a concurrent creator cannot
        cause a spurious failure.
        """
        try:
            os.makedirs(path)
        except OSError:
            if not os.path.exists(path):
                raise
        return path

    def _weights_path(self):
        """Return ``<checkpoints>/<model_name>/<model_name>.h5``."""
        return os.path.join(self.checkpoints, self.model_name,
                            self.model_name + '.h5')

    def _load_images(self):
        """Read and resize every decodable image in ``self.test_data_path``.

        ``self.channles`` (attribute name as defined by ``Build_model``)
        selects colour reading for 3 and grayscale reading for 1. Each image
        is resized to ``normal_size x normal_size``. Files that OpenCV cannot
        decode (``cv2.imread`` returns ``None``) are reported and skipped
        rather than crashing ``cv2.resize``.

        Raises:
            ValueError: if ``self.channles`` is neither 1 nor 3.
        """
        if self.channles == 3:
            read_flag = cv2.IMREAD_COLOR
        elif self.channles == 1:
            read_flag = cv2.IMREAD_GRAYSCALE
        else:
            raise ValueError('unsupported channel count: %r' % (self.channles,))
        size = (self.normal_size, self.normal_size)
        images = []
        # Sorted for a deterministic processing/print order.
        for name in sorted(os.listdir(self.test_data_path)):
            img = cv2.imread(os.path.join(self.test_data_path, name), read_flag)
            if img is None:
                print('skip unreadable file:', name)
                continue
            images.append(cv2.resize(img, size))
        return images

    def Predict(self):
        """Classify every test image and print the results and accuracy.

        For each image, prints the predicted label and its confidence, and
        flags predictions that differ from ``self.className``. Finally
        prints the accuracy in percent and the elapsed wall-clock time.

        Raises:
            IOError: if the weights file does not exist.
            ValueError: if ``self.channles`` is neither 1 nor 3.
        """
        start = time.time()
        model = Build_model(self.config).build_model()
        weights_path = self._weights_path()
        # The previous check tested the truthiness of the joined path string,
        # which is always True. Check that the file actually exists.
        if not os.path.exists(weights_path):
            print('weights is not exist')
            raise IOError('weights file not found: %s' % weights_path)
        model.load_weights(weights_path)
        print('weights is loaded')

        # Read the label list once instead of re-opening the file per image.
        class_names = self.classes_id()
        data_list = self._load_images()
        correct = 0
        for img in data_list:
            batch = np.array([img_to_array(img)], dtype='float') / 255.0
            pred = model.predict(batch).tolist()[0]
            confidence = max(pred)
            label = class_names[pred.index(confidence)]
            print('predict label is: ', label)
            print('predict confidect is: ', confidence)
            if label != self.className:
                print('____________________wrong label____________________', label)
            else:
                correct += 1
        if data_list:
            # Multiply by a float first so Python 2 integer division cannot
            # truncate the ratio to 0.
            print('\naccuacy is:%.2f' % (100.0 * correct / len(data_list)), "%")
        else:
            print('no readable images in', self.test_data_path)
        print('Done')
        end = time.time()
        print("usg time:", end - start)
def main():
    """Create a PREDICT runner from the project config and evaluate it."""
    runner = PREDICT(config)
    runner.Predict()


if __name__ == '__main__':
    main()