forked from lucidrains/robotic-transformer-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
data.py
522 lines (430 loc) · 18.6 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
# Taken from https://docs.google.com/spreadsheets/d/1rPBD77tk60AEIGZrGSODwyyzs5FgCU9Uz3h-3_t2A9g/edit#gid=0
import abc
import dataclasses
from typing import Any, Dict, Iterable, Optional, Union
import numpy as np
import reverb
import tensorflow as tf
import tensorflow_datasets as tfds
import tree
from rlds import rlds_types, transformations
# Hide every GPU from TensorFlow at import time. In this file TF is only used
# to build specs and the tf.data input pipeline.
# NOTE(review): presumably this keeps TF from reserving GPU memory needed by
# the (PyTorch) model consuming these batches — confirm.
tf.config.experimental.set_visible_devices([], "GPU")
def dataset2path(name):
    """Return the ``gs://gresearch/robotics`` directory of dataset ``name``.

    Datasets are published under version ``0.1.0`` except ``robo_net``
    (``1.0.0``) and ``language_table`` (``0.0.1``).
    """
    version = (
        "1.0.0"
        if name == "robo_net"
        else "0.0.1"
        if name == "language_table"
        else "0.1.0"
    )
    return f"gs://gresearch/robotics/{name}/{version}"
def as_gif(images, path="temp.gif"):
    """Write ``images`` as an animated GIF to ``path`` and return its bytes.

    Args:
        images: Non-empty sequence of frames exposing a PIL-style
            ``save(path, save_all=..., append_images=..., duration=..., loop=...)``
            method (presumably ``PIL.Image.Image`` objects — confirm at callers).
        path: File the GIF is written to before being read back.

    Returns:
        The raw bytes of the file written at ``path``.

    Raises:
        ValueError: If ``images`` is empty.
    """
    if not images:
        raise ValueError("as_gif requires at least one image")
    # The first frame performs the save; the remaining frames are appended.
    # In PIL, `duration` is milliseconds per frame and loop=0 loops forever.
    images[0].save(path, save_all=True, append_images=images[1:], duration=1000, loop=0)
    # Read back inside a context manager so the handle is closed promptly
    # (a bare `open(...).read()` leaves it open until garbage collection).
    with open(path, "rb") as gif_file:
        return gif_file.read()
def _features_to_tensor_spec(feature: tfds.features.FeatureConnector) -> tf.TensorSpec:
    """Map a (possibly nested) tfds feature onto the matching TensorSpec structure.

    A leaf that is a bare ``tf.DType`` becomes a scalar ``tf.TensorSpec`` of
    that dtype; every other leaf supplies its own ``get_tensor_spec()``.
    """

    def _leaf_spec(leaf):
        if not isinstance(leaf, tf.DType):
            return leaf.get_tensor_spec()
        # A bare dtype carries no shape, so it is described as a scalar.
        return tf.TensorSpec(shape=(), dtype=leaf)

    # A FeaturesDict can sometimes arrive as a plain dict; tf.nest handles the
    # nested structure either way.
    return tf.nest.map_structure(_leaf_spec, feature)
def _encoded_feature(
    feature: Optional[tfds.features.FeatureConnector],
    image_encoding: Optional[str],
    tensor_encoding: Optional[tfds.features.Encoding],
):
    """Return ``feature`` with the requested encodings applied to its leaves.

    - ``tfds.features.Image`` leaves are rebuilt with ``image_encoding`` as the
      encoding format, when ``image_encoding`` is truthy.
    - Non-string ``tfds.features.Tensor`` leaves are rebuilt with
      ``tensor_encoding``, when ``tensor_encoding`` is truthy.
    - Every other leaf is returned unchanged.

    A falsy ``feature`` (e.g. ``None``) yields ``None``.
    """
    if not feature:
        return None

    def _encode_leaf(leaf):
        if image_encoding and isinstance(leaf, tfds.features.Image):
            return tfds.features.Image(
                shape=leaf.shape,
                dtype=leaf.dtype,
                use_colormap=leaf.use_colormap,
                encoding_format=image_encoding,
            )
        if not (tensor_encoding and isinstance(leaf, tfds.features.Tensor)):
            return leaf
        # String tensors are left as-is; only numeric tensors get re-encoded.
        if leaf.dtype != tf.string:
            return tfds.features.Tensor(
                shape=leaf.shape, dtype=leaf.dtype, encoding=tensor_encoding
            )
        return leaf

    return tf.nest.map_structure(_encode_leaf, feature)
