-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_nn.py
158 lines (129 loc) · 4.45 KB
/
run_nn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""
Author: Tim Ruhkopf
Email: [email protected]
Purpose: This script checks whether the ResNet model actually learns
something when given sufficient capacity.
"""
import torch
from torch.utils.data import TensorDataset, DataLoader
import pyro
import matplotlib
import os
import pickle
import datetime
from pathlib import Path
from src.resnet import ResNet
from src.utils import load_npz_kmnist, get_git_revision_short_hash
from src.blackboxpipe import BlackBoxPipe
# Reproducibility: fix the pyro and torch RNG seeds, and record the short git
# hash of the checked-out revision (it becomes part of the run name below).
pyro.set_rng_seed(0)
torch.manual_seed(0)
git_hash = get_git_revision_short_hash()
# Render figures with the non-interactive Agg backend (no display required).
matplotlib.use('Agg')
# (0) Computation device and run configuration. -------------------------------
# TEST=True switches to a tiny smoke-test setup: one epoch, batch size 1, a
# two-stage network and (further below) a truncated dataset.
TEST = False
ROOT_DATA = 'Data/Raw/'
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# NOTE(review): INIT_LAMB, EPS, NOISE, SEARCH_SPACE and BUDGET are not read
# anywhere in this script. The integer values look like log10 exponents
# (an earlier linear SEARCH_SPACE was (10e-5, 10e-1)). The original note
# said INIT_LAMB=-3 means 0.01, but 10**-3 == 0.001 -- confirm.
INIT_LAMB = -3
EPS = 0.
NOISE = 0.
SEARCH_SPACE = (-5, -1)

# resnet_config is forwarded verbatim as ResNet(**resnet_config).
# NOTE(review): 'architecture' entries look like per-stage channel counts --
# confirm against src.resnet.
if TEST:
    EPOCHS, BATCH_SIZE, BUDGET = 1, 1, 3
    resnet_config = {
        'img_size': (28, 28),
        'architecture': ((1, 2), (2, 2, 2)),
        'no_classes': 10,
    }
else:
    # Full-run configuration.
    EPOCHS, BATCH_SIZE, BUDGET = 10, 8, 10
    resnet_config = {
        'img_size': (28, 28),
        'architecture': (
            (1, 8), (8, 16, 16), (16, 16, 16), (16, 16, 16),
            (16, 32, 32), (32, 32, 32),
        ),
        'no_classes': 10,
    }

# Run name: "run_<git hash>_<YYYYmmdd_HHMMSS>" (local wall-clock time).
s = '{:%Y%m%d_%H%M%S}'
timestamp = s.format(datetime.datetime.now())
RUNIDX = 'run_{}_{}'.format(git_hash, timestamp)
# (1) Load KMNIST, preprocessing as in
# https://github.com/rois-codh/kmnist/blob/master/benchmarks/kuzushiji_mnist_cnn.py
x_train, x_test, y_train, y_test = load_npz_kmnist(
    folder=ROOT_DATA,
    files=[
        'kmnist-train-imgs.npz', 'kmnist-test-imgs.npz',
        'kmnist-train-labels.npz', 'kmnist-test-labels.npz',
    ],
)
if TEST:
    # Smoke test: keep the first n training and the first n // 10 test samples.
    n = 1001
    x_train, y_train = x_train[:n], y_train[:n]
    x_test, y_test = x_test[:n // 10], y_test[:n // 10]

# (2) Preprocess the tensors. -------------------------------------------------
# Scale pixel values into 0 - 1, in place.
# NOTE(review): assumes raw values are 0-255 and the tensors are already
# floating point (in-place division fails on integer tensors) -- confirm.
x_train /= 255.
x_test /= 255.
# Cast the labels (loaded as floats, per the original note) to int64.
y_train = y_train.type(torch.LongTensor)
y_test = y_test.type(torch.LongTensor)
# Insert a singleton channel axis at dim 1 (greyscale images).
x_train = x_train.unsqueeze(dim=1)
x_test = x_test.unsqueeze(dim=1)
print("y's shape: {}\nx's shape: {}".format(y_train.shape, x_train.shape))
# Wrap the tensors in datasets/loaders. The train loader yields shuffled
# mini-batches of BATCH_SIZE. The test loader yields one shuffled sample
# per batch. Both load in the main process (num_workers=0).
trainset = TensorDataset(x_train, y_train)
testset = TensorDataset(x_test, y_test)
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True,
                         num_workers=0)
testloader = DataLoader(testset, batch_size=1, shuffle=True, num_workers=0)

# (3) Build the network from resnet_config and move it to DEVICE. -------------
resnet = ResNet(**resnet_config)
resnet.to(DEVICE)
# (4) Train the network via BlackBoxPipe with SGD at a fixed learning rate,
# then persist the pipe object.
# Output directory: <cwd>/models/<RUNIDX>, created (with parents) if missing.
root = os.getcwd()
modeldir = root + '/models/{}'.format(RUNIDX)
Path(modeldir).mkdir(exist_ok=True, parents=True)

pipe = BlackBoxPipe(
    resnet,
    trainloader,
    testloader,
    epochs=EPOCHS,
    path=modeldir,
    device=DEVICE,
)
# NOTE(review): the argument is presumably the SGD learning rate (an earlier
# variant used 0.003) -- confirm in src.blackboxpipe.
pipe.evaluate_model_with_SGD(0.001)

# Strip the network and both data loaders so only the pipe's remaining state
# is pickled. NOTE(review): the original comment said the model was "already
# written out"; presumably evaluate_model_with_SGD saves it under modeldir.
del pipe.model, pipe.trainloader, pipe.testloader

filename = '{}/{}.pkl'.format(modeldir, RUNIDX)
with open(filename, 'wb') as handle:
    pickle.dump(pipe, handle, protocol=pickle.HIGHEST_PROTOCOL)
if TEST:
    # Post-run inspection (smoke-test mode only): unpickle a pipe written by
    # the block above and show a heatmap of one model's confusion matrix.
    # The project classes are imported again so this section can also be
    # pasted into an interactive session; unpickling needs them importable.
    from src.blackboxpipe import BlackBoxPipe
    from src.resnet import ResNet
    import matplotlib
    # Choose the interactive backend BEFORE pyplot (or seaborn, which imports
    # pyplot) is imported. Older matplotlib releases ignore use() once pyplot
    # is loaded.
    matplotlib.use('TkAgg')
    import matplotlib.pyplot as plt
    import pandas as pd
    import seaborn as sns

    # Default to the pickle written by this run. Previously both values were
    # '' which made the path '/' and open() always failed. Edit them to
    # inspect an earlier run instead.
    modelfolder = modeldir
    file = '{}.pkl'.format(RUNIDX)
    filename = '{}/{}'.format(modelfolder, file)
    # pickle.load can execute arbitrary code: only load files you produced.
    with open(filename, 'rb') as handle:
        pipepickled = pickle.load(handle)

    # Plot the confusion matrix of the selected model.
    model_idx = 0
    c_mat = pipepickled.confusion_matrices[model_idx]
    # c_mat appears to be a torch tensor (it exposes .numpy()). .cpu() makes
    # the conversion work even if it was stored on a CUDA device; it is a
    # no-op for CPU tensors.
    confused = pd.DataFrame(c_mat.cpu().numpy())
    sns.heatmap(confused, annot=True)
    plt.show()