-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathutils.py
98 lines (80 loc) · 2.87 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Nov 20 17:24:29 2019
@author: Samiul Arshad <[email protected]>
"""
import logging
import re
import os
import torch
def one_hot(y, num_class):
    '''
    Converts class labels to a one-hot representation.

    Parameters
    ----------
    y : torch.Tensor
        Integer class labels of shape (batch, 1) or (batch,). The dtype must
        be int64 (LongTensor), as required by ``Tensor.scatter_``.
    num_class : int
        Total number of classes.

    Returns
    -------
    y_onehot : torch.Tensor
        Tensor of shape (batch, num_class) with torch's default floating
        dtype, allocated on the same device as ``y``. Row i has a 1 at
        column y[i] and 0 elsewhere.
    '''
    # scatter_ needs an index with the same rank as the target, so a flat
    # label vector is turned into a (batch, 1) column first.
    if y.dim() == 1:
        y = y.unsqueeze(1)
    batch = y.shape[0]
    # Allocate on y's device; a CPU buffer would make scatter_ fail for
    # labels living on an accelerator.
    y_onehot = torch.zeros(batch, num_class, device=y.device)
    y_onehot.scatter_(1, y, 1)
    return y_onehot
def find_latest_step(path):
    '''
    Returns the highest step number among checkpoint files in ``path``.

    Only directory entries named like 'step_<n>_epoch_<m>.pt' (decimal
    digits for n and m) are considered; anything else is ignored.
    Returns 0 when no entry matches.
    '''
    checkpoint_pattern = re.compile(r'^step_(?P<n_step>\d+)_epoch_\d+\.pt$')
    candidates = (checkpoint_pattern.match(entry) for entry in os.listdir(path))
    found_steps = [int(hit.group('n_step')) for hit in candidates if hit]
    return max(found_steps, default=0)
def find_latest_epoch(path, step):
    '''
    Returns the highest epoch number saved for a given step in ``path``.

    Only directory entries named exactly 'step_<step>_epoch_<m>.pt' (m made
    of decimal digits) are considered. Returns -1 when no entry matches.

    Parameters
    ----------
    path : str
        Directory to scan (passed to ``os.listdir``).
    step : int or str
        Step whose epochs are searched; it is converted with ``str`` and
        matched literally.
    '''
    # Every fragment is a raw string: a non-raw '\d' / '\.' is an invalid
    # escape sequence that Python warns about at compile time. The step text
    # is escaped so that it can never be interpreted as regex syntax.
    epoch_regex = re.compile(
        r'^step_' + re.escape(str(step)) + r'_epoch_(?P<n_epoch>\d+)\.pt$')
    epochs_completed = []
    for f in os.listdir(path):
        m = epoch_regex.match(f)
        if m:
            epochs_completed.append(int(m.group('n_epoch')))
    return max(epochs_completed, default=-1)
def find_latest_epoch_and_step(path):
    '''
    Locates the most recent checkpoint in ``path``.

    Note the return order is (step, epoch), despite the function name.

    Returns
    -------
    tuple
        ``(step, epoch)``: the result of ``find_latest_step(path)`` (0 when
        no checkpoint exists) followed by ``find_latest_epoch(path, step)``
        for that step (-1 when none exists).
    '''
    latest_step = find_latest_step(path)
    return latest_step, find_latest_epoch(path, latest_step)
def setup_logging(log_dir):
    '''
    Configures root logging to write to ``<log_dir>/log.txt`` and to stderr.

    The directory is created if necessary and the log file is opened in
    append mode. ``logging.basicConfig`` only has an effect when the root
    logger has no handlers yet. A console handler is attached at most once,
    so calling this function repeatedly does not duplicate console output.

    Parameters
    ----------
    log_dir : str
        Directory that will contain ``log.txt``.
    '''
    console_name = 'utils.setup_logging.console'
    os.makedirs(log_dir, exist_ok=True)
    logpath = os.path.join(log_dir, 'log.txt')
    # Append mode also creates the file when it does not exist, so no
    # existence check is needed to choose between 'a' and 'w'.
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(message)s',
                        datefmt='%m-%d %H:%M:%S',
                        filename=logpath,
                        filemode='a')
    root = logging.getLogger('')
    # A second console handler would print every record twice; skip if the
    # one installed by an earlier call is still attached.
    if any(h.get_name() == console_name for h in root.handlers):
        return
    # Handler that writes DEBUG messages or higher to sys.stderr (the
    # default stream of StreamHandler).
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.set_name(console_name)
    # Simpler format for console use.
    console.setFormatter(
        logging.Formatter('%(asctime)s: %(levelname)-8s %(message)s'))
    root.addHandler(console)
def writer_histogram(writer, model, epoch):
    '''
    Logs a histogram of every parameter of ``model`` and, when present, of
    its gradient.

    Tags are '<ModelClass>/<param/path>' for parameter values and
    '<ModelClass>/<param/path>/grad' for gradients, where the dots in the
    parameter name are replaced by slashes.

    Parameters
    ----------
    writer : object
        Anything exposing ``add_histogram(tag, values, step)``; presumably a
        TensorBoard SummaryWriter (confirm against callers).
    model : torch.nn.Module
        Model whose ``named_parameters()`` are logged.
    epoch
        Passed unchanged as the step argument of ``add_histogram``.
    '''
    prefix = model.__class__.__name__
    for tag, value in model.named_parameters():
        tag = tag.replace('.', '/')
        writer.add_histogram(prefix + '/' + tag,
                             value.detach().cpu().numpy(), epoch)
        # Parameters that have not been through a backward pass have no grad.
        if value.grad is None:
            continue
        try:
            writer.add_histogram(prefix + '/' + tag + '/grad',
                                 value.grad.cpu().numpy(), epoch)
        except Exception:
            # Best effort: a failing gradient histogram must not abort the
            # caller, but it should leave a trace rather than vanish silently.
            logging.getLogger(__name__).debug(
                'Skipping gradient histogram for %s', tag, exc_info=True)