forked from tatz1101/Edge-AI-Platform-Tutorials
-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot_training_log.py
199 lines (177 loc) · 7.17 KB
/
plot_training_log.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
#modified by Daniele Bagni on 19 Sept 2018
#!/usr/bin/env python
import inspect
import os
import random
import sys
import matplotlib
matplotlib.use('Agg')
import matplotlib.cm as cmx
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import matplotlib.legend as lgd
import matplotlib.markers as mks
from config import cifar10_config as config
# Root directory of the Caffe installation, taken from the project's
# cifar10_config module; get_log_parsing_script() builds the path of
# tools/extra/parse_log.sh beneath it.
caffe_path = config.CAFFE_ROOT # i.e. "/caffe/Caffe-SSD-Ristretto"
def get_log_parsing_script():
    """Return the path of the ``parse_log.sh`` script used to parse a log.

    The path is built under ``tools/extra`` of the directory named by the
    module-level ``caffe_path``. (The original kept a commented-out variant
    that used the directory containing this file instead.)
    """
    return caffe_path + '/tools/extra/parse_log.sh'
def get_log_file_suffix():
    """Return the file-name suffix expected on training log files."""
    suffix = '.log'
    return suffix
def get_chart_type_description_separator():
    """Return the text that separates the two field names in a chart type
    description (y-axis field first, x-axis field second)."""
    separator = ' vs. '
    return separator
def is_x_axis_field(field):
    """Return True when ``field`` is one of the x-axis fields.

    Only 'Iters' and 'Seconds' are x-axis fields; every other field name is
    treated as a y-axis quantity.
    """
    return field == 'Iters' or field == 'Seconds'
def create_field_index():
    """Describe the column layout of the parsed 'Train' and 'Test' data files.

    Returns:
        A ``(field_index, fields)`` tuple. ``field_index`` maps a data file
        type ('Train' or 'Test') to a dict from field name to column index.
        ``fields`` is the sorted list of all distinct field names across
        both data file types.
    """
    field_index = {
        'Train': {
            'Iters': 0,
            'Seconds': 1,
            'Train loss': 2,
            'Train learning rate': 3,
        },
        'Test': {
            'Iters': 0,
            'Seconds': 1,
            'Test accuracy': 2,
            'Test loss': 3,
        },
    }
    # 'Iters' and 'Seconds' appear in both layouts; the set removes duplicates.
    names = set()
    for columns in field_index.values():
        names.update(columns)
    return field_index, sorted(names)
def get_supported_chart_types():
    """List the descriptions of every chart this script can draw.

    Each y-axis field (any field that is not an x-axis field) is paired with
    each x-axis field. A description is the y-axis field, the separator from
    get_chart_type_description_separator(), then the x-axis field.

    Returns:
        A list of description strings. A description's position in the list
        is the integer chart type that the other functions take.
    """
    _, fields = create_field_index()
    separator = get_chart_type_description_separator()
    x_axis_fields = [name for name in fields if is_x_axis_field(name)]

    supported_chart_types = []
    # Iterating the fields directly (rather than by index with xrange) keeps
    # the original ordering and also runs unchanged on Python 3.
    for y_axis_field in fields:
        if is_x_axis_field(y_axis_field):
            continue
        for x_axis_field in x_axis_fields:
            supported_chart_types.append(
                '%s%s%s' % (y_axis_field, separator, x_axis_field))
    return supported_chart_types
def get_chart_type_description(chart_type):
    """Return the description string for the integer ``chart_type``.

    ``chart_type`` is used as an index into the list returned by
    get_supported_chart_types(); an out-of-range index raises IndexError.
    """
    return get_supported_chart_types()[chart_type]
def get_data_file_type(chart_type):
    """Return the first whitespace-separated word of the chart description.

    Given the field names from create_field_index(), that word is 'Train' or
    'Test', i.e. the data file type the chart reads from.
    """
    words = get_chart_type_description(chart_type).split()
    return words[0]
def get_data_file(chart_type, path_to_log):
    """Return the name of the parsed data file for a log and chart type.

    The name is the log's base name (directory stripped), a dot, and the
    lower-cased data file type, so it is relative to the working directory.
    """
    log_name = os.path.basename(path_to_log)
    extension = get_data_file_type(chart_type).lower()
    return log_name + '.' + extension