@dataclasses.dataclass
class RLDSSpec(metaclass=abc.ABCMeta):
    """Specification of an RLDS Dataset.

    It is used to hold a spec that can be converted into a TFDS DatasetInfo or
    a `tf.data.Dataset` spec.

    Every field is optional; a field that is unset (or otherwise falsy) is
    omitted from the generated step / episode specs.

    Attributes:
        observation_info: Feature stored under the step's OBSERVATION key.
        action_info: Feature stored under the step's ACTION key.
        reward_info: Feature stored under the step's REWARD key.
        discount_info: Feature stored under the step's DISCOUNT key.
        step_metadata_info: Extra per-step features; each entry is added to
            the step under its own key.
        episode_metadata_info: Extra per-episode features; each entry is added
            to the episode alongside STEPS.
    """

    observation_info: Optional[tfds.features.FeatureConnector] = None
    action_info: Optional[tfds.features.FeatureConnector] = None
    reward_info: Optional[tfds.features.FeatureConnector] = None
    discount_info: Optional[tfds.features.FeatureConnector] = None
    step_metadata_info: Optional[tfds.features.FeaturesDict] = None
    episode_metadata_info: Optional[tfds.features.FeaturesDict] = None

    def step_tensor_spec(self) -> Dict[str, tf.TensorSpec]:
        """Obtains the TensorSpec of an RLDS step.

        Returns:
            A dict holding the specs of whichever of observation / action /
            discount / reward are set, one entry per ``step_metadata_info``
            feature, and scalar boolean specs for the IS_FIRST, IS_LAST and
            IS_TERMINAL flags.
        """
        step = {}
        if self.observation_info:
            step[rlds_types.OBSERVATION] = _features_to_tensor_spec(
                self.observation_info
            )
        if self.action_info:
            step[rlds_types.ACTION] = _features_to_tensor_spec(self.action_info)
        if self.discount_info:
            step[rlds_types.DISCOUNT] = _features_to_tensor_spec(self.discount_info)
        if self.reward_info:
            step[rlds_types.REWARD] = _features_to_tensor_spec(self.reward_info)
        if self.step_metadata_info:
            for k, v in self.step_metadata_info.items():
                step[k] = _features_to_tensor_spec(v)
        step[rlds_types.IS_FIRST] = tf.TensorSpec(shape=(), dtype=bool)
        step[rlds_types.IS_LAST] = tf.TensorSpec(shape=(), dtype=bool)
        step[rlds_types.IS_TERMINAL] = tf.TensorSpec(shape=(), dtype=bool)
        return step

    def episode_tensor_spec(self) -> Dict[str, tf.TensorSpec]:
        """Obtains the TensorSpec of an RLDS episode.

        Returns:
            A dict whose STEPS entry is a ``tf.data.DatasetSpec`` over
            ``step_tensor_spec()``, plus one entry per ``episode_metadata_info``
            feature.
        """
        episode = {}
        episode[rlds_types.STEPS] = tf.data.DatasetSpec(
            element_spec=self.step_tensor_spec()
        )
        if self.episode_metadata_info:
            for k, v in self.episode_metadata_info.items():
                episode[k] = _features_to_tensor_spec(v)
        return episode

    def to_dataset_config(
        self,
        name: str,
        image_encoding: Optional[str] = None,
        tensor_encoding: Optional[tfds.features.Encoding] = None,
        citation: Optional[str] = None,
        homepage: Optional[str] = None,
        description: Optional[str] = None,
        overall_description: Optional[str] = None,
    ) -> tfds.rlds.rlds_base.DatasetConfig:
        """Obtains the DatasetConfig for TFDS from the Spec.

        Each feature field is passed through ``_encoded_feature`` with the given
        ``image_encoding`` / ``tensor_encoding``; unset fields become ``None``.
        """

        def encode(feature):
            return _encoded_feature(feature, image_encoding, tensor_encoding)

        return tfds.rlds.rlds_base.DatasetConfig(
            name=name,
            description=description,
            overall_description=overall_description,
            homepage=homepage,
            citation=citation,
            observation_info=encode(self.observation_info),
            action_info=encode(self.action_info),
            reward_info=encode(self.reward_info),
            discount_info=encode(self.discount_info),
            step_metadata_info=encode(self.step_metadata_info),
            episode_metadata_info=encode(self.episode_metadata_info),
        )

    def to_features_dict(self):
        """Returns a TFDS FeaturesDict representing the dataset config.

        STEPS is a ``tfds.features.Dataset`` of the step features (the three
        boolean flags plus whichever feature fields are set). Any
        ``episode_metadata_info`` entries are merged in after STEPS.
        """
        step_config = {
            rlds_types.IS_FIRST: tf.bool,
            rlds_types.IS_LAST: tf.bool,
            rlds_types.IS_TERMINAL: tf.bool,
        }
        if self.observation_info:
            step_config[rlds_types.OBSERVATION] = self.observation_info
        if self.action_info:
            step_config[rlds_types.ACTION] = self.action_info
        if self.discount_info:
            step_config[rlds_types.DISCOUNT] = self.discount_info
        if self.reward_info:
            step_config[rlds_types.REWARD] = self.reward_info
        if self.step_metadata_info:
            for k, v in self.step_metadata_info.items():
                step_config[k] = v

        episode_config = {rlds_types.STEPS: tfds.features.Dataset(step_config)}
        if self.episode_metadata_info:
            # Metadata is unpacked after STEPS, so same-named keys take precedence.
            episode_config = {**episode_config, **self.episode_metadata_info}
        return tfds.features.FeaturesDict(episode_config)
