-
Notifications
You must be signed in to change notification settings - Fork 251
/
Copy path(arXiv 2021) PSAN.py
105 lines (82 loc) · 4.03 KB
/
(arXiv 2021) PSAN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import numpy as np
import torch
from torch import nn
from torch.nn import init
# Paper URL: https://arxiv.org/pdf/2107.00782
# Paper: Polarized Self-Attention: Towards High-quality Pixel-wise Regression
class ParallelPolarizedSelfAttention(nn.Module):
    """Polarized self-attention block with the two branches combined in parallel.

    A channel-only gate and a spatial-only gate are each computed from the
    same input; the input is rescaled by each gate independently and the two
    rescaled tensors are added together.

    Args:
        channel: Number of channels of the 4-D input ``(B, C, H, W)``; the
            1x1 convolutions are built for exactly this many input channels.
    """

    def __init__(self, channel: int = 512):
        super().__init__()
        half = channel // 2
        # NOTE: attribute names and construction order are kept stable so
        # that state_dict keys and seeded initialisation stay unchanged.
        self.ch_wv = nn.Conv2d(channel, half, kernel_size=1)
        self.ch_wq = nn.Conv2d(channel, 1, kernel_size=1)
        # Applied to a (B, H*W, 1) tensor, i.e. normalises over positions.
        self.softmax_channel = nn.Softmax(dim=1)
        # Applied to a (B, 1, C//2) tensor, i.e. normalises over channels.
        self.softmax_spatial = nn.Softmax(dim=-1)
        self.ch_wz = nn.Conv2d(half, channel, kernel_size=1)
        self.ln = nn.LayerNorm(channel)
        self.sigmoid = nn.Sigmoid()
        self.sp_wv = nn.Conv2d(channel, half, kernel_size=1)
        self.sp_wq = nn.Conv2d(channel, half, kernel_size=1)
        self.agp = nn.AdaptiveAvgPool2d((1, 1))

    def _channel_gate(self, x: torch.Tensor) -> torch.Tensor:
        """Return the per-channel sigmoid gate of shape ``(B, C, 1, 1)``."""
        batch, channels, _, _ = x.size()
        values = self.ch_wv(x).reshape(batch, channels // 2, -1)  # B, C//2, H*W
        query = self.ch_wq(x).reshape(batch, -1, 1)  # B, H*W, 1
        query = self.softmax_channel(query)
        # Position-weighted sum of the value features.
        pooled = torch.matmul(values, query).unsqueeze(-1)  # B, C//2, 1, 1
        # Channels are moved to the last axis so LayerNorm normalises over them.
        logits = self.ch_wz(pooled).reshape(batch, channels, 1).permute(0, 2, 1)
        gate = self.sigmoid(self.ln(logits))  # B, 1, C
        return gate.permute(0, 2, 1).reshape(batch, channels, 1, 1)

    def _spatial_gate(self, x: torch.Tensor) -> torch.Tensor:
        """Return the per-position sigmoid gate of shape ``(B, 1, H, W)``."""
        batch, channels, height, width = x.size()
        half = channels // 2
        values = self.sp_wv(x).reshape(batch, half, -1)  # B, C//2, H*W
        query = self.agp(self.sp_wq(x))  # B, C//2, 1, 1
        query = query.permute(0, 2, 3, 1).reshape(batch, 1, half)  # B, 1, C//2
        query = self.softmax_spatial(query)
        # Channel-weighted sum of the value features at every position.
        scores = torch.matmul(query, values)  # B, 1, H*W
        return self.sigmoid(scores.reshape(batch, 1, height, width))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both gates to ``x`` and return their sum (same shape as ``x``)."""
        # Channel-only self-attention
        channel_out = self._channel_gate(x) * x
        # Spatial-only self-attention
        spatial_out = self._spatial_gate(x) * x
        return spatial_out + channel_out
class SequentialPolarizedSelfAttention(nn.Module):
    """Polarized self-attention block with the two branches chained in sequence.

    The input is first rescaled by a channel-only gate; the spatial-only gate
    is then computed from, and applied to, that channel-gated tensor.

    Args:
        channel: Number of channels of the 4-D input ``(B, C, H, W)``; the
            1x1 convolutions are built for exactly this many input channels.
    """

    def __init__(self, channel: int = 512):
        super().__init__()
        half = channel // 2
        # NOTE: attribute names and construction order are kept stable so
        # that state_dict keys and seeded initialisation stay unchanged.
        self.ch_wv = nn.Conv2d(channel, half, kernel_size=1)
        self.ch_wq = nn.Conv2d(channel, 1, kernel_size=1)
        # Applied to a (B, H*W, 1) tensor, i.e. normalises over positions.
        self.softmax_channel = nn.Softmax(dim=1)
        # Applied to a (B, 1, C//2) tensor, i.e. normalises over channels.
        self.softmax_spatial = nn.Softmax(dim=-1)
        self.ch_wz = nn.Conv2d(half, channel, kernel_size=1)
        self.ln = nn.LayerNorm(channel)
        self.sigmoid = nn.Sigmoid()
        self.sp_wv = nn.Conv2d(channel, half, kernel_size=1)
        self.sp_wq = nn.Conv2d(channel, half, kernel_size=1)
        self.agp = nn.AdaptiveAvgPool2d((1, 1))

    def _channel_gate(self, x: torch.Tensor) -> torch.Tensor:
        """Return the per-channel sigmoid gate of shape ``(B, C, 1, 1)``."""
        batch, channels, _, _ = x.size()
        values = self.ch_wv(x).reshape(batch, channels // 2, -1)  # B, C//2, H*W
        query = self.ch_wq(x).reshape(batch, -1, 1)  # B, H*W, 1
        query = self.softmax_channel(query)
        # Position-weighted sum of the value features.
        pooled = torch.matmul(values, query).unsqueeze(-1)  # B, C//2, 1, 1
        # Channels are moved to the last axis so LayerNorm normalises over them.
        logits = self.ch_wz(pooled).reshape(batch, channels, 1).permute(0, 2, 1)
        gate = self.sigmoid(self.ln(logits))  # B, 1, C
        return gate.permute(0, 2, 1).reshape(batch, channels, 1, 1)

    def _spatial_gate(self, feat: torch.Tensor) -> torch.Tensor:
        """Return the per-position sigmoid gate of shape ``(B, 1, H, W)``."""
        batch, channels, height, width = feat.size()
        half = channels // 2
        values = self.sp_wv(feat).reshape(batch, half, -1)  # B, C//2, H*W
        query = self.agp(self.sp_wq(feat))  # B, C//2, 1, 1
        query = query.permute(0, 2, 3, 1).reshape(batch, 1, half)  # B, 1, C//2
        query = self.softmax_spatial(query)
        # Channel-weighted sum of the value features at every position.
        scores = torch.matmul(query, values)  # B, 1, H*W
        return self.sigmoid(scores.reshape(batch, 1, height, width))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gate ``x`` by channel, then by position; output has the shape of ``x``."""
        # Channel-only self-attention
        channel_out = self._channel_gate(x) * x
        # Spatial-only self-attention, driven by the channel-gated features
        spatial_out = self._spatial_gate(channel_out) * channel_out
        return spatial_out
if __name__ == '__main__':
    # Smoke test: push one random (1, 512, 7, 7) tensor through the
    # sequential block and print the shape of the result.
    # The tensor is named `feature_map` so the builtin `input` is not shadowed.
    feature_map = torch.randn(1, 512, 7, 7)
    block = SequentialPolarizedSelfAttention(channel=512)
    output = block(feature_map)
    print(output.shape)