def get_field_descriptions(chart_type):
    """Split a chart type's description into its two field names.

    Returns:
        ``(x_axis_field, y_axis_field)``. Note the order: the description
        lists the y-axis field first, but the x-axis field is returned first.
    """
    separator = get_chart_type_description_separator()
    parts = get_chart_type_description(chart_type).split(separator)
    return parts[1], parts[0]
def get_field_indices(x_axis_field, y_axis_field):
    """Return the data-file column indices of the two plotted fields.

    Args:
        x_axis_field: name of the x-axis field, e.g. 'Iters'.
        y_axis_field: name of the y-axis field, e.g. 'Train loss'.

    Returns:
        ``(x_index, y_index)`` looked up in create_field_index().

    Raises:
        KeyError: if the fields do not belong to a known data file type.
    """
    # The first word of the y-axis field ('Train' / 'Test') is the data file
    # type. The previous code obtained it from a ``chart_type`` name that is
    # not a parameter here and only exists as a global when the file runs as
    # a script, so calling this function after an import raised NameError.
    data_file_type = y_axis_field.split()[0]
    columns = create_field_index()[0][data_file_type]
    return columns[x_axis_field], columns[y_axis_field]
def load_data(data_file, field_idx0, field_idx1):
    """Read two numeric columns from a whitespace-separated data file.

    Lines starting with '#' (after stripping surrounding whitespace) and
    blank lines are ignored.

    Args:
        data_file: path of the file to read.
        field_idx0: column index whose values go into the first list.
        field_idx1: column index whose values go into the second list.

    Returns:
        ``[values0, values1]``: two lists of floats of equal length.

    Raises:
        IndexError: if a data line has fewer columns than requested.
        ValueError: if a selected column does not hold a number.
    """
    data = [[], []]
    with open(data_file, 'r') as f:
        for line in f:
            line = line.strip()
            # Check for emptiness first: indexing line[0] on a blank line
            # (e.g. a trailing newline) used to raise IndexError.
            if not line or line.startswith('#'):
                continue
            fields = line.split()
            data[0].append(float(fields[field_idx0]))
            data[1].append(float(fields[field_idx1]))
    return data
def random_marker():
    """Return a randomly chosen matplotlib marker code.

    ``MarkerStyle.markers`` maps marker codes (the values accepted by the
    ``marker=`` argument) to human-readable descriptions. The previous code
    picked from the descriptions, most of which are not valid marker codes,
    and indexed ``dict.values()`` directly, which fails on Python 3.

    Returns:
        One key of ``MarkerStyle.markers``, excluding the entries described
        as 'nothing' so the marker is actually drawn.
    """
    markers = mks.MarkerStyle.markers
    drawable = [code for code, description in markers.items()
                if description != 'nothing']
    return random.choice(drawable)
def get_data_label(path_to_log):
    """Return the legend label for a log: the part of ``path_to_log`` between
    its last '/' and the last occurrence of the log file suffix."""
    start = path_to_log.rfind('/') + 1
    stop = path_to_log.rfind(get_log_file_suffix())
    return path_to_log[start:stop]
def get_legend_loc(chart_type):
    """Choose the legend location for a chart type.

    Returns:
        'upper right' when the y-axis field mentions 'loss' or
        'learning rate'; 'lower right' otherwise (which covers accuracy).
    """
    _, y_axis = get_field_descriptions(chart_type)
    if 'loss' in y_axis or 'learning rate' in y_axis:
        return 'upper right'
    return 'lower right'
