-
Notifications
You must be signed in to change notification settings - Fork 1
/
vgg.py
104 lines (97 loc) · 3.29 KB
/
vgg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
import torch.nn as nn
class vgg_layer(nn.Module):
def __init__(self, nin, nout):
super(vgg_layer, self).__init__()
self.main = nn.Sequential(
nn.Conv2d(nin, nout, 3, 1, 1),
nn.BatchNorm2d(nout),
nn.LeakyReLU(0.2, inplace=True)
)
def forward(self, input):
return self.main(input)
class content_encoder(nn.Module):
def __init__(self, content_dim, nc=3):
super(content_encoder, self).__init__()
# 64 x 64
self.c1 = nn.Sequential(
vgg_layer(nc, 64),
vgg_layer(64, 64),
)
# 32 x 32
self.c2 = nn.Sequential(
vgg_layer(64, 128),
vgg_layer(128, 128),
)
# 16 x 16
self.c3 = nn.Sequential(
vgg_layer(128, 256),
vgg_layer(256, 256),
vgg_layer(256, 256),
)
# 8 x 8
self.c4 = nn.Sequential(
vgg_layer(256, 512),
vgg_layer(512, 512),
vgg_layer(512, 512),
)
# 4 x 4
self.c5 = nn.Sequential(
nn.Conv2d(512, content_dim, 4, 1, 0),
# nn.BatchNorm2d(content_dim),
# nn.Tanh()
)
self.mp = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
def forward(self, input):
h1 = self.c1(input) # 64 -> 32
h2 = self.c2(self.mp(h1)) # 32 -> 16
h3 = self.c3(self.mp(h2)) # 16 -> 8
h4 = self.c4(self.mp(h3)) # 8 -> 4
h5 = self.c5(self.mp(h4)) # 4 -> 1
return h5
class decoder(nn.Module):
def __init__(self, content_dim, pose_dim, nc=1):
super(decoder, self).__init__()
# 1 x 1 -> 4 x 4
self.upc1 = nn.Sequential(
nn.ConvTranspose2d(content_dim+pose_dim, 512, 4, 1, 0),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace=True)
)
# 8 x 8
self.upc2 = nn.Sequential(
vgg_layer(512*2, 512),
vgg_layer(512, 512),
vgg_layer(512, 256)
)
# 16 x 16
self.upc3 = nn.Sequential(
vgg_layer(256*2, 256),
vgg_layer(256, 256),
vgg_layer(256, 128)
)
# 32 x 32
self.upc4 = nn.Sequential(
vgg_layer(128*2, 128),
vgg_layer(128, 64)
)
# 64 x 64
self.upc5 = nn.Sequential(
vgg_layer(64*2, 64),
nn.ConvTranspose2d(64, nc, 3, 1, 1),
nn.Sigmoid()
)
self.up = nn.UpsamplingNearest2d(scale_factor=2)
def forward(self, input):
content, pose = input
content, skip = content
d1 = self.upc1(torch.cat([content, pose], 1)) # 1 -> 4
up1 = self.up(d1) # 4 -> 8
d2 = self.upc2(torch.cat([up1, skip[3]], 1)) # 8 x 8
up2 = self.up(d2) # 8 -> 16
d3 = self.upc3(torch.cat([up2, skip[2]], 1)) # 16 x 16
up3 = self.up(d3) # 8 -> 32
d4 = self.upc4(torch.cat([up3, skip[1]], 1)) # 32 x 32
up4 = self.up(d4) # 32 -> 64
output = self.upc5(torch.cat([up4, skip[0]], 1)) # 64 x 64
return output