# Type aliases used in the annotations below.
RLDS_SPEC = RLDSSpec
# A single TensorSpec, or a dict of feature name -> TensorSpec. Uses
# `typing.Dict` for consistency with the rest of the file; the builtin
# `dict[...]` generic would also require Python 3.9+.
TENSOR_SPEC = Union[tf.TensorSpec, Dict[str, tf.TensorSpec]]
@dataclasses.dataclass
class TrajectoryTransform(metaclass=abc.ABCMeta):
    """Specification of the TrajectoryTransform applied to a dataset of episodes.

    A TrajectoryTransform is a set of rules transforming a dataset
    of RLDS episodes to a dataset of trajectories.
    This involves three distinct stages:
    - An optional `episode_to_steps_map_fn(episode)` is called at the episode
      level, and can be used to select or modify steps.
      - Augmentation: an `episode_key` could be propagated to `steps` for
        debugging.
      - Selection: Particular steps can be selected.
      - Stripping: Features can be removed from steps. Prefer using `step_map_fn`.
    - An optional `step_map_fn` is called at the flattened steps dataset for each
      step, and can be used to featurize a step, e.g. add/remove features, or
      augment images.
    - A `pattern` leverages DM patterns to set a rule of slicing an episode to a
      dataset of overlapping trajectories.

    Importantly, each TrajectoryTransform must define an `expected_tensor_spec`
    which specifies a nested TensorSpec of the resulting dataset. This is what
    this TrajectoryTransform will produce, and can be used as an interface with
    a neural network.
    """

    # RLDS spec of the episodes consumed by `transform_episodic_rlds_dataset`.
    episode_dataset_spec: RLDS_SPEC
    # Spec that `get_for_cached_trajectory_transform` swaps in as
    # `episode_dataset_spec`.
    episode_to_steps_fn_dataset_spec: RLDS_SPEC
    # element_spec of the (step-mapped) steps dataset.
    steps_dataset_spec: Any
    # Reverb structured-writer pattern used to slice steps into trajectories.
    pattern: reverb.structured_writer.Pattern
    # Maps one episode to its dataset of steps.
    episode_to_steps_map_fn: Any
    # Nested spec of each produced trajectory.
    expected_tensor_spec: TENSOR_SPEC
    # Optional per-step transform applied before the pattern is evaluated.
    step_map_fn: Optional[Any] = None

    def get_for_cached_trajectory_transform(self):
        """Creates a copy of this traj transform to use with caching.

        The returned TrajectoryTransform copy will be initialized with the
        default version of the `episode_to_steps_map_fn`, because the effect of
        that function has already been materialized in the cached copy of the
        dataset.

        Returns:
            trajectory_transform: A copy of the TrajectoryTransform with
            overridden `episode_dataset_spec` and `episode_to_steps_map_fn`.
        """
        return dataclasses.replace(
            self,
            episode_dataset_spec=self.episode_to_steps_fn_dataset_spec,
            episode_to_steps_map_fn=lambda e: e[rlds_types.STEPS],
        )

    def transform_episodic_rlds_dataset(self, episodes_dataset: tf.data.Dataset):
        """Applies this TrajectoryTransform to the dataset of episodes."""
        # Map each episode to its steps dataset, then flatten the resulting
        # dataset-of-datasets into one stream of steps.
        steps_dataset = episodes_dataset.map(
            self.episode_to_steps_map_fn, num_parallel_calls=tf.data.AUTOTUNE
        ).flat_map(lambda x: x)
        return self._create_pattern_dataset(steps_dataset)

    def transform_steps_rlds_dataset(
        self, steps_dataset: tf.data.Dataset
    ) -> tf.data.Dataset:
        """Applies this TrajectoryTransform to the dataset of episode steps."""
        return self._create_pattern_dataset(steps_dataset)

    def create_test_dataset(
        self,
    ) -> tf.data.Dataset:
        """Creates a test dataset of trajectories.

        It is guaranteed that the structure of this dataset will be the same as
        when flowing real data. Hence this is a useful construct for tests or
        initialization of JAX models.

        Returns:
            dataset: A test dataset made of zeros structurally identical to the
            target dataset of trajectories.
        """
        zeros = transformations.zeros_from_spec(self.expected_tensor_spec)
        return tf.data.Dataset.from_tensors(zeros)

    def _create_pattern_dataset(
        self, steps_dataset: tf.data.Dataset
    ) -> tf.data.Dataset:
        """Create PatternDataset from the `steps_dataset`."""
        config = create_structured_writer_config("temp", self.pattern)

        # Further transform each step if the `step_map_fn` is provided.
        if self.step_map_fn is not None:
            steps_dataset = steps_dataset.map(self.step_map_fn)
        # Trajectories never straddle episodes: IS_LAST marks each episode end.
        pattern_dataset = reverb.PatternDataset(
            input_dataset=steps_dataset,
            configs=[config],
            respect_episode_boundaries=True,
            is_end_of_episode=lambda x: x[rlds_types.IS_LAST],
        )
        return pattern_dataset
