-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmain_frequentist.py
105 lines (82 loc) · 3.42 KB
/
main_frequentist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from __future__ import print_function
import os
import argparse
import torch
import numpy as np
import torch.nn as nn
from torch.optim import Adam
import data
import config_frequentist as cfg
from models.NonBayesianModels.AlexNet import AlexNet
from models.NonBayesianModels.LeNet import LeNet
from models.NonBayesianModels.ThreeConvThreeFC import ThreeConvThreeFC
# CUDA settings: prefer the first CUDA device when one is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
def getModel(net_type, inputs, outputs):
    """Construct the frequentist network selected by ``net_type``.

    Args:
        net_type: Network identifier. Compared exactly against ``'lenet'``,
            ``'alexnet'`` and ``'3conv3fc'``.
        inputs: Passed to the model constructor as its second argument.
        outputs: Passed to the model constructor as its first argument.

    Returns:
        A new ``LeNet``, ``AlexNet`` or ``ThreeConvThreeFC`` instance.

    Raises:
        ValueError: If ``net_type`` is not one of the supported identifiers.
    """
    if net_type == 'lenet':
        return LeNet(outputs, inputs)
    if net_type == 'alexnet':
        return AlexNet(outputs, inputs)
    if net_type == '3conv3fc':
        return ThreeConvThreeFC(outputs, inputs)
    raise ValueError('Network should be either [LeNet / AlexNet / 3Conv3FC]')
def train_model(net, optimizer, criterion, train_loader):
    """Run one optimisation pass over every batch in ``train_loader``.

    Puts ``net`` in training mode, then for each ``(inputs, targets)`` batch
    moves both to the module-level ``device``, clears gradients, computes the
    loss, back-propagates and steps the optimizer.

    Args:
        net: Model to train; called on each batch of inputs.
        optimizer: Optimizer whose ``zero_grad``/``step`` are invoked per batch.
        criterion: Callable taking ``(output, target)`` and returning the loss.
        train_loader: Iterable yielding ``(inputs, targets)`` pairs.

    Returns:
        The sum over batches of ``loss.item()`` multiplied by the size of the
        batch's first dimension. The caller is expected to normalise it.
    """
    net.train()
    running_loss = 0.0
    for batch_inputs, batch_targets in train_loader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        optimizer.zero_grad()
        predictions = net(batch_inputs)
        batch_loss = criterion(predictions, batch_targets)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch's first-dimension size.
        running_loss += batch_loss.item() * batch_inputs.size(0)
    return running_loss
def validate_model(net, criterion, valid_loader):
    """Accumulate the loss of ``net`` over every batch in ``valid_loader``.

    Puts ``net`` in evaluation mode and runs forward passes only; no
    parameters are updated.

    Args:
        net: Model to evaluate; called on each batch of inputs.
        criterion: Callable taking ``(output, target)`` and returning the loss.
        valid_loader: Iterable yielding ``(inputs, targets)`` pairs.

    Returns:
        The sum over batches of ``loss.item()`` multiplied by the size of the
        batch's first dimension. The caller is expected to normalise it.
    """
    valid_loss = 0.0
    net.eval()
    # Nothing is back-propagated here, so disable autograd to avoid recording
    # graphs that are never used. The computed loss values are unaffected.
    with torch.no_grad():
        for inputs, target in valid_loader:
            inputs, target = inputs.to(device), target.to(device)
            output = net(inputs)
            loss = criterion(output, target)
            valid_loss += loss.item() * inputs.size(0)
    return valid_loss
def run(dataset, net_type):
    """Train a frequentist network on ``dataset`` and checkpoint the best epoch.

    Hyper-parameters are read from ``config_frequentist`` (``cfg``). After
    every epoch the per-sample training and validation losses are printed,
    and the model's ``state_dict`` is saved whenever the validation loss is
    less than or equal to the best value seen so far.

    Args:
        dataset: Dataset name, forwarded to ``data.getDataset`` and used in
            the checkpoint directory path.
        net_type: Network identifier, forwarded to ``getModel`` and used in
            the checkpoint file name.
    """
    # Hyper Parameter settings
    n_epochs = cfg.n_epochs
    lr = cfg.lr
    num_workers = cfg.num_workers
    valid_size = cfg.valid_size
    batch_size = cfg.batch_size

    trainset, testset, inputs, outputs = data.getDataset(dataset)
    # The test loader is returned by the helper but not used during training.
    train_loader, valid_loader, _ = data.getDataloader(
        trainset, testset, valid_size, batch_size, num_workers)
    net = getModel(net_type, inputs, outputs).to(device)

    ckpt_dir = f'checkpoints/{dataset}/frequentist'
    ckpt_name = f'{ckpt_dir}/model_{net_type}.pt'
    # exist_ok=True already tolerates an existing directory.
    os.makedirs(ckpt_dir, exist_ok=True)

    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(net.parameters(), lr=lr)
    # np.inf (lowercase) is the spelling that survives NumPy 2.0.
    valid_loss_min = np.inf

    for epoch in range(1, n_epochs + 1):
        train_loss = train_model(net, optimizer, criterion, train_loader)
        valid_loss = validate_model(net, criterion, valid_loader)

        # NOTE(review): this divides by len(loader.dataset). If
        # data.getDataloader splits train/valid with samplers over a single
        # shared dataset, these denominators over-count -- confirm.
        train_loss = train_loss / len(train_loader.dataset)
        valid_loss = valid_loss / len(valid_loader.dataset)

        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
            epoch, train_loss, valid_loss))

        # save model if validation loss has decreased
        if valid_loss <= valid_loss_min:
            print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
                valid_loss_min, valid_loss))
            torch.save(net.state_dict(), ckpt_name)
            valid_loss_min = valid_loss
if __name__ == '__main__':
    # Command-line entry point: choose the network and dataset, then train.
    cli = argparse.ArgumentParser(
        description="PyTorch Frequentist Model Training")
    cli.add_argument(
        '--net_type', type=str, default='alexnet', help='model')
    cli.add_argument(
        '--dataset', type=str, default='CIFAR10',
        help='dataset = [MNIST/CIFAR10/CIFAR100]')
    options = cli.parse_args()

    run(dataset=options.dataset, net_type=options.net_type)