def plot_chart(chart_type, path_to_png, path_to_log_list):
    """Parse each log, plot the requested chart and save it as an image.

    For every log the parsing script from get_log_parsing_script() is run,
    then the data file named by get_data_file() is read from the current
    working directory (presumably written there by the parsing script) and
    drawn as one line on a shared figure.

    Args:
        chart_type: integer index into get_supported_chart_types().
        path_to_png: where the finished figure is saved.
        path_to_log_list: paths of the training logs to plot.

    Raises:
        OSError: if the parsing script cannot be executed.
    """
    # Local import: subprocess is not among this file's top-level imports.
    import subprocess

    # Resolved once, up front: these depend only on chart_type, and the axis
    # labels below must be available even when path_to_log_list is empty.
    x_axis_field, y_axis_field = get_field_descriptions(chart_type)
    x_idx, y_idx = get_field_indices(x_axis_field, y_axis_field)

    linewidth = 0.75
    ## If there are too many datapoints, do not use marker.
    ## use_marker = False
    use_marker = True
    # Upper bound on marker retries so a plot that can never succeed cannot
    # spin forever; after that the line is drawn without a marker.
    max_marker_attempts = 1000

    for path_to_log in path_to_log_list:
        # An argument list (no shell) passes paths containing spaces or
        # shell metacharacters to the script unmodified.
        subprocess.call([get_log_parsing_script(), path_to_log])
        data_file = get_data_file(chart_type, path_to_log)
        data = load_data(data_file, x_idx, y_idx)
        ## TODO: more systematic color cycle for lines
        color = [random.random(), random.random(), random.random()]
        label = get_data_label(path_to_log)

        plotted = False
        if use_marker:
            for _ in range(max_marker_attempts):
                try:
                    plt.plot(data[0], data[1], label=label, color=color,
                             marker=random_marker(), linewidth=linewidth)
                except ValueError:
                    ## Some markers throw ValueError: Unrecognized marker style
                    continue
                plotted = True
                break
        if not plotted:
            plt.plot(data[0], data[1], label=label, color=color,
                     linewidth=linewidth)

    legend_loc = get_legend_loc(chart_type)
    plt.legend(loc=legend_loc, ncol=1)  # adjust ncol to fit the space
    plt.title(get_chart_type_description(chart_type))
    plt.xlabel(x_axis_field)
    plt.ylabel(y_axis_field)
    plt.savefig(path_to_png)
    #plt.show()
def print_help():
    """Print the usage text and the numbered chart types, then exit.

    Always terminates the process through sys.exit(), so it never returns
    to its caller.
    """
    supported_chart_types = get_supported_chart_types()
    # print(...) with one parenthesised argument is valid both as a Python 2
    # print statement and as a Python 3 function call.
    print("""This script mainly serves as the basis of your customizations.
Customization is a must.
You can copy, paste, edit them in whatever way you want.
Be warned that the fields in the training log may change in the future.
You had better check the data files and change the mapping from field name to
field index in create_field_index before designing your own plots.
Usage:
./plot_training_log.py chart_type[0-%s] /where/to/save.png /path/to/first.log ...
Notes:
1. Supporting multiple logs.
2. Log file name must end with the lower-cased "%s".
Supported chart types:""" % (len(supported_chart_types) - 1,
                             get_log_file_suffix()))
    for i, description in enumerate(supported_chart_types):
        print(' %d: %s' % (i, description))
    sys.exit()
def is_valid_chart_type(chart_type):
    """Return True when ``chart_type`` is a non-negative index smaller than
    the number of supported chart types."""
    return 0 <= chart_type < len(get_supported_chart_types())
if __name__ == '__main__':
    # Command line: chart_type /where/to/save.png /path/to/first.log ...
    if len(sys.argv) < 4:
        print_help()
    else:
        # print_help() ends the process, so every branch that calls it stops
        # here.
        try:
            chart_type = int(sys.argv[1])
        except ValueError:
            # A non-numeric chart type used to end in an unhandled traceback.
            print('%s is not a valid chart type.' % sys.argv[1])
            print_help()
        if not is_valid_chart_type(chart_type):
            print('%s is not a valid chart type.' % chart_type)
            print_help()
        path_to_png = sys.argv[2]
        if not path_to_png.endswith('.png'):
            # The old message had no placeholder, so applying % to it raised
            # TypeError instead of printing anything.
            print('Path must end with .png: %s' % path_to_png)
            sys.exit()
        path_to_logs = sys.argv[3:]
        for path_to_log in path_to_logs:
            if not os.path.exists(path_to_log):
                print('Path does not exist: %s' % path_to_log)
                sys.exit()
            if not path_to_log.endswith(get_log_file_suffix()):
                print('Log file must end in %s.' % get_log_file_suffix())
                print_help()
        ## plot_chart accepts multiple path_to_logs
        plot_chart(chart_type, path_to_png, path_to_logs)