-
Notifications
You must be signed in to change notification settings - Fork 1
/
siamese3D.py
72 lines (64 loc) · 2.43 KB
/
siamese3D.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import torch.nn as nn
class Siamese3D(nn.Module):
    """Two-branch 3D CNN classifier over a single-channel volume.

    Both branches receive the *same* input tensor and have independent
    (unshared) weights. Their flattened features are concatenated and passed
    through Linear(-> 100) -> BatchNorm1d -> Linear(-> n_classes) -> Softmax.

    Each branch is:
        Conv3d(1 -> 32, k=5, s=2) -> BatchNorm3d
        3 x [Conv3d(32 -> 32, k=3) -> BatchNorm3d
             -> Conv3d(32 -> 32, k=3) -> BatchNorm3d -> AvgPool3d(2)]
    All convolutions use 'valid' (no) padding.

    NOTE(review): no nonlinearity is applied between layers. In eval mode
    the whole network before the softmax is therefore an affine map of the
    input. Confirm whether activations (e.g. ReLU) were intended. Adding them
    would change the outputs of existing checkpoints, so they are not added.

    NOTE(review): forward() returns softmax probabilities, not logits. If the
    model is trained with nn.CrossEntropyLoss, which expects raw logits,
    softmax ends up applied twice. Confirm against the training code.

    Args:
        n_classes: number of output classes.
        input_shape: optional spatial size (D, H, W) of the input volume.
            When given, the input width of ``dense_100`` is derived from it.
            When omitted, the historical hard-coded width of 512 is kept. That
            width corresponds to a final 2x2x2 grid per branch, e.g. cubic
            inputs with side 91..106.

    Raises:
        ValueError: if ``input_shape`` does not have three entries, or is too
            small to pass through the convolution/pooling stack.
    """

    # dense_100 input width used when no input_shape is supplied (legacy value).
    _LEGACY_IN_FEATURES = 512

    def __init__(self, n_classes, input_shape=None):
        super().__init__()
        # Module creation order is kept as-is so parameter order and
        # state_dict keys match previously saved checkpoints.
        # conv3d_5_2[b] is the stem convolution of branch b (b in {0, 1}).
        self.conv3d_5_2 = nn.ModuleList([nn.Conv3d(in_channels=1, out_channels=32, kernel_size=(5, 5, 5), stride=(2, 2, 2), padding='valid') for _ in range(2)])
        # Six 3x3x3 convs per branch: indices 0-5 for branch 0, 6-11 for branch 1.
        self.conv3d_3_1 = nn.ModuleList([nn.Conv3d(in_channels=32, out_channels=32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding='valid') for _ in range(12)])
        # Seven BatchNorm3d per branch: indices 0-6 for branch 0, 7-13 for branch 1.
        self.bn3d = nn.ModuleList([nn.BatchNorm3d(num_features=32) for _ in range(14)])
        # Three pools per branch: indices 0-2 for branch 0, 3-5 for branch 1.
        self.avg = nn.ModuleList([nn.AvgPool3d(kernel_size=(2, 2, 2)) for _ in range(6)])

        if input_shape is None:
            in_features = self._LEGACY_IN_FEATURES
        else:
            channels = self.conv3d_3_1[-1].out_channels
            # Two branches, each contributing channels * spatial positions.
            in_features = 2 * channels * self._branch_voxels(input_shape)

        self.dense_100 = nn.Linear(in_features=in_features, out_features=100)
        self.dense = nn.Linear(in_features=100, out_features=n_classes)
        self.bn = nn.BatchNorm1d(num_features=100)
        self.softmax = nn.Softmax(dim=1)

    @staticmethod
    def _branch_voxels(input_shape):
        """Return the number of spatial positions one branch outputs for a (D, H, W) input."""
        if len(input_shape) != 3:
            raise ValueError(f"input_shape must be (D, H, W), got {tuple(input_shape)}")
        voxels = 1
        for size in input_shape:
            # Stem: Conv3d(kernel 5, stride 2, no padding).
            size = (size - 5) // 2 + 1
            if size < 1:
                raise ValueError(f"input_shape {tuple(input_shape)} is too small for Siamese3D")
            for _ in range(3):
                # Two unpadded kernel-3 convs shrink by 2 each; AvgPool3d(2) halves (floor).
                size = (size - 4) // 2
                if size < 1:
                    raise ValueError(f"input_shape {tuple(input_shape)} is too small for Siamese3D")
            voxels *= size
        return voxels

    def _branch(self, x, b):
        """Run branch ``b`` (0 or 1) on ``x`` and return its features flattened from dim 1."""
        conv_base, bn_base, pool_base = 6 * b, 7 * b, 3 * b
        out = self.bn3d[bn_base](self.conv3d_5_2[b](x))
        for stage in range(3):
            for j in range(2):
                k = 2 * stage + j
                out = self.conv3d_3_1[conv_base + k](out)
                out = self.bn3d[bn_base + 1 + k](out)
            out = self.avg[pool_base + stage](out)
        return torch.flatten(out, start_dim=1)

    def forward(self, x):
        """Classify ``x``; returns softmax probabilities over dim 1.

        ``x`` is presumably a batched volume of shape (N, 1, D, H, W). The
        convs require one input channel; confirm the layout with the data
        pipeline. The flattened branch features must match
        ``dense_100.in_features``.
        """
        features = torch.cat((self._branch(x, 0), self._branch(x, 1)), dim=1)
        hidden = self.bn(self.dense_100(features))
        return self.softmax(self.dense(hidden))
def test():
    """Smoke-test Siamese3D: build a 5-class model, print it, run one batch.

    With the default constructor, dense_100 expects 512 features. That is
    two branches x 32 channels x a final 2x2x2 grid. A 91x91x91 volume
    produces exactly that grid. A 91x109x91 volume would give a 2x3x2 grid
    (768 features) and fail with a shape mismatch in dense_100.
    """
    model = Siamese3D(n_classes=5)
    print(model)
    # (batch, channels, D, H, W); batch > 1 so BatchNorm1d works in train mode.
    volume = torch.randn(3, 1, 91, 91, 91)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        out = model(volume)
    print(f"For input {volume.size()}, output is {out.size()}")
if __name__ == "__main__":
    # Run the smoke test only when executed as a script, not on import.
    test()