CPCA2d.py
import torch
from torch import nn
import torch.nn.functional as F

# Paper: Channel prior convolutional attention for medical image segmentation
# Paper link: https://arxiv.org/pdf/2306.05196


class ChannelAttention(nn.Module):
    """Channel attention: average- and max-pooled descriptors are passed
    through a shared two-layer 1x1-conv bottleneck and summed."""

    def __init__(self, input_channels, internal_neurons):
        super(ChannelAttention, self).__init__()
        self.fc1 = nn.Conv2d(in_channels=input_channels, out_channels=internal_neurons,
                             kernel_size=1, stride=1, bias=True)
        self.fc2 = nn.Conv2d(in_channels=internal_neurons, out_channels=input_channels,
                             kernel_size=1, stride=1, bias=True)
        self.input_channels = input_channels

    def forward(self, inputs):
        # Average-pooled branch: (B, C, H, W) -> (B, C, 1, 1)
        x1 = F.adaptive_avg_pool2d(inputs, output_size=(1, 1))
        x1 = self.fc1(x1)
        x1 = F.relu(x1, inplace=True)
        x1 = self.fc2(x1)
        x1 = torch.sigmoid(x1)
        # Max-pooled branch through the same shared bottleneck
        x2 = F.adaptive_max_pool2d(inputs, output_size=(1, 1))
        x2 = self.fc1(x2)
        x2 = F.relu(x2, inplace=True)
        x2 = self.fc2(x2)
        x2 = torch.sigmoid(x2)
        # Sum the two branches to form the per-channel attention vector
        x = x1 + x2
        x = x.view(-1, self.input_channels, 1, 1)
        return x
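

# A minimal, hedged sanity check for ChannelAttention on its own (not part of
# the original file); the tensor sizes below are illustrative assumptions.
def _channel_attention_sanity_check():
    att = ChannelAttention(input_channels=16, internal_neurons=4)
    weights = att(torch.randn(2, 16, 32, 32))
    print(weights.size())  # expected: torch.Size([2, 16, 1, 1])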


class CPCABlock(nn.Module):
    """Channel Prior Convolutional Attention (CPCA): channel attention
    followed by multi-scale depthwise strip convolutions that produce a
    spatial attention map."""

    def __init__(self, in_channels, out_channels, channelAttention_reduce=4):
        super().__init__()
        self.C = in_channels
        self.O = out_channels
        assert in_channels == out_channels
        self.ca = ChannelAttention(input_channels=in_channels,
                                   internal_neurons=in_channels // channelAttention_reduce)
        # Depthwise convolutions at multiple receptive-field sizes:
        # a 5x5 base plus 7, 11, and 21 strip-convolution pairs
        self.dconv5_5 = nn.Conv2d(in_channels, in_channels, kernel_size=5, padding=2, groups=in_channels)
        self.dconv1_7 = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 7), padding=(0, 3), groups=in_channels)
        self.dconv7_1 = nn.Conv2d(in_channels, in_channels, kernel_size=(7, 1), padding=(3, 0), groups=in_channels)
        self.dconv1_11 = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 11), padding=(0, 5), groups=in_channels)
        self.dconv11_1 = nn.Conv2d(in_channels, in_channels, kernel_size=(11, 1), padding=(5, 0), groups=in_channels)
        self.dconv1_21 = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 21), padding=(0, 10), groups=in_channels)
        self.dconv21_1 = nn.Conv2d(in_channels, in_channels, kernel_size=(21, 1), padding=(10, 0), groups=in_channels)
        # A single 1x1 pointwise conv, shared by the input projection,
        # the spatial-attention map, and the output projection
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 1), padding=0)
        self.act = nn.GELU()

    def forward(self, inputs):
        # Pointwise projection + GELU ("Global Perceptron")
        inputs = self.conv(inputs)
        inputs = self.act(inputs)
        # Apply the channel attention vector
        channel_att_vec = self.ca(inputs)
        inputs = channel_att_vec * inputs
        # Multi-scale depthwise strip convolutions
        x_init = self.dconv5_5(inputs)
        x_1 = self.dconv1_7(x_init)
        x_1 = self.dconv7_1(x_1)
        x_2 = self.dconv1_11(x_init)
        x_2 = self.dconv11_1(x_2)
        x_3 = self.dconv1_21(x_init)
        x_3 = self.dconv21_1(x_3)
        x = x_1 + x_2 + x_3 + x_init
        # Spatial attention map, applied to the channel-attended features
        spatial_att = self.conv(x)
        out = spatial_att * inputs
        out = self.conv(out)
        return out


if __name__ == '__main__':
    x = torch.randn(4, 16, 64, 64)
    print(x.size())
    block = CPCABlock(in_channels=16, out_channels=16, channelAttention_reduce=4)
    # Run the input through the CPCABlock module
    output = block(x)
    # Print the shape of the output tensor
    print(output.size())
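
    # Hedged extras (not in the original file): run the standalone channel-
    # attention check, then drop CPCABlock into a small conv stage; 'stage'
    # and the layer sizes here are illustrative assumptions.
    _channel_attention_sanity_check()
    stage = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        CPCABlock(in_channels=16, out_channels=16, channelAttention_reduce=4),
    )
    print(stage(torch.randn(1, 3, 64, 64)).size())  # expected: torch.Size([1, 16, 64, 64])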