-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_iql.py
436 lines (380 loc) · 17.4 KB
/
main_iql.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
import os, time
import gzip
from datetime import datetime
from tqdm import tqdm
from functools import partial
import numpy as np
import jax
import jax.numpy as jp
import flax
import pickle
# import wandb
from ml_collections import config_flags
# from agent import hiql
from agent import iql
from utils.gc_dataset import GCSDataset
from utils import d4rl_utils, d4rl_ant, ant_diagnostics, viz_utils
from utils.additional import record_video, CsvLogger
# from jaxrl_m.wandb import setup_wandb, default_wandb_config
from jaxrl_m.evaluation import supply_rng, evaluate_with_trajectories, EpisodeMonitor
from absl import app, flags
FLAGS = flags.FLAGS

# --- Run bookkeeping / logging -------------------------------------------
flags.DEFINE_string('env_name', 'antmaze-large-diverse-v2', '')
flags.DEFINE_string('save_dir', 'experiment_output/', '')
flags.DEFINE_string('run_group', 'Debug', '')
flags.DEFINE_integer('seed', 0, '')
flags.DEFINE_integer('eval_episodes', 50, '')
flags.DEFINE_integer('num_video_episodes', 2, '')
flags.DEFINE_integer('log_interval', 1000, '')
flags.DEFINE_integer('eval_interval', 100000, '')
flags.DEFINE_integer('save_interval', 100000, '')

# --- Optimisation / network ----------------------------------------------
flags.DEFINE_integer('batch_size', 1024, '')
# Number of training steps; with the default of 0 the training loop does not run.
flags.DEFINE_integer('pretrain_steps', 0, '')
# NOTE(review): main() in this file does not forward layer_norm into
# FLAGS.config -- confirm whether it should.
flags.DEFINE_integer('layer_norm', 1, '')
# Combined by main() into FLAGS.config['value_hidden_dims'] as
# (value_hidden_dim,) * value_num_layers.
flags.DEFINE_integer('value_hidden_dim', 512, '')
flags.DEFINE_integer('value_num_layers', 3, '')
flags.DEFINE_integer('use_rep', 0, '')
flags.DEFINE_integer('rep_dim', None, '')
# NOTE(review): rep_type is not forwarded into FLAGS.config by main() -- confirm.
flags.DEFINE_enum('rep_type', 'state', ['state', 'diff', 'concat'], '')
flags.DEFINE_integer('policy_train_rep', 0, '')
flags.DEFINE_integer('use_waypoints', 0, '')
flags.DEFINE_integer('way_steps', 1, '')
flags.DEFINE_float('pretrain_expectile', 0.7, '')

# --- Goal sampling (copied into FLAGS.gcdataset by main()) ----------------
flags.DEFINE_float('p_randomgoal', 0.3, '')
flags.DEFINE_float('p_trajgoal', 0.5, '')
flags.DEFINE_float('p_currgoal', 0.2, '')
flags.DEFINE_float('high_p_randomgoal', 0., '')
flags.DEFINE_integer('geom_sample', 1, '')

# --- RL hyperparameters ----------------------------------------------------
flags.DEFINE_float('discount', 0.99, '')
flags.DEFINE_float('temperature', 1, '')
flags.DEFINE_float('high_temperature', 1, '')

# NOTE(review): visual / encoder are not forwarded into FLAGS.config by main()
# -- confirm whether they are meant to be used.
flags.DEFINE_integer('visual', 0, '')
flags.DEFINE_string('encoder', 'impala', '')
flags.DEFINE_string('algo_name', None, '')  # Not used, only for logging

# wandb configuration flags are disabled in this script (metrics are written
# to CSV files instead).

# Nested config dicts; individual entries can be overridden from the command
# line (e.g. --config.<key>=...) because the configs are not locked.
config_flags.DEFINE_config_dict('config', iql.get_default_configs(), lock_config=False)
gcdataset_config = GCSDataset.get_default_config()
config_flags.DEFINE_config_dict('gcdataset', gcdataset_config, lock_config=False)
@jax.jit
def get_debug_statistics(agent, batch):
    """Summarize the agent's value estimates on a training batch.

    Calls ``agent.network`` with ``info=True`` and ``method='value'`` on the
    batch's observations and goals, then averages the returned ``'v'`` entry.

    Args:
        agent: Agent whose ``network`` is queried; it is an argument of a
            ``jax.jit``-compiled function, so it must be jit-compatible.
        batch: Mapping with ``'observations'`` and ``'goals'`` entries.

    Returns:
        A dict with the single key ``'v'`` holding the mean value.
    """
    value_info = agent.network(
        batch['observations'],
        batch['goals'],
        info=True,
        method='value',
    )
    return {'v': value_info['v'].mean()}
