forked from ikostrikov/jaxrl2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_online.py
127 lines (106 loc) · 4.18 KB
/
train_online.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#! /usr/bin/env python
import gymnasium as gym
from gymnasium.envs.registration import registry
import tqdm
import wandb
from absl import app, flags
from ml_collections import config_flags
from jaxrl3.agents import SACLearner
from jaxrl3.data import ReplayBuffer
from jaxrl3.evaluation import evaluate
from jaxrl3.wrappers import wrap_gym, set_universal_seed
# Command-line configuration via absl flags. FLAGS is the process-global
# flag registry populated by app.run() before main() executes.
FLAGS = flags.FLAGS
# Gymnasium env id; dm_control ids may be given without the "dm_control/" prefix.
flags.DEFINE_string("env_name", "HalfCheetah-v4", "Environment name.")
# NOTE(review): save_dir is defined but not referenced in this file — confirm
# whether checkpoint/TB logging was intended here.
flags.DEFINE_string("save_dir", "./tmp/", "Tensorboard logging dir.")
flags.DEFINE_integer("seed", 42, "Random seed.")
flags.DEFINE_integer("eval_episodes", 10, "Number of episodes used for evaluation.")
flags.DEFINE_integer("log_interval", 1000, "Logging interval.")
flags.DEFINE_integer("eval_interval", 5000, "Eval interval.")
flags.DEFINE_integer("batch_size", 256, "Mini batch size.")
flags.DEFINE_integer("max_steps", int(1e6), "Number of training steps.")
# Steps of uniform-random exploration before gradient updates begin.
flags.DEFINE_integer(
    "start_training", int(1e4), "Number of training steps to start training."
)
flags.DEFINE_boolean("tqdm", True, "Use tqdm progress bar.")
flags.DEFINE_boolean("wandb", True, "Log wandb.")
# NOTE(review): save_video is defined but not referenced in this file.
flags.DEFINE_boolean("save_video", False, "Save videos during evaluation.")
# Hyperparameter config loaded from a ml_collections config file; unlocked so
# CLI overrides (--config.xxx=...) may add/modify fields.
config_flags.DEFINE_config_file(
    "config",
    "configs/sac_default.py",
    "File path to the training hyperparameter configuration.",
    lock_config=False,
)
def main(_):
    """Train a SAC agent online in a Gymnasium env, logging to Weights & Biases.

    Args:
        _: positional argv remainder supplied by absl's app.run (unused).
    """
    # Honor the --wandb flag: previously wandb.init was unconditional and the
    # flag was silently ignored. "disabled" makes all wandb.* calls no-ops.
    wandb.init(
        project="jaxrl3_online",
        mode="online" if FLAGS.wandb else "disabled",
    )
    wandb.config.update(FLAGS)

    def check_env_id(env_id):
        """Return a registry-valid env id, adding the 'dm_control/' prefix if needed.

        Raises:
            ValueError: if the (possibly prefixed) id is not in the registry.
        """
        # Avoid shadowing the builtin `id` from the original implementation.
        dm_control_env_ids = [
            registered_id
            for registered_id in registry
            if registered_id.startswith("dm_control/")
            and registered_id != "dm_control/compatibility-env-v0"
        ]
        if not env_id.startswith("dm_control/"):
            for registered_id in dm_control_env_ids:
                if env_id in registered_id:
                    env_id = "dm_control/" + env_id
                    # Stop after the first match: continuing could match the
                    # already-prefixed id against another entry and prefix twice.
                    break
        if env_id not in registry:
            raise ValueError("Provide valid env id.")
        return env_id

    def make_and_wrap_env(env_id):
        """Create the env and apply the standard jaxrl3 wrappers."""
        env = gym.make(check_env_id(env_id))
        return wrap_gym(env, rescale_actions=True)

    env = make_and_wrap_env(FLAGS.env_name)
    # deque_size=1 keeps only the latest finished episode's statistics.
    env = gym.wrappers.RecordEpisodeStatistics(env, deque_size=1)
    set_universal_seed(env, FLAGS.seed)

    eval_env = make_and_wrap_env(FLAGS.env_name)
    # Offset the eval seed so eval rollouts differ from training rollouts.
    set_universal_seed(eval_env, FLAGS.seed + 2)

    kwargs = dict(FLAGS.config)
    agent = SACLearner(FLAGS.seed, env.observation_space, env.action_space, **kwargs)

    # Buffer sized to hold the whole run, so nothing is ever evicted.
    replay_buffer = ReplayBuffer(
        env.observation_space, env.action_space, FLAGS.max_steps
    )
    replay_buffer.seed(FLAGS.seed)

    # RecordEpisodeStatistics keys -> human-readable metric names.
    # Hoisted out of the loop (was rebuilt on every finished episode).
    episode_stat_names = {"r": "return", "l": "length", "t": "time"}

    observation, _ = env.reset()
    done = False
    for i in tqdm.tqdm(
        range(1, FLAGS.max_steps + 1), smoothing=0.1, disable=not FLAGS.tqdm
    ):
        # Uniform-random warmup before the agent starts acting.
        if i < FLAGS.start_training:
            action = env.action_space.sample()
        else:
            action = agent.sample_actions(observation)
        next_observation, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated

        # mask=0 only on true termination: truncated (time-limit) episodes
        # still bootstrap from the next state.
        mask = 0.0 if terminated else 1.0

        replay_buffer.insert(
            dict(
                observations=observation,
                actions=action,
                rewards=reward,
                masks=mask,
                dones=done,
                next_observations=next_observation,
            )
        )
        observation = next_observation

        if done:
            observation, _ = env.reset()
            done = False
            for k, v in info["episode"].items():
                wandb.log({f"training/{episode_stat_names[k]}": v}, step=i)

        if i >= FLAGS.start_training:
            batch = replay_buffer.sample(FLAGS.batch_size)
            update_info = agent.update(batch)
            if i % FLAGS.log_interval == 0:
                for k, v in update_info.items():
                    wandb.log({f"training/{k}": v}, step=i)

        if i % FLAGS.eval_interval == 0:
            eval_info = evaluate(agent, eval_env, num_episodes=FLAGS.eval_episodes)
            for k, v in eval_info.items():
                wandb.log({f"evaluation/{k}": v}, step=i)
if __name__ == "__main__":
    # app.run parses command-line flags into FLAGS, then invokes main(argv).
    app.run(main)