-
Notifications
You must be signed in to change notification settings - Fork 13
/
simclr.py
156 lines (122 loc) · 5.4 KB
/
simclr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import hydra
from omegaconf import DictConfig
import logging
import numpy as np
from PIL import Image
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18, resnet34
from torchvision import transforms
from models import SimCLR
from tqdm import tqdm
# Module-level logger; train() uses it for model info and checkpoint messages.
logger = logging.getLogger(__name__)
class AverageMeter:
    """Keep the latest value and a weighted running mean of a scalar metric.

    Attributes:
        name: label for the metric.
        val: value passed to the most recent ``update`` call.
        sum: weighted sum of every value seen since the last ``reset``.
        count: total weight seen since the last ``reset``.
        avg: ``sum / count``.
    """

    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
class CIFAR10Pair(CIFAR10):
    """Generate mini-batch pairs on the CIFAR10 training set.

    ``__getitem__`` returns ``(views, target)``:
    - ``views`` stacks two independent applications of ``self.transform`` to
      the same image, forming a positive pair.
    - ``target`` is the class label, passed through ``self.target_transform``
      when one was given.
    """

    def __getitem__(self, idx):
        if self.transform is None:
            # torch.stack needs tensors; without a transform we would only have a PIL image.
            raise TypeError('CIFAR10Pair requires a transform that returns tensors')
        img, target = self.data[idx], self.targets[idx]
        # NOTE(review): self.data entries are presumably HWC uint8 arrays (as in
        # torchvision's CIFAR10); converting to PIL lets the transforms consume them.
        img = Image.fromarray(img)
        # Two independent draws of the random augmentation pipeline form the positive pair.
        views = torch.stack([self.transform(img), self.transform(img)])
        # Honor target_transform (a CIFAR10 constructor option) instead of silently ignoring it.
        if self.target_transform is not None:
            target = self.target_transform(target)
        return views, target
def nt_xent(x, t=0.5):
    """NT-Xent (normalized temperature-scaled cross-entropy) loss from SimCLR.

    Args:
        x: 2-D tensor of shape ``(2N, D)``. Its rows come in positive pairs:
            row ``2k`` and row ``2k + 1`` are two views of the same sample.
        t: softmax temperature.

    Returns:
        Scalar tensor: the mean over all 2N rows of the (2N-1)-way
        cross-entropy that asks each row to pick its partner among all
        other rows.

    Raises:
        ValueError: if ``x`` does not have a positive, even number of rows.
    """
    n = x.size(0)
    if n < 2 or n % 2:
        raise ValueError(f'nt_xent expects a positive even number of rows, got {n}')
    x = F.normalize(x, dim=1)
    # Cosine similarities lie in [-1, 1]. Do not clamp them: clamping at a small
    # positive value would erase every negative similarity and its gradient.
    logits = (x @ x.t()) / t  # scale with temperature
    # (2N-1)-way softmax: exclude each row's similarity with itself.
    # A -inf logit contributes exactly zero probability after softmax.
    self_mask = torch.eye(n, dtype=torch.bool, device=x.device)
    logits = logits.masked_fill(self_mask, float('-inf'))
    # Partner index: 2k -> 2k+1 and 2k+1 -> 2k, i.e. flip the lowest bit.
    targets = torch.arange(n, device=x.device) ^ 1
    return F.cross_entropy(logits, targets)
def get_lr(step, total_steps, lr_max, lr_min):
    """Return the cosine-annealed value at ``step`` out of ``total_steps``.

    The value equals ``lr_max`` at ``step == 0`` and decays along half a
    cosine period to ``lr_min`` at ``step == total_steps``.
    """
    cosine = np.cos(step / total_steps * np.pi)
    half_span = (lr_max - lr_min) * 0.5
    return lr_min + half_span * (1 + cosine)
# Color distortion = random color jittering followed by random color dropping (grayscale).
# See Appendix A of SimCLR: https://arxiv.org/abs/2002.05709
def get_color_distortion(s=0.5):  # 0.5 for CIFAR10 by default
    """Build SimCLR's color-distortion augmentation.

    Args:
        s: strength of the distortion. Brightness, contrast and saturation
            jitter are ``0.8 * s``; hue jitter is ``0.2 * s``.

    Returns:
        A ``transforms.Compose`` that applies ``ColorJitter`` with probability
        0.8, then ``RandomGrayscale`` with probability 0.2.
    """
    strength = 0.8 * s
    jitter = transforms.ColorJitter(brightness=strength,
                                    contrast=strength,
                                    saturation=strength,
                                    hue=0.2 * s)
    return transforms.Compose([
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
    ])
