# trainer.py
# Forked from johschmidt42/PyTorch-2D-3D-UNet-Tutorial
import numpy as np
import torch
class Trainer:
    """Drive a basic training / validation loop for a PyTorch model.

    Each epoch trains on every batch of ``training_DataLoader``, optionally
    evaluates on ``validation_DataLoader`` and optionally steps
    ``lr_scheduler``. Three per-epoch histories are kept on the instance:

    * ``training_loss``   -- mean of the batch losses of the training pass
    * ``validation_loss`` -- mean of the batch losses of the validation pass
    * ``learning_rate``   -- ``lr`` of the optimizer's first parameter group,
      sampled at the end of the training pass

    Args:
        model: Network to optimise.
        device: Device every batch is moved to before the forward pass.
        criterion: Loss callable invoked as ``criterion(output, target)``.
        optimizer: Optimizer updating ``model``'s parameters.
        training_DataLoader: Sized iterable yielding ``(x, y)`` batches.
        validation_DataLoader: Optional sized iterable yielding ``(x, y)``
            batches; validation is skipped when ``None``.
        lr_scheduler: Optional scheduler, stepped once per epoch.
        epochs: Number of epochs performed by one ``run_trainer`` call.
        epoch: Initial value of the epoch counter.
        notebook: Use ``tqdm.notebook`` progress bars instead of console ones.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 device: torch.device,
                 criterion: torch.nn.Module,
                 optimizer: torch.optim.Optimizer,
                 training_DataLoader: torch.utils.data.DataLoader,
                 validation_DataLoader: torch.utils.data.DataLoader = None,
                 lr_scheduler: torch.optim.lr_scheduler = None,
                 epochs: int = 100,
                 epoch: int = 0,
                 notebook: bool = False
                 ):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.training_DataLoader = training_DataLoader
        self.validation_DataLoader = validation_DataLoader
        self.device = device
        self.epochs = epochs
        self.epoch = epoch
        self.notebook = notebook

        # Per-epoch histories, appended to by _train() / _validate().
        self.training_loss = []
        self.validation_loss = []
        self.learning_rate = []

    def run_trainer(self):
        """Run ``self.epochs`` epochs and return the recorded histories.

        Returns:
            Tuple ``(training_loss, validation_loss, learning_rate)``; these
            are the instance's own lists, not copies.
        """
        _, trange = self._progress_tools()

        for _ in trange(self.epochs, desc='Progress'):
            self.epoch += 1  # epoch counter

            self._train()

            if self.validation_DataLoader is not None:
                self._validate()

            self._step_lr_scheduler()

        return self.training_loss, self.validation_loss, self.learning_rate

    def _progress_tools(self):
        """Return ``(tqdm, trange)`` from the notebook or console flavour of tqdm."""
        # Imported lazily, so tqdm is only needed once a loop actually runs.
        if self.notebook:
            from tqdm.notebook import tqdm, trange
        else:
            from tqdm import tqdm, trange
        return tqdm, trange

    def _step_lr_scheduler(self):
        """Advance the learning-rate scheduler by one epoch, if one is set."""
        if self.lr_scheduler is None:
            return

        is_plateau = isinstance(self.lr_scheduler,
                                torch.optim.lr_scheduler.ReduceLROnPlateau)
        if is_plateau and self.validation_DataLoader is not None:
            # Feed the loss of the epoch that just finished. Indexing with -1
            # (rather than the loop index) stays correct when the history
            # already holds entries from an earlier run_trainer() call.
            self.lr_scheduler.step(self.validation_loss[-1])
        else:
            # NOTE(review): a ReduceLROnPlateau scheduler without a validation
            # loader also ends up here and is stepped without a metric --
            # confirm whether that combination should be supported.
            self.lr_scheduler.step()

    def _train(self):
        """Run one training pass; record its mean loss and the current lr."""
        tqdm, _ = self._progress_tools()

        self.model.train()  # train mode
        batch_losses = []  # one entry per batch

        # The context manager closes the bar even if a batch raises.
        with tqdm(self.training_DataLoader, desc='Training',
                  total=len(self.training_DataLoader), leave=False) as batch_iter:
            for x, y in batch_iter:
                inputs, target = x.to(self.device), y.to(self.device)
                self.optimizer.zero_grad()  # clear gradients of the previous batch
                out = self.model(inputs)  # forward pass
                loss = self.criterion(out, target)
                loss_value = loss.item()
                batch_losses.append(loss_value)
                loss.backward()  # backward pass
                self.optimizer.step()  # update the parameters

                batch_iter.set_description(f'Training: (loss {loss_value:.4f})')

        self.training_loss.append(np.mean(batch_losses))
        self.learning_rate.append(self.optimizer.param_groups[0]['lr'])

    def _validate(self):
        """Run one validation pass without gradients; record its mean loss."""
        tqdm, _ = self._progress_tools()

        self.model.eval()  # evaluation mode
        batch_losses = []  # one entry per batch

        # The context manager closes the bar even if a batch raises; no
        # gradients are tracked anywhere inside the validation pass.
        with tqdm(self.validation_DataLoader, desc='Validation',
                  total=len(self.validation_DataLoader), leave=False) as batch_iter, \
                torch.no_grad():
            for x, y in batch_iter:
                inputs, target = x.to(self.device), y.to(self.device)
                out = self.model(inputs)
                loss = self.criterion(out, target)
                loss_value = loss.item()
                batch_losses.append(loss_value)

                batch_iter.set_description(f'Validation: (loss {loss_value:.4f})')

        self.validation_loss.append(np.mean(batch_losses))