import logging
import os
from typing import Union

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import yaml

from model.training import build_lr_scheduler

sns.set_style('whitegrid', {'axes.edgecolor': '.2'})
sns.set('poster', rc={'xtick.bottom': True,
                      'ytick.left': True,
                      'axes.edgecolor': '.2',
                      'font.weight': 'bold',
                      'axes.titleweight': 'bold',
                      'axes.labelweight': 'bold'})
sns.set_palette('husl')
class Standardizer:
    """Z-score standardization."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, x, rev=False):
        if rev:
            return (x * self.std) + self.mean
        return (x - self.mean) / self.std
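
# A minimal usage sketch for Standardizer (the values below are illustrative):
#
#     >>> import torch
#     >>> y = torch.tensor([1.0, 2.0, 3.0])
#     >>> standardizer = Standardizer(mean=y.mean(), std=y.std())
#     >>> z = standardizer(y)            # z-scores with mean ~0 and std ~1
#     >>> standardizer(z, rev=True)      # recovers the original tensor
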
def create_logger(name: str, log_dir: str = None) -> logging.Logger:
    """
    Creates a logger with a stream handler and, if log_dir is given, a file handler.

    :param name: The name of the logger.
    :param log_dir: The directory in which to save the logs.
    :return: The logger.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.propagate = False

    # Stream handler logs to the console.
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    logger.addHandler(ch)

    # File handler logs to <log_dir>/<name>.log; guard against the default log_dir=None.
    if log_dir is not None:
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        fh = logging.FileHandler(os.path.join(log_dir, name + '.log'))
        fh.setLevel(logging.INFO)
        logger.addHandler(fh)
    return logger
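
# A minimal usage sketch (the logger name and directory are hypothetical):
#
#     >>> logger = create_logger(name='train', log_dir='logs')
#     >>> logger.info('Training started')   # printed to the console and written to logs/train.log
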
def dict_to_str(dictionary: dict,
                level: int = 0,
                ) -> str:
    """
    A helper function to log dictionaries in a pretty way.

    Args:
        dictionary (dict): A general Python dictionary.
        level (int): A recursion level counter, sets the visual indentation.

    Returns:
        str: A text representation of the dictionary.
    """
    message = ''
    for key, value in dictionary.items():
        if isinstance(value, dict):
            # Recurse into nested dictionaries, indenting two spaces per level.
            message += ' ' * level * 2 + str(key) + ':\n' + dict_to_str(value, level + 1)
        else:
            message += ' ' * level * 2 + str(key) + ': ' + str(value) + '\n'
    return message
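
# For example (illustrative dictionary):
#
#     >>> print(dict_to_str({'lr': 0.001, 'scheduler': {'type': 'noam', 'warmup': 2}}))
#     lr: 0.001
#     scheduler:
#       type: noam
#       warmup: 2
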
def string_representer(dumper, data):
    """
    A custom string representer that uses YAML block literals for multiline strings.
    """
    if len(data.splitlines()) > 1:
        return dumper.represent_scalar(tag='tag:yaml.org,2002:str', value=data, style='|')
    return dumper.represent_scalar(tag='tag:yaml.org,2002:str', value=data)
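
# Once registered with yaml.add_representer(str, string_representer) (as done in
# save_yaml_file below), multiline strings are dumped as block literals, e.g.
# (output sketch, assuming PyYAML's default chomping for strings without a
# trailing newline):
#
#     >>> yaml.dump({'note': 'line one\nline two'})
#     'note: |-\n  line one\n  line two\n'
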
def save_yaml_file(path: str,
                   content: Union[list, dict],
                   ) -> None:
    """
    Save a YAML file (usually an input / restart file, but also a conformers file).

    Args:
        path (str): The YAML file path to save.
        content (list, dict): The content to save.
    """
    if not isinstance(path, str):
        raise TypeError(f'path must be a string, got {path} which is a {type(path)}')
    yaml.add_representer(str, string_representer)
    content = yaml.dump(data=content)
    # Create the parent directory if the path points into one that does not exist yet.
    if os.path.dirname(path) and not os.path.exists(os.path.dirname(path)):
        os.makedirs(os.path.dirname(path))
    with open(path, 'w') as f:
        f.write(content)
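
# A minimal usage sketch (path and content are hypothetical):
#
#     >>> save_yaml_file(path='out/restart.yml', content={'epochs': 100, 'lr': 0.001})
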
def plot_train_val_loss(log_file):
    """
    Plots the training and validation loss by parsing the log file.

    :param log_file: The path to the log file created during training.
    """
    train_loss = []
    val_loss = []
    with open(log_file) as f:
        lines = f.readlines()
        for line in lines:
            # The loss value is the last whitespace-separated token of a matching line.
            if ': Training Loss' in line:
                train_loss.append(float(line.split(' ')[-1].rstrip()))
            if ': Validation Loss' in line:
                val_loss.append(float(line.split(' ')[-1].rstrip()))
    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    ax.plot(np.arange(len(train_loss)), train_loss, label='Train Loss')
    ax.plot(np.arange(len(val_loss)), val_loss, label='Val Loss')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.legend()
    fig.savefig(os.path.join(os.path.dirname(log_file), 'train_val_loss.pdf'), bbox_inches='tight')
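
# The parser above expects one loss value per epoch in lines shaped like
# (illustrative log lines):
#
#     Epoch 3: Training Loss 0.412
#     Epoch 3: Validation Loss 0.387
#
# and saves train_val_loss.pdf next to the log file:
#
#     >>> plot_train_val_loss(log_file='logs/train.log')
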
def get_optimizer_and_scheduler(args, model, train_data_size):
    """
    Builds the optimizer and learning rate scheduler specified by the parsed arguments.
    """
    if args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise NotImplementedError('Optimizer not implemented.')
    if args.scheduler == 'plateau':
        # Shrink the learning rate by 30% after 5 epochs without improvement,
        # never going below 1% of the initial learning rate.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.7,
                                                               patience=5, min_lr=args.lr / 100)
    elif args.scheduler == 'noam':
        scheduler = build_lr_scheduler(optimizer=optimizer, args=args, train_data_size=train_data_size)
    else:
        scheduler = None
    return optimizer, scheduler
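
# A minimal usage sketch (argument values are illustrative; `model` is any torch.nn.Module):
#
#     >>> from argparse import Namespace
#     >>> args = Namespace(optimizer='adam', lr=1e-3, scheduler='plateau')
#     >>> optimizer, scheduler = get_optimizer_and_scheduler(args, model, train_data_size=1000)
#     >>> # after each validation epoch: scheduler.step(val_loss)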