-
Notifications
You must be signed in to change notification settings - Fork 180
/
Main.py
39 lines (35 loc) · 1.07 KB
/
Main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from Diffusion.Train import train, eval
def main(model_config=None):
    """Train the diffusion model or sample from it, depending on ``"state"``.

    A built-in default configuration is used as the base. Entries in
    ``model_config`` override those defaults key by key, so a caller may pass
    only the settings it wants to change (e.g. ``{"state": "eval"}``). Keys it
    omits keep their default values. The caller's dict is not modified.

    Args:
        model_config: Optional mapping of configuration overrides. ``None``
            (the default) runs with the defaults unchanged.

    Raises:
        ValueError: If the resulting ``"state"`` is neither ``"train"`` nor
            ``"eval"``.
    """
    modelConfig = {
        "state": "train",  # "train" or "eval"
        "epoch": 200,
        "batch_size": 80,
        "T": 1000,
        "channel": 128,
        "channel_mult": [1, 2, 3, 4],
        "attn": [2],
        "num_res_blocks": 2,
        "dropout": 0.15,
        "lr": 1e-4,
        "multiplier": 2.,
        "beta_1": 1e-4,
        "beta_T": 0.02,
        "img_size": 32,
        "grad_clip": 1.,
        "device": "cuda:0",  # requires a CUDA-capable GPU
        "training_load_weight": None,
        "save_weight_dir": "./Checkpoints/",
        "test_load_weight": "ckpt_199_.pt",
        "sampled_dir": "./SampledImgs/",
        "sampledNoisyImgName": "NoisyNoGuidenceImgs.png",
        "sampledImgName": "SampledNoGuidenceImgs.png",
        "nrow": 8,
    }
    if model_config is not None:
        # Merge instead of replacing, so a partial override dict does not drop
        # keys that train/eval need. Building a new dict leaves the caller's
        # mapping untouched.
        modelConfig = {**modelConfig, **model_config}

    state = modelConfig["state"]
    if state == "train":
        train(modelConfig)
    elif state == "eval":
        # `eval` is the function imported from Diffusion.Train, which shadows
        # the Python builtin of the same name in this module.
        eval(modelConfig)
    else:
        # Fail loudly: previously any unrecognised value (e.g. a typo) silently
        # fell through to eval.
        raise ValueError(
            f"Unknown state {state!r}; expected 'train' or 'eval'"
        )
# Script entry point: run main() with its default configuration only when this
# file is executed directly. Importing the module does not start a run.
if __name__ == '__main__':
    main()