# diffutrainer.py
# Forked from lucidrains/denoising-diffusion-pytorch.
from denoising_diffusion_pytorch import GaussianDiffusion, Trainer
#from UnetGen import UnetGen
import torch
import argparse
# Command-line options for a training run. They are parsed immediately, and
# the rest of this script reads every setting from `opt`.
parser = argparse.ArgumentParser()
# define params and their types with defaults if needed
parser.add_argument('--images', type=str, default="", help='path to images')
parser.add_argument('--lr', type=float, default=4e-5, help='learning rate')
parser.add_argument('--steps', type=int, default=1000, help='number of diffusion steps')
parser.add_argument('--accum', type=int, default=10, help='number of iterations per gradient update')
parser.add_argument('--trainsteps', type=int, default=100000, help='number of iterations')
parser.add_argument('--dir', type=str, default="train", help='folder for storing sampled images')
parser.add_argument('--name', type=str, default="oma", help='basename for storing sampled images')
parser.add_argument('--amp', action="store_true", help='use automatic mixed precision')
parser.add_argument('--imageSize', type=int, default=512, help='image size')
parser.add_argument('--batchSize', type=int, default=2, help='batch size')
parser.add_argument('--saveEvery', type=int, default=100, help='image and model save frequency')
parser.add_argument('--losstype', type=str, default="l2", help='loss type: l1 or l2')
parser.add_argument('--load', type=str, default="", help='path to pth file')
# NOTE(review): --nostrict has no help text and is not read anywhere in this
# script; it only reaches Trainer through `opts` -- confirm its intended use.
parser.add_argument('--nostrict', action="store_true", help='')
# The values are forwarded to the Unet constructor as dim_mults.
parser.add_argument('--mults', type=int, nargs='*', default=[1, 1, 2, 2, 4, 8], help='multipliers passed to the Unet as dim_mults')
parser.add_argument('--nsamples', type=int, default=2, help='how many samples to generate')
# The help text must list exactly the names the model-selection code accepts
# (it previously advertised "unetok5" and omitted "unet2").
parser.add_argument('--model', type=str, default="unet1", help='model architecture: unet0, unet0k5, unet1, unet2, unetcn0')
opt = parser.parse_args()
# Select the Unet implementation named by --model. Only the chosen
# alt_models module is imported; the name `Unet` is bound for the code below.
mtype = opt.model
if mtype == "unet0":
    from alt_models.Unet0 import Unet
elif mtype == "unet0k5":
    from alt_models.Unet0k5 import Unet
elif mtype == "unet1":
    from alt_models.Unet1 import Unet
elif mtype == "unet2":
    from alt_models.Unet2 import Unet
elif mtype == "unetcn0":
    from alt_models.UnetCN0 import Unet
else:
    print("Unsupported model: "+mtype)
    # Exit with a non-zero status so calling shells/schedulers see the
    # failure. The bare exit() helper reported status 0 and is only defined
    # when the `site` module has been loaded.
    raise SystemExit(1)
# Build the denoising network with base dim 64 and the --mults multipliers,
# and move it to the GPU. A single .cuda() is enough: the call moves the
# module's parameters in place, so a second transfer would be a no-op.
model = Unet(
    dim=64,
    dim_mults=tuple(opt.mults),
).cuda()
print(model)  # print the module so the architecture appears in the run output

# Wrap the network in the Gaussian diffusion object configured from the
# command line, and move it to the GPU as well.
diffusion = GaussianDiffusion(
    model,
    image_size=opt.imageSize,
    timesteps=opt.steps,     # number of steps
    loss_type=opt.losstype,  # L1 or L2
).cuda()
# Keyword settings for the Trainer, all taken from the parsed command line
# except the fixed EMA decay.
trainer_settings = dict(
    image_size=opt.imageSize,
    train_batch_size=opt.batchSize,
    train_lr=opt.lr,
    save_and_sample_every=opt.saveEvery,
    train_num_steps=opt.trainsteps,        # total training steps
    gradient_accumulate_every=opt.accum,   # gradient accumulation steps
    ema_decay=0.995,                       # exponential moving average decay
    amp=opt.amp,                           # turn on mixed precision training
    results_folder=opt.dir,
    nsamples=opt.nsamples,
    opts=opt,                              # full option namespace is handed over too
)

# The diffusion object and the image folder path are the two positional
# arguments; everything else is passed by keyword.
trainer = Trainer(diffusion, opt.images, **trainer_settings)
# Optionally restore training state from the file given by --load, then
# start the training loop.
if opt.load != "":
    # NOTE: torch.load is pickle-based; only point --load at files you trust.
    data = torch.load(opt.load)
    # Restore the step counter and both weight sets by hand; the
    # trainer.load(data) call that used to be here was left disabled.
    trainer.step = data['step']
    trainer.model.load_state_dict(data['model'])
    trainer.ema_model.load_state_dict(data['ema'])
    # Report the multipliers stored in the checkpoint, if any, so they can be
    # compared with --mults. Only a missing or non-iterable 'mults' entry is
    # treated as "not stored"; anything else (e.g. Ctrl-C) propagates.
    try:
        stored_mults = ",".join(str(x) for x in data['mults'])
    except (KeyError, TypeError):
        print("loaded "+opt.load+", no mults stored")
    else:
        print("loaded "+opt.load+", correct mults: "+stored_mults)

trainer.train()