@@ -54,6 +54,8 @@ class DatasetImporterBuilder(
54
54
'ssot_session_key' ,
55
55
]
56
56
57
+ images_from_observation_dict = {}
58
+
57
59
58
60
@abc .abstractmethod
59
61
def get_description (self ):
@@ -83,9 +85,15 @@ def _info(self) -> tfds.core.DatasetInfo:
83
85
84
86
tmp = dict (features )
85
87
88
+ # add all image features from observations to a new featuresdict
89
+ self .images_from_observation_dict = self .get_images_from_observation_dict ()
90
+ if self .images_from_observation_dict :
91
+ tmp ['display_image' ] = self .images_from_observation_dict
92
+
86
93
for key in self .KEYS_TO_STRIP :
87
94
if key in tmp :
88
95
del tmp [key ]
96
+
89
97
features = tfds .features .FeaturesDict (tmp )
90
98
91
99
return tfds .core .DatasetInfo (
@@ -120,15 +128,28 @@ def _generate_examples(
120
128
def converter_fn (example ):
121
129
# Decode the RLDS Episode and transform it to numpy.
122
130
example_out = dict (example )
131
+
123
132
example_out ['steps' ] = tf .data .Dataset .from_tensor_slices (
124
133
example_out ['steps' ]
125
134
).map (decode_fn )
135
+
126
136
steps = list (iter (example_out ['steps' ].take (- 1 )))
127
137
example_out ['steps' ] = steps
128
138
129
139
example_out = dataset_utils .as_numpy (example_out )
140
+ first_step = example_out ['steps' ][0 ]
141
+ image_feature_dict = {}
142
+
143
+ for feature_name in self .images_from_observation_dict :
144
+ image_feature_dict [feature_name ] = first_step ['observation' ][
145
+ feature_name
146
+ ]
147
+
148
+ if image_feature_dict :
149
+ example_out ['display_image' ] = image_feature_dict
130
150
131
151
example_id = example_out ['tfds_id' ].decode ('utf-8' )
152
+
132
153
del example_out ['tfds_id' ]
133
154
for key in self .KEYS_TO_STRIP :
134
155
if key in example_out :
@@ -148,3 +169,17 @@ def get_ds_builder(self):
148
169
ds_location = self .get_dataset_location ()
149
170
ds_builder = tfds .builder_from_directory (ds_location )
150
171
return ds_builder
172
+
173
+ def get_images_from_observation_dict (self ):
174
+ features = self .get_ds_builder ().info .features
175
+ tmp = dict (features )
176
+ images_from_observation = {}
177
+ if 'steps' in tmp and 'observation' in tmp ['steps' ]:
178
+ observation = tmp ['steps' ]['observation' ]
179
+ for feature_name , feature_data in observation .items ():
180
+ if isinstance (feature_data , tfds .features .Image ):
181
+ images_from_observation [feature_name ] = feature_data
182
+ images_from_observation_dict = tfds .features .FeaturesDict (
183
+ images_from_observation
184
+ )
185
+ return images_from_observation_dict
0 commit comments