Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEATURE] Add Reset Scence and Stop Rendering function and fix black reforamtion #367

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
53 changes: 27 additions & 26 deletions examples/drone/hover_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import genesis as gs
from genesis.utils.geom import quat_to_xyz, transform_by_quat, inv_quat, transform_quat_by_quat


def gs_rand_float(lower, upper, shape, device):
return (upper - lower) * torch.rand(size=shape, device=device) + lower


class HoverEnv:
def __init__(self, num_envs, env_cfg, obs_cfg, reward_cfg, command_cfg, show_viewer=False, device="cuda"):
self.device = torch.device(device)
Expand Down Expand Up @@ -52,18 +54,19 @@ def __init__(self, num_envs, env_cfg, obs_cfg, reward_cfg, command_cfg, show_vie

# add target
if self.env_cfg["visualize_target"]:
self.target = self.scene.add_entity(morph=gs.morphs.Mesh(
file="meshes/sphere.obj",
scale=0.05,
fixed=True,
collision=False,
),
surface=gs.surfaces.Rough(
diffuse_texture=gs.textures.ColorTexture(
color=(1.0, 0.5, 0.5),
),
),
)
self.target = self.scene.add_entity(
morph=gs.morphs.Mesh(
file="meshes/sphere.obj",
scale=0.05,
fixed=True,
collision=False,
),
surface=gs.surfaces.Rough(
diffuse_texture=gs.textures.ColorTexture(
color=(1.0, 0.5, 0.5),
),
),
)

# add camera
if self.env_cfg["visualize_camera"]:
Expand Down Expand Up @@ -118,9 +121,7 @@ def _resample_commands(self, envs_idx):

def _at_target(self):
at_target = (
(torch.norm(self.rel_pos, dim=1) < self.env_cfg["at_target_threshold"])
.nonzero(as_tuple=False)
.flatten()
(torch.norm(self.rel_pos, dim=1) < self.env_cfg["at_target_threshold"]).nonzero(as_tuple=False).flatten()
)
return at_target

Expand All @@ -132,7 +133,7 @@ def step(self, actions):
# self.drone.control_dofs_position(target_dof_pos)

# 14468 is hover rpm
self.drone.set_propellels_rpm((1 + exec_actions*0.8) * 14468.429183500699)
self.drone.set_propellels_rpm((1 + exec_actions * 0.8) * 14468.429183500699)
self.scene.step()

# update buffers
Expand All @@ -155,12 +156,12 @@ def step(self, actions):

# check termination and reset
self.crash_condition = (
(torch.abs(self.base_euler[:, 1]) > self.env_cfg["termination_if_pitch_greater_than"]) |
(torch.abs(self.base_euler[:, 0]) > self.env_cfg["termination_if_roll_greater_than"]) |
(torch.abs(self.rel_pos[:, 0]) > self.env_cfg["termination_if_x_greater_than"]) |
(torch.abs(self.rel_pos[:, 1]) > self.env_cfg["termination_if_y_greater_than"]) |
(torch.abs(self.rel_pos[:, 2]) > self.env_cfg["termination_if_z_greater_than"]) |
(self.base_pos[:, 2] < self.env_cfg["termination_if_close_to_ground"])
(torch.abs(self.base_euler[:, 1]) > self.env_cfg["termination_if_pitch_greater_than"])
| (torch.abs(self.base_euler[:, 0]) > self.env_cfg["termination_if_roll_greater_than"])
| (torch.abs(self.rel_pos[:, 0]) > self.env_cfg["termination_if_x_greater_than"])
| (torch.abs(self.rel_pos[:, 1]) > self.env_cfg["termination_if_y_greater_than"])
| (torch.abs(self.rel_pos[:, 2]) > self.env_cfg["termination_if_z_greater_than"])
| (self.base_pos[:, 2] < self.env_cfg["termination_if_close_to_ground"])
)
self.reset_buf = (self.episode_length_buf > self.max_episode_length) | self.crash_condition

Expand Down Expand Up @@ -246,15 +247,15 @@ def _reward_smooth(self):

def _reward_yaw(self):
yaw = self.base_euler[:, 2]
yaw = torch.where(yaw > 180, yaw - 360, yaw)/180*3.14159 # use rad for yaw_reward
yaw = torch.where(yaw > 180, yaw - 360, yaw) / 180 * 3.14159 # use rad for yaw_reward
yaw_rew = torch.exp(self.reward_cfg["yaw_lambda"] * torch.abs(yaw))
return yaw_rew

def _reward_angular(self):
angular_rew = torch.norm(self.base_ang_vel/3.14159, dim=1)
angular_rew = torch.norm(self.base_ang_vel / 3.14159, dim=1)
return angular_rew

def _reward_crash(self):
crash_rew = torch.zeros((self.num_envs,), device=self.device, dtype=gs.tc_float)
crash_rew[self.crash_condition] = 1
return crash_rew
return crash_rew
3 changes: 2 additions & 1 deletion examples/drone/hover_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def main():

obs, _ = env.reset()

max_sim_step = int(env_cfg["episode_length_s"]*env_cfg["max_visualize_FPS"])
max_sim_step = int(env_cfg["episode_length_s"] * env_cfg["max_visualize_FPS"])
with torch.no_grad():
if args.record:
env.cam.start_recording()
Expand All @@ -59,6 +59,7 @@ def main():
actions = policy(obs)
obs, _, rews, dones, infos = env.step(actions)


if __name__ == "__main__":
main()

Expand Down
2 changes: 1 addition & 1 deletion examples/drone/hover_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def get_cfgs():
"yaw": 0.01,
"angular": -2e-4,
"crash": -10.0,
}
},
}
command_cfg = {
"num_commands": 3,
Expand Down
Empty file added examples/rigid/load_model.py
Empty file.
Empty file added examples/rigid/load_urdf.py
Empty file.
32 changes: 28 additions & 4 deletions genesis/ext/pyrender/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ class Viewer(pyglet.window.Window):
- ``w``: Toggles wireframe mode
(scene default, flip wireframes, all wireframe, or all solid).
- ``z``: Resets the camera to the initial view.
- ``x``: Resets simulation to initial state.
- ``Space``: Pause simulation util pressing Space again

Note
----
Expand Down Expand Up @@ -182,7 +184,9 @@ class Viewer(pyglet.window.Window):

def __init__(
self,
scene,
context,
viewer,
viewport_size=None,
render_flags=None,
viewer_flags=None,
Expand All @@ -200,6 +204,8 @@ def __init__(
if viewport_size is None:
viewport_size = (640, 480)
self.gs_context = context
self.render_scene = scene
self._viewer = viewer
self._scene = context._scene
self._viewport_size = viewport_size
self._render_lock = RLock()
Expand All @@ -210,9 +216,8 @@ def __init__(
self._should_close = False
self._run_in_thread = run_in_thread
self._seg_node_map = context.seg_node_map

self._pause_draw = False
self._video_saver = None

self._default_render_flags = {
"flip_wireframe": False,
"all_wireframe": False,
Expand Down Expand Up @@ -290,8 +295,10 @@ def __init__(
" [v]: vertex normal",
" [w]: world frame",
" [l]: link frame",
" [x]: reset simulation",
" [d]: wireframe",
" [c]: camera & frustrum",
" [Space]: pause simulation",
" [F11]: full-screen mode",
],
]
Expand Down Expand Up @@ -624,7 +631,7 @@ def draw_offscreen(self):

def on_draw(self):
"""Redraw the scene into the viewing window."""
if self._renderer is None:
if self._renderer is None or self._pause_draw:
return

if self.run_in_thread or not self.auto_start:
Expand Down Expand Up @@ -852,7 +859,20 @@ def on_key_press(self, symbol, modifiers):
# S saves the current frame as an image
elif symbol == pyglet.window.key.S:
self._save_image()

elif symbol == pyglet.window.key.SPACE:
if not self.gs_context.pause_rendering_shown:
self._pause_draw = True
self.gs_context.pause_rendering_shown = True
self._viewer._pause_render_flag = True
# gs.logger.info("pause_rendering......")
else:
self._pause_draw = False
self.gs_context.pause_rendering_shown = False
self._viewer._pause_render_flag = False
# gs.logger.info("start_rendering......")
elif symbol == pyglet.window.key.X:
self.render_scene.reset()
# gs.logger.info("Start reset......")
# T toggles through geom types
# elif symbol == pyglet.window.key.T:
# if self.gs_context.rigid_shown == 'visual':
Expand Down Expand Up @@ -1219,5 +1239,9 @@ def _location_to_x_y(self, location):
elif location == TextAlign.TOP_CENTER:
return (self.viewport_size[0] / 2.0, self.viewport_size[1] - TEXT_PADDING)

@scene.setter
def scene(self, value):
self._scene = value


__all__ = ["Viewer"]
2 changes: 1 addition & 1 deletion genesis/vis/rasterizer_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def init_meshes(self):
self.world_frame_shown = False
self.link_frame_shown = False
self.camera_frustum_shown = False

self.pause_rendering_shown = False
self.world_frame_mesh = mu.create_frame(
origin_radius=0.012,
axis_radius=0.005,
Expand Down
11 changes: 7 additions & 4 deletions genesis/vis/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,17 @@ def __init__(self, options, context):
self._camera_init_lookat = options.camera_lookat
self._camera_up = options.camera_up
self._camera_fov = options.camera_fov

self._pause_render_flag = False
self.context = context

if self._max_FPS is not None:
self.rate = Rate(self._max_FPS)

def build(self, scene):
def build(self, scene, viewer):
self.scene = scene

# set viewer camera
self.setup_camera()

self.viewer = viewer
# viewer
if gs.platform == "Linux":
run_in_thread = True
Expand All @@ -67,13 +66,16 @@ def build(self, scene):
gs.raise_exception("Viewer has some issues on Windows. Can anyone help?")

self._pyrender_viewer = pyrender.Viewer(
scene=self.scene,
viewer=self.viewer,
context=self.context,
viewport_size=self._res,
run_in_thread=run_in_thread,
auto_start=auto_start,
view_center=self._camera_init_lookat,
shadow=self.context.shadow,
plane_reflection=self.context.plane_reflection,
update_flag=self._pause_render_flag,
viewer_flags={
"window_title": f"Genesis {gs.__version__}",
"refresh_rate": self._refresh_rate,
Expand Down Expand Up @@ -111,6 +113,7 @@ def setup_camera(self):
def update(self):
with self.lock:
buffer_updates = self.context.update()
# gs.logger.info(f"Updating viewer....")
for buffer_id, buffer_data in buffer_updates.items():
self._pyrender_viewer.pending_buffer_updates[buffer_id] = buffer_data

Expand Down
8 changes: 5 additions & 3 deletions genesis/vis/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def build(self):
self._context.build(self._scene)

if self._viewer is not None:
self._viewer.build(self._scene)
self._viewer.build(self._scene, self._viewer)
self.viewer_lock = self._viewer.lock
else:
self.viewer_lock = DummyViewerLock()
Expand Down Expand Up @@ -129,12 +129,14 @@ def build(self):
def update(self, force=True):
if force: # force update
self.reset()

if self._viewer is not None:
if self._viewer is not None and not self._viewer._pause_render_flag:
# gs.logger.info("Updating viewer in visualizer.....")
if self._viewer.is_alive():
self._viewer.update()
else:
gs.raise_exception("Viewer closed.")
# else:
# gs.logger.info("Skip updating viewer in visualizer.....")

def update_visual_states(self):
"""
Expand Down
Loading