class TrajectoryTransformBuilder(object):
    """Facilitates creation of the `TrajectoryTransform`.

    Constructor args:
        dataset_spec: RLDS spec of the episodes the transform will consume.
        episode_to_steps_map_fn: Maps an episode to its dataset of steps.
            Defaults to selecting ``episode[rlds_types.STEPS]``.
        step_map_fn: Optional function applied to every flattened step.
        pattern_fn: Called with a reverb reference step; must return the
            structured-writer pattern. Required by `build`.
        expected_tensor_spec: Trajectory spec that `build` validates against
            when `validate_expected_tensor_spec` is True.
    """

    def __init__(
        self,
        dataset_spec: RLDS_SPEC,
        episode_to_steps_map_fn=lambda e: e[rlds_types.STEPS],
        step_map_fn=None,
        pattern_fn=None,
        expected_tensor_spec=None,
    ):
        self._rds_dataset_spec = dataset_spec
        # NOTE(review): never read in this file; kept in case external code
        # relies on the attribute.
        self._steps_spec = None
        self._episode_to_steps_map_fn = episode_to_steps_map_fn
        self._step_map_fn = step_map_fn
        self._pattern_fn = pattern_fn
        self._expected_tensor_spec = expected_tensor_spec

    def build(self, validate_expected_tensor_spec: bool = True) -> TrajectoryTransform:
        """Creates `TrajectoryTransform` from a `TrajectoryTransformBuilder`.

        The pipeline is traced on a zero-valued episode dataset derived from the
        RLDS spec to infer the steps spec and the resulting trajectory signature.

        Args:
            validate_expected_tensor_spec: When True, `expected_tensor_spec` must
                be set and must equal the inferred trajectory signature.

        Returns:
            The configured `TrajectoryTransform`. Its `expected_tensor_spec` is
            the inferred signature.

        Raises:
            ValueError: If `pattern_fn` is unset, or if validation is requested
                but `expected_tensor_spec` is unset.
            RuntimeError: If validation is requested and the inferred signature
                differs from `expected_tensor_spec`.
        """
        if validate_expected_tensor_spec and self._expected_tensor_spec is None:
            raise ValueError("`expected_tensor_spec` must be set.")
        if self._pattern_fn is None:
            # Fail fast with a clear message instead of a TypeError from
            # calling None after the dataset tracing below.
            raise ValueError("`pattern_fn` must be set.")

        # Trace the episode -> steps pipeline on zero-valued data.
        episode_ds = zero_episode_dataset_from_spec(self._rds_dataset_spec)
        steps_ds = episode_ds.flat_map(self._episode_to_steps_map_fn)
        episode_to_steps_fn_dataset_spec = self._rds_dataset_spec
        if self._step_map_fn is not None:
            steps_ds = steps_ds.map(self._step_map_fn)

        # Build the pattern against a reference step shaped like a real step.
        zeros_spec = transformations.zeros_from_spec(
            steps_ds.element_spec
        )  # pytype: disable=wrong-arg-types
        ref_step = reverb.structured_writer.create_reference_step(zeros_spec)
        pattern = self._pattern_fn(ref_step)

        steps_ds_spec = steps_ds.element_spec
        target_tensor_structure = create_reverb_table_signature(
            "temp_table", steps_ds_spec, pattern
        )
        if (
            validate_expected_tensor_spec
            and self._expected_tensor_spec != target_tensor_structure
        ):
            raise RuntimeError(
                "The tensor spec of the TrajectoryTransform doesn't "
                "match the expected spec.\n"
                "Expected:\n%s\nActual:\n%s\n"
                % (
                    str(self._expected_tensor_spec).replace(
                        "TensorSpec", "tf.TensorSpec"
                    ),
                    str(target_tensor_structure).replace("TensorSpec", "tf.TensorSpec"),
                )
            )
        return TrajectoryTransform(
            episode_dataset_spec=self._rds_dataset_spec,
            episode_to_steps_fn_dataset_spec=episode_to_steps_fn_dataset_spec,
            steps_dataset_spec=steps_ds_spec,
            pattern=pattern,
            episode_to_steps_map_fn=self._episode_to_steps_map_fn,
            step_map_fn=self._step_map_fn,
            expected_tensor_spec=target_tensor_structure,
        )
