-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtrain.py
34 lines (26 loc) · 1.05 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import argparse
import collections
import importlib
import numpy as np
import torch
from parse_config import ConfigParser
def main(config):
    """Seed the RNGs, configure cuDNN for reproducibility, and run training.

    The trainer implementation is selected at runtime: the module
    ``src.trainer.<type>`` is imported, where ``<type>`` is read from
    ``config['trainer']`` (key ``'type'``, default ``'trainer'``). That
    module's ``Trainer`` class is instantiated with ``config``, and its
    ``train()`` method is called.

    Args:
        config: The parsed experiment configuration. When run as a script,
            this is the object returned by ``ConfigParser.from_args``. It must:
            - expose the raw settings mapping as ``config.config``, which is
              read for the optional ``'seed'`` key (default 1);
            - support ``config['trainer']``, returning a mapping.
    """
    import random  # Local import: only needed to seed Python's own RNG.

    seed = config.config.get('seed', 1)
    # Seed every RNG a training loop is likely to draw from.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Deterministic cuDNN kernels only make runs reproducible if the
    # autotuner is disabled. With benchmark=True, cuDNN chooses algorithms
    # by timing them, and that choice can vary from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    trainer_module = importlib.import_module(
        'src.trainer.' + config['trainer'].get('type', 'trainer'))
    trainer = trainer_module.Trainer(config)
    trainer.train()
if __name__ == '__main__':
    # Command-line interface. Every option is an optional string defaulting
    # to None. The parser itself (not parsed arguments) is handed to
    # ConfigParser.from_args, which presumably calls parse_args internally.
    cli = argparse.ArgumentParser()
    cli_options = (
        (('-c', '--config'), 'config file path (default: None)'),
        (('-d', '--device'), 'indices of GPUs to enable (default: all)'),
        (('-r', '--resume'), 'path to latest checkpoint (default: None)'),
    )
    for option_flags, option_help in cli_options:
        cli.add_argument(*option_flags, default=None, type=str,
                         help=option_help)
    config = ConfigParser.from_args(cli)
    main(config)