This repository has been archived by the owner on Feb 20, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 19
/
torchutils.py
69 lines (53 loc) · 1.83 KB
/
torchutils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
from torch import save
import torch.nn as nn
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import numpy as np
import shutil
# Decide once, at import time, which tensor constructor the module uses:
# legacy CUDA float tensors when torch reports a usable GPU, otherwise the
# default CPU tensor type.
cudaAvailable = bool(torch.cuda.is_available())
if cudaAvailable:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.Tensor
def toZeroThreshold(x, t=0.1):
    """Zero every element of ``x`` that is not strictly greater than ``t``.

    Args:
        x: input tensor of any shape.
        t: threshold. Elements ``<= t`` (and NaNs, since ``NaN > t`` is
            False) are replaced by 0; elements ``> t`` are kept unchanged.

    Returns:
        A new tensor with the same shape, dtype and device as ``x``.
    """
    # zeros_like follows x's dtype and device. The module-level ``Tensor``
    # constructor used previously always produced float32 on the device picked
    # at import time, so a CPU input on a CUDA machine made torch.where fail
    # with a device mismatch.
    return torch.where(x > t, x, torch.zeros_like(x))
def weights_init(layer):
    """Re-initialise ``layer.weight`` with Xavier-normal values for conv-like layers.

    A layer qualifies when its class name contains the substring ``'Conv'``
    (so e.g. ``Conv2d`` and ``ConvTranspose2d`` both match). Any other layer is
    left untouched, and the bias is never modified. The single-module
    signature presumably makes this usable with ``nn.Module.apply`` — confirm
    against callers.

    Args:
        layer: a module that, when its class name contains ``'Conv'``, is
            expected to expose a ``weight`` parameter.
    """
    if 'Conv' not in type(layer).__name__:
        return
    nn.init.xavier_normal_(layer.weight.data)
class Plotter:
    """Collect logged metric values and plot them against epochs.

    For every tracked metric ``name`` two instance attributes are created:
    ``name`` (the list of logged values) and ``name + '_freq'`` (the epoch
    spacing between consecutive values; the k-th value, counting from 1, is
    drawn at x = k * freq).

    NOTE: a metric name that equals an existing attribute or method (e.g.
    ``'log'`` or ``'plot'``) would shadow it on the instance.
    """

    def __init__(self, attributes=None):
        """Create an empty history for every tracked metric.

        Args:
            attributes: iterable of ``(name, freq)`` entries; only the first
                two items of each entry are read. ``None`` (the default)
                means ``[('trainloss', 1)]``.
        """
        if attributes is None:
            # Fresh list per instance: a list literal as the default value
            # would be shared (and mutable) across every Plotter().
            attributes = [('trainloss', 1)]
        # Copy so later mutation of self.attributes never leaks back into the
        # caller's list (and a one-shot iterable is not left exhausted here).
        self.attributes = list(attributes)
        for entry in self.attributes:
            name, freq = entry[0], entry[1]
            setattr(self, name, [])
            setattr(self, name + '_freq', freq)

    def log(self, attr, value):
        """Append ``value`` to the history of metric ``attr``.

        An unregistered name normally raises AttributeError.
        """
        getattr(self, attr).append(value)

    def savelog(self, filename):
        """Placeholder for persisting the logged values; currently does nothing."""
        # TODO: not implemented; ``filename`` is ignored.
        pass

    def plot(self, ylabel, attributes=None, ymax=None, filename='plot.png'):
        """Plot the logged histories and save the figure to ``filename``.

        Args:
            ylabel: label for the y axis (the x axis is labelled "epoch").
            attributes: names of the metrics to draw; ``None`` draws every
                tracked metric, in registration order.
            ymax: optional upper limit for the y axis.
            filename: path of the image file to write.
        """
        plt.style.use('ggplot')
        if attributes is None:
            attributes = [entry[0] for entry in self.attributes]
        # Use a dedicated figure so we neither draw onto nor close a figure the
        # caller may have open, and always release it even if plotting fails.
        fig, ax = plt.subplots()
        try:
            if ymax is not None:
                # Applied before plotting, as before: set_ylim turns y
                # autoscaling off, so the bottom keeps the fresh-axes value.
                ax.set_ylim(top=ymax)
            ax.set_xlabel("epoch")
            ax.set_ylabel(ylabel)
            for name in attributes:
                values = getattr(self, name)
                xs = getattr(self, name + '_freq') * np.arange(1, len(values) + 1)
                ax.plot(xs, values, label=name)
            if ax.get_lines():
                # Avoid matplotlib's "no artists with labels" warning.
                ax.legend()
            fig.savefig(filename)
        finally:
            plt.close(fig)
def save_checkpoint(state, filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``filename`` using ``torch.save``.

    Args:
        state: any object ``torch.save`` can serialize (presumably a dict of
            model/optimizer state — confirm against callers).
        filename: destination path; defaults to ``'checkpoint.pth.tar'``.
    """
    torch.save(state, filename)