@hydra.main(config_path='simclr_config.yml')
def train(args: DictConfig) -> None:
    """Pre-train a SimCLR model on CIFAR10 using the hydra config ``args``.

    Config fields read: data_dir, batch_size, workers, backbone,
    projection_dim, learning_rate, momentum, weight_decay, epochs,
    temperature, log_interval.

    Every ``log_interval`` epochs, the model state dict is saved to
    ``simclr_<backbone>_epoch<N>.pt``, a path relative to the current working
    directory.

    Raises:
        RuntimeError: if no CUDA device is available.
        ValueError: if ``args.backbone`` is not a supported ResNet.
    """
    if not torch.cuda.is_available():
        raise RuntimeError('SimCLR training requires a CUDA device')
    cudnn.benchmark = True

    train_transform = transforms.Compose([transforms.RandomResizedCrop(32),
                                          transforms.RandomHorizontalFlip(p=0.5),
                                          get_color_distortion(s=0.5),
                                          transforms.ToTensor()])
    data_dir = hydra.utils.to_absolute_path(args.data_dir)  # get absolute path of data dir
    train_set = CIFAR10Pair(root=data_dir,
                            train=True,
                            transform=train_transform,
                            download=True)
    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              # pinned host memory lets .cuda(non_blocking=True) overlap copies
                              pin_memory=True,
                              drop_last=True)

    # Prepare model: map config names to constructors rather than eval()-ing a config string.
    backbones = {'resnet18': resnet18, 'resnet34': resnet34}
    if args.backbone not in backbones:
        raise ValueError('Unsupported backbone {!r}; expected one of {}'.format(
            args.backbone, sorted(backbones)))
    base_encoder = backbones[args.backbone]
    model = SimCLR(base_encoder, projection_dim=args.projection_dim).cuda()
    logger.info('Base model: %s', args.backbone)
    logger.info('feature dim: %s, projection dim: %s', model.feature_dim, args.projection_dim)

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
        nesterov=True)

    # Cosine annealing, stepped once per batch.
    # LambdaLR multiplies the optimizer's base lr (args.learning_rate) by the
    # lambda's return value, so the schedule is written as a factor:
    # 1.0 (lr = learning_rate) decaying to 1e-3 / learning_rate (lr = 1e-3).
    total_steps = args.epochs * len(train_loader)
    scheduler = LambdaLR(
        optimizer,
        lr_lambda=lambda step: get_lr(step, total_steps, 1.0, 1e-3 / args.learning_rate))

    # SimCLR training
    model.train()
    for epoch in range(1, args.epochs + 1):
        loss_meter = AverageMeter("SimCLR_loss")
        train_bar = tqdm(train_loader)
        for x, _ in train_bar:
            # x stacks a positive pair per sample (B, 2, ...). Flatten to (2B, ...)
            # so rows 2k and 2k+1 are the two views, matching nt_xent's pairing.
            sizes = x.size()
            x = x.view(sizes[0] * 2, sizes[2], sizes[3], sizes[4]).cuda(non_blocking=True)

            optimizer.zero_grad()
            _, rep = model(x)  # only the second output (rep) feeds the contrastive loss
            loss = nt_xent(rep, args.temperature)
            loss.backward()
            optimizer.step()
            scheduler.step()

            loss_meter.update(loss.item(), x.size(0))
            train_bar.set_description("Train epoch {}, SimCLR loss: {:.4f}".format(epoch, loss_meter.avg))

        # save checkpoint every log_interval epochs
        if epoch % args.log_interval == 0:
            logger.info("==> Save checkpoint. Train epoch {}, SimCLR loss: {:.4f}".format(epoch, loss_meter.avg))
            torch.save(model.state_dict(), 'simclr_{}_epoch{}.pt'.format(args.backbone, epoch))
# Script entry point. train() is called with no arguments because the
# @hydra.main decorator supplies the DictConfig it expects.
if __name__ == '__main__':
    train()