-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtensorboard.py
59 lines (47 loc) · 1.97 KB
/
tensorboard.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import numbers
import warnings

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
class Tensorboard:
    """Creates Tensorboard visualization.

    Wraps a ``SummaryWriter`` that writes to ``logdir`` together with a fixed
    model input (``images``) reused for image and graph logging. Call
    ``close()`` when done, or use the instance as a context manager, so that
    pending events are flushed to disk.
    """

    def __init__(self, logdir, images):
        """
        Arguments:
            logdir: Log directory passed to ``SummaryWriter``.
            images: Model input reused by ``tbimages`` and ``tbmodel``
                (presumably a batched image tensor -- confirm against callers).
        """
        self.tb = SummaryWriter(logdir)
        self.images = images

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always release the writer; returning None lets exceptions propagate.
        self.close()

    def close(self):
        """Flush pending events and close the underlying SummaryWriter."""
        self.tb.close()

    def tbimages(self, model, name, epoch):
        """Run the model on the stored images and log its sigmoid output(s).

        Each output is passed through ``torch.sigmoid``, arranged into a grid,
        added to TensorBoard, and also saved as a PNG in the current working
        directory. The tag and the file name are the same string.

        Arguments:
            model: Network model; called as ``model(self.images)``. A tuple or
                list return value is treated as several image outputs.
            name(string or list): Name interpolated into the tag/file name.
            epoch(integer): Current epoch; interpolated into the tag/file name
                and recorded as the TensorBoard global step.
        """
        # Visualisation only: the result is detached anyway, so skip building
        # an autograd graph.
        # NOTE(review): the model is called in whatever train/eval mode it is
        # currently in; in train mode, BatchNorm running stats get updated --
        # confirm callers switch to eval() first if that matters.
        with torch.no_grad():
            outputs = model(self.images)
        if isinstance(outputs, (tuple, list)):
            for idx, output in enumerate(outputs):
                self._log_image(output, f'img_{idx}_{name}_{epoch}.png', epoch)
        else:
            self._log_image(outputs, f'img_{name}_{epoch}.png', epoch)

    def _log_image(self, output, tag, epoch):
        """Log sigmoid(output) as an image grid under ``tag`` and save it to the file ``tag``."""
        image = torch.sigmoid(output).detach().cpu()
        img_grid = torchvision.utils.make_grid(image)
        self.tb.add_image(tag, img_grid, global_step=epoch)
        torchvision.utils.save_image(image, tag)

    def tbmodel(self, model):
        """Add the model graph to the summary.

        Arguments:
            model: Network model; traced with the stored ``self.images`` as
                its input.
        """
        self.tb.add_graph(model, self.images)

    def tbmetrics(self, name, data, x_label):
        """Add scalar data to the summary.

        Arguments:
            name(string): Tag used when ``data`` is a single scalar. Ignored
                when ``data`` is a dict, because each entry supplies its own tag.
            data: A real number (int, float, numpy scalar), a single-element
                tensor, or a dict mapping tag -> scalar value.
            x_label: Global step recorded with the value(s).

        Unsupported inputs are not logged. A ``RuntimeWarning`` is emitted
        instead, so dropped metrics are visible rather than silently lost.
        """
        if isinstance(data, dict):
            for metric_name, metric_value in data.items():
                self.tb.add_scalar(metric_name, metric_value, x_label)
            return
        is_scalar = isinstance(data, numbers.Real) or (
            isinstance(data, torch.Tensor) and data.numel() == 1
        )
        if isinstance(name, str) and is_scalar:
            self.tb.add_scalar(name, data, x_label)
            return
        warnings.warn(
            f'Tensorboard.tbmetrics: unsupported metric {name!r} with data of '
            f'type {type(data).__name__}; nothing was logged',
            RuntimeWarning,
            stacklevel=2,
        )