-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathmain.py
executable file
·55 lines (43 loc) · 1.38 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""Entry point."""
import os
import torch
import data
import config
import utils
import trainer
import numpy
import random
# Module-level logger, obtained from the project's ``utils.get_logger`` helper.
logger = utils.get_logger()
def _seed_everything(args):
    """Seed every RNG used by the run so results are reproducible.

    Seeds PyTorch, NumPy and Python's ``random`` module with
    ``args.random_seed``, switches cuDNN to deterministic kernels, and also
    seeds CUDA when ``args.num_gpu > 0``.
    """
    torch.manual_seed(args.random_seed)
    numpy.random.seed(args.random_seed)
    random.seed(args.random_seed)
    # Deterministic cuDNN kernels trade some speed for run-to-run
    # reproducibility.
    # NOTE(review): PyTorch's reproducibility notes also recommend
    # ``torch.backends.cudnn.benchmark = False``; it is not touched here.
    torch.backends.cudnn.deterministic = True
    if args.num_gpu > 0:
        torch.cuda.manual_seed(args.random_seed)


def _build_trainer(args):
    """Construct the dataset and the matching trainer for ``args.network_type``.

    Returns:
        ``trainer.Trainer`` over a ``data.text.Corpus`` for ``'rnn'``, or
        ``trainer.CNNTrainer`` over a ``data.image.Image`` for any network
        type whose name contains ``'cnn'``.

    Raises:
        NotImplementedError: if ``args.network_type`` matches neither.
    """
    if args.network_type == 'rnn':
        dataset = data.text.Corpus(args.data_path)
        return trainer.Trainer(args, dataset)
    if 'cnn' in args.network_type:
        dataset = data.image.Image(args)
        return trainer.CNNTrainer(args, dataset)
    # Dispatch is on network_type, so that is the value worth reporting.
    raise NotImplementedError(f"{args.network_type} is not supported")


def main(args):  # pylint:disable=redefined-outer-name
    """Entry point: prepare the run, build a trainer and execute ``args.mode``.

    Args:
        args: Parsed configuration, as produced by ``config.get_args()``.
            This function and its helpers read ``random_seed``, ``num_gpu``,
            ``network_type``, ``data_path``, ``mode`` and ``load_path``.

    Modes:
        ``'train'``: saves ``args`` via ``utils.save_args``, then trains.
        ``'derive'``: requires ``load_path``; calls ``trnr.derive()``.
        anything else: requires ``load_path``; calls ``trnr.test()``.

    Raises:
        NotImplementedError: if ``args.network_type`` is not supported.
        ValueError: if ``derive``/test mode is requested without a
            non-empty ``args.load_path``.
    """
    utils.prepare_dirs(args)
    _seed_everything(args)
    trnr = _build_trainer(args)

    if args.mode == 'train':
        utils.save_args(args)
        trnr.train()
    elif args.mode == 'derive':
        # Explicit check rather than ``assert`` so it still runs under
        # ``python -O``; ``not`` also rejects ``None``, like the branch below.
        if not args.load_path:
            raise ValueError("`--load_path` should be given in "
                             "`derive` mode")
        trnr.derive()
    else:
        # Every mode other than 'train'/'derive' evaluates a pretrained model.
        if not args.load_path:
            raise ValueError("[!] You should specify `load_path` to load a "
                             "pretrained model")
        trnr.test()
if __name__ == "__main__":
    args, unparsed = config.get_args()
    # ``unparsed`` was previously discarded without notice. Presumably it holds
    # command-line tokens that config.get_args() did not recognise
    # (parse_known_args-style) — confirm in config.py. Surfacing them means a
    # mistyped flag is reported instead of silently falling back to defaults.
    if unparsed:
        logger.warning("Unparsed args: %s", unparsed)
    main(args)