-
Notifications
You must be signed in to change notification settings - Fork 21
/
test_continuous.py
92 lines (76 loc) · 3.22 KB
/
test_continuous.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import re
import time
from tqdm import tqdm
import argparse
import numpy as np
import tensorflow as tf
import imageio
from model import PWCDCNet
from flow_utils import vis_flow_pyramid
def factor_crop(image, factor = 64):
    """Crop an (H, W, C) image so both spatial dims are multiples of ``factor``.

    The top-left region is kept; trailing rows/columns that do not fill a
    whole ``factor``-sized cell are dropped.

    Args:
        image: array with ``ndim == 3``, unpacked as (height, width, channels).
        factor: spatial multiple to crop to (default 64).

    Returns:
        ``image[:factor*(h//factor), :factor*(w//factor)]`` (a view for NumPy
        arrays).

    Raises:
        ValueError: if ``image`` is not 3-D, or if height or width is smaller
            than ``factor`` (the crop would be empty).
    """
    # Explicit raise instead of `assert`, so the check survives `python -O`.
    if image.ndim != 3:
        raise ValueError(f'factor_crop expects a 3-D (H, W, C) image, got ndim={image.ndim}')
    h, w, _ = image.shape
    if h < factor or w < factor:
        raise ValueError(f'image of size {h}x{w} is smaller than crop factor {factor}')
    return image[:factor * (h // factor), :factor * (w // factor)]
class Tester(object):
    """Run a PWCDCNet flow model over consecutive image pairs and save figures.

    ``args`` is read for ``args.resume`` (checkpoint path or None) and
    ``args.input_images`` (ordered list of image paths).
    """

    # Root directory under which per-sequence figure folders are created.
    figure_root = './test_figure'

    def __init__(self, args):
        self.args = args
        # Let TF allocate GPU memory on demand rather than reserving it upfront.
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config = config)
        self._build_graph()

    def _build_graph(self):
        """Build the inference graph, then restore or freshly initialise weights."""
        # (batch=1, 2 frames, H, W, 3 channels); H and W stay dynamic so any
        # factor_crop-ed resolution can be fed.
        self.images = tf.placeholder(tf.float32, shape = (1, 2, None, None, 3))
        self.model = PWCDCNet()
        self.flow_final, self.flows \
            = self.model(self.images[:, 0], self.images[:, 1])
        self.saver = tf.train.Saver()
        if self.args.resume is not None:
            print(f'Loading learned model from checkpoint {self.args.resume}')
            self.saver.restore(self.sess, self.args.resume)
        else:
            print('!!! Test with un-learned model !!!')
            self.sess.run(tf.global_variables_initializer())

    @classmethod
    def _figure_path(cls, img_path):
        """Return ``(out_dir, out_file)`` for the figure of a pair starting at ``img_path``.

        The file is ``<figure_root>/<parent dir name>/<file stem>.png``. Parts
        are taken with ``os.path`` so stems containing dots (``frame.0001``) and
        bare file names without a directory are handled correctly.
        """
        parent = os.path.basename(os.path.dirname(os.path.abspath(img_path)))
        stem = os.path.splitext(os.path.basename(img_path))[0]
        out_dir = os.path.join(cls.figure_root, parent)
        return out_dir, os.path.join(out_dir, f'{stem}.png')

    def test(self):
        """Estimate flow for each consecutive pair in ``args.input_images``.

        For pair (frame i, frame i+1) a figure is written with
        ``vis_flow_pyramid`` to the path given by ``_figure_path(frame i)``.

        Raises:
            ValueError: if the two frames of a pair differ in size after cropping.
        """
        os.makedirs(self.figure_root, exist_ok = True)
        paths = self.args.input_images
        # Materialise the pairs so tqdm knows the total and shows real progress.
        image_path_pairs = list(zip(paths[:-1], paths[1:]))
        for img1_path, img2_path in tqdm(image_path_pairs, desc = 'Processing'):
            frames = [factor_crop(imageio.imread(p)) for p in (img1_path, img2_path)]
            if frames[0].shape != frames[1].shape:
                raise ValueError(f'Frame sizes differ after cropping: {img1_path} {frames[0].shape} '
                                 f'vs {img2_path} {frames[1].shape}')
            images = np.array(frames)/255.
            images_expand = np.expand_dims(images, axis = 0)

            flows = self.sess.run(self.flows, feed_dict = {self.images: images_expand})
            # Rescale each level's output; the constant 20 and the
            # 2**(num_levels - level) divisor presumably undo training-time
            # flow scaling and per-level downsampling — confirm against model.py.
            flow_set = [flow[0] * (20/2**(self.model.num_levels - level))
                        for level, flow in enumerate(flows)]

            out_dir, out_file = self._figure_path(img1_path)
            os.makedirs(out_dir, exist_ok = True)
            vis_flow_pyramid(flow_set, images = images, filename = out_file)
        print('Figure saved')
def expand_wildcards(patterns):
    """Expand glob patterns in ``patterns`` into matching file paths.

    Entries containing ``*`` or ``?`` are replaced by their matches in sorted
    order (so frames with zero-padded indices stay consecutive); other entries
    are kept verbatim. The order of entries is preserved, and a pattern with
    no match contributes nothing.
    """
    from glob import glob
    expanded = []
    for pattern in patterns:
        if '*' in pattern or '?' in pattern:
            expanded.extend(sorted(glob(pattern)))
        else:
            expanded.append(pattern)
    return expanded


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input_images', type = str, nargs = '+', required = True,
                        help = 'Target images (required)')
    parser.add_argument('-r', '--resume', type = str, default = None,
                        help = 'Learned parameter checkpoint file [None]')
    parser.add_argument('-g', '--gpu', type = str, default = None,
                        help = 'GPU id(s) for CUDA_VISIBLE_DEVICES, -1 for CPU [prompt]')
    args = parser.parse_args()

    # Expand wild-cards, e.g. a quoted 'frames/*.png' the shell left unexpanded.
    args.input_images = expand_wildcards(args.input_images)
    if len(args.input_images) < 2:
        raise ValueError('# of input images must be >= 2')

    print(args.resume)
    # Preview at most the first six paths; mention the rest only if any remain.
    preview = 6
    for image in args.input_images[:preview]:
        print(image)
    if len(args.input_images) > preview:
        print(f'... and more ({len(args.input_images)} images)')

    # Set before Tester creates its TF session (presumably when TF reads
    # CUDA_VISIBLE_DEVICES); only prompt when --gpu was not given.
    gpu = args.gpu if args.gpu is not None else input('Input utilize gpu-id (-1:cpu) : ')
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    tester = Tester(args)
    tester.test()