@@ -1,14 +1,14 @@
 import logging
 import os
-import shutil
 import sys
 
 import torch
 import torch.nn.functional as F
-import yaml
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 
+from utils import save_config_file, accuracy, save_checkpoint
+
 torch.manual_seed(0)
 
 apex_support = False
@@ -22,36 +22,6 @@
 apex_support = False
 
 
-def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
-    torch.save(state, filename)
-    if is_best:
-        shutil.copyfile(filename, 'model_best.pth.tar')
-
-
-def _save_config_file(model_checkpoints_folder, args):
-    if not os.path.exists(model_checkpoints_folder):
-        os.makedirs(model_checkpoints_folder)
-    with open(os.path.join(model_checkpoints_folder, 'config.yml'), 'w') as outfile:
-        yaml.dump(args, outfile, default_flow_style=False)
-
-
-def accuracy(output, target, topk=(1,)):
-    """Computes the accuracy over the k top predictions for the specified values of k"""
-    with torch.no_grad():
-        maxk = max(topk)
-        batch_size = target.size(0)
-
-        _, pred = output.topk(maxk, 1, True, True)
-        pred = pred.t()
-        correct = pred.eq(target.view(1, -1).expand_as(pred))
-
-        res = []
-        for k in topk:
-            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
-            res.append(correct_k.mul_(100.0 / batch_size))
-        return res
-
-
 class SimCLR(object):
 
     def __init__(self, *args, **kwargs):
@@ -86,7 +56,7 @@ def train(self, train_loader):
                                               opt_level='O2',
                                               keep_batchnorm_fp32=True)
         # save config file
-        _save_config_file(self.writer.log_dir, self.args)
+        save_config_file(self.writer.log_dir, self.args)
 
         n_iter = 0
         logging.info(f"Start SimCLR training for {self.args.epochs} epochs.")
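
Note: the utils module that the new import refers to is not part of this commit's visible diff. What follows is a plausible reconstruction of utils.py, assuming the three removed helpers were moved over verbatim and that _save_config_file was renamed to save_config_file (which the updated call site in train() suggests); the commit itself does not show the file.

import os
import shutil

import torch
import yaml


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    # Persist the training state; keep a separate copy of the best model so far.
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


def save_config_file(model_checkpoints_folder, args):
    # Dump the run configuration next to the checkpoints for reproducibility.
    if not os.path.exists(model_checkpoints_folder):
        os.makedirs(model_checkpoints_folder)
    with open(os.path.join(model_checkpoints_folder, 'config.yml'), 'w') as outfile:
        yaml.dump(args, outfile, default_flow_style=False)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the maxk largest logits per row, transposed to (maxk, batch).
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # A sample counts as correct at k if the target appears in its top k.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res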
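
For reference, accuracy returns one tensor per requested k, each holding a percentage in [0, 100]. A minimal smoke test, assuming the reconstructed module above:

import torch

from utils import accuracy

logits = torch.randn(8, 10)           # fake model outputs: 8 samples, 10 classes
targets = torch.randint(0, 10, (8,))  # fake integer labels
top1, top5 = accuracy(logits, targets, topk=(1, 5))
print(top1.item(), top5.item())       # two percentages, e.g. 12.5 62.5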