Skip to content

Commit 651d40a

Browse files
committed
Major refactor, small fixes
1 parent f60e9b8 commit 651d40a

File tree

2 files changed

+38
-33
lines changed

2 files changed

+38
-33
lines changed

simclr.py

+3-33
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import logging
22
import os
3-
import shutil
43
import sys
54

65
import torch
76
import torch.nn.functional as F
8-
import yaml
97
from torch.utils.tensorboard import SummaryWriter
108
from tqdm import tqdm
119

10+
from utils import save_config_file, accuracy, save_checkpoint
11+
1212
torch.manual_seed(0)
1313

1414
apex_support = False
@@ -22,36 +22,6 @@
2222
apex_support = False
2323

2424

25-
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
26-
torch.save(state, filename)
27-
if is_best:
28-
shutil.copyfile(filename, 'model_best.pth.tar')
29-
30-
31-
def _save_config_file(model_checkpoints_folder, args):
32-
if not os.path.exists(model_checkpoints_folder):
33-
os.makedirs(model_checkpoints_folder)
34-
with open(os.path.join(model_checkpoints_folder, 'config.yml'), 'w') as outfile:
35-
yaml.dump(args, outfile, default_flow_style=False)
36-
37-
38-
def accuracy(output, target, topk=(1,)):
39-
"""Computes the accuracy over the k top predictions for the specified values of k"""
40-
with torch.no_grad():
41-
maxk = max(topk)
42-
batch_size = target.size(0)
43-
44-
_, pred = output.topk(maxk, 1, True, True)
45-
pred = pred.t()
46-
correct = pred.eq(target.view(1, -1).expand_as(pred))
47-
48-
res = []
49-
for k in topk:
50-
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
51-
res.append(correct_k.mul_(100.0 / batch_size))
52-
return res
53-
54-
5525
class SimCLR(object):
5626

5727
def __init__(self, *args, **kwargs):
@@ -86,7 +56,7 @@ def train(self, train_loader):
8656
opt_level='O2',
8757
keep_batchnorm_fp32=True)
8858
# save config file
89-
_save_config_file(self.writer.log_dir, self.args)
59+
save_config_file(self.writer.log_dir, self.args)
9060

9161
n_iter = 0
9262
logging.info(f"Start SimCLR training for {self.args.epochs} epochs.")

utils.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import os
2+
import shutil
3+
4+
import torch
5+
import yaml
6+
7+
8+
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
9+
torch.save(state, filename)
10+
if is_best:
11+
shutil.copyfile(filename, 'model_best.pth.tar')
12+
13+
14+
def save_config_file(model_checkpoints_folder, args):
15+
if not os.path.exists(model_checkpoints_folder):
16+
os.makedirs(model_checkpoints_folder)
17+
with open(os.path.join(model_checkpoints_folder, 'config.yml'), 'w') as outfile:
18+
yaml.dump(args, outfile, default_flow_style=False)
19+
20+
21+
def accuracy(output, target, topk=(1,)):
22+
"""Computes the accuracy over the k top predictions for the specified values of k"""
23+
with torch.no_grad():
24+
maxk = max(topk)
25+
batch_size = target.size(0)
26+
27+
_, pred = output.topk(maxk, 1, True, True)
28+
pred = pred.t()
29+
correct = pred.eq(target.view(1, -1).expand_as(pred))
30+
31+
res = []
32+
for k in topk:
33+
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
34+
res.append(correct_k.mul_(100.0 / batch_size))
35+
return res

0 commit comments

Comments
 (0)