forked from CompVis/adaptive-style-transfer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
module.py
246 lines (203 loc) · 10.1 KB
/
module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
# Copyright (C) 2018 Artsiom Sanakoyeu and Dmytro Kotovenko
#
# This file is part of Adaptive Style Transfer
#
# Adaptive Style Transfer is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Adaptive Style Transfer is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import division
from ops import *
def encoder(image, options, reuse=True, name="encoder"):
    """
    Build the encoder network inside the variable scope `name`.

    The image is instance-normalized, reflect-padded by 15 on both sides of
    axes 1 and 2 and then sent through five conv -> instance_norm -> ReLU
    stages. The conv calls pass (3, 1) for the first stage and (3, 2) for the
    other four - presumably kernel size and stride of ops.conv2d.
    Args:
        image: input tensor; the padding spec has four entries, so a rank 4
            tensor is expected (presumably [batch, height, width, channels]).
        options: object exposing `gf_dim` (base number of conv kernels) and
            `is_training` (forwarded to every instance_norm call).
        reuse: if True, switch the scope to variable reuse; otherwise assert
            that the scope is not already in reuse mode.
        name: name of the encoder variable scope.
    Returns: Encoded image, i.e. the output of the fifth stage
        (`options.gf_dim * 8` kernels).
    """
    with tf.variable_scope(name):
        scope = tf.get_variable_scope()
        if not reuse:
            assert scope.reuse is False
        else:
            scope.reuse_variables()

        normalized = instance_norm(input=image,
                                   is_training=options.is_training,
                                   name='g_e0_bn')
        net = tf.pad(normalized, [[0, 0], [15, 15], [15, 15], [0, 0]], "REFLECT")

        # (conv scope, norm scope, number of kernels, 4th positional conv2d argument)
        stages = (
            ('g_e1_c', 'g_e1_bn', options.gf_dim, 1),
            ('g_e2_c', 'g_e2_bn', options.gf_dim, 2),
            ('g_e3_c', 'g_e3_bn', options.gf_dim * 2, 2),
            ('g_e4_c', 'g_e4_bn', options.gf_dim * 4, 2),
            ('g_e5_c', 'g_e5_bn', options.gf_dim * 8, 2),
        )
        for conv_scope, norm_scope, num_kernels, step in stages:
            convolved = conv2d(net, num_kernels, 3, step, padding='VALID', name=conv_scope)
            net = tf.nn.relu(instance_norm(input=convolved,
                                           is_training=options.is_training,
                                           name=norm_scope))
        return net
