trainers.py
'''
xECG Project Repository (https://github.com/jtrpinto/xECG)

File: trainers.py
- Contains the model training routine. Used by train_model_ptb.py and
  train_model_uoftdb.py.

"Explaining ECG Biometrics: Is It All In The QRS?"
João Ribeiro Pinto and Jaime S. Cardoso
19th International Conference of the Biometrics Special Interest Group (BIOSIG 2020)

[email protected] | https://jtrpinto.github.io
'''
import torch
import numpy as np
import sys
import pickle

def train_model(model, loss_fn, optimiser, train_loader, n_epochs, device,
                patience=np.inf, valid_loader=None, filename=None):
    # Repeat training for the desired number of epochs
    train_hist = []
    train_idr = []
    valid_hist = []
    valid_idr = []

    # For early stopping:
    plateau = 0
    best_valid_loss = None

    print(model)

    for epoch in range(n_epochs):
        print('Epoch {}/{}'.format(epoch + 1, n_epochs))

        # Training loop
        model.train()  # set model to training mode (affects dropout and batch norm.)
        for i, (X, y) in enumerate(train_loader):
            # Copy the mini-batch to the device
            X = X.float().to(device)
            y = y.to(device)

            ypred = model(X)           # forward pass
            loss = loss_fn(ypred, y)   # compute the loss
            optimiser.zero_grad()      # set all gradients to zero (otherwise they are accumulated)
            loss.backward()            # backward pass (i.e. compute gradients)
            optimiser.step()           # update the parameters

            # Display the mini-batch loss
            sys.stdout.write('\r........mini-batch no. {} loss: {:.4f}'.format(i + 1, loss.item()))
            sys.stdout.flush()

            if torch.isnan(loss):
                print('NaN loss. Terminating train.')
                return [], []

        # Compute the training and validation losses to monitor the training progress (optional)
        print()
        with torch.no_grad():  # inference only, so gradients are not needed
            model.eval()  # set model to inference mode (affects dropout and batch norm.)

            # Training set loss and identification rate (IDR, i.e. accuracy)
            train_loss = 0.
            t_corrects = 0
            t_total = 0
            for i, (X, y) in enumerate(train_loader):
                X = X.float().to(device)
                y = y.to(device)
                ypred = model(X)                 # forward pass
                train_loss += loss_fn(ypred, y)  # accumulate the mini-batch losses
                t_corrects += (torch.argmax(ypred, 1) == y).float().sum()
                t_total += y.shape[0]
            train_loss /= i + 1
            train_hist.append(train_loss.item())
            t_idr = (t_corrects / t_total).item()
            train_idr.append(t_idr)
            print('....train loss: {:.4f} :: IDR {:.4f}'.format(train_loss.item(), t_idr))

            if valid_loader is None:
                print()
                continue

            # Validation set loss and identification rate
            valid_loss = 0.
            v_corrects = 0
            v_total = 0
            for i, (X, y) in enumerate(valid_loader):
                X = X.float().to(device)
                y = y.to(device)
                ypred = model(X)                 # forward pass
                valid_loss += loss_fn(ypred, y)  # accumulate the mini-batch losses
                v_corrects += (torch.argmax(ypred, 1) == y).float().sum()
                v_total += y.shape[0]
            valid_loss /= i + 1
            valid_hist.append(valid_loss.item())
            v_idr = (v_corrects / v_total).item()
            valid_idr.append(v_idr)
            print('....valid loss: {:.4f} :: IDR {:.4f}'.format(valid_loss.item(), v_idr))

            # Checkpointing: save the model and histories whenever the validation
            # loss improves; otherwise count the epoch towards early stopping
            if best_valid_loss is None or valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(model.state_dict(), filename + '.pth')
                with open(filename + '_trainhist.pk', 'wb') as hf:
                    pickle.dump({'loss': train_hist, 'idr': train_idr}, hf)
                with open(filename + '_validhist.pk', 'wb') as hf:
                    pickle.dump({'loss': valid_hist, 'idr': valid_idr}, hf)
                plateau = 0
                print('....Saving...')
            else:
                plateau += 1
                if plateau >= patience:
                    print('....Early stopping the train.')
                    return train_hist, valid_hist

    return train_hist, valid_hist
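

# --- Usage sketch (illustrative, not part of the original project code) ---
# A minimal example of how train_model can be called. The tiny model, the
# synthetic 1-lead signals, and all hyperparameters below are assumptions for
# illustration only; the project's real entry points are train_model_ptb.py
# and train_model_uoftdb.py.
if __name__ == '__main__':
    from torch.utils.data import DataLoader, TensorDataset

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Hypothetical data: 64 single-lead segments of 1000 samples, 10 identities
    X = torch.randn(64, 1, 1000)
    y = torch.randint(0, 10, (64,))
    loader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True)

    # Hypothetical stand-in classifier producing one logit per identity
    model = torch.nn.Sequential(
        torch.nn.Conv1d(1, 8, kernel_size=5), torch.nn.ReLU(),
        torch.nn.AdaptiveAvgPool1d(1), torch.nn.Flatten(),
        torch.nn.Linear(8, 10),
    ).to(device)

    # Reusing the training data as the validation set just to exercise the
    # checkpointing and early-stopping paths
    train_model(model, torch.nn.CrossEntropyLoss(),
                torch.optim.Adam(model.parameters(), lr=1e-3),
                loader, n_epochs=2, device=device,
                patience=5, valid_loader=loader, filename='toy_model')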