forked from yfeng95/PRNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
predictor.py
106 lines (89 loc) · 5.8 KB
/
predictor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import tensorflow as tf
import tensorflow.contrib.layers as tcl
from tensorflow.contrib.framework import arg_scope
import numpy as np
def resBlock(x, num_outputs, kernel_size = 4, stride=1, activation_fn=tf.nn.relu, normalizer_fn=tcl.batch_norm, scope=None):
assert num_outputs%2==0 #num_outputs must be divided by channel_factor(2 here)
with tf.variable_scope(scope, 'resBlock'):
shortcut = x
if stride != 1 or x.get_shape()[3] != num_outputs:
shortcut = tcl.conv2d(shortcut, num_outputs, kernel_size=1, stride=stride,
activation_fn=None, normalizer_fn=None, scope='shortcut')
x = tcl.conv2d(x, num_outputs/2, kernel_size=1, stride=1, padding='SAME')
x = tcl.conv2d(x, num_outputs/2, kernel_size=kernel_size, stride=stride, padding='SAME')
x = tcl.conv2d(x, num_outputs, kernel_size=1, stride=1, activation_fn=None, padding='SAME', normalizer_fn=None)
x += shortcut
x = normalizer_fn(x)
x = activation_fn(x)
return x
class resfcn256(object):
def __init__(self, resolution_inp = 256, resolution_op = 256, channel = 3, name = 'resfcn256'):
self.name = name
self.channel = channel
self.resolution_inp = resolution_inp
self.resolution_op = resolution_op
def __call__(self, x, is_training = True):
with tf.variable_scope(self.name) as scope:
with arg_scope([tcl.batch_norm], is_training=is_training, scale=True):
with arg_scope([tcl.conv2d, tcl.conv2d_transpose], activation_fn=tf.nn.relu,
normalizer_fn=tcl.batch_norm,
biases_initializer=None,
padding='SAME',
weights_regularizer=tcl.l2_regularizer(0.0002)):
size = 16
# x: s x s x 3
se = tcl.conv2d(x, num_outputs=size, kernel_size=4, stride=1) # 256 x 256 x 16
se = resBlock(se, num_outputs=size * 2, kernel_size=4, stride=2) # 128 x 128 x 32
se = resBlock(se, num_outputs=size * 2, kernel_size=4, stride=1) # 128 x 128 x 32
se = resBlock(se, num_outputs=size * 4, kernel_size=4, stride=2) # 64 x 64 x 64
se = resBlock(se, num_outputs=size * 4, kernel_size=4, stride=1) # 64 x 64 x 64
se = resBlock(se, num_outputs=size * 8, kernel_size=4, stride=2) # 32 x 32 x 128
se = resBlock(se, num_outputs=size * 8, kernel_size=4, stride=1) # 32 x 32 x 128
se = resBlock(se, num_outputs=size * 16, kernel_size=4, stride=2) # 16 x 16 x 256
se = resBlock(se, num_outputs=size * 16, kernel_size=4, stride=1) # 16 x 16 x 256
se = resBlock(se, num_outputs=size * 32, kernel_size=4, stride=2) # 8 x 8 x 512
se = resBlock(se, num_outputs=size * 32, kernel_size=4, stride=1) # 8 x 8 x 512
pd = tcl.conv2d_transpose(se, size * 32, 4, stride=1) # 8 x 8 x 512
pd = tcl.conv2d_transpose(pd, size * 16, 4, stride=2) # 16 x 16 x 256
pd = tcl.conv2d_transpose(pd, size * 16, 4, stride=1) # 16 x 16 x 256
pd = tcl.conv2d_transpose(pd, size * 16, 4, stride=1) # 16 x 16 x 256
pd = tcl.conv2d_transpose(pd, size * 8, 4, stride=2) # 32 x 32 x 128
pd = tcl.conv2d_transpose(pd, size * 8, 4, stride=1) # 32 x 32 x 128
pd = tcl.conv2d_transpose(pd, size * 8, 4, stride=1) # 32 x 32 x 128
pd = tcl.conv2d_transpose(pd, size * 4, 4, stride=2) # 64 x 64 x 64
pd = tcl.conv2d_transpose(pd, size * 4, 4, stride=1) # 64 x 64 x 64
pd = tcl.conv2d_transpose(pd, size * 4, 4, stride=1) # 64 x 64 x 64
pd = tcl.conv2d_transpose(pd, size * 2, 4, stride=2) # 128 x 128 x 32
pd = tcl.conv2d_transpose(pd, size * 2, 4, stride=1) # 128 x 128 x 32
pd = tcl.conv2d_transpose(pd, size, 4, stride=2) # 256 x 256 x 16
pd = tcl.conv2d_transpose(pd, size, 4, stride=1) # 256 x 256 x 16
pd = tcl.conv2d_transpose(pd, 3, 4, stride=1) # 256 x 256 x 3
pd = tcl.conv2d_transpose(pd, 3, 4, stride=1) # 256 x 256 x 3
pos = tcl.conv2d_transpose(pd, 3, 4, stride=1, activation_fn = tf.nn.sigmoid)#, padding='SAME', weights_initializer=tf.random_normal_initializer(0, 0.02))
return pos
@property
def vars(self):
return [var for var in tf.global_variables() if self.name in var.name]
class PosPrediction():
def __init__(self, resolution_inp = 256, resolution_op = 256):
# -- hyper settings
self.resolution_inp = resolution_inp
self.resolution_op = resolution_op
self.MaxPos = resolution_inp*1.1
# network type
self.network = resfcn256(self.resolution_inp, self.resolution_op)
# net forward
self.x = tf.placeholder(tf.float32, shape=[None, self.resolution_inp, self.resolution_inp, 3])
self.x_op = self.network(self.x, is_training = False)
self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
def restore(self, model_path):
tf.train.Saver(self.network.vars).restore(self.sess, model_path)
def predict(self, image):
pos = self.sess.run(self.x_op,
feed_dict = {self.x: image[np.newaxis, :,:,:]})
pos = np.squeeze(pos)
return pos*self.MaxPos
def predict_batch(self, images):
pos = self.sess.run(self.x_op,
feed_dict = {self.x: images})
return pos*self.MaxPos