-
Notifications
You must be signed in to change notification settings - Fork 1
/
main_DG.py
72 lines (61 loc) · 3.07 KB
/
main_DG.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from torch.utils.data import DataLoader
import torch
from data import Data
from trainer import DoppelGANger, CGAN, RCGAN, NAIVEGAN, TimeGAN, TimeGAN2
import os
from util import options_parser
def _run_paths(dataset, gan_type, dis_type, wl, ks):
    """Return ``(checkpoint_dir, time_logging_file, config_logging_file)`` for a run.

    Layout: ``runs/<dataset>[_wl][_ks_<ks>]/<model>/1``, where ``<model>`` is
    ``<gan_type>`` when ``dis_type`` is None and ``Gen_<gan_type>_Dis_<dis_type>``
    otherwise. Checkpoints live in a ``checkpoint`` sub-directory of that run
    directory; ``time.log`` and ``config.log`` sit directly inside it.
    """
    run_dir = 'runs/{}'.format(dataset)
    if wl:
        run_dir = "{}_wl".format(run_dir)
    if ks is not None:
        run_dir = "{}_ks_{}".format(run_dir, ks)
    if dis_type is None:
        run_dir = '{}/{}/1'.format(run_dir, gan_type)
    else:
        run_dir = '{}/Gen_{}_Dis_{}/1'.format(run_dir, gan_type, dis_type)
    return ('{}/checkpoint'.format(run_dir),
            '{}/time.log'.format(run_dir),
            '{}/config.log'.format(run_dir))


def main():
    """Parse CLI options, build the dataset and trainer, and run training.

    Output locations are derived by ``_run_paths`` from the dataset name, the
    ``w_lambert`` / ``kernel_smoothing`` options and the generator /
    discriminator types; the checkpoint directory is created if missing.

    Trainer selection by ``--gan_type``:
      * ``'RCGAN'`` / ``'RGAN'`` -> ``RCGAN`` (conditioning from ``--is_conditional``)
      * ``'NaiveGAN'``           -> ``NAIVEGAN``
      * ``'CGAN'``               -> ``CGAN``
      * ``'TimeGAN'``            -> ``TimeGAN2``
      * anything else            -> ``DoppelGANger`` with ``gen_type=gan_type``
    """
    args = options_parser().parse_args()
    device = args.device
    dataset_name = args.dataset
    gan_type = args.gan_type
    dis_type = args.dis_type
    # The option arrives as a string; only the exact literal 'True' enables it.
    wl = args.w_lambert == 'True'
    ks = args.kernel_smoothing

    checkpoint_dir, time_logging_file, config_logging_file = _run_paths(
        dataset_name, gan_type, dis_type, wl, ks)
    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(checkpoint_dir, exist_ok=True)

    sample_len = int(args.sample_len)
    batch_size = args.batch_size

    # load data
    dataset = Data(sample_len=sample_len, name=dataset_name, w_lambert=wl, ks=ks)
    real_train_dl = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

    # define Hyperparameters
    epoch = args.num_epochs
    save_frequency = args.save_frequency

    if gan_type in ('RCGAN', 'RGAN'):
        trainer = RCGAN(real_train_dl, device=device, checkpoint_dir=checkpoint_dir,
                        time_logging_file=time_logging_file, batch_size=batch_size,
                        config_logging_file=config_logging_file,
                        isConditional=args.is_conditional)
    elif gan_type == 'NaiveGAN':
        trainer = NAIVEGAN(real_train_dl, device=device, checkpoint_dir=checkpoint_dir,
                           batch_size=batch_size, time_logging_file=time_logging_file,
                           config_logging_file=config_logging_file)
    elif gan_type == 'CGAN':
        trainer = CGAN(real_train_dl, device=device, batch_size=batch_size,
                       checkpoint_dir=checkpoint_dir, time_logging_file=time_logging_file,
                       config_logging_file=config_logging_file)
    elif gan_type == 'TimeGAN':
        # NOTE(review): 'TimeGAN' is routed to the TimeGAN2 class; the imported
        # TimeGAN class is unused here — confirm this is intentional.
        trainer = TimeGAN2(real_train_dl, device=device, checkpoint_dir=checkpoint_dir,
                           batch_size=batch_size, config_logging_file=config_logging_file,
                           time_logging_file=time_logging_file)
    else:
        # Any other gan_type is forwarded to DoppelGANger as its generator type.
        trainer = DoppelGANger(real_train_dl=real_train_dl, device=device,
                               checkpoint_dir=checkpoint_dir,
                               time_logging_file=time_logging_file,
                               config_logging_file=config_logging_file,
                               sample_len=sample_len, batch_size=batch_size,
                               gen_type=gan_type, dis_type=dis_type)

    trainer.train(epochs=epoch, writer_frequency=1, saver_frequency=save_frequency)


if __name__ == "__main__":
    main()