train.py
"""
Rough outline:
1) create dataset
2) create model, setup
"""
import yaml
import torch
import os
from options.train_options import CycleTrainOptions
from models.cyclegan import CycleGAN
from data.dataset import CycleDataset
from utils.model_utils import print_losses, get_latest_num

if __name__ == '__main__':
    # parse options
    parser = CycleTrainOptions()
    opt = parser.parse()
    opt.phase = 'train'
    parser.export_options(opt)

    # config params
    with open(opt.config, 'r') as file:
        config = yaml.safe_load(file)
    warmup_epochs = config['train']['warmup_epochs']
    decay_epochs = config['train']['decay_epochs']
    save_epoch_freq = config['train']['save_epoch_freq']
    total_epochs = warmup_epochs + decay_epochs
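
    # The config file is assumed to follow roughly this YAML layout (the example
    # values here are hypothetical, not taken from the repo's configs):
    #
    #   train:
    #     warmup_epochs: 100
    #     decay_epochs: 100
    #     save_epoch_freq: 5
    #   dataset:
    #     batch_size: 1
    #     in_order: false   # false -> the dataloader below shuffles
    #     num_workers: 4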

    if not opt.continue_train:
        start_epoch = 1
    else:
        # resume: start from the epoch after the loaded checkpoint
        if opt.load_epoch == 'latest':
            start_epoch = get_latest_num(
                os.path.join(opt.checkpoints_dir, opt.model_name)
            ) + 1
        else:
            start_epoch = int(opt.load_epoch) + 1
        config['train']['start_epoch'] = start_epoch + 1
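
    # Checkpoints are assumed to live under <checkpoints_dir>/<model_name>/, with
    # get_latest_num returning the highest epoch number saved there (matching the
    # epoch-tagged save_networks calls at the end of the training loop).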

    # create model + dataset
    model = CycleGAN(opt, config)
    model.general_setup()
    dataset = CycleDataset(
        opt.to_train,
        dataroot=opt.dataroot,
        **config['dataset']
    )
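    # Note: the whole config['dataset'] block (including loader-only keys such as
    # batch_size and num_workers) is expanded into CycleDataset, so the dataset
    # class is assumed to accept or ignore those extra kwargs.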
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config['dataset']['batch_size'],
        shuffle=not config['dataset']['in_order'],
        num_workers=config['dataset']['num_workers']
    )
    max_size = len(dataloader)
    X_size, Y_size = dataset.both_len()
    print(f'Number of X images: {X_size}, Number of Y images: {Y_size}')
    print(f"Starting training loop from epoch {start_epoch}...")
    print("Losses printed as [epoch / total epochs] [batch / total batches]")

    # main loop
    print("Total epochs: ", total_epochs)
    for epoch in range(start_epoch, total_epochs + 1):
        print(f"LR: {model.schedulers[0].get_last_lr()}")
        for i, data in enumerate(dataloader):
            model.setup_input(data)
            model.optimize()
            model.update_schedulers()
            losses = model.get_losses()  # ordered dict
            print_losses(losses, epoch, total_epochs, i + 1, max_size)
        if (epoch % save_epoch_freq == 0) or (epoch == total_epochs):
            # save version with latest and also with epoch num
            print(f"Saving models at end of epoch {epoch}")
            model.save_networks()
            model.save_networks(str(epoch))
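
# Example invocation (flag names are a guess based on the option attributes used
# above; see options/train_options.py for the actual CLI definition):
#   python train.py --config configs/train.yaml --dataroot ./datasets/my_dataset
# and to resume from the most recent checkpoint:
#   python train.py --config configs/train.yaml --dataroot ./datasets/my_dataset \
#       --continue_train --load_epoch latest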