import torch
import torch.nn as nn

from utils import mean_std, normalize


# Decoder: reconstructs the output image from the content/style features
# produced by the attention module.
decoder = nn.Sequential(
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 256, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 128, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 64, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 3, (3, 3)),
)
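
# Quick shape smoke test (an illustrative sketch, not part of the original
# file): the decoder maps relu4_1-sized features back to image space,
# upsampling by 8x overall. The 32x32 feature size below assumes a 256x256
# input image.
if __name__ == '__main__':
    _feats = torch.rand(1, 512, 32, 32)
    assert decoder(_feats).shape == (1, 3, 256, 256)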
# The VGG-19 feature extractor network (encoder)
vgg = nn.Sequential(
    nn.Conv2d(3, 3, (1, 1)),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(3, 64, (3, 3)),
    nn.ReLU(),  # relu1-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),  # relu1-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 128, (3, 3)),
    nn.ReLU(),  # relu2-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),  # relu2-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 256, (3, 3)),
    nn.ReLU(),  # relu3-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 512, (3, 3)),
    nn.ReLU(),  # relu4-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-1, the deepest layer used by the model below
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU()  # relu5-4
)
# This is a Self-Attention Network (SANet) in which the content image features (Fc)
# and the style image features (Fs) are matched with an attention mechanism
class SANet(nn.Module):
    def __init__(self, in_channel: int):
        super(SANet, self).__init__()
        self.f = nn.Conv2d(in_channel, in_channel, (1, 1))
        self.g = nn.Conv2d(in_channel, in_channel, (1, 1))
        self.h = nn.Conv2d(in_channel, in_channel, (1, 1))
        self.softmax = nn.Softmax(-1)
        self.out_conv = nn.Conv2d(in_channel, in_channel, (1, 1))

    def forward(self, Fc: torch.Tensor, Fs: torch.Tensor):
        B, _, H, W = Fc.size()
        # permute transposes the last two axes
        # F_Fc_norm plays the role of the 'query' if you are familiar with attention mechanisms
        F_Fc_norm = self.f(normalize(Fc)).view(B, -1, H * W).permute(0, 2, 1)
        B, _, H, W = Fs.size()
        # This plays the role of the 'key'
        G_Fs_norm = self.g(normalize(Fs)).view(B, -1, H * W)
        # The attention map: each content position attends over all style positions
        attention = self.softmax(torch.bmm(F_Fc_norm, G_Fs_norm))
        # This plays the role of the 'value'
        H_Fs = self.h(Fs).view(B, -1, H * W)
        # Finally, the output is the attention-weighted sum of the style 'values'
        out = torch.bmm(H_Fs, attention.permute(0, 2, 1))
        B, C, H, W = Fc.size()
        # Reshape back to the content feature shape
        out = out.view(B, C, H, W)
        out = self.out_conv(out)
        # Skip connection with the content features
        out += Fc
        return out
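
# Quick shape smoke test (an illustrative sketch, not part of the original
# file): the output of a SANet keeps the shape of the content features, even
# when the style features have a different spatial size. The sizes below are
# arbitrary assumptions.
if __name__ == '__main__':
    _san = SANet(512)
    _Fc = torch.rand(2, 512, 32, 32)   # content features
    _Fs = torch.rand(2, 512, 24, 24)   # style features
    assert _san(_Fc, _Fs).shape == _Fc.shape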
# Combine two SANets across feature levels
class SelfAttentionModule(nn.Module):
    def __init__(self, in_channel: int):
        super(SelfAttentionModule, self).__init__()
        # Two SANets, for the relu5_1 and relu4_1 feature levels
        self.SAN1 = SANet(in_channel)
        self.SAN2 = SANet(in_channel)
        # Layers used to merge the two outputs
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.merge_conv_pad = nn.ReflectionPad2d((1, 1, 1, 1))
        self.merge_conv = nn.Conv2d(in_channel, in_channel, (3, 3))

    def forward(self, Fc: list, Fs: list):
        # Fc, Fs: lists of encoder features, ordered [relu1_1, ..., relu5_1]
        # First, attention on the deepest features (relu5_1 output)
        Fcsc_5 = self.SAN1(Fc[-1], Fs[-1])
        # Upsample to match the shape of Fcsc_4
        Fcsc_5_up = self.upsample(Fcsc_5)
        # Then, attention on the relu4_1 output
        Fcsc_4 = self.SAN2(Fc[-2], Fs[-2])
        # Finally, sum the outputs of both SANets and merge them with a convolution
        Fcsc_m = Fcsc_4 + Fcsc_5_up
        Fcsc_m = self.merge_conv_pad(Fcsc_m)
        Fcsc_m = self.merge_conv(Fcsc_m)
        return Fcsc_m
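
# Quick shape smoke test (an illustrative sketch, not part of the original
# file): with 256x256 input images, relu4_1 features are 32x32 and relu5_1
# features are 16x16, so the merged output has the relu4_1 spatial size.
if __name__ == '__main__':
    _sam = SelfAttentionModule(512)
    _Fc = [torch.rand(2, 512, 32, 32), torch.rand(2, 512, 16, 16)]  # [relu4_1, relu5_1]
    _Fs = [torch.rand(2, 512, 32, 32), torch.rand(2, 512, 16, 16)]
    assert _sam(_Fc, _Fs).shape == (2, 512, 32, 32)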
# Compute the output images and the losses
class MultiLevelStyleAttention(nn.Module):
    def __init__(self, encoder: nn.Sequential, decoder: nn.Sequential):
        super(MultiLevelStyleAttention, self).__init__()
        # Split the encoder into slices, one per feature level
        encoder_layers = list(encoder.children())
        self.enc_1 = nn.Sequential(*encoder_layers[:4])     # input  -> relu1_1
        self.enc_2 = nn.Sequential(*encoder_layers[4:11])   # relu1_1 -> relu2_1
        self.enc_3 = nn.Sequential(*encoder_layers[11:18])  # relu2_1 -> relu3_1
        self.enc_4 = nn.Sequential(*encoder_layers[18:31])  # relu3_1 -> relu4_1
        self.enc_5 = nn.Sequential(*encoder_layers[31:44])  # relu4_1 -> relu5_1
        # Transforms
        self.sa_module = SelfAttentionModule(512)
        self.decoder = decoder
        self.mse_loss = nn.MSELoss()
        # Freeze the encoder parameters
        for n in ['enc_1', 'enc_2', 'enc_3', 'enc_4', 'enc_5']:
            for param in getattr(self, n).parameters():
                param.requires_grad = False
    def get_encoder_features(self, x: torch.Tensor):
        '''
        x : batch of images
        '''
        results = [x]
        for i in range(5):
            # Get the enc_i slice from self
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            # Propagate the previous output through the enc_i layers
            results.append(func(results[-1]))
        # Return the [relu1_1, relu2_1, relu3_1, relu4_1, relu5_1] features of x
        return results[1:]
    # A plain MSE loss with safety checks
    def content_loss(self, _input: torch.Tensor, _target: torch.Tensor):
        assert (_input.size() == _target.size())
        assert (_target.requires_grad is False)
        return self.mse_loss(_input, _target)

    # MSE loss between the mean and standard deviation of the input and target
    def style_loss(self, _input: torch.Tensor, _target: torch.Tensor):
        assert (_input.size() == _target.size())
        assert (_target.requires_grad is False)
        _input_mean, _input_std = mean_std(_input)
        _target_mean, _target_std = mean_std(_target)
        return self.mse_loss(_input_mean, _target_mean) \
            + self.mse_loss(_input_std, _target_std)
    # During training, compute the losses;
    # otherwise, only return the stylized output images
    def forward(self, Ic: torch.Tensor, Is: torch.Tensor, train: bool = True):
        # Extract features from the style and the content images
        Fs = self.get_encoder_features(Is)
        Fc = self.get_encoder_features(Ic)
        # The stylized images
        Ics = self.decoder(self.sa_module(Fc, Fs))
        # If we are not training, stop here
        if not train:
            return Ics
        # Extract the features of the stylized images
        Ics_feats = self.get_encoder_features(Ics)
        # Content loss
        # Only the relu4_1 and relu5_1 outputs are used because
        # they carry most of the content information
        Lc = self.content_loss(normalize(Ics_feats[-1]), normalize(Fc[-1])) \
            + self.content_loss(normalize(Ics_feats[-2]), normalize(Fc[-2]))
        # Style loss
        # The style loss is computed from all the relu layers because
        # style can be extracted from every layer
        # (even though the earlier layers carry more style information)
        Ls = sum([self.style_loss(Ics_feats[i], Fs[i]) for i in range(5)])
        # For the results produced by the checkpoint 159000.pt, I fine-tuned
        # for 50000 steps with the line above replaced by:
        # Ls = sum([self.style_loss(Ics_feats[i], Fs[i]) for i in range(4)])
        # By removing the last layer of the style features, we reduce the
        # leakage of the style image's content into the transferred image.
        # I also added 3 to the style_weight in config.json to
        # put more emphasis on the style, since the sum becomes smaller.
        # Then the images stylized with their own style and their own content
        Icc = self.decoder(self.sa_module(Fc, Fc))
        Iss = self.decoder(self.sa_module(Fs, Fs))
        # Extract their features
        Icc_feats = self.get_encoder_features(Icc)
        Iss_feats = self.get_encoder_features(Iss)
        # These two loss terms are the innovation of this paper.
        # They check that stylizing an image with itself
        # gives back the original image and its features.
        # identity1 loss
        loss_lambda1 = self.content_loss(Icc, Ic) + self.content_loss(Iss, Is)
        # identity2 loss
        loss_lambda2 = sum([self.content_loss(Icc_feats[i], Fc[i]) + self.content_loss(Iss_feats[i], Fs[i]) for i in range(5)])
        return Lc, Ls, loss_lambda1, loss_lambda2
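
# Minimal usage sketch (illustrative, not part of the original file). The
# encoder is expected to be initialized with pre-trained VGG-19 weights; the
# checkpoint name 'vgg_normalised.pth' and the image size below are
# assumptions, see config.json and the training script for the actual values.
if __name__ == '__main__':
    # vgg.load_state_dict(torch.load('vgg_normalised.pth'))  # assumed checkpoint name
    model = MultiLevelStyleAttention(vgg, decoder)
    Ic = torch.rand(2, 3, 256, 256)  # batch of content images
    Is = torch.rand(2, 3, 256, 256)  # batch of style images
    # Training mode: returns the four loss terms
    Lc, Ls, l1, l2 = model(Ic, Is, train=True)
    loss = Lc + Ls + l1 + l2  # in practice each term is weighted; see config.json
    # Inference mode: returns the stylized images
    with torch.no_grad():
        Ics = model(Ic, Is, train=False)
    assert Ics.shape == Ic.shape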