# module.py
import torch
import torch.nn as nn
import torch.nn.functional as F


def logsumexp_2d(tensor):
    """Numerically stable log-sum-exp over the flattened spatial dims."""
    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
    # Subtract the per-channel max before exp() to avoid overflow, add it back after.
    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
    return outputs
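# Sanity-check sketch (not part of the original module): the result should
# match torch.logsumexp applied to the flattened spatial dims.
#   >>> t = torch.randn(2, 3, 4, 4)
#   >>> ref = torch.logsumexp(t.view(2, 3, -1), dim=2, keepdim=True)
#   >>> torch.allclose(logsumexp_2d(t), ref)
#   True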


def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
    """3x3 convolution with padding."""
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=strd, padding=padding, bias=bias
    )
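# Usage sketch (values assumed for illustration): conv3x3(64, 128) builds
# nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False); with the
# default stride and padding it preserves the spatial size.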


class ConvBlock(nn.Module):
    """Pre-activated residual block whose three conv outputs are concatenated.

    The stage widths are out/2 + out/4 + out/4, so the concatenation yields
    out_planes channels (out_planes should be divisible by 4).
    """

    def __init__(self, in_planes, out_planes):
        super(ConvBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = conv3x3(in_planes, int(out_planes / 2))
        self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
        self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
        self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
        self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
        # Project the residual with a 1x1 conv when the channel counts differ.
        if in_planes != out_planes:
            self.downsample = nn.Sequential(
                nn.BatchNorm2d(in_planes),
                nn.ReLU(True),
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False),
            )
        else:
            self.downsample = None

    def forward(self, x):
        residual = x
        out1 = self.bn1(x)
        out1 = F.relu(out1, True)
        out1 = self.conv1(out1)
        out2 = self.bn2(out1)
        out2 = F.relu(out2, True)
        out2 = self.conv2(out2)
        out3 = self.bn3(out2)
        out3 = F.relu(out3, True)
        out3 = self.conv3(out3)
        # Concatenate all three stages along the channel dim, then add the residual.
        out3 = torch.cat((out1, out2, out3), 1)
        if self.downsample is not None:
            residual = self.downsample(residual)
        out3 += residual
        return out3
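# Usage sketch (shapes assumed for illustration, not from the original file):
# the block keeps the spatial size and maps in_planes -> out_planes channels.
#   >>> block = ConvBlock(64, 128)
#   >>> block(torch.randn(1, 64, 32, 32)).shape
#   torch.Size([1, 128, 32, 32])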


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class ChannelGate(nn.Module):
    """Channel attention gate.

    Pools the input spatially, maps the pooled descriptor through a small
    MLP to one logit per face class, and scales the corresponding channel
    groups by the resulting sigmoid gates.
    """

    def __init__(
        self, gate_channels, face_classes, reduction_ratio=16, pool_types=("avg",)
    ):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, face_classes),
        )
        self.pool_types = pool_types
        self.face_classes = face_classes

    def forward(self, x):
        b, c, h, w = x.shape
        assert c % self.face_classes == 0, "channels must split evenly into face classes"
        # Sum the attention logits produced by each requested pooling method.
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == "avg":
                avg_pool = F.avg_pool2d(
                    x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))
                )
                channel_att_raw = self.mlp(avg_pool)
            elif pool_type == "max":
                max_pool = F.max_pool2d(
                    x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))
                )
                channel_att_raw = self.mlp(max_pool)
            elif pool_type == "lp":
                lp_pool = F.lp_pool2d(
                    x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))
                )
                channel_att_raw = self.mlp(lp_pool)
            elif pool_type == "lse":
                # Log-sum-exp (LSE) pooling over the spatial dims.
                lse_pool = logsumexp_2d(x)
                channel_att_raw = self.mlp(lse_pool)
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw
        # One sigmoid gate per face class, broadcast as (b, 1, face_classes, 1, 1).
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).unsqueeze(1)
        # Split the channels into face_classes groups and gate each group.
        x = x.view(b, -1, self.face_classes, h, w)
        out = x * scale
        return out.view(b, -1, h, w).contiguous()
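

if __name__ == "__main__":
    # Minimal smoke test; the shapes and hyperparameters below are assumed
    # for illustration, not taken from the original file.
    gate = ChannelGate(gate_channels=64, face_classes=8)
    x = torch.randn(2, 64, 16, 16)
    out = gate(x)
    assert out.shape == x.shape  # gating rescales channels but keeps the shape
    print("ChannelGate output shape:", tuple(out.shape))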