def decoder(features, options, reuse=True, name="decoder"):
    """
    Build the decoder network inside the variable scope `name`.

    Nine residual blocks (scopes 'g_r1' .. 'g_r9') are stacked on the input,
    followed by four deconv2d -> instance_norm -> ReLU stages and a final
    reflect-padded conv whose sigmoid output is rescaled to (-1, 1).
    Args:
        features: input tensor; its last dimension must be statically known
            because it sets the number of kernels of the residual blocks.
        options: object exposing `gf_dim` (base number of kernels) and
            `is_training` (forwarded to the instance_norm calls of the
            deconv stages).
        reuse: if True, switch the scope to variable reuse; otherwise assert
            that the scope is not already in reuse mode.
        name: name of the decoder variable scope.
    Returns: Decoded image with values produced by `sigmoid(x) * 2 - 1`.
    """
    with tf.variable_scope(name):
        scope = tf.get_variable_scope()
        if not reuse:
            assert scope.reuse is False
        else:
            scope.reuse_variables()

        def _residual_block(x, dim, ks=3, s=1, name='res'):
            # Two reflect-padded conv + instance_norm layers with a ReLU in
            # between; the block input is added back to the result.
            pad = int((ks - 1) / 2)
            paddings = [[0, 0], [pad, pad], [pad, pad], [0, 0]]
            out = tf.pad(x, paddings, "REFLECT")
            out = conv2d(out, dim, ks, s, padding='VALID', name=name+'_c1')
            out = instance_norm(out, name+'_bn1')
            out = tf.pad(tf.nn.relu(out), paddings, "REFLECT")
            out = conv2d(out, dim, ks, s, padding='VALID', name=name+'_c2')
            out = instance_norm(out, name+'_bn2')
            return out + x

        # Now stack 9 residual blocks
        num_kernels = features.get_shape().as_list()[-1]
        net = features
        for block_scope in ('g_r1', 'g_r2', 'g_r3', 'g_r4', 'g_r5',
                            'g_r6', 'g_r7', 'g_r8', 'g_r9'):
            net = _residual_block(net, num_kernels, name=block_scope)

        # Decode image.
        # (deconv scope, norm scope, number of kernels)
        stages = (
            ('g_d1_dc', 'g_d1_bn', options.gf_dim * 8),
            ('g_d2_dc', 'g_d2_bn', options.gf_dim * 4),
            ('g_d3_dc', 'g_d3_bn', options.gf_dim * 2),
            ('g_d4_dc', 'g_d4_bn', options.gf_dim),
        )
        for deconv_scope, norm_scope, stage_kernels in stages:
            net = deconv2d(net, stage_kernels, 3, 2, name=deconv_scope)
            net = tf.nn.relu(instance_norm(input=net,
                                           name=norm_scope,
                                           is_training=options.is_training))

        # Pad by 3 so that the 7-wide 'VALID' conv keeps the spatial size.
        net = tf.pad(net, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT")
        logits = conv2d(net, 3, 7, 1, padding='VALID', name='g_pred_c')
        return tf.nn.sigmoid(logits)*2. - 1.
def discriminator(image, options, reuse=True, name="discriminator"):
    """
    Discriminator agent, that provides us with information about image
    plausibility at different scales.

    Seven conv -> instance_norm -> lrelu feature stages (h0 .. h6) are
    chained; after stages 0, 1, 3, 5 and 6 an extra conv2d head with one
    output kernel and `activation_fn=None` produces the prediction for that
    scale.
    Args:
        image: input tensor
        options: object exposing `df_dim`, the base number of kernels in the
            conv layers
        reuse: if True, switch the scope to variable reuse; otherwise assert
            that the scope is not already in reuse mode
        name: name of the discriminator variable scope
    Returns:
        Dict with keys "scale_0", "scale_1", "scale_3", "scale_5" and
        "scale_6" holding the image estimates at the respective scales.
    """
    with tf.variable_scope(name):
        scope = tf.get_variable_scope()
        if not reuse:
            assert scope.reuse is False
        else:
            scope.reuse_variables()

        def _features(inputs, num_kernels, conv_scope, norm_scope):
            # One feature stage: conv (ks=5) -> instance_norm -> lrelu.
            convolved = conv2d(inputs, num_kernels, ks=5, name=conv_scope)
            return lrelu(instance_norm(convolved, name=norm_scope))

        def _prediction(inputs, kernel_size, pred_scope):
            # Prediction head: single-kernel conv without activation.
            return conv2d(inputs, 1, ks=kernel_size, s=1, name=pred_scope, activation_fn=None)

        estimates = {}

        h0 = _features(image, options.df_dim * 2, 'd_h0_conv', 'd_bn0')
        estimates["scale_0"] = _prediction(h0, 5, 'd_h0_pred')

        h1 = _features(h0, options.df_dim * 2, 'd_h1_conv', 'd_bn1')
        estimates["scale_1"] = _prediction(h1, 10, 'd_h1_pred')

        h2 = _features(h1, options.df_dim * 4, 'd_h2_conv', 'd_bn2')

        h3 = _features(h2, options.df_dim * 8, 'd_h3_conv', 'd_bn3')
        estimates["scale_3"] = _prediction(h3, 10, 'd_h3_pred')

        h4 = _features(h3, options.df_dim * 8, 'd_h4_conv', 'd_bn4')

        h5 = _features(h4, options.df_dim * 16, 'd_h5_conv', 'd_bn5')
        estimates["scale_5"] = _prediction(h5, 6, 'd_h5_pred')

        h6 = _features(h5, options.df_dim * 16, 'd_h6_conv', 'd_bn6')
        estimates["scale_6"] = _prediction(h6, 3, 'd_h6_pred')

        return estimates
# ====== Define different types of losses applied to discriminator's output. ====== #
def abs_criterion(in_, target):
    """
    Mean absolute difference between `in_` and `target`.
    Args:
        in_: predicted values.
        target: reference values, compatible with `in_` for subtraction.
    Returns:
        Scalar tensor: mean over all elements of |in_ - target|.
    """
    difference = in_ - target
    return tf.reduce_mean(tf.abs(difference))
def mae_criterion(in_, target):
    """
    Mean absolute error between `in_` and `target`.

    Computes the same quantity as abs_criterion.
    Args:
        in_: predicted values.
        target: reference values, compatible with `in_` for subtraction.
    Returns:
        Scalar tensor: mean over all elements of |in_ - target|.
    """
    absolute_error = tf.abs(in_ - target)
    return tf.reduce_mean(absolute_error)
def mse_criterion(in_, target):
    """
    Mean squared error between `in_` and `target`.
    Args:
        in_: predicted values.
        target: reference values, compatible with `in_` for subtraction.
    Returns:
        Scalar tensor: mean over all elements of (in_ - target) ** 2.
    """
    difference = in_ - target
    return tf.reduce_mean(difference ** 2)
def sce_criterion(logits, labels):
    """
    Mean sigmoid cross-entropy between `logits` and `labels`.
    Args:
        logits: unnormalized predictions (before the sigmoid).
        labels: target values passed as `labels` to
            tf.nn.sigmoid_cross_entropy_with_logits.
    Returns:
        Scalar tensor: mean of the element-wise sigmoid cross-entropy.
    """
    per_element = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)
    return tf.reduce_mean(per_element)
