-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot_metrics.py
92 lines (68 loc) · 2.13 KB
/
plot_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.ticker import MaxNLocator
# Plot training and test accuracy parsed from a training log.
#
# Usage: python plot_metrics.py LOGFILE
# Writes figures/metric_plots.svg (the 'figures' directory is created if needed).
fig_dir = Path('figures')

# Fail with a usage message rather than an IndexError traceback when the log
# path is missing; extra arguments are ignored, as before.
if len(sys.argv) < 2:
    sys.exit(f'usage: {sys.argv[0]} LOGFILE')

fig_dir.mkdir(exist_ok=True)
logpath = sys.argv[1]

# One Series element per log line. Only ASCII tags and numbers are parsed
# later, so undecodable bytes are replaced instead of aborting the run.
with open(logpath, 'r', encoding='utf-8', errors='replace') as logfile:
    logs = pd.Series(logfile, dtype='string')
logs = logs.str.rstrip()  # remove trailing newlines (and other trailing whitespace)
#
# extract training accuracy
#
# Every line starting with '[TRAIN]' is scanned for numbers: the first two are
# skipped, and the third and fourth are taken as the 'per-char' and 'parser'
# accuracies. Lines where the pattern does not match produce NaN rows.
# NOTE(review): '[0-9.]+' also matches a lone '.', so a period before the
# numbers could shift which fields are captured -- confirm against the log format.
TRAIN_BUCKET = 100  # number of consecutive [TRAIN] lines grouped per x-axis value

# Raw strings: '\[' in a plain string literal is an invalid escape sequence.
train_mask = logs.str.match(r'\[TRAIN\]')
train_stats = (
    logs[train_mask]
    .str.extract(r'[0-9.]+.*?[0-9.]+.*?([0-9.]+).*?([0-9.]+)')
    .astype(float)
    .rename(columns={0: 'per-char', 1: 'parser'})
    .reset_index(drop=True)
)
# Row i (the i-th [TRAIN] line) is labelled (i // TRAIN_BUCKET) * TRAIN_BUCKET,
# so every bucket of consecutive lines shares one 'batch' value and the plot
# aggregates them into a single point.
train_stats.insert(0, 'batch', (train_stats.index // TRAIN_BUCKET) * TRAIN_BUCKET)
train_results = pd.melt(
    train_stats,
    id_vars=['batch'],
    var_name='metric',
    value_name='accuracy',
)
#
# extract test accuracy
#
# Every line starting with '[TEST]' is scanned for numbers: the first is
# skipped, and the next four are taken as per-char / parser accuracy on
# generated data, then per-char / parser accuracy on real data.
# Rows are numbered 0, 1, 2, ... in log order and that number is used as the
# 'epoch' (assumes one [TEST] line per epoch -- confirm against the trainer).

# Raw strings: '\[' in a plain string literal is an invalid escape sequence.
test_mask = logs.str.match(r'\[TEST\]')
test_stats = (
    logs[test_mask]
    .str.extract(r'[0-9.]+.*?([0-9.]+).*?([0-9.]+).*?([0-9.]+).*?([0-9.]+)')
    .astype(float)
    .rename(columns={
        0: 'per-char (generated)',
        1: 'parser (generated)',
        2: 'per-char (real)',
        3: 'parser (real)',
    })
    .reset_index(drop=True)
    # rename_axis + reset_index also works on pandas < 1.5, where
    # reset_index(names=...) is not available.
    .rename_axis('epoch')
    .reset_index()
)
test_results = pd.melt(
    test_stats,
    id_vars=['epoch'],
    var_name='metric',
    value_name='accuracy',
)
#
# plot results
#
# Prefer Arimo, but go through the generic sans-serif family so machines
# without it fall back quietly instead of emitting findfont warnings.
sns.set_theme(
    context='paper',
    style='whitegrid',
    font='sans-serif',
    rc={'font.sans-serif': ['Arimo', 'Arial', 'Liberation Sans', 'DejaVu Sans', 'sans-serif']},
)
facecolor = '#f8f5f0'
fig, axs = plt.subplots(1, 2, figsize=(8, 3), facecolor=facecolor)

# training accuracy: mean per batch bucket, band spanning the bucket's
# min..max (a 100% percentile interval)
sns.lineplot(data=train_results, x='batch', y='accuracy', hue='metric',
             errorbar=('pi', 100), ax=axs[0])
axs[0].set_xlabel('Batch number')
axs[0].set_ylabel('Accuracy (%)')
# A fixed tick list would stretch the axis out to its largest tick for short
# runs; let the locator choose ~5 round ticks (0, 30000, ..., 120000 for a
# 120k-batch run) for whatever range the log covers.
axs[0].xaxis.set_major_locator(MaxNLocator(nbins=5, steps=[1, 2, 3, 5, 10]))
axs[0].set_title('training')

# test accuracy
sns.lineplot(data=test_results, x='epoch', y='accuracy', hue='metric', ax=axs[1])
axs[1].set_xlabel('Epoch')
axs[1].set_ylabel('Accuracy (%)')
axs[1].xaxis.set_major_locator(MaxNLocator(integer=True))  # epochs are whole numbers
axs[1].set_title('test')

fig.tight_layout()
fig.savefig(fig_dir / 'metric_plots.svg', bbox_inches='tight')