@jax.jit
def get_gcvalue(agent, s, g):
    """Return the goal-conditioned value of states ``s`` for goals ``g``.

    ``agent.network(..., method='value')`` yields a pair of value estimates;
    the result is their element-wise average.
    """
    first_estimate, second_estimate = agent.network(s, g, method='value')
    averaged = (first_estimate + second_estimate) / 2
    return averaged
def get_v(agent, goal, observations):
    """Score every row of ``observations`` against the same ``goal``.

    ``goal`` is tiled ``observations.shape[0]`` times along the leading axis
    (``jp.tile(goal, (N, 1))``) and the pair is passed to ``get_gcvalue``.
    """
    batch_size = observations.shape[0]
    tiled_goals = jp.tile(goal, (batch_size, 1))
    return get_gcvalue(agent, observations, tiled_goals)
@jax.jit
def get_traj_v(agent, trajectory):
    """Compute pairwise goal-conditioned values along one trajectory.

    For observations ``o_0 .. o_{T-1}`` in ``trajectory['observations']``,
    ``all_values[i, j]`` is the value of state ``o_i`` conditioned on goal
    ``o_j`` (the average of the two estimates returned by
    ``agent.network(..., method='value')``).

    Returns:
        dict with three columns of that matrix:
        ``'dist_to_beginning'`` (goal ``o_0``), ``'dist_to_end'`` (goal
        ``o_{T-1}``) and ``'dist_to_middle'`` (goal ``o_{T // 2}``).
    """
    def _add_batch_dim(x):
        return x[None]

    def _pair_value(s, g):
        # The network is called on a batch of size 1 for each (state, goal) pair.
        # jax.tree_util.tree_map replaces the deprecated (and, in recent JAX
        # releases, removed) jax.tree_map alias.
        v1, v2 = agent.network(
            jax.tree_util.tree_map(_add_batch_dim, s),
            jax.tree_util.tree_map(_add_batch_dim, g),
            method='value',
        )
        return (v1 + v2) / 2

    observations = trajectory['observations']
    # Outer vmap maps over states (rows); inner vmap maps over goals (columns).
    all_values = jax.vmap(
        jax.vmap(_pair_value, in_axes=(None, 0)), in_axes=(0, None)
    )(observations, observations)
    return {
        'dist_to_beginning': all_values[:, 0],
        'dist_to_end': all_values[:, -1],
        'dist_to_middle': all_values[:, all_values.shape[1] // 2],
    }
def _build_exp_name():
    """Return a run identifier built from the seed, SLURM metadata and start time.

    Format: ``sd<seed:03d>_``, then the optional parts ``s_<SLURM_JOB_ID>.``,
    ``<SLURM_PROCID>.`` and ``rs_<SLURM_RESTART_COUNT>.``, then the integer
    Unix timestamp taken when this function is called.
    """
    g_start_time = int(datetime.now().timestamp())
    exp_name = f'sd{FLAGS.seed:03d}_'
    if 'SLURM_JOB_ID' in os.environ:
        exp_name += f's_{os.environ["SLURM_JOB_ID"]}.'
    if 'SLURM_PROCID' in os.environ:
        exp_name += f'{os.environ["SLURM_PROCID"]}.'
    if 'SLURM_RESTART_COUNT' in os.environ:
        exp_name += f'rs_{os.environ["SLURM_RESTART_COUNT"]}.'
    exp_name += f'{g_start_time}'
    return exp_name


def _sync_flags_into_configs():
    """Copy top-level command-line flags into the ``gcdataset`` and ``config`` dicts."""
    gcdataset_keys = ('p_randomgoal', 'p_trajgoal', 'p_currgoal', 'geom_sample',
                      'high_p_randomgoal', 'way_steps', 'discount')
    for key in gcdataset_keys:
        FLAGS.gcdataset[key] = getattr(FLAGS, key)

    for key in ('pretrain_expectile', 'discount', 'temperature', 'high_temperature',
                'use_waypoints', 'way_steps'):
        FLAGS.config[key] = getattr(FLAGS, key)
    FLAGS.config['value_hidden_dims'] = (FLAGS.value_hidden_dim,) * FLAGS.value_num_layers
    for key in ('use_rep', 'rep_dim', 'policy_train_rep'):
        FLAGS.config[key] = getattr(FLAGS, key)


def _set_top_down_camera(env, lookat, distance):
    """Aim ``env.viewer``'s camera at ``lookat`` (x, y) from ``distance`` with elevation -90."""
    cam = env.viewer.cam
    cam.lookat[0], cam.lookat[1] = lookat
    cam.distance = distance
    cam.elevation = -90


def _make_antmaze_env_and_dataset():
    """Build the evaluation env, the offline dataset and optional visualisation helpers.

    Returns:
        ``(env, dataset, viz_env, viz_dataset, viz)``. The last three are
        ``None`` unless ``'large'`` appears in ``FLAGS.env_name``.

    Raises:
        NotImplementedError: if ``FLAGS.env_name`` is not an antmaze env
            (the only family this script supports).
    """
    env_name = FLAGS.env_name
    if 'antmaze' not in env_name:
        raise NotImplementedError

    if 'ultra' in env_name:
        # d4rl_ext is imported only for its import-time side effects
        # (presumably registering the ultra mazes with gym -- confirm).
        import d4rl_ext  # noqa: F401
        import gym
        env = EpisodeMonitor(gym.make(env_name))
    else:
        env = d4rl_utils.make_env(env_name)

    dataset = d4rl_utils.get_dataset(env, env_name)
    # Shift every reward by -1.0.
    dataset = dataset.copy({'rewards': dataset['rewards'] - 1.0})

    # Render once before touching env.viewer below (presumably this is what
    # creates the viewer -- confirm).
    env.render(mode='rgb_array', width=200, height=200)

    viz_env = viz_dataset = viz = None
    if 'large' in env_name:
        _set_top_down_camera(env, lookat=(18, 12), distance=50)
        viz_env, viz_dataset = d4rl_ant.get_env_and_dataset(env_name)
        viz = ant_diagnostics.Visualizer(env_name, viz_env, viz_dataset, discount=FLAGS.discount)
    elif 'ultra' in env_name:
        _set_top_down_camera(env, lookat=(26, 18), distance=70)
    else:
        _set_top_down_camera(env, lookat=(18, 12), distance=50)
    return env, dataset, viz_env, viz_dataset, viz


def _evaluate(agent, step, env, goal_info, discrete, pretrain_dataset,
              example_trajectory, viz_env, viz_dataset, viz):
    """Roll out the current agent in ``env`` and collect evaluation metrics.

    Returns:
        dict of metrics (``evaluation/*``, optional ``video`` and, for large
        antmazes, ``debugging/*``) to log at ``step``.
    """
    policy_fn = partial(supply_rng(agent.sample_actions), discrete=discrete)
    high_policy_fn = partial(supply_rng(agent.sample_high_actions))
    policy_rep_fn = agent.get_policy_rep
    # First element of every leaf of the dataset's observations.
    # jax.tree_util.tree_map replaces the deprecated/removed jax.tree_map alias.
    base_observation = jax.tree_util.tree_map(
        lambda arr: arr[0], pretrain_dataset.dataset['observations'])

    eval_info, trajs, renders = evaluate_with_trajectories(
        policy_fn, high_policy_fn, policy_rep_fn, env, env_name=FLAGS.env_name,
        num_episodes=FLAGS.eval_episodes,
        base_observation=base_observation, num_video_episodes=FLAGS.num_video_episodes,
        use_waypoints=FLAGS.use_waypoints,
        eval_temperature=0,
        goal_info=goal_info, config=FLAGS.config,
    )
    eval_metrics = {f'evaluation/{k}': v for k, v in eval_info.items()}

    if FLAGS.num_video_episodes > 0:
        eval_metrics['video'] = record_video('Video', step, renders=renders)

    # NOTE: the images below (value_viz, traj_image, image_v) were logged to
    # wandb, which is disabled in this script, so they are computed but not
    # recorded. They are kept in case the helpers have side effects.
    traj_metrics = get_traj_v(agent, example_trajectory)
    value_viz = viz_utils.make_visual_no_image(
        traj_metrics,
        [partial(viz_utils.visualize_metric, metric_name=k) for k in traj_metrics.keys()],
    )
    # eval_metrics['value_traj_viz'] = wandb.Image(value_viz)

    if FLAGS.env_name.startswith('antmaze') and 'large' in FLAGS.env_name:
        traj_image = d4rl_ant.trajectory_image(viz_env, viz_dataset, trajs)
        # eval_metrics['trajectories'] = wandb.Image(traj_image)
        new_metrics_dist = viz.get_distance_metrics(trajs)
        eval_metrics.update({f'debugging/{k}': v for k, v in new_metrics_dist.items()})
        image_v = d4rl_ant.gcvalue_image(viz_env, viz_dataset, partial(get_v, agent))
        # eval_metrics['v'] = wandb.Image(image_v)

    return eval_metrics


def _save_checkpoint(agent, step):
    """Pickle the agent state dict and config to ``FLAGS.save_dir/params_<step>.pkl``."""
    save_dict = dict(
        agent=flax.serialization.to_state_dict(agent),
        config=FLAGS.config.to_dict(),
    )
    fname = os.path.join(FLAGS.save_dir, f'params_{step}.pkl')
    print(f'Saving to {fname}')
    with open(fname, "wb") as f:
        pickle.dump(save_dict, f)


def main(_):
    """Train the IQL agent on an antmaze dataset, with periodic eval and checkpoints.

    Training metrics go to ``train.csv`` and evaluation metrics to
    ``eval.csv`` under ``FLAGS.save_dir``; checkpoints are written every
    ``FLAGS.save_interval`` steps.
    """
    exp_name = _build_exp_name()  # only consumed by the disabled wandb setup
    # exp_name += f'_{FLAGS.wandb["name"]}'

    _sync_flags_into_configs()

    # wandb logging is disabled; these lines are kept as the hook point.
    params_dict = {**FLAGS.gcdataset.to_dict(), **FLAGS.config.to_dict()}
    # FLAGS.wandb['name'] = FLAGS.wandb['exp_descriptor'] = exp_name
    # FLAGS.wandb['group'] = FLAGS.wandb['exp_prefix'] = FLAGS.run_group
    # setup_wandb(params_dict, **FLAGS.wandb)
    os.makedirs(FLAGS.save_dir, exist_ok=True)

    goal_info = None
    discrete = False
    env, dataset, viz_env, viz_dataset, viz = _make_antmaze_env_and_dataset()
    env.reset()

    pretrain_dataset = GCSDataset(dataset, **FLAGS.gcdataset.to_dict())
    total_steps = FLAGS.pretrain_steps
    example_batch = dataset.sample(1)

    # NOTE(review): max_steps is hard-coded to 1000 while training runs for
    # FLAGS.pretrain_steps steps -- confirm how iql.create_learner uses it.
    agent = iql.create_learner(
        FLAGS.seed,
        example_batch['observations'],
        example_batch['actions'],
        max_steps=1000,
        **FLAGS.config,
    )

    # Fixed 50-transition slice of the dataset used for value diagnostics.
    example_trajectory = pretrain_dataset.sample(50, indx=np.arange(1000, 1050))

    train_logger = CsvLogger(os.path.join(FLAGS.save_dir, 'train.csv'))
    eval_logger = CsvLogger(os.path.join(FLAGS.save_dir, 'eval.csv'))
    try:
        first_time = last_time = time.time()
        for i in tqdm(range(1, total_steps + 1), smoothing=0.1, dynamic_ncols=True):
            pretrain_batch = pretrain_dataset.sample(FLAGS.batch_size)
            agent, update_info = supply_rng(agent.update)(pretrain_batch)

            if i % FLAGS.log_interval == 0:
                debug_statistics = get_debug_statistics(agent, pretrain_batch)
                train_metrics = {f'training/{k}': v for k, v in update_info.items()}
                train_metrics.update(
                    {f'pretraining/debug/{k}': v for k, v in debug_statistics.items()})
                now = time.time()
                # Average wall-clock seconds per step since the previous log.
                train_metrics['time/epoch_time'] = (now - last_time) / FLAGS.log_interval
                train_metrics['time/total_time'] = now - first_time
                last_time = now
                train_logger.log(train_metrics, step=i)

            if i == 1 or i % FLAGS.eval_interval == 0:
                eval_metrics = _evaluate(
                    agent, i, env, goal_info, discrete, pretrain_dataset,
                    example_trajectory, viz_env, viz_dataset, viz,
                )
                eval_logger.log(eval_metrics, step=i)

            if i % FLAGS.save_interval == 0:
                _save_checkpoint(agent, i)
    finally:
        # Close (and flush) the CSV logs even if training or evaluation raises.
        train_logger.close()
        eval_logger.close()
if __name__ == '__main__':
    import sys
    from utils.config import Config

    # Default hyperparameters for launching this script directly.
    # NOTE(review): algo_name is only used for logging, and 'hiql' is set here
    # although this script trains the `iql` agent -- confirm the label.
    config = Config(env_name='antmaze-medium-play-v2',
                    run_group='EXP',
                    seed=0,
                    pretrain_steps=500002,
                    eval_interval=50000,
                    save_interval=125000,
                    p_currgoal=0.2,
                    p_trajgoal=0.5,
                    p_randomgoal=0.3,
                    high_p_randomgoal=0.3,
                    discount=0.99,
                    temperature=1.0,
                    high_temperature=1,
                    pretrain_expectile=0.7,
                    geom_sample=1,
                    layer_norm=1,
                    value_hidden_dim=512,
                    value_num_layers=3,
                    batch_size=1024,
                    use_rep=0,
                    policy_train_rep=0,
                    algo_name='hiql',
                    use_waypoints=1,
                    way_steps=25,)
    argv = config.to_argv()
    # Insert the defaults right after the program name rather than appending
    # them. absl keeps the last value of a repeated flag, so appending would
    # silently override flags typed on the real command line. Inserting also
    # keeps the defaults in front of any `--` positional separator.
    sys.argv[1:1] = argv
    app.run(main)