-
Notifications
You must be signed in to change notification settings - Fork 492
/
main.py
executable file
·54 lines (43 loc) · 1.44 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""Entry point."""
import os
import torch
import data
import config
import utils
import trainer
# Module-level logger obtained from the project's ``utils`` helpers.
# NOTE(review): not referenced anywhere else in the visible part of this file.
logger = utils.get_logger()
def main(args):  # pylint:disable=redefined-outer-name
    """Run the mode selected on the command line.

    Prepares the output directories, seeds the RNGs, builds the dataset that
    matches the configured network/dataset, constructs a ``trainer.Trainer``
    and dispatches on ``args.mode`` (``train``, ``derive``, ``test`` or
    ``single``).

    Args:
        args: Parsed configuration object returned by ``config.get_args()``.
            Attributes read here: ``random_seed``, ``num_gpu``,
            ``network_type``, ``dataset``, ``data_path``, ``mode``,
            ``load_path`` and ``dag_path``.

    Raises:
        NotImplementedError: If no dataset loader matches the configured
            ``network_type`` / ``dataset``.
        ValueError: If a path argument required by the selected mode is
            missing, or if ``args.mode`` is not a known mode.
    """
    utils.prepare_dirs(args)

    # Seed the CPU RNG, and the CUDA RNG as well when GPUs are requested,
    # so runs are reproducible for a given ``random_seed``.
    torch.manual_seed(args.random_seed)
    if args.num_gpu > 0:
        torch.cuda.manual_seed(args.random_seed)

    # NOTE(review): the first branch keys on ``network_type`` while the second
    # keys on ``dataset``; kept unchanged to preserve behaviour — confirm this
    # asymmetry is intended.
    if args.network_type == 'rnn':
        dataset = data.text.Corpus(args.data_path)
    elif args.dataset == 'cifar':
        dataset = data.image.Image(args.data_path)
    else:
        raise NotImplementedError(f"{args.dataset} is not supported")

    trnr = trainer.Trainer(args, dataset)

    if args.mode == 'train':
        utils.save_args(args)
        trnr.train()
    elif args.mode == 'derive':
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, and ``not`` also rejects a ``None`` load path.
        if not args.load_path:
            raise ValueError("`--load_path` should be given in "
                             "`derive` mode")
        trnr.derive()
    elif args.mode == 'test':
        if not args.load_path:
            raise ValueError("[!] You should specify `load_path` to load a "
                             "pretrained model")
        trnr.test()
    elif args.mode == 'single':
        if not args.dag_path:
            raise ValueError("[!] You should specify `dag_path` to load a dag")
        utils.save_args(args)
        trnr.train(single=True)
    else:
        raise ValueError(f"[!] Mode not found: {args.mode}")
if __name__ == "__main__":
    # ``config.get_args()`` returns a 2-tuple; only the first element is passed
    # to ``main``. NOTE(review): ``unparsed`` is bound but never used —
    # presumably leftover CLI arguments; confirm in ``config`` whether they
    # should be reported instead of silently ignored.
    args, unparsed = config.get_args()
    main(args)