forked from real-stanford/diffusion_policy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathreal_inference_util.py
52 lines (49 loc) · 1.83 KB
/
real_inference_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from typing import Dict, Callable, Tuple
import numpy as np
from diffusion_policy.common.cv2_util import get_image_transform
def get_real_obs_dict(
        env_obs: Dict[str, np.ndarray],
        shape_meta: dict,
        ) -> Dict[str, np.ndarray]:
    """Convert raw environment observations to the layout in ``shape_meta``.

    Only keys listed under ``shape_meta['obs']`` are copied to the result.
    Each entry is handled according to its ``'type'`` (default ``'low_dim'``):

    * ``'rgb'``: ``env_obs[key]`` is unpacked as (T, H, W, C) and the entry's
      ``'shape'`` as (C, H, W). If the spatial size differs, or the input is
      ``uint8``, every frame goes through the callable returned by
      ``get_image_transform``; ``uint8`` input is additionally scaled to
      ``float32`` by dividing by 255. The result is stored as (T, C, H, W).
    * ``'low_dim'``: the array is passed through unchanged, except that a
      key containing ``'pose'`` whose declared shape is ``(2,)`` keeps only
      the first two components of the last axis.

    Entries with any other type are ignored.

    Args:
        env_obs: Mapping from observation key to array.
        shape_meta: Metadata dict with an ``'obs'`` mapping of
            ``key -> {'type': ..., 'shape': ...}``. ``'shape'`` may be a
            tuple or a list.

    Returns:
        Dict mapping each handled key to its converted array.

    Raises:
        AssertionError: If an rgb input's channel count does not match the
            channel count declared in ``shape_meta``.
        KeyError: If a key declared in ``shape_meta['obs']`` is missing
            from ``env_obs``.
    """
    obs_dict_np = dict()
    obs_shape_meta = shape_meta['obs']
    for key, attr in obs_shape_meta.items():
        # Named obs_type (not `type`) to avoid shadowing the builtin.
        obs_type = attr.get('type', 'low_dim')
        shape = attr.get('shape')
        if obs_type == 'rgb':
            this_imgs_in = env_obs[key]
            _, hi, wi, ci = this_imgs_in.shape
            co, ho, wo = shape
            assert ci == co, (
                f"channel mismatch for '{key}': input has {ci}, "
                f"shape_meta declares {co}")
            is_uint8 = this_imgs_in.dtype == np.uint8
            out_imgs = this_imgs_in
            if (ho != hi) or (wo != wi) or is_uint8:
                # NOTE(review): get_image_transform is defined elsewhere;
                # presumably it resizes/crops (wi, hi) -> (wo, ho) - confirm.
                tf = get_image_transform(
                    input_res=(wi, hi),
                    output_res=(wo, ho),
                    bgr_to_rgb=False)
                out_imgs = np.stack([tf(x) for x in this_imgs_in])
                if is_uint8:
                    out_imgs = out_imgs.astype(np.float32) / 255
            # THWC to TCHW
            obs_dict_np[key] = np.moveaxis(out_imgs, -1, 1)
        elif obs_type == 'low_dim':
            this_data_in = env_obs[key]
            # Normalise to a tuple first: metadata loaded from config files
            # usually carries shapes as lists, and [2] == (2,) is False.
            is_xy_only = shape is not None and tuple(shape) == (2,)
            if 'pose' in key and is_xy_only:
                # take X,Y coordinates
                this_data_in = this_data_in[..., [0, 1]]
            obs_dict_np[key] = this_data_in
    return obs_dict_np
def get_real_obs_resolution(
        shape_meta: dict
        ) -> Tuple[int, int]:
    """Return the (width, height) shared by all rgb observations.

    Walks ``shape_meta['obs']`` and, for each entry whose ``'type'`` is
    ``'rgb'``, unpacks its ``'shape'`` as (C, H, W). All rgb entries must
    declare the same spatial size.

    Returns:
        ``(width, height)`` of the rgb observations, or ``None`` when
        ``shape_meta['obs']`` contains no rgb entry.

    Raises:
        AssertionError: If two rgb entries declare different resolutions.
    """
    resolution = None
    for spec in shape_meta['obs'].values():
        # Entries without an explicit type count as low_dim and are skipped.
        if spec.get('type', 'low_dim') != 'rgb':
            continue
        _, height, width = spec.get('shape')
        candidate = (width, height)
        if resolution is None:
            resolution = candidate
        assert resolution == candidate
    return resolution