def reduce_spatial_dim(input_tensor):
    """
    Average a tensor over its two spatial dimensions.

    Since labels and discriminator outputs are of different shapes (and even
    ranks) this routine brings the outputs down to one value per depth entry.
    Args:
        input_tensor: tensor of shape
            [batch_size, spatial_resol_1, spatial_resol_2, depth]
    Returns:
        tensor of shape [batch_size, depth]
    """
    # Axis 1 is reduced twice: once the first spatial axis is gone, the
    # second spatial axis sits at position 1.
    first_axis_averaged = tf.reduce_mean(input_tensor=input_tensor, axis=1)
    return tf.reduce_mean(input_tensor=first_axis_averaged, axis=1)
def add_spatial_dim(input_tensor, dims_list, resol_list):
    """
    Insert new dimensions at the positions given in dims_list and repeat the
    tensor along each of them the number of times given in resol_list.
    Args:
        input_tensor: tensor of shape [batch_size, depth0]
        dims_list: list of integers with position of new dimensions to append.
        resol_list: list of integers with corresponding new dimensionalities
            for each dimension.
    Returns:
        tensor of new shape
    """
    result = input_tensor
    for axis, resolution in zip(dims_list, resol_list):
        # New axis of size 1, then `resolution` copies concatenated along it.
        with_axis = tf.expand_dims(input=result, axis=axis)
        result = tf.concat(values=[with_axis]*resolution, axis=axis)
    return result
def repeat_scalar(input_tensor, shape):
    """
    Repeat scalar values.
    :param input_tensor: tensor of shape [batch_size, 1]; the size of the
        second dimension is checked at run time with tf.assert_equal.
    :param shape: new_shape of the element of the tensor
    :return: tensor with elements repeated, reshaped to
        [batch_size] + shape + [1].
    """
    second_dim_is_one = tf.assert_equal(tf.shape(input_tensor)[1], 1)
    # NOTE(review): only the batch-size lookup is placed under the control
    # dependency; the final reshape consumes batch_size, so the assertion
    # still guards the returned tensor. Confirm against the upstream layout.
    with tf.control_dependencies([second_dim_is_one]):
        batch_size = tf.shape(input_tensor)[0]
    multiples = tf.stack(values=[1, tf.reduce_prod(shape)], axis=0)
    tiled = tf.tile(input_tensor, multiples)
    new_shape = tf.concat(values=[[batch_size], shape, [1]], axis=0)
    return tf.reshape(tiled, new_shape)
def transformer_block(input_tensor, kernel_size=10):
    """
    This is a simplified version of transformer block described in our paper
    https://arxiv.org/abs/1807.10201.

    Implemented as an average pooling with stride 1 and 'SAME' padding.
    Args:
        input_tensor: Image(or tensor of rank 4) we want to transform.
        kernel_size: Size of the pooling kernel we apply to the input_tensor.
    Returns:
        Transformed tensor
    """
    pooled = slim.avg_pool2d(inputs=input_tensor,
                             kernel_size=kernel_size,
                             stride=1,
                             padding='SAME')
    return pooled