def zero_episode_dataset_from_spec(rlds_spec: RLDS_SPEC):
    """Build a zero-valued dataset of episodes matching ``rlds_spec``.

    Episode-level metadata entries are zero-filled from the spec; each episode
    then gets a zero-valued STEPS dataset built from ``step_tensor_spec()``.
    """
    metadata_spec = {
        key: spec
        for key, spec in rlds_spec.episode_tensor_spec().items()
        if key != rlds_types.STEPS
    }
    if not metadata_spec:
        # With no metadata, a placeholder "fake" key is used to seed the
        # dataset (presumably because an empty dict element is not usable
        # here); it is removed again once STEPS is attached.
        base_episodes = tf.data.Dataset.from_tensors({"fake": ""})
    else:
        base_episodes = transformations.zero_dataset_like(
            tf.data.DatasetSpec(metadata_spec)
        )

    step_spec = rlds_spec.step_tensor_spec()

    def _attach_zero_steps(episode):
        episode[rlds_types.STEPS] = transformations.zero_dataset_like(
            tf.data.DatasetSpec(step_spec)
        )
        episode.pop("fake", None)
        return episode

    return base_episodes.map(_attach_zero_steps)
def create_reverb_table_signature(
    table_name: str, steps_dataset_spec, pattern: reverb.structured_writer.Pattern
) -> reverb.reverb_types.SpecNest:
    """Infer the signature of a Reverb table written with ``pattern``.

    Args:
        table_name: Table name recorded in the structured-writer config.
        steps_dataset_spec: Spec of the steps the pattern is applied to.
        pattern: Structured-writer pattern describing each written item.

    Returns:
        The spec nest returned by ``reverb.structured_writer.infer_signature``.
    """
    return reverb.structured_writer.infer_signature(
        [create_structured_writer_config(table_name, pattern)],
        steps_dataset_spec,
    )
def create_structured_writer_config(
    table_name: str, pattern: reverb.structured_writer.Pattern
) -> Any:
    """Build a structured-writer config for ``pattern`` targeting ``table_name``.

    The config is created with an empty ``conditions`` list.
    """
    return reverb.structured_writer.create_config(
        pattern=pattern,
        table=table_name,
        conditions=[],
    )
def n_step_pattern_builder(n: int) -> Any:
    """Creates trajectory of length `n` from all fields of a `ref_step`.

    Args:
        n: Trajectory length; must be at least 1.

    Returns:
        A ``pattern_fn`` mapping a reference step (field name -> reference, or
        field name -> dict of references) to the same structure with every
        leaf sliced to its last ``n`` entries.

    Raises:
        ValueError: If ``n < 1``. ``[-0:]`` selects the *whole* history and a
            negative ``n`` drops leading entries, so either would silently
            build wrong-length trajectories.
    """
    if n < 1:
        raise ValueError(f"`n` must be a positive integer, got {n!r}.")

    def _last_n(node):
        return node[-n:]

    def transform_fn(ref_step):
        traj = {}
        for key in ref_step:
            entry = ref_step[key]
            if isinstance(entry, dict):
                # Nested dicts (e.g. observation/action) are sliced leaf-wise.
                traj[key] = tree.map_structure(_last_n, entry)
            else:
                traj[key] = _last_n(entry)
        return traj

    return transform_fn
