-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmain.py
112 lines (98 loc) · 3.02 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import argparse
import sys
from test import test
from train import train
from configs.defaults import get_cfg
import research_platform.utils.checkpoint as cu
from research_platform.utils.misc import launch_job
"""
General launcher script for Empirical Risk Minimization (i.e., supervised learning) cases.
"""
def parse_args(default=False, argv=None):
    """Build the launcher's command-line parser and parse the arguments.

    Args:
        default (bool): Unused; kept only so existing callers that pass it
            keep working.
        argv (list[str] | None): Argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace: with attributes ``gpu_ids``, ``shard_id``,
        ``num_shards``, ``init_method``, ``cfg_file`` and ``opts``.
    """
    parser = argparse.ArgumentParser(description="Parse arguments")
    parser.add_argument(
        "--gpu_ids",
        # Previously this reused the --shard_id help text, which was wrong.
        help="GPU ids to expose through CUDA_VISIBLE_DEVICES, e.g. '0,1'",
        type=str,
    )
    parser.add_argument(
        "--shard_id",
        help="The shard id of current node, Starts from 0 to num_shards - 1",
        default=0,
        type=int,
    )
    parser.add_argument(
        "--num_shards",
        help="Number of shards using by the job",
        default=1,
        type=int,
    )
    parser.add_argument(
        "--init_method",
        help="Initialization method, includes TCP or shared file-system",
        default="tcp://localhost:9999",
        type=str,
    )
    parser.add_argument(
        "--cfg",
        dest="cfg_file",
        help="Path to the config file",
        default=None,
        type=str,
    )
    parser.add_argument(
        "opts",
        help="See configs/defaults.py for all options",
        default=None,
        nargs=argparse.REMAINDER,
    )
    if argv is None:
        argv = sys.argv[1:]
    # With no arguments at all, show the usage text; parsing still proceeds
    # and returns the defaults.
    if not argv:
        parser.print_help()
    return parser.parse_args(argv)
def load_config(args):
    """Build the experiment config from defaults, a config file and CLI overrides.

    Layers are applied lowest to highest precedence:
    the values returned by ``get_cfg()``, then the file named by
    ``args.cfg_file`` (if not None), then the ``opts`` list (if not None),
    and finally a few attributes copied directly from ``args`` when present.
    Every name listed in ``cfg.DATA.APPEND_TO_OUTPUT_DIRNAME`` is then appended
    to ``OUTPUT_DIR`` as ``_<name>-<cfg.DATA[name]>``, and that directory is
    passed to ``cu.make_checkpoint_dir``.

    Args:
        args (argparse.Namespace): must provide ``cfg_file`` and ``opts``;
            ``num_shards``/``shard_id`` (only as a pair), ``rng_seed`` and
            ``output_dir`` are copied into the config only if present.

    Returns:
        The populated config object obtained from ``get_cfg()``.
    """
    cfg = get_cfg()

    config_path = args.cfg_file
    if config_path is not None:
        cfg.merge_from_file(config_path)

    # Command-line overrides win over anything read from the file.
    overrides = args.opts
    if overrides is not None:
        cfg.merge_from_list(overrides)

    # Shard info is only copied when both fields exist on args.
    if hasattr(args, "num_shards") and hasattr(args, "shard_id"):
        cfg.NUM_SHARDS = args.num_shards
        cfg.SHARD_ID = args.shard_id

    # Optional one-to-one overrides, applied in this order.
    for arg_name, cfg_key in (("rng_seed", "RNG_SEED"), ("output_dir", "OUTPUT_DIR")):
        if hasattr(args, arg_name):
            setattr(cfg, cfg_key, getattr(args, arg_name))

    # Suffix the output directory with the selected cfg.DATA entries.
    run_dir = cfg.OUTPUT_DIR
    for field in cfg.DATA.APPEND_TO_OUTPUT_DIRNAME:
        run_dir += f'_{field}-{cfg.DATA[field]}'
    cfg.OUTPUT_DIR = run_dir

    cu.make_checkpoint_dir(cfg.OUTPUT_DIR)
    return cfg
def main():
    """Entry point: parse arguments, build the config, then run ERM train/test.

    Training runs if ``cfg.ERM_TRAIN.ENABLE`` is set, and testing runs afterwards
    if ``cfg.ERM_TEST.ENABLE`` is set. Both go through ``launch_job`` using
    ``args.init_method``.
    """
    args = parse_args()

    # Restrict which GPUs are visible before the config is loaded and any
    # job is launched.
    if args.gpu_ids is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids

    # Multi-shard runs may redirect output to a job directory. parse_args()
    # defines no --job_dir option, so the previous `args.job_dir` access raised
    # AttributeError on every multi-shard run; only override when one exists.
    job_dir = getattr(args, "job_dir", None)
    if args.num_shards > 1 and job_dir is not None:
        args.output_dir = str(job_dir)

    cfg = load_config(args)

    if cfg.ERM_TRAIN.ENABLE:
        launch_job(cfg=cfg, init_method=args.init_method, func=train)
    if cfg.ERM_TEST.ENABLE:
        launch_job(cfg=cfg, init_method=args.init_method, func=test)


if __name__ == "__main__":
    main()