run_rllib.py
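
# Entry point for training or evaluating RLlib agents (PPO or IMPALA) on Gym Retro games.
# Example invocation (the game and state names below are illustrative, not part of this repo):
#   python run_rllib.py Airstriker-Genesis Level1 --train --agent PPO --framework tf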
import argparse
import numpy as np
import retro
from utils import retro_wrappers
import ray
from ray.rllib.agents import ppo, impala, dqn
from utils.rllib_utils import register_retro, train, test
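
# Command-line interface: positional game and state, plus flags for checkpointing,
# training vs. evaluation, agent/framework selection, episode count, recording,
# and the Retro scenario file to load.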
parser = argparse.ArgumentParser()
parser.add_argument("game", type=str)
parser.add_argument("state", type=str)
parser.add_argument("-c", "--checkpoint", type=str)
parser.add_argument("-t", "--train", action="store_true")
parser.add_argument("-a", "--agent", type=str, default="PPO")
parser.add_argument("-f", "--framework", type=str, default="tf")
parser.add_argument("-e", "--episodes", type=int, default=1)
parser.add_argument("-r", "--record", action="store_true")
parser.add_argument("-s", "--scenario", type=str, default="scenario")
args = parser.parse_args()
if __name__ == "__main__":
    game = args.game
    state = args.state
    wrapper = retro_wrappers.get_wrapper(game)
    checkpoint = args.checkpoint
    training = args.train
    agent = args.agent
    framework = args.framework
    episode_count = args.episodes
    record = args.record
    scenario = args.scenario
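
    # Start (or attach to) Ray and register the Retro environment with RLlib;
    # it is later referenced by the game name when the trainer is built.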
    info = ray.init(ignore_reinit_error=True)
    register_retro(game, state, scenario, wrapper)
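
    # Build the trainer for the requested algorithm; each branch starts from the
    # algorithm's default config and overrides a handful of settings.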
    if agent == "PPO":
        trainer_config = ppo.DEFAULT_CONFIG.copy()
        trainer_config['log_level'] = "WARN"
        trainer_config['clip_rewards'] = True
        trainer_config["num_gpus"] = 0
        trainer_config['output'] = './checkpoints/'
        trainer_config['num_workers'] = 0
        trainer_config["num_cpus_per_worker"] = 4
        trainer_config["num_envs_per_worker"] = 1
        trainer_config['lambda'] = 0.95
        trainer_config['kl_coeff'] = 0.5
        trainer_config['clip_param'] = 0.1
        trainer_config['vf_clip_param'] = 10.0
        trainer_config['entropy_coeff'] = 0.01
        trainer_config["train_batch_size"] = 500
        trainer_config['rollout_fragment_length'] = 100
        trainer_config['sgd_minibatch_size'] = 128
        trainer_config['num_sgd_iter'] = 10
        trainer_config['batch_mode'] = "truncate_episodes"
        trainer_config['observation_filter'] = "NoFilter"
        trainer_config['framework'] = 'tf' if framework == "tf" else 'torch'
        agent = ppo.PPOTrainer(config=trainer_config, env=game)
    elif agent == "IMPALA":
        trainer_config = impala.DEFAULT_CONFIG.copy()
        trainer_config['log_level'] = "WARN"
        trainer_config['clip_rewards'] = True
        trainer_config["num_gpus"] = 1
        trainer_config['output'] = './checkpoints/'
        trainer_config['rollout_fragment_length'] = 50
        trainer_config['train_batch_size'] = 500
        trainer_config["remote_worker_envs"] = True
        trainer_config['num_workers'] = 8
        trainer_config['num_envs_per_worker'] = 4
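        # Linearly anneal the learning rate from 5e-4 at step 0 toward ~0 by 20M timesteps.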
        trainer_config['lr_schedule'] = [
            [0, 0.0005],
            [20000000, 0.000000000001],
        ]
        trainer_config['framework'] = 'tf' if framework == "tf" else 'torch'
        agent = impala.ImpalaTrainer(config=trainer_config, env=game)
    else:
        # Fail fast on unsupported agent names; only PPO and IMPALA are wired up here.
        parser.error(f"unsupported agent: {agent} (expected PPO or IMPALA)")
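
    # Either run the training loop (optionally resuming from a checkpoint) or roll out
    # evaluation episodes with rendering/recording, via the helpers in utils.rllib_utils.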
    if training:
        trainer = train(agent, checkpoint=checkpoint)
    else:
        test(agent, game, state, scenario, wrapper, checkpoint=checkpoint,
             render=True, record=record, episode_count=episode_count)