-
Notifications
You must be signed in to change notification settings - Fork 2
/
rd_test.py
128 lines (100 loc) · 4.69 KB
/
rd_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""
This script is designed for testing the model trained by rd_train.py
Created on Sun Feb 16 2018
@ Author: Bo Peng
@ University of Wisconsin - Madison
@ Project: Road Extraction
"""
import dataProc as dp
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from scipy.misc import imsave
"""define data dimensions"""
# Raw image size in pixels (width x height).
im_width = 1500
im_height = 1500

# Image patch size and number of image channels.
im_patch_width = 64
im_patch_height = 64
nChannel = 3

# Label patch size.
lb_patch_width = 16
lb_patch_height = 16

# Number of raw test images requested from dp.load_data.
# NOTE(review): 0 presumably means "no limit / load everything" -- confirm
# against dataProc.load_data.
num_rawIm_test = 0

# Batch size: the number of patch positions that fit across the image width
# when stepping by the label patch width (90 for the values above).
# NOTE(review): presumably one batch corresponds to one row of patches --
# confirm against dp.patch_center_point / dp.data_batch.
batch_size_test = (im_width - im_patch_width) // lb_patch_width + 1
"""Data Preprocessing: Normalization"""
# Load the raw test images and their label masks.
im_test, lb_test = dp.load_data('./data_sub/test', num_rawIm_test)
# Report how many images were actually loaded; num_rawIm_test is only the
# requested count (and is 0 by default), so echoing it is misleading.
print('Number of test images loaded: ', len(im_test))

# Convert the labels to int32 and map 255 -> 1.
# Floor division is used on purpose: true division (`/`) would silently turn
# the labels into float64 under Python 3. A single vectorised operation also
# replaces the per-image Python loop.
# NOTE(review): assumes label pixels are 0 or 255 and that all label images
# share one shape -- confirm against dp.load_data.
lb_test = np.array(lb_test).astype(np.int32) // 255
print('Label data type changed.')

# Coordinates of the patch centre points for ONE test image. Unlike training
# and validation, the final argument is 1 because test images are processed
# one at a time.
patch_cpt_test = dp.patch_center_point(
    im_height,
    im_width,
    im_patch_height,
    im_patch_width,
    lb_patch_height,
    lb_patch_width,
    1,
)

# Number of patches in one single test image.
num_patch_test = len(patch_cpt_test)
# Restore the trained model and predict every test image patch-batch by
# patch-batch, writing one mosaicked prediction image per test image plus a
# text file with the per-image average accuracy.
with tf.Session() as sess:
    # Rebuild the graph from the saved meta file, then load the weights of the
    # most recent checkpoint found in ./model_save/.
    saver = tf.train.import_meta_graph('./model_save/rd_model.ckpt-27800.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./model_save/'))
    graph = tf.get_default_graph()

    # Look up the tensors by the names they were given in the saved graph.
    x = graph.get_tensor_by_name("x:0")
    y_true = graph.get_tensor_by_name("y_true:0")
    y_pred = graph.get_tensor_by_name("y_pred:0")
    acc = graph.get_tensor_by_name("acc:0")

    # Number of batches needed to cover one image. It depends only on the
    # patch layout, so it is computed once instead of once per image.
    num_iterations = int(num_patch_test / batch_size_test + 0.5)

    # Average accuracy of each test image, in processing order.
    acc_test_avg = []

    # Loop over the test set and predict the images one by one.
    for k in range(len(im_test)):
        print('Predition for test image #{0}'.format(k))

        # Accuracy and raw prediction of every batch of this image.
        acc_test_batch = []
        prediction_batch = []

        for it in range(num_iterations):
            # Extract the image and label patches of the current batch.
            # The first two arguments must be lists, hence the wrapping.
            im_patch_batch_test, lb_patch_batch_test = dp.data_batch(
                [im_test[k]],
                [lb_test[k]],
                patch_cpt_test,
                im_patch_height,
                im_patch_width,
                lb_patch_height,
                lb_patch_width,
                batch_size_test,
                it,
            )

            # Patch normalization.
            im_patch_batch_test = [
                dp.image_normalize(im) for im in im_patch_batch_test
            ]

            Feed_Dict_Test = {x: im_patch_batch_test,
                              y_true: lb_patch_batch_test}

            # Fetch accuracy and prediction in ONE forward pass. Two separate
            # sess.run calls would evaluate the network twice per batch and
            # could pair an accuracy with a prediction from a different run.
            batch_acc, batch_pred = sess.run([acc, y_pred],
                                             feed_dict=Feed_Dict_Test)
            acc_test_batch.append(batch_acc)
            prediction_batch.append(batch_pred)

            if it % 10 == 0:
                msg = "Test #{0}...Iteration #{1}...Batch acc: {2:>6.1%}"
                print(msg.format(k, it, batch_acc))

        # Average accuracy of the current test image.
        acc_test_avg.append(np.average(np.array(acc_test_batch)))
        print('Test #{0}...Avg Acc: {1:>6.1%}'.format(k, acc_test_avg[k]))

        # Mosaic all batch predictions into one full-size image.
        im_test_prediction = dp.pred_mosaic(prediction_batch, patch_cpt_test)

        # Write the predicted label image to disk.
        # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2; this
        # requires an older SciPy. The ./test_prediction/ directory must
        # already exist.
        imsave('./test_prediction/test_pred_{0}.tif'.format(k),
               im_test_prediction)

    # Save the prediction accuracy of all test images.
    np.savetxt('./test_prediction/testset_accuracy.txt', acc_test_avg)

    # To inspect the last prediction interactively:
    # plt.imshow(im_test_prediction)
    # plt.show()