-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_pup.py
137 lines (105 loc) · 5.65 KB
/
train_pup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import sys
from utils.NiftiDataset import *
import utils.NiftiDataset as NiftiDataset
from torch.utils.data import DataLoader
from options.train_options import TrainOptions
# from logger import *
import time
from models import create_model
from utils.visualizer import Visualizer
# from test import inference
if __name__ == '__main__':
    # ----- Load the training options (CLI / config) -----
    opt = TrainOptions().parse()

    # ----- Transformation and augmentation pipeline for the data -----
    # min_pixel: minimum number of labelled voxels a random crop must contain,
    # expressed via opt.min_pixel as a percentage of the patch volume.
    min_pixel = int(opt.min_pixel * ((opt.patch_size[0] * opt.patch_size[1] * opt.patch_size[2]) / 100))
    trainTransforms = [
        NiftiDataset.Resample(opt.new_resolution, opt.resample),
        NiftiDataset.Augmentation(),
        NiftiDataset.Padding((opt.patch_size[0], opt.patch_size[1], opt.patch_size[2])),
        NiftiDataset.RandomCrop((opt.patch_size[0], opt.patch_size[1], opt.patch_size[2]), opt.drop_ratio, min_pixel)
    ]

    # Paired (supervised) dataset; samples are fed to the network in batches.
    train_set = NiftiDataSet(opt.data_path, which_direction='AtoB', transforms=trainTransforms, shuffle_labels=False, train=True)
    print('length train list:', len(train_set))
    print(train_set[0][1].shape)
    train_loader = DataLoader(train_set, batch_size=opt.batch_size, shuffle=True, num_workers=opt.workers, pin_memory=True)

    # Optional unpaired (unsupervised) dataset for the second training phase.
    train_loader_unpaired = None
    if opt.data_path_2 is not None:
        train_set_unpaired = NiftiDataSet(opt.data_path_2, which_direction='AtoB', transforms=trainTransforms, shuffle_labels=True, train=True)
        print('length train list (unpaired):', len(train_set_unpaired))
        # Fixed: previously printed the paired set's shape here by mistake.
        print(train_set_unpaired[0][1].shape)
        train_loader_unpaired = DataLoader(train_set_unpaired, batch_size=opt.batch_size, shuffle=True, num_workers=opt.workers, pin_memory=True)

    # -----------------------------------------------------
    model = create_model(opt)  # creation of the model
    model.setup(opt)
    if opt.epoch_count > 1:
        # Resume: epoch_count > 1 means a checkpoint for that epoch exists.
        model.load_networks(opt.epoch_count)
    visualizer = Visualizer(opt)

    total_steps = 0
    super_train = opt.super_train

    def _train_one_epoch(loader, epoch, total_steps):
        """Run one training epoch over `loader` and return the updated total_steps.

        Factored out of the former duplicated supervised/unsupervised loop
        bodies; behavior (logging cadence, checkpointing cadence, optimizer
        dispatch) is identical for both phases.
        """
        epoch_iter = 0
        iter_data_time = time.time()
        for data in loader:
            iter_start_time = time.time()
            if total_steps % opt.print_freq == 0:
                # Data-loading latency for the upcoming loss report.
                t_data = iter_start_time - iter_data_time
                visualizer.reset()
            total_steps += opt.batch_size
            epoch_iter += opt.batch_size
            model.set_input(data)
            # pup_gan's optimize_parameters takes the options object; other
            # models expose the argument-free signature.
            if opt.model == 'pup_gan':
                model.optimize_parameters(opt)
            else:
                model.optimize_parameters()
            if total_steps % opt.print_freq == 0:
                losses = model.get_current_losses()
                t = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data)
            if total_steps % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save_networks('latest')
            iter_data_time = time.time()
        return total_steps

    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch <= opt.super_epoch and super_train == 1:
            # Phase 1: supervised training on the paired dataset.
            total_steps = _train_one_epoch(train_loader, epoch, total_steps)
        else:
            # Phase 2: unsupervised training. If we started supervised and then
            # switch, opt.super_train must be updated; if we only train
            # unsupervised, opt.super_train is already 0.
            if super_train == 1:
                opt.super_train = 0
            if train_loader_unpaired is None:
                # Fail with a clear message instead of a NameError when the
                # unpaired phase is reached without a second dataset.
                raise RuntimeError('Unsupervised training phase reached but no data_path_2 was provided.')
            total_steps = _train_one_epoch(train_loader_unpaired, epoch, total_steps)

        # End-of-epoch checkpointing and reporting (formerly duplicated in
        # both branches above).
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save_networks('latest')
            model.save_networks(epoch)
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
        model.update_learning_rate()