logger.py
"""Wrapper class for logging into the TensorBoard and comet.ml"""
__author__ = 'Erdene-Ochir Tuguldur'
__all__ = ['Logger']
import os
from tensorboardX import SummaryWriter
from hparams import HParams as hp
class Logger(object):
def __init__(self, dataset_name, model_name):
self.model_name = model_name
self.project_name = "%s-%s" % (dataset_name, self.model_name)
self.logdir = os.path.join(hp.logdir, self.project_name)
self.writer = SummaryWriter(log_dir=self.logdir)
def log_step(self, phase, step, loss_dict, image_dict):
if phase == 'train':
if step % 50 == 0:
# self.writer.add_scalar('lr', get_lr(), step)
# self.writer.add_scalar('%s-step/loss' % phase, loss, step)
for key in sorted(loss_dict):
self.writer.add_scalar('%s-step/%s' % (phase, key), loss_dict[key], step)
if step % 1000 == 0:
for key in sorted(image_dict):
self.writer.add_image('%s/%s' % (self.model_name, key), image_dict[key], step)
def log_epoch(self, phase, step, loss_dict):
for key in sorted(loss_dict):
self.writer.add_scalar('%s/%s' % (phase, key), loss_dict[key], step)
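
A minimal usage sketch. The dataset and model names below are hypothetical placeholders; it assumes `hp.logdir` in the repo's `hparams` module points to a writable directory, and that image values are CHW arrays as expected by `SummaryWriter.add_image`.

# usage_sketch.py -- illustrative only, names are placeholders
import numpy as np

from logger import Logger

# hypothetical dataset/model names; the pair becomes the TensorBoard run directory
logger = Logger(dataset_name='ljspeech', model_name='text2mel')

for step in range(1, 2001):
    loss_dict = {'loss': 1.0 / step, 'l1': 0.5 / step}      # scalar losses to plot per step
    image_dict = {'attention': np.random.rand(1, 80, 120)}  # CHW image, values in [0, 1]
    logger.log_step('train', step, loss_dict, image_dict)

# at the end of an epoch, log epoch-level averages against the current step
logger.log_epoch('valid', 2000, {'loss': 0.01})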