#!/usr/bin/env python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import json
import os
import pickle

import gym

import ray
from ray.rllib.agent import get_agent_class
from ray.rllib.dqn.common.wrappers import wrap_dqn
from ray.rllib.models import ModelCatalog
from ray.tune.registry import get_registry
EXAMPLE_USAGE = """
example usage:
    ./rollout.py /tmp/ray/checkpoint_dir/checkpoint-0 --run DQN
        --env CartPole-v0 --steps 1000000 --out rollouts.pkl
"""
parser = argparse.ArgumentParser(
    formatter_class=argparse.RawDescriptionHelpFormatter,
    description="Roll out a reinforcement learning agent "
    "given a checkpoint.",
    epilog=EXAMPLE_USAGE)

parser.add_argument(
    "checkpoint", type=str, help="Checkpoint from which to roll out.")

required_named = parser.add_argument_group("required named arguments")
required_named.add_argument(
    "--run", type=str, required=True,
    help="The algorithm or model to train. This may refer to the name "
    "of a built-in algorithm (e.g. RLlib's DQN or PPO), or a "
    "user-defined trainable function or class registered in the "
    "tune registry.")
required_named.add_argument(
    "--env", type=str, help="The gym environment to use.")

parser.add_argument(
    "--no-render", default=False, action="store_const", const=True,
    help="Suppress rendering of the environment.")
parser.add_argument(
    "--steps", default=None, help="Number of steps to roll out.")
parser.add_argument(
    "--out", default=None, help="Output filename.")
parser.add_argument(
    "--config", default="{}", type=json.loads,
    help="Algorithm-specific configuration (e.g. env, hyperparams). "
    "Suppresses loading of configuration from checkpoint.")
if __name__ == "__main__":
    args = parser.parse_args()

    if not args.config:
        # No --config given: load the configuration saved next to the
        # checkpoint file.
        config_dir = os.path.dirname(args.checkpoint)
        config_path = os.path.join(config_dir, "params.json")
        with open(config_path) as f:
            args.config = json.load(f)

    if not args.env:
        if not args.config.get("env"):
            parser.error("the following arguments are required: --env")
        args.env = args.config.get("env")
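
    # Start the Ray runtime before constructing the agent.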
    ray.init()

    cls = get_agent_class(args.run)
    agent = cls(env=args.env, config=args.config)
    agent.restore(args.checkpoint)
    # --steps defaults to None; avoid int(None) and roll out indefinitely
    # when no step budget is given.
    num_steps = int(args.steps) if args.steps is not None else None
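
    # Recreate the environment with the same wrappers the agent was
    # trained with: DQN uses its own observation wrappers, everything
    # else goes through the generic preprocessor wrapper.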
    if args.run == "DQN":
        env = gym.make(args.env)
        env = wrap_dqn(get_registry(), env, args.config.get("model", {}))
    else:
        env = ModelCatalog.get_preprocessor_as_wrapper(
            get_registry(), gym.make(args.env))
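
    # Roll out full episodes until the step budget is exhausted (or
    # forever when --steps is unset), recording transitions when --out
    # is given.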
    if args.out is not None:
        rollouts = []
    steps = 0
    while steps < (num_steps or steps + 1):
        if args.out is not None:
            rollout = []
        state = env.reset()
        done = False
        reward_total = 0.0
        while not done and steps < (num_steps or steps + 1):
            action = agent.compute_action(state)
            next_state, reward, done, _ = env.step(action)
            reward_total += reward
            if not args.no_render:
                env.render()
            if args.out is not None:
                rollout.append([state, action, next_state, reward, done])
            steps += 1
            state = next_state
        if args.out is not None:
            rollouts.append(rollout)
        print("Episode reward", reward_total)

    if args.out is not None:
        # Use a context manager so the output file is closed promptly
        # rather than leaking the handle.
        with open(args.out, "wb") as f:
            pickle.dump(rollouts, f)
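
# A sketch of reading the recorded rollouts back (assuming
# --out rollouts.pkl was passed, as in the example usage above):
#     import pickle
#     with open("rollouts.pkl", "rb") as f:
#         rollouts = pickle.load(f)
# Each entry is one episode: a list of
# [state, action, next_state, reward, done] records.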