forked from OndrejTexler/Few-Shot-Patch-Based-Training
-
Notifications
You must be signed in to change notification settings - Fork 0
/
logger.py
31 lines (24 loc) · 947 Bytes
/
logger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import tensorflow as tf
import os
import shutil
class Logger(object):
def __init__(self, log_dir, suffix=None):
"""Create a summary writer logging to log_dir."""
# self.writer = tf.summary.create_file_writer(log_dir)
def scalar_summary(self, tag, value, step):
"""Log a scalar variable."""
# with self.writer.as_default():
# tf.summary.scalar(tag, value, step=step)
# self.writer.flush()
class ModelLogger(object):
def __init__(self, log_dir, save_func):
self.log_dir = log_dir
self.save_func = save_func
def save(self, model, epoch, isGenerator):
if isGenerator:
new_path = os.path.join(self.log_dir, "model_%05d.pth" % epoch)
else:
new_path = os.path.join(self.log_dir, "disc_%05d.pth" % epoch)
self.save_func(model, new_path)
def copy_file(self, source):
shutil.copy(source, self.log_dir)