utils.py
import numpy as np
import torch
import torchvision.transforms.functional as F


class NPZParser:
    """Load an episode from an .npz file and return a preprocessed video segment."""

    def __init__(self, segment_length, image_size=64):
        self.segment_length = segment_length
        self.image_size = image_size

    def preprocess(self, images):
        # Scale pixel values from [0, 255] to [0, 1], then resize to a square frame.
        images = images / 255
        # images = F.center_crop(images, min(images.shape[2:]))
        images = F.resize(images, [self.image_size, self.image_size])
        return images

    def get_segment(self, episode, stepsize=1):
        # Shrink the stepsize if the episode is too short for the requested segment.
        if stepsize * self.segment_length > len(episode):
            stepsize = max(1, len(episode) // self.segment_length)
        # Sample a random start index, then take every `stepsize`-th frame.
        start = np.random.randint(max(len(episode) - stepsize * self.segment_length + 1, 1))
        images = list(episode[start: start + stepsize * self.segment_length: stepsize])
        return images

    def get_stepsize(self, dataset_name):
        # Normalize each dataset's base stepsize against fractal20220817_data's,
        # so all datasets are subsampled to a comparable temporal rate.
        return max(round(BASE_STEPSIZE.get(dataset_name, 1) / BASE_STEPSIZE['fractal20220817_data']), 1)

    def parse(self, npz_file, dataset_name):
        images = np.load(npz_file)[DISPLAY_KEY.get(dataset_name, 'image')]
        images = torch.Tensor(
            np.array(self.get_segment(images, self.get_stepsize(dataset_name)))
        ).permute(0, 3, 1, 2)  # T, H, W, C -> T, C, H, W
        images = self.preprocess(images)
        return images


# Per-dataset base stepsize; get_stepsize() divides these by the
# fractal20220817_data value to pick a per-dataset frame-sampling stride.
BASE_STEPSIZE = {
    'fractal20220817_data': 3,
    'kuka': 10,
    'bridge': 5,
    'taco_play': 15,
    'jaco_play': 10,
    'berkeley_cable_routing': 10,
    'roboturk': 10,
    'viola': 20,
    'toto': 30,
    'language_table': 10,
    'columbia_cairlab_pusht_real': 10,
    'stanford_kuka_multimodal_dataset_converted_externally_to_rlds': 20,
    'stanford_hydra_dataset_converted_externally_to_rlds': 10,
    'austin_buds_dataset_converted_externally_to_rlds': 20,
    'nyu_franka_play_dataset_converted_externally_to_rlds': 3,
    'maniskill_dataset_converted_externally_to_rlds': 20,
    'furniture_bench_dataset_converted_externally_to_rlds': 10,
    'ucsd_kitchen_dataset_converted_externally_to_rlds': 2,
    'ucsd_pick_and_place_dataset_converted_externally_to_rlds': 3,
    'austin_sailor_dataset_converted_externally_to_rlds': 20,
    'bc_z': 10,
    'utokyo_pr2_opening_fridge_converted_externally_to_rlds': 10,
    'utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds': 10,
    'utokyo_xarm_pick_and_place_converted_externally_to_rlds': 10,
    'utokyo_xarm_bimanual_converted_externally_to_rlds': 10,
    'robo_net': 1,
    'kaist_nonprehensile_converted_externally_to_rlds': 10,
    'stanford_mask_vit_converted_externally_to_rlds': 1,  # no groundtruth value
    'dlr_sara_pour_converted_externally_to_rlds': 10,
    'dlr_sara_grid_clamp_converted_externally_to_rlds': 10,
    'dlr_edan_shared_control_converted_externally_to_rlds': 5,
    'asu_table_top_converted_externally_to_rlds': 12.5,
    'iamlab_cmu_pickup_insert_converted_externally_to_rlds': 20,
    'uiuc_d3field1': 1,
    'uiuc_d3field2': 1,
    'uiuc_d3field3': 1,
    'uiuc_d3field4': 1,
    'utaustin_mutex': 20,
    'berkeley_fanuc_manipulation': 10,
    'cmu_playing_with_food': 10,
    'cmu_play_fusion': 5,
    'cmu_stretch': 10,
}

# Image key to read from each dataset's .npz archive; datasets absent from
# this mapping fall back to the default key 'image' (see NPZParser.parse).
DISPLAY_KEY = {
    'taco_play': 'rgb_static',
    'roboturk': 'front_rgb',
    'viola': 'agentview_rgb',
    'language_table': 'rgb',
    'stanford_robocook_converted_externally_to_rlds1': 'image_1',
    'stanford_robocook_converted_externally_to_rlds2': 'image_2',
    'stanford_robocook_converted_externally_to_rlds3': 'image_3',
    'stanford_robocook_converted_externally_to_rlds4': 'image_4',
    'uiuc_d3field1': 'image_1',
    'uiuc_d3field2': 'image_2',
    'uiuc_d3field3': 'image_3',
    'uiuc_d3field4': 'image_4',
}
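

if __name__ == '__main__':
    # Minimal usage sketch: assumes a hypothetical 'episode.npz' (not part of the
    # repo) storing T x H x W x C uint8 frames under the dataset's display key
    # ('image' by default for datasets absent from DISPLAY_KEY, e.g. 'bridge').
    parser = NPZParser(segment_length=16, image_size=64)
    clip = parser.parse('episode.npz', 'bridge')
    print(clip.shape)  # expected: torch.Size([16, 3, 64, 64]) for RGB frames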