-
Notifications
You must be signed in to change notification settings - Fork 8
/
checkpoint.py
113 lines (99 loc) · 3.41 KB
/
checkpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import torch
from tensorboardX import SummaryWriter
# Whether torch can see a CUDA device at import time; this is the default
# for ``load_checkpoint``'s ``cuda`` argument.
use_cuda = torch.cuda.is_available()

# Initial, empty checkpoint structure: epoch counter, per-metric history
# lists, and empty dicts for the model state, configs and vocab mappings.
# NOTE(review): these lists/dicts are single module-level objects; a caller
# that mutates them in place (rather than copying, e.g. copy.deepcopy)
# mutates this shared template.
default_checkpoint = dict(
    epoch=0,
    train_losses=[],
    train_symbol_accuracy=[],
    train_sentence_accuracy=[],
    train_wer=[],
    train_score=[],
    validation_losses=[],
    validation_symbol_accuracy=[],
    validation_sentence_accuracy=[],
    validation_wer=[],
    validation_score=[],
    lr=[],
    grad_norm=[],
    model={},
    configs={},
    token_to_id={},
    id_to_token={},
)
def save_checkpoint(checkpoint, dir="./checkpoints", prefix=""):
    """Save a checkpoint dict to ``<prefix>/<dir>/<epoch>.pth``.

    Args:
        checkpoint (dict): Checkpoint to save. Its ``"epoch"`` entry names
            the output file.
        dir (str): Directory, joined under ``prefix``, to save into. It is
            created (with parents) if it does not exist.
        prefix (str): Parent path that ``dir`` is joined onto.
    """
    # Pad the epoch to 4 digits so lexical sorting of filenames matches
    # numeric order, e.g. 0009.pth sorts before 0010.pth.
    filename = "{num:0>4}.pth".format(num=checkpoint["epoch"])
    save_dir = os.path.join(prefix, dir)
    # exist_ok=True replaces the racy "exists? then makedirs" check. An empty
    # save_dir means "current directory", which needs no creation (and
    # os.makedirs("") would raise).
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(checkpoint, os.path.join(save_dir, filename))
def load_checkpoint(path, cuda=use_cuda):
    """Load a checkpoint previously written with ``torch.save``.

    Args:
        path (str): File path of the checkpoint.
        cuda (bool): If true, tensors are restored to the devices they were
            saved from. If false, every storage is kept on the CPU, so a
            GPU-saved checkpoint can be loaded on a CPU-only machine.
            Defaults to the module-level ``use_cuda``.

    Returns:
        Whatever object ``torch.load`` deserializes from ``path``.

    Note:
        ``torch.load`` unpickles the file; only load checkpoints you trust.
    """
    load_kwargs = {}
    if not cuda:
        # torch.load hands the callable each storage already deserialized on
        # the CPU; returning it unchanged keeps it there.
        load_kwargs["map_location"] = lambda storage, loc: storage
    return torch.load(path, **load_kwargs)
def init_tensorboard(name="", base_dir="./tensorboard"):
    """Create a tensorboardX ``SummaryWriter``.

    The log directory is ``os.path.join(name, base_dir)``: ``name`` is the
    parent directory and ``base_dir`` is placed beneath it.

    Args:
        name (str): Parent directory for the logs. The default ``""`` means
            no parent, i.e. logs go to ``base_dir`` itself.
        base_dir (str): Log sub-directory joined under ``name``. If this is
            an absolute path, ``os.path.join`` discards ``name``.

    Returns:
        SummaryWriter: Writer logging to the joined directory.
    """
    log_dir = os.path.join(name, base_dir)
    return SummaryWriter(log_dir)
def write_tensorboard(
    writer,
    epoch,
    grad_norm,
    train_loss,
    train_symbol_accuracy,
    train_sentence_accuracy,
    train_wer,
    train_score,
    validation_loss,
    validation_symbol_accuracy,
    validation_sentence_accuracy,
    validation_wer,
    validation_score,
    model,
):
    """Log one epoch's metrics and the encoder/decoder parameter histograms.

    Each metric is written with ``writer.add_scalar(tag, value, epoch)``.
    Then, for every parameter of ``model.encoder`` followed by
    ``model.decoder``, ``writer.add_histogram`` logs the parameter values
    under ``"<encoder|decoder>/<param name>"``. If a gradient is present, it
    is also logged under ``".../grad"``.

    Args:
        writer: Object providing ``add_scalar`` and ``add_histogram``
            (e.g. a tensorboardX ``SummaryWriter``).
        epoch: Step value attached to every logged entry.
        grad_norm, train_*, validation_*: Scalar metrics to log.
        model: Object whose ``encoder`` and ``decoder`` attributes provide
            ``named_parameters()``.
    """
    scalar_table = (
        ("train_loss", train_loss),
        ("train_symbol_accuracy", train_symbol_accuracy),
        ("train_sentence_accuracy", train_sentence_accuracy),
        ("train_wer", train_wer),
        ("train_score", train_score),
        ("validation_loss", validation_loss),
        ("validation_symbol_accuracy", validation_symbol_accuracy),
        ("validation_sentence_accuracy", validation_sentence_accuracy),
        ("validation_wer", validation_wer),
        ("validation_score", validation_score),
        ("grad_norm", grad_norm),
    )
    for tag, value in scalar_table:
        writer.add_scalar(tag, value, epoch)

    # Fetch each submodule lazily, so the decoder is only touched after the
    # encoder has been fully logged.
    for scope in ("encoder", "decoder"):
        submodule = getattr(model, scope)
        for param_name, param in submodule.named_parameters():
            tag = "{}/{}".format(scope, param_name)
            writer.add_histogram(tag, param.detach().cpu().numpy(), epoch)
            grad = param.grad
            if grad is not None:
                writer.add_histogram(
                    "{}/grad".format(tag), grad.detach().cpu().numpy(), epoch
                )