-
Notifications
You must be signed in to change notification settings - Fork 9
/
live_loss_plot.py
48 lines (37 loc) · 1.49 KB
/
live_loss_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import matplotlib.pyplot as plt
from keras.callbacks import Callback
from IPython.display import clear_output
#from matplotlib.ticker import FormatStrFormatter
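# TODO: rewrite the plotting code to use matplotlib's object-oriented API
# (explicit Figure/Axes objects) instead of the pyplot state machine.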
def translate_metric(x):
    """Return a human-readable title for the Keras metric name *x*.

    Known metric names are mapped to descriptive titles; any other name is
    returned unchanged.

    >>> translate_metric('acc')
    'Accuracy'
    >>> translate_metric('mse')
    'mse'
    """
    translations = {
        'acc': "Accuracy",
        # Newer Keras versions log accuracy under the full name.
        'accuracy': "Accuracy",
        'loss': "Log-loss (cost function)",
    }
    return translations.get(x, x)
class PlotLosses(Callback):
    """Keras callback that redraws live training curves after every epoch.

    Intended for notebooks: after each epoch, the previous cell output is
    cleared with IPython's ``clear_output(wait=True)``. A fresh matplotlib
    figure is then drawn with one subplot per base metric (every metric whose
    name does not start with ``'val_'``). Each subplot shows the training
    curve and, when validation values are available, the matching ``'val_'``
    curve.

    Args:
        figsize: Forwarded to ``plt.figure``; ``None`` uses matplotlib's
            default size.
    """

    def __init__(self, figsize=None):
        super(PlotLosses, self).__init__()
        self.figsize = figsize

    def on_train_begin(self, logs=None):
        """Reset the epoch history and pick the metrics to plot.

        Older Keras versions list the compiled metric names in
        ``self.params['metrics']``. When that key is absent (newer Keras
        releases dropped it), ``base_metrics`` is left as ``None`` and is
        inferred from the first epoch's logs in ``on_epoch_end``.
        """
        metrics = self.params.get('metrics')
        if metrics is None:
            self.base_metrics = None
        else:
            self.base_metrics = [metric for metric in metrics
                                 if not metric.startswith('val_')]
        self.logs = []

    def on_epoch_end(self, epoch, logs=None):
        """Record this epoch's logs and redraw all curves.

        Args:
            epoch: Epoch index as passed by Keras (unused; the x-axis is
                derived from the number of recorded epochs).
            logs: Mapping of metric name to value for this epoch, as passed
                by Keras; ``None`` is treated as empty.
        """
        # Store a copy so later mutation of the caller's dict cannot rewrite
        # the recorded history.
        self.logs.append(dict(logs or {}))
        if self.base_metrics is None:
            self.base_metrics = [metric for metric in self.logs[-1]
                                 if not metric.startswith('val_')]

        # An explicit do_validation=False (older Keras) disables the
        # validation curve; otherwise it is drawn whenever val_ values exist.
        do_validation = self.params.get('do_validation', True)
        epochs = range(1, len(self.logs) + 1)
        missing = float('nan')  # leaves a gap for epochs lacking a value

        clear_output(wait=True)
        plt.figure(figsize=self.figsize)
        for metric_id, metric in enumerate(self.base_metrics):
            plt.subplot(1, len(self.base_metrics), metric_id + 1)
            plt.plot(epochs,
                     [log.get(metric, missing) for log in self.logs],
                     label="training")
            val_metric = 'val_' + metric
            if do_validation and any(val_metric in log for log in self.logs):
                plt.plot(epochs,
                         [log.get(val_metric, missing) for log in self.logs],
                         label="validation")
            plt.title(translate_metric(metric))
            plt.xlabel('epoch')
            plt.legend(loc='center right')
        plt.tight_layout()
        plt.show()