-
Notifications
You must be signed in to change notification settings - Fork 19
/
test.py
108 lines (91 loc) · 3.99 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from __future__ import division
import argparse
import numpy as np
import time, math, glob
import scipy.misc
import os
import imageio
import pdb
import tensorlayer as tl
import tensorflow as tf
from model import *
from utils import *
def _build_arg_parser():
    """Create the command-line parser for FEQE benchmark evaluation.

    Returns:
        argparse.ArgumentParser with the checkpoint/output/dataset options,
        the layer-type options and the network-size options.
    """
    cli = argparse.ArgumentParser(description="")
    cli.add_argument("--model_path", type=str, default="checkpoint/FEQE/model.ckpt", help="model path")
    cli.add_argument('--save_path', type=str, default='results')
    cli.add_argument("--dataset", default="Set5", type=str, help="dataset name, Default: Set5")

    # Layer-type choices forwarded to FEQE() via the `opt` dict in main();
    # presumably they must match the architecture of the restored checkpoint.
    layer_choices = (
        ('--downsample_type', 'desubpixel'),
        ('--upsample_type', 'subpixel'),
        ('--conv_type', 'default'),
        ('--body_type', 'resnet'),
    )
    for flag, default in layer_choices:
        cli.add_argument(flag, type=str, default=default)

    # Network-size options; which one matters depends on --body_type.
    size_options = (
        ('--n_feats', 16, 'number of convolution feats'),
        ('--n_blocks', 20, 'number of residual block if body_type=resnet'),
        ('--n_groups', 0, 'number of residual group if body_type=res_in_res'),
        ('--n_convs', 0, 'number of conv layers if body_type=conv'),
        ('--n_squeezes', 0, 'number of squeeze blocks if body_type=squeeze'),
    )
    for flag, default, text in size_options:
        cli.add_argument(flag, type=int, default=default, help=text)

    cli.add_argument('--scale', type=int, default=4)
    return cli


parser = _build_arg_parser()
args = parser.parse_args()
# Start-up banner followed by every parsed command-line setting, one per line.
for banner_line in (
    '############################################################',
    '# Image Super Resolution - PIRM2018 - TEAM_ALEX #',
    '# Implemented by Thang Vu, [email protected] #',
    '############################################################',
    '',
    '_____________YOUR SETTINGS_____________',
):
    print(banner_line)
for opt_name, opt_value in vars(args).items():
    # Right-align option names in a 20-character column.
    print("%20s: %s" % (str(opt_name), str(opt_value)))
print('')
def _shave(img, border):
    """Return `img` with `border` pixels removed from each side of the first two axes.

    Unlike ``img[b:-b, b:-b]`` this also works for ``border == 0``, where
    the negative-slice form would yield an empty array.
    """
    return img[border:img.shape[0] - border, border:img.shape[1] - border]


def main():
    """Evaluate a trained FEQE checkpoint on one benchmark dataset.

    Reads every ``*.png`` under ``./data/test_benchmark/<args.dataset>``,
    builds the low-resolution input with ``downsample_fn``, runs the network,
    writes each super-resolved image to
    ``<args.save_path>/<args.dataset>/<original file name>`` and prints the
    average PSNR. Before scoring and saving, ``args.scale`` pixels are
    shaved from every border of both the output and the ground truth.

    All settings come from the module-level ``args`` namespace.

    Raises:
        RuntimeError: if no ``*.png`` file exists for the dataset. This used
            to surface as a ZeroDivisionError only after the model had been
            built and restored.
    """
    #==================Data==================================
    print('Loading data...')
    test_hr_path = os.path.join('./data/test_benchmark', args.dataset)
    hr_paths = sorted(glob.glob(os.path.join(test_hr_path, '*.png')))
    if not hr_paths:
        # Fail fast instead of restoring the model and then dividing by zero.
        raise RuntimeError('No .png images found in %s' % test_hr_path)

    #=================Model===================================
    print('Loading model...')
    t_lr = tf.placeholder('float32', [1, None, None, 3], name='input_image')
    t_hr = tf.placeholder('float32', [1, None, None, 3], name='label_image')
    opt = {
        'n_feats': args.n_feats,
        'n_blocks': args.n_blocks,
        'n_groups': args.n_groups,
        'n_convs': args.n_convs,
        'n_squeezes': args.n_squeezes,
        'downsample_type': args.downsample_type,
        'upsample_type': args.upsample_type,
        'conv_type': args.conv_type,
        'body_type': args.body_type,
        'scale': args.scale
    }
    t_sr = FEQE(t_lr, opt)

    config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    # The context manager closes the session even if evaluation fails part-way;
    # the original never closed it.
    with tf.Session(config=config) as sess:
        tl.layers.initialize_global_variables(sess)
        saver = tf.train.Saver()
        saver.restore(sess, args.model_path)

        #=================result=================================
        save_path = os.path.join(args.save_path, args.dataset)
        if not os.path.exists(save_path):
            os.makedirs(save_path)

        psnr_sum = 0
        for i, hr_path in enumerate(hr_paths):
            print('processing image %d' % i)
            hr_org = imageio.imread(hr_path)
            lr = downsample_fn(hr_org)
            [hr, lr] = normalize([hr_org, lr])
            # Add the batch axis expected by the [1, H, W, 3] placeholders.
            lr = lr[np.newaxis, :, :, :]
            hr = hr[np.newaxis, :, :, :]
            [sr] = sess.run([t_sr], {t_lr: lr, t_hr: hr})
            sr = np.squeeze(sr)
            [sr] = restore([sr])
            # Exclude a border of `scale` pixels from scoring and output.
            sr = _shave(sr, args.scale)
            hr_org = _shave(hr_org, args.scale)
            psnr_sum += compute_PSNR(sr, hr_org)

            # scipy.misc.imsave was removed in SciPy 1.2; use imageio, which
            # this file already imports. imsave byte-scaled its input, whereas
            # imwrite does not, so convert to uint8 explicitly. This assumes sr
            # is on the same 0-255 scale as hr_org, which compute_PSNR compares
            # it against. TODO confirm against utils.restore.
            if sr.dtype != np.uint8:
                sr = np.clip(np.round(sr), 0, 255).astype(np.uint8)
            imageio.imwrite(os.path.join(save_path, os.path.basename(hr_path)), sr)

    print('Average PSNR: %.4f' % (psnr_sum / len(hr_paths)))
    print('Finish')
if __name__ == "__main__":
    # Run the evaluation only when executed as a script; importing this
    # module does not call main().
    main()