def get_observation_and_action_from_step(step):
    """Keep only the observation fields and the (decoded) actions of a step.

    The observation is reduced to ``image``, ``embedding`` (taken from
    ``natural_language_embedding``) and ``instruction`` (taken from
    ``natural_language_instruction``). Every ``tf.int32`` action entry is
    decoded with ``tf.argmax`` over its last axis; other entries pass through.
    """
    obs = step["observation"]
    observation = {
        "image": obs["image"],
        "embedding": obs["natural_language_embedding"],
        "instruction": obs["natural_language_instruction"],
    }

    def _decode(value):
        # int32 actions are treated as one-hot discrete actions.
        if value.dtype == tf.int32:
            return tf.argmax(value, axis=-1)
        return value

    action = {name: _decode(value) for name, value in step["action"].items()}
    return {"observation": observation, "action": action}
def create_dataset(
    datasets=("fractal20220817_data",),
    split="train",
    trajectory_length=6,
    batch_size=32,
    num_epochs=1,
) -> Iterable[Dict[str, Union[np.ndarray, Dict[str, np.ndarray]]]]:
    """Build a batched numpy iterator of fixed-length trajectories.

    Args:
        datasets: Dataset name, or iterable of names, resolved via
            `dataset2path`. Defaults to a tuple to avoid a shared mutable
            default.
        split: TFDS split string, e.g. ``"train"`` or ``"train[:10]"``.
        trajectory_length: Number of trailing steps kept per trajectory.
        batch_size: Trajectories per batch; an incomplete final batch is
            dropped.
        num_epochs: Passed to ``Dataset.repeat``.

    Returns:
        An iterator of numpy batches, each a dict with ``"observation"``
        (image / embedding / instruction) and ``"action"`` entries. Presumably
        arrays lead with (batch_size, trajectory_length) — confirm.

    Raises:
        ValueError: If no dataset names are given.
    """
    # A bare string would otherwise be iterated one character at a time.
    if isinstance(datasets, str):
        datasets = [datasets]
    dataset_names = list(datasets)
    if not dataset_names:
        raise ValueError("`datasets` must contain at least one dataset name.")

    trajectory_datasets = []
    for dataset in dataset_names:
        b = tfds.builder_from_directory(builder_dir=dataset2path(dataset))
        ds = b.as_dataset(split=split)
        # The RLDSSpec for the RT1 dataset: only observations and actions.
        rt1_spec = RLDSSpec(
            observation_info=b.info.features["steps"]["observation"],
            action_info=b.info.features["steps"]["action"],
        )
        trajectory_transform = TrajectoryTransformBuilder(
            rt1_spec, pattern_fn=n_step_pattern_builder(trajectory_length)
        ).build(validate_expected_tensor_spec=False)
        trajectory_datasets.append(
            trajectory_transform.transform_episodic_rlds_dataset(ds)
        )

    # Interleave the per-dataset trajectory streams (no explicit weights).
    trajectory_dataset = tf.data.Dataset.sample_from_datasets(trajectory_datasets)
    trajectory_dataset = trajectory_dataset.map(
        get_observation_and_action_from_step, num_parallel_calls=tf.data.AUTOTUNE
    )

    # Shuffle (buffer of 16 batches' worth), batch, repeat, prefetch.
    trajectory_dataset = trajectory_dataset.shuffle(batch_size * 16)
    trajectory_dataset = trajectory_dataset.batch(
        batch_size,
        drop_remainder=True,
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=False,
    )
    trajectory_dataset = trajectory_dataset.repeat(num_epochs)
    trajectory_dataset = trajectory_dataset.prefetch(tf.data.AUTOTUNE)
    return iter(trajectory_dataset.as_numpy_iterator())
if __name__ == "__main__":
    # Smoke test: pull one batch from a small slice of the train split and
    # print the shape of every array in it.
    ds = create_dataset(datasets=["fractal20220817_data"], split="train[:10]")
    batch = next(ds)
    # tree.map_structure already recurses into nested dicts, so the mapped
    # function only ever receives leaf arrays; no dict special-casing needed.
    shapes = tree.map_structure(lambda leaf: leaf.shape, batch)
    print(shapes)