forked from real-stanford/diffusion_policy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path pusht_image_env.py
66 lines (57 loc) · 1.91 KB
/
pusht_image_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from gym import spaces
from diffusion_policy.env.pusht.pusht_env import PushTEnv
import numpy as np
import cv2
class PushTImageEnv(PushTEnv):
    """PushT environment that returns rendered RGB images as observations.

    Each observation is a dict with:
      * ``'image'``: float32 array of shape ``(3, render_size, render_size)``
        with values in ``[0, 1]`` (channel-first copy of the rendered frame).
      * ``'agent_pos'``: float32 array of shape ``(2,)`` with the agent
        position; the observation space bounds it to ``[0, window_size]``.

    ``render('rgb_array')`` returns the frame cached by the most recent
    ``_get_obs`` call, with the latest action (if any) drawn on it as a red
    cross. The marker is drawn only on the cached render frame, never on the
    observation image.
    """

    metadata = {"render.modes": ["rgb_array"], "video.frames_per_second": 10}

    def __init__(self,
                 legacy=False,
                 block_cog=None,
                 damping=None,
                 render_size=96):
        """Create the env; arguments are forwarded to ``PushTEnv``.

        ``render_action=False`` is passed to the parent, and the action marker
        is instead drawn by ``_get_obs`` onto the cached render frame only.
        """
        super().__init__(
            legacy=legacy,
            block_cog=block_cog,
            damping=damping,
            render_size=render_size,
            render_action=False)
        ws = self.window_size
        self.observation_space = spaces.Dict({
            'image': spaces.Box(
                low=0,
                high=1,
                shape=(3, render_size, render_size),
                dtype=np.float32
            ),
            'agent_pos': spaces.Box(
                low=0,
                high=ws,
                shape=(2,),
                dtype=np.float32
            )
        })
        # Frame produced by the latest _get_obs call; served by render().
        self.render_cache = None

    def _get_obs(self):
        """Render the current state and return the observation dict.

        Side effect: stores the rendered frame (with the action marker, when
        ``self.latest_action`` is set) in ``self.render_cache``.
        """
        img = super()._render_frame(mode='rgb_array')
        # float32 so the value matches the declared 'agent_pos' Box dtype
        # (a float64 array is rejected by Box.contains for a float32 space).
        agent_pos = np.array(self.agent.position, dtype=np.float32)
        # Channels-last frame -> channels-first, scaled to [0, 1].
        # NOTE(review): the /255 assumes a uint8 frame from _render_frame.
        # astype() copies, so drawing on `img` below cannot alter img_obs.
        img_obs = np.moveaxis(img.astype(np.float32) / 255, -1, 0)
        obs = {
            'image': img_obs,
            'agent_pos': agent_pos
        }

        # Draw the most recent action onto the render frame only.
        # presumably latest_action is the last action given to step(), in
        # window coordinates [0, window_size] — confirm against PushTEnv.
        if self.latest_action is not None:
            action = np.array(self.latest_action)
            # Map window coordinates to render pixels. Using the actual sizes
            # (instead of a hard-coded /512*96) keeps the marker correctly
            # placed for any render_size; identical for the 512/96 defaults.
            coord = (action / self.window_size * self.render_size).astype(np.int32)
            # Scale the marker relative to the 96px reference resolution.
            # Clamp to >= 1: OpenCV's line drawing asserts thickness > 0,
            # which int() truncation would violate for render_size < 96.
            marker_size = max(1, int(8 / 96 * self.render_size))
            thickness = max(1, int(1 / 96 * self.render_size))
            # OpenCV point arguments are passed as plain Python ints.
            cv2.drawMarker(img, (int(coord[0]), int(coord[1])),
                           color=(255, 0, 0), markerType=cv2.MARKER_CROSS,
                           markerSize=marker_size, thickness=thickness)
        self.render_cache = img
        return obs

    def render(self, mode='rgb_array'):
        """Return the cached render frame; only ``'rgb_array'`` is supported.

        If no frame has been cached yet, one is produced via ``_get_obs``.
        """
        assert mode == 'rgb_array'
        if self.render_cache is None:
            self._get_obs()
        return self.render_cache