train_scratch.py

from __future__ import print_function
import argparse  # Python command-line argument parsing
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tensorboardX import SummaryWriter
from nets import resnet34, CNN, CNNCifar10, resnet18, resnet50, MLP, AlexNet, vgg8_bn
from utils import test, get_dataset
import warnings
warnings.filterwarnings('ignore')


def train(model, train_loader, optimizer, epoch):
    """Run one epoch of supervised training on train_loader."""
    model.train()
    for idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        data, target = data.cuda(), target.cuda()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
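

# tqdm is imported above but never used; a minimal sketch of how the loop in
# train() could report per-batch progress (an alternative sketch, not the
# original code), kept commented so behavior is unchanged:
#
#   for idx, (data, target) in enumerate(tqdm(train_loader, desc="train")):
#       ...  # same body as the loop above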


def adjust_learning_rate(lr, optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 40 epochs."""
    lr = lr * (0.1 ** (epoch // 40))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
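

# A quick worked example of the step schedule above, assuming the default
# initial lr of 0.1: 0.1 * (0.1 ** (epoch // 40)) evaluates to 0.1 for
# epochs 0-39, 0.01 for epochs 40-79, and 0.001 for epochs 80-119, i.e. a
# tenfold decay at every 40-epoch boundary.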


# def get_model(dataset, net):
#     if "mnist" in dataset:
#         if net == "mlp":
#             model = MLP().cuda()
#         elif net == "lenet":
#             model = CNN().cuda()
#         elif net == "alexnet":
#             model = AlexNet().cuda()
#     elif dataset == "svhn":
#         if net == "alexnet":
#             model = CNNCifar10().cuda()
#         elif net == "vgg":
#             model = CNNCifar10().cuda()
#         elif net == "resnet18":
#             model = resnet18(num_classes=10).cuda()
#     elif dataset == "cifar10":
#         # model = resnet18(num_classes=10).cuda()
#         model = CNNCifar10().cuda()
#     elif dataset == "cifar100":
#         model = resnet50(num_classes=100).cuda()
#     elif dataset == "imagenet":
#         model = resnet18(num_classes=12).cuda()
#     return model


def main():
    # Training settings
    parser = argparse.ArgumentParser(description='Train a model from scratch')
    parser.add_argument('--dataset', type=str, default="cifar10",
                        help='dataset to train on (default: cifar10)')
    parser.add_argument('--net', type=str, default="cifar10",
                        help='network architecture name, used in the checkpoint path')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.1,
                        help='learning rate (default: 0.1)')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--model', type=str, default='resnet34',
                        help='model architecture (default: resnet34); not used below')
    args = parser.parse_args()

    train_loader, test_loader = get_dataset(args.dataset)
    # model = get_teacher_model(args.dataset, args.net)
    model = CNNCifar10().cuda()
    # model = vgg8_bn(num_classes=10).cuda()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    bst_acc = -1
    public = "pretrained_large/{}_{}".format(args.dataset, args.net)
    # SummaryWriter creates log_dir if it does not exist, so the checkpoint
    # directory used by torch.save below is present before the first save.
    tf_writer = SummaryWriter(log_dir=public)
    for epoch in range(1, args.epochs + 1):
        # adjust_learning_rate(args.lr, optimizer, epoch)
        train(model, train_loader, optimizer, epoch)
        acc, loss = test(model, test_loader)
        if acc > bst_acc:
            bst_acc = acc
            torch.save(model.state_dict(), '{}/{}_{}.pkl'.format(public, args.dataset, args.net))
        tf_writer.add_scalar('test_acc', acc, epoch)
        print("Epoch:{},\t test_acc:{}, best_acc:{}".format(epoch, acc, bst_acc))


if __name__ == '__main__':
    main()
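
# Example invocation (flag names from the parser above; the dataset value
# must be one that utils.get_dataset understands -- "cifar10" is the default
# here):
#   python train_scratch.py --dataset cifar10 --epochs 100 --lr 0.1 --momentum 0.9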