-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
125 lines (112 loc) · 3.82 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
EECS 445 - Introduction to Machine Learning
Fall 2018 - Project 2
Utility functions
"""
import ast
import json
import os

import matplotlib.pyplot as plt
import numpy as np
def config(attr):
    """
    Retrieves the queried attribute value from the config file. Loads the
    config file on first call.

    ``attr`` is a dot-separated path into the nested config mapping, e.g.
    ``config('a.b')`` returns ``cfg['a']['b']``. The parsed file is cached on
    the function object (``config.config``), so ``config.json`` (resolved
    against the current working directory) is read only once per process.

    Raises:
        FileNotFoundError: if ``config.json`` does not exist on first call.
        ValueError / SyntaxError: if the file is neither valid JSON nor a
            valid Python literal.
        KeyError: if any component of ``attr`` is missing.
    """
    if not hasattr(config, 'config'):
        with open('config.json', encoding='utf-8') as f:
            text = f.read()
        try:
            # Parse as real JSON: unlike eval(), this accepts true/false/null
            # and never executes code contained in the file.
            config.config = json.loads(text)
        except json.JSONDecodeError:
            # Backward compatibility: accept Python-literal syntax (single
            # quotes, True/False/None) that the old eval()-based loader took,
            # without evaluating arbitrary expressions.
            config.config = ast.literal_eval(text)
    node = config.config
    for part in attr.split('.'):
        node = node[part]
    return node
def denormalize_image(image):
    """ Rescale the image's color space from (min, max) to (0, 1)

    Min and max are reduced over the first two axes of ``image`` and
    broadcast against the remaining axes, so each slice along the trailing
    axis is rescaled independently (presumably an (H, W, C) image with
    per-channel scaling -- TODO confirm with callers).

    A slice whose values are all equal has zero range; it is mapped to 0
    instead of producing NaN from a 0/0 division.
    """
    low = np.min(image, axis=(0, 1))
    ptp = np.max(image, axis=(0, 1)) - low
    # For constant slices (image - low) is exactly 0, so dividing by 1
    # yields 0 rather than 0/0 = NaN.
    ptp = np.where(ptp == 0, 1, ptp)
    return (image - low) / ptp
def hold_training_plot():
    """
    Leave interactive mode and show the open figure(s).

    With interactive mode off, matplotlib's show() blocks until the windows
    are closed, which keeps the program alive so the final training plot
    stays on screen.
    """
    plt.ioff()  # turn interactive mode off so the next show() call blocks
    plt.show()
def log_cnn_training(epoch, stats):
    """
    Print the most recent CNN validation and training metrics to stdout.

    Args:
        epoch: epoch number shown in the header line.
        stats: sequence of (valid_acc, valid_loss, train_acc, train_loss)
            entries; only the last entry is printed.
    """
    valid_acc, valid_loss, train_acc, train_loss = stats[-1]
    report = (
        'Epoch {}'.format(epoch),
        '\tValidation Loss: {}'.format(valid_loss),
        '\tValidation Accuracy: {}'.format(valid_acc),
        '\tTrain Loss: {}'.format(train_loss),
        '\tTrain Accuracy: {}'.format(train_acc),
    )
    for line in report:
        print(line)
def make_cnn_training_plot():
    """
    Create an interactive two-panel figure for tracking CNN training.

    The left panel is labelled for accuracy and the right for loss, both
    against epoch.

    Returns:
        The array of two Axes produced by plt.subplots(1, 2).
    """
    plt.ion()  # interactive mode so the figure can update during training
    _, axes = plt.subplots(1, 2, figsize=(10, 5))
    plt.suptitle('CNN Training')
    for ax, metric in zip(axes, ('Accuracy', 'Loss')):
        ax.set_xlabel('Epoch')
        ax.set_ylabel(metric)
    return axes
def update_cnn_training_plot(axes, epoch, stats):
    """
    Draw the accumulated CNN metrics onto the training figure.

    Every call plots the whole ``stats`` history (not only the newest
    point), placing entries at x = epoch - len(stats) + 1 ... epoch.
    Accuracy goes on axes[0] and loss on axes[1], with validation in blue
    and training in red.

    Args:
        axes: the pair of Axes returned by make_cnn_training_plot().
        epoch: epoch number of the last entry in ``stats``.
        stats: sequence of (valid_acc, valid_loss, train_acc, train_loss)
            entries.
    """
    valid_acc, valid_loss, train_acc, train_loss = (
        [row[col] for row in stats] for col in range(4)
    )
    epochs = range(epoch - len(stats) + 1, epoch + 1)
    panels = ((valid_acc, train_acc), (valid_loss, train_loss))
    for ax, (valid_series, train_series) in zip(axes, panels):
        ax.plot(epochs, valid_series, linestyle='--', marker='o', color='b')
        ax.plot(epochs, train_series, linestyle='--', marker='o', color='r')
        ax.legend(['Validation', 'Train'])
    plt.pause(0.00001)  # brief pause lets the interactive window redraw
def save_cnn_training_plot():
    """
    Write the current matplotlib figure to cnn_training_plot.png (200 dpi).

    The path is relative, so the file lands in the current working directory.
    """
    out_path = 'cnn_training_plot.png'
    plt.savefig(out_path, dpi=200)
def log_ae_training(epoch, stats):
    """
    Print the most recent autoencoder validation and training losses.

    Args:
        epoch: epoch number shown in the header line.
        stats: sequence of (valid_loss, train_loss) entries; only the last
            entry is printed.
    """
    valid_loss, train_loss = stats[-1]
    report = (
        'Epoch {}'.format(epoch),
        '\tValidation Mean squared error loss: {}'.format(valid_loss),
        '\tTrain Mean squared error loss: {}'.format(train_loss),
    )
    for line in report:
        print(line)
def make_ae_training_plot():
    """
    Create an interactive single-panel figure for tracking autoencoder loss.

    Returns:
        The single Axes produced by plt.subplots(1, 1), labelled with epoch
        on x and MSE on y.
    """
    plt.ion()  # interactive mode so the figure can update during training
    _, ax = plt.subplots(1, 1, figsize=(5, 5))
    plt.suptitle('Autoencoder Training')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('MSE')
    return ax
def update_ae_training_plot(axes, epoch, stats):
    """
    Draw the accumulated autoencoder losses onto the training figure.

    Every call plots the whole ``stats`` history (not only the newest
    point), placing entries at x = epoch - len(stats) + 1 ... epoch, with
    validation in blue and training in red.

    Args:
        axes: the Axes returned by make_ae_training_plot().
        epoch: epoch number of the last entry in ``stats``.
        stats: sequence of (valid_loss, train_loss) entries.
    """
    valid_loss, train_loss = ([row[col] for row in stats] for col in (0, 1))
    epochs = range(epoch - len(stats) + 1, epoch + 1)
    for series, colour in ((valid_loss, 'b'), (train_loss, 'r')):
        axes.plot(epochs, series, linestyle='--', marker='o', color=colour)
    axes.legend(['Validation', 'Train'])
    plt.pause(0.00001)  # brief pause lets the interactive window redraw
def save_ae_training_plot():
    """
    Write the current matplotlib figure to ae_training_plot.png (200 dpi).

    The path is relative, so the file lands in the current working directory.
    """
    out_path = 'ae_training_plot.png'
    plt.savefig(out_path, dpi=200)