-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathargs.py
125 lines (94 loc) · 2.85 KB
/
args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import argparse
def add_trailing_slash(path):
    """Return *path* ending with "/" so file names can be appended to it.

    Args:
        path: Directory path as a string.

    Returns:
        ``path`` unchanged when it is empty or already ends with "/",
        otherwise ``path + "/"``.
    """
    # An empty string is returned as-is: indexing it would raise IndexError,
    # and appending "/" would turn "" into "/", a different location.
    if not path or path.endswith("/"):
        return path
    return path + "/"
def finalize_args(parser):
    """Parse the command line with *parser* and normalize the directory options.

    Args:
        parser: A parser whose namespace provides ``save_dir`` and
            ``logs_dir`` attributes.

    Returns:
        The parsed namespace, with ``save_dir`` and ``logs_dir`` passed
        through ``add_trailing_slash``.
    """
    parsed = parser.parse_args()
    for dir_option in ("save_dir", "logs_dir"):
        normalized = add_trailing_slash(getattr(parsed, dir_option))
        setattr(parsed, dir_option, normalized)
    return parsed
def common_args(parser):
    """Register the options shared by the training and testing parsers.

    Adds ``--timesteps``, ``--save-dir``, ``--logs-dir``, ``--load-model``,
    ``--algo``, ``--policy``, ``--hyper-opt`` and ``--short-life`` to
    *parser* in place.

    Args:
        parser: The ``argparse.ArgumentParser`` to extend.
    """
    parser.add_argument(
        "--timesteps",
        type=int,
        default=int(1e6),
        help="number of frames to train (default: 1e6)",
    )
    parser.add_argument(
        "--save-dir",
        default="./models/",
        help="directory to save agent checkpoints (default: ./models/)",
    )
    parser.add_argument(
        "--logs-dir",
        default="./logs/",
        help="directory to save tensorboard logs (default: ./logs/)",
    )
    parser.add_argument("--load-model", help="path of the model to load")
    parser.add_argument(
        "--algo",
        default="ppo2",
        choices=["a2c", "ppo2"],
        help="algorithm to use: a2c | ppo2",
    )
    parser.add_argument(
        "--policy",
        default="cnn",
        choices=["cnn", "cnnlstm"],
        # This help text previously repeated the --algo wording.
        help="policy to use: cnn | cnnlstm",
    )
    parser.add_argument(
        "--hyper-opt",
        action="store_true",
        default=False,
        help="set it to use the optimal hyperparameters",
    )
    parser.add_argument(
        "--short-life",
        action="store_true",
        default=False,
        help="whether or not to use ShortLife wrapper",
    )
def get_test_args():
    """Build the testing-suite parser and return the parsed arguments.

    On top of the options registered by ``common_args``, the parser takes
    the positionals ``test_id`` and ``rank`` (an int) and the option
    ``--num-processes`` (int, default 1).

    Returns:
        Whatever ``finalize_args`` returns for the assembled parser.
    """
    test_parser = argparse.ArgumentParser(
        description="Sonic's reinforcement learning testing suite"
    )
    common_args(test_parser)

    # Registration order is kept as-is: it fixes the order of the positionals.
    register = test_parser.add_argument
    register("test_id", help="test id (used for the logs' name)")
    register(
        "--num-processes",
        default=1,
        type=int,
        help="how many training CPU processes to use (default: 1)",
    )
    register("rank", type=int, help="rank number")

    return finalize_args(test_parser)
def get_train_args():
    """Build the training parser and return the parsed arguments.

    On top of the options registered by ``common_args``, the parser takes
    ``--game``, ``--level``, ``--num-processes`` (int, default 4), the flag
    ``--joint`` and the positional ``train_id``.

    Returns:
        Whatever ``finalize_args`` returns for the assembled parser.
    """
    parser = argparse.ArgumentParser(description="Sonic's reinforcement learning")
    common_args(parser)
    parser.add_argument(
        "--game",
        default="SonicTheHedgehog-Genesis",
        help="game to train on (default: SonicTheHedgehog-Genesis)",
    )
    parser.add_argument(
        "--level",
        default="GreenHillZone.Act1",
        help="level to train on (default: GreenHillZone.Act1)",
    )
    parser.add_argument(
        "--num-processes",
        type=int,
        default=4,
        help="how many training CPU processes to use (default: 4)",
    )
    parser.add_argument(
        "--joint",
        action="store_true",
        default=False,
        help="train on the full train set",
    )
    parser.add_argument("train_id", help="training id (used for the logs' name)")
    return finalize_args(parser)