model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50
def get_backbone(backbone, castrate=True):
    # Strip the classification head: record the feature dimension and replace
    # the final fc layer with Identity so the backbone outputs raw features.
    if castrate:
        backbone.output_dim = backbone.fc.in_features
        backbone.fc = torch.nn.Identity()
    return backbone
def D(p, z):  # negative cosine similarity
    # z.detach() implements the stop-gradient: gradients flow only through p.
    return - F.cosine_similarity(p, z.detach(), dim=-1).mean()
class projection_MLP(nn.Module):
    def __init__(self, in_dim, hidden_dim=2048, out_dim=2048):
        super().__init__()
        ''' page 3 baseline setting
        Projection MLP. The projection MLP (in f) has BN applied to each
        fully-connected (fc) layer, including its output fc. Its output fc
        has no ReLU. The hidden fc is 2048-d. This MLP has 3 layers.
        '''
        self.layer1 = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.layer2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.layer3 = nn.Sequential(
            nn.Linear(hidden_dim, out_dim),
            nn.BatchNorm1d(out_dim)  # output fc has BN but no ReLU
        )
        self.num_layers = 3

    def set_layers(self, num_layers):
        self.num_layers = num_layers

    def forward(self, x):
        if self.num_layers == 3:
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
        elif self.num_layers == 2:
            x = self.layer1(x)
            x = self.layer3(x)
        else:
            raise ValueError(f"num_layers must be 2 or 3, got {self.num_layers}")
        return x
class prediction_MLP(nn.Module):
    def __init__(self, in_dim=2048, hidden_dim=512, out_dim=2048):  # bottleneck structure
        super().__init__()
        ''' page 3 baseline setting
        Prediction MLP. The prediction MLP (h) has BN applied to its hidden
        fc layers. Its output fc does not have BN (ablation in Sec. 4.4) or
        ReLU. This MLP has 2 layers. The dimension of h's input and output
        (z and p) is d = 2048, and h's hidden layer's dimension is 512,
        making h a bottleneck structure (ablation in supplement).
        '''
        self.layer1 = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.layer2 = nn.Linear(hidden_dim, out_dim)
        """
        Adding BN to the output of the prediction MLP h does not work
        well (Table 3d). We find that this is not about collapsing.
        The training is unstable and the loss oscillates.
        """

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x
class SimSiam(nn.Module):
    def __init__(self, backbone=resnet50()):
        super().__init__()
        self.backbone = get_backbone(backbone)
        self.projector = projection_MLP(backbone.output_dim)
        self.encoder = nn.Sequential(  # f encoder
            self.backbone,
            self.projector
        )
        self.predictor = prediction_MLP()

    def forward(self, x1, x2):
        f, h = self.encoder, self.predictor
        z1, z2 = f(x1), f(x2)
        p1, p2 = h(z1), h(z2)
        # Symmetrized loss: each prediction is compared against the
        # stop-gradient projection of the other view.
        L = D(p1, z2) / 2 + D(p2, z1) / 2
        return L
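

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal smoke test: instantiate SimSiam with a fresh ResNet-50 backbone
# and run one forward pass on two random "views" of a batch. The batch size
# and 224x224 input resolution below are assumptions for illustration only.
if __name__ == "__main__":
    model = SimSiam(resnet50())
    x1 = torch.randn(8, 3, 224, 224)  # first augmented view of the batch
    x2 = torch.randn(8, 3, 224, 224)  # second augmented view of the batch
    loss = model(x1, x2)
    print(f"SimSiam loss on random inputs: {loss.item():.4f}")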