Skip to content

Commit ca75b94

Browse files
author
The TensorFlow Datasets Authors
committed
Add "display_image" feature to robotics dataset importer builder.
PiperOrigin-RevId: 638110575
1 parent 496c2d4 commit ca75b94

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

tensorflow_datasets/robotics/dataset_importer_builder.py

+35
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ class DatasetImporterBuilder(
5454
'ssot_session_key',
5555
]
5656

57+
images_from_observation_dict = {}
58+
5759

5860
@abc.abstractmethod
5961
def get_description(self):
@@ -83,9 +85,15 @@ def _info(self) -> tfds.core.DatasetInfo:
8385

8486
tmp = dict(features)
8587

88+
# add all image features from observations to a new featuresdict
89+
self.images_from_observation_dict = self.get_images_from_observation_dict()
90+
if self.images_from_observation_dict:
91+
tmp['display_image'] = self.images_from_observation_dict
92+
8693
for key in self.KEYS_TO_STRIP:
8794
if key in tmp:
8895
del tmp[key]
96+
8997
features = tfds.features.FeaturesDict(tmp)
9098

9199
return tfds.core.DatasetInfo(
@@ -120,15 +128,28 @@ def _generate_examples(
120128
def converter_fn(example):
121129
# Decode the RLDS Episode and transform it to numpy.
122130
example_out = dict(example)
131+
123132
example_out['steps'] = tf.data.Dataset.from_tensor_slices(
124133
example_out['steps']
125134
).map(decode_fn)
135+
126136
steps = list(iter(example_out['steps'].take(-1)))
127137
example_out['steps'] = steps
128138

129139
example_out = dataset_utils.as_numpy(example_out)
140+
first_step = example_out['steps'][0]
141+
image_feature_dict = {}
142+
143+
for feature_name in self.images_from_observation_dict:
144+
image_feature_dict[feature_name] = first_step['observation'][
145+
feature_name
146+
]
147+
148+
if image_feature_dict:
149+
example_out['display_image'] = image_feature_dict
130150

131151
example_id = example_out['tfds_id'].decode('utf-8')
152+
132153
del example_out['tfds_id']
133154
for key in self.KEYS_TO_STRIP:
134155
if key in example_out:
@@ -148,3 +169,17 @@ def get_ds_builder(self):
148169
ds_location = self.get_dataset_location()
149170
ds_builder = tfds.builder_from_directory(ds_location)
150171
return ds_builder
172+
173+
def get_images_from_observation_dict(self):
174+
features = self.get_ds_builder().info.features
175+
tmp = dict(features)
176+
images_from_observation = {}
177+
if 'steps' in tmp and 'observation' in tmp['steps']:
178+
observation = tmp['steps']['observation']
179+
for feature_name, feature_data in observation.items():
180+
if isinstance(feature_data, tfds.features.Image):
181+
images_from_observation[feature_name] = feature_data
182+
images_from_observation_dict = tfds.features.FeaturesDict(
183+
images_from_observation
184+
)
185+
return images_from_observation_dict

0 commit comments

Comments
 (0)