-
Notifications
You must be signed in to change notification settings - Fork 0
/
callbacks.py
77 lines (71 loc) · 2.76 KB
/
callbacks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import tensorflow as tf
from keras.callbacks import Callback
from keras.callbacks import TensorBoard
class NBatchLogger(Callback):
    """Print metric averages every `display` batches.

    At the end of each batch the metric values found in ``logs`` are added to a
    running sum. Every ``display`` batches the sums are divided by ``display``
    and printed as ``step: <step>/<steps> ... - <name>: <value>``, after which
    the sums are reset. Values whose magnitude is at most 1e-3 are printed in
    scientific notation so that they do not collapse to ``0.0000``.

    # Arguments
        display: positive int, number of batches averaged between printouts.

    # Raises
        ValueError: if ``display`` is not positive.
    """

    def __init__(self, display):
        # The base initialiser sets attributes Keras expects on every callback.
        super(NBatchLogger, self).__init__()
        if display <= 0:
            # A non-positive window would make `self.step % self.display`
            # fail (or never trigger) later in training.
            raise ValueError('display must be a positive integer, got %r'
                             % (display,))
        # Batch counter across the whole run; it is not reset between epochs.
        self.step = 0
        self.display = display
        # Running sum of each metric over the current `display` window.
        self.metric_cache = {}

    @property
    def curr_step(self):
        """Read-only alias of ``step``, kept for backward compatibility."""
        return self.step

    def on_batch_end(self, batch, logs=None):
        """Accumulate this batch's metrics and print averages when due.

        # Arguments
            batch: index of the batch within the epoch (unused).
            logs: dict of metric name -> value for this batch, or None.
        """
        logs = logs or {}
        self.step += 1

        # Older Keras lists metric names in params['metrics']; otherwise fall
        # back to whatever was logged, minus Keras' bookkeeping entries.
        metric_names = self.params.get('metrics')
        if metric_names is None:
            metric_names = [k for k in logs if k not in ('batch', 'size')]

        for k in metric_names:
            if k in logs:
                self.metric_cache[k] = self.metric_cache.get(k, 0) + logs[k]

        if self.step % self.display == 0:
            metrics_log = ''
            for (k, v) in self.metric_cache.items():
                val = v / self.display
                if abs(val) > 1e-3:
                    metrics_log += ' - %s: %.4f' % (k, val)
                else:
                    # Tiny values would print as 0.0000 in fixed-point form.
                    metrics_log += ' - %s: %.4e' % (k, val)
            print('step: {}/{} ...{}'.format(self.step,
                                             self.params.get('steps'),
                                             metrics_log))
            self.metric_cache.clear()
class BatchTensorBoard(TensorBoard):
    """TensorBoard callback that writes scalar summaries after every batch.

    Every entry of the batch ``logs`` except the bookkeeping keys ``batch`` and
    ``size`` is written as a scalar summary tagged with its name. The step of
    each summary is the number of training samples seen so far. Epoch-end
    writing is disabled (see ``on_epoch_end``).

    The constructor arguments are forwarded unchanged to ``TensorBoard``.
    """

    def __init__(self, log_dir='./logs',
                 histogram_freq=0,
                 batch_size=32,
                 write_graph=True,
                 write_grads=False,
                 write_images=False,
                 embeddings_freq=0,
                 embeddings_layer_names=None,
                 embeddings_metadata=None):
        # Keywords rather than positions, so this call does not depend on the
        # parameter order of TensorBoard.__init__ in a given Keras release.
        super(BatchTensorBoard, self).__init__(
            log_dir=log_dir,
            histogram_freq=histogram_freq,
            batch_size=batch_size,
            write_graph=write_graph,
            write_grads=write_grads,
            write_images=write_images,
            embeddings_freq=embeddings_freq,
            embeddings_layer_names=embeddings_layer_names,
            embeddings_metadata=embeddings_metadata)
        # Number of training samples processed so far; used as summary step.
        self.seen = 0

    # TODO: find a better way to keep epoch-level summaries from clashing
    # with the per-batch ones.
    def on_epoch_end(self, epoch, logs=None):
        """Do nothing, so that per-batch summaries are not overwritten.

        Presumably the base implementation writes the same tags with the epoch
        index as the step, which would interleave with the per-batch points.
        NOTE(review): any other epoch-end work of the base class (e.g.
        histogram/embedding summaries in Keras 2.x) is skipped as well, so
        confirm that ``histogram_freq``/``embeddings_freq`` are not relied on.
        """
        pass

    def on_batch_end(self, batch, logs=None):
        """Write one scalar summary per logged value, then flush the writer.

        # Arguments
            batch: index of the batch within the epoch (unused).
            logs: dict of metric name -> numeric value for this batch, or None.
        """
        logs = logs or {}
        for name, value in logs.items():
            if name in ('batch', 'size'):
                continue
            summary = tf.Summary()
            summary_value = summary.value.add()
            # float() handles numpy scalars and plain Python numbers alike;
            # `.item()` only exists on numpy values.
            summary_value.simple_value = float(value)
            summary_value.tag = name
            self.writer.add_summary(summary, self.seen)
        self.writer.flush()
        # Advance by the actual number of samples in this batch when Keras
        # reports it. `self.batch_size` is TensorBoard's histogram-feed batch
        # size, not the training batch size, and the last batch of an epoch
        # may be smaller anyway.
        self.seen += logs.get('size', self.batch_size)