-
Notifications
You must be signed in to change notification settings - Fork 13
/
models.py
28 lines (20 loc) · 1.01 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch.nn as nn
class SimCLR(nn.Module):
    """SimCLR network: a backbone encoder followed by an MLP projection head.

    The backbone is adapted for small, CIFAR-10-sized images as described in
    Section B.9 of the SimCLR paper. The stem convolution is replaced by a 3x3,
    stride-1 convolution, and the first max-pooling layer is removed. The
    encoder's final fully connected layer is replaced by an identity, so the
    encoder outputs its pooled features.

    Args:
        base_encoder: Callable that builds the backbone.
            It is invoked as ``base_encoder(pretrained=False)``. The returned
            module must expose ``conv1`` (an ``nn.Conv2d``), ``maxpool`` and
            ``fc`` (with ``in_features``) attributes, as torchvision ResNets do.
        projection_dim: Output width of the projection head.
        hidden_dim: Width of the projection head's hidden layer.
            The default of 2048 matches the previously hard-coded value.
    """

    def __init__(self, base_encoder, projection_dim: int = 128,
                 hidden_dim: int = 2048):
        super().__init__()
        # Build the backbone without pretrained weights.
        # NOTE(review): torchvision >= 0.13 deprecates ``pretrained`` in favour
        # of ``weights``; kept as-is so older torchvision versions still work.
        self.enc = base_encoder(pretrained=False)
        # Must be read before ``fc`` is replaced with an identity below.
        self.feature_dim = self.enc.fc.in_features

        # CIFAR-10 stem: 3x3 / stride 1 / padding 1 instead of 7x7 / stride 2,
        # and no initial max pooling (SimCLR paper, Section B.9).
        # Reuse the original stem's channel counts so the following layers
        # (e.g. bn1) still match, whatever width the encoder was built with.
        old_stem = self.enc.conv1
        self.enc.conv1 = nn.Conv2d(old_stem.in_channels, old_stem.out_channels,
                                   kernel_size=3, stride=1, padding=1,
                                   bias=False)
        # torchvision's ResNet constructor initialises every conv with
        # kaiming_normal_(fan_out, relu). A conv swapped in afterwards would
        # otherwise keep PyTorch's default Conv2d init, so re-apply it here.
        nn.init.kaiming_normal_(self.enc.conv1.weight, mode="fan_out",
                                nonlinearity="relu")
        self.enc.maxpool = nn.Identity()
        # Drop the classifier so the encoder returns features.
        self.enc.fc = nn.Identity()

        # Two-layer MLP projection head.
        # The Sequential layout (Linear, ReLU, Linear) is kept unchanged so
        # existing checkpoints ("projector.0.*", "projector.2.*") still load.
        self.projection_dim = projection_dim
        self.hidden_dim = hidden_dim
        self.projector = nn.Sequential(
            nn.Linear(self.feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, projection_dim),
        )

    def forward(self, x):
        """Encode ``x`` and project the resulting features.

        Args:
            x: Input batch for the encoder.
                Presumably (batch, channels, height, width) images; this
                depends on ``base_encoder``, so confirm against callers.

        Returns:
            Tuple ``(feature, projection)``. ``feature`` is the encoder output,
            whose last dimension is expected to be ``feature_dim``.
            ``projection`` is ``projector(feature)``, whose last dimension is
            ``projection_dim``.
        """
        feature = self.enc(x)
        projection = self.projector(feature)
        return feature, projection