Skip to content

Commit

Permalink
fix black formation problem
Browse files Browse the repository at this point in the history
  • Loading branch information
JonnyDing committed Dec 28, 2024
1 parent e21dc34 commit abf4771
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 43 deletions.
53 changes: 27 additions & 26 deletions examples/drone/hover_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import genesis as gs
from genesis.utils.geom import quat_to_xyz, transform_by_quat, inv_quat, transform_quat_by_quat


def gs_rand_float(lower, upper, shape, device):
return (upper - lower) * torch.rand(size=shape, device=device) + lower


class HoverEnv:
def __init__(self, num_envs, env_cfg, obs_cfg, reward_cfg, command_cfg, show_viewer=False, device="cuda"):
self.device = torch.device(device)
Expand Down Expand Up @@ -52,18 +54,19 @@ def __init__(self, num_envs, env_cfg, obs_cfg, reward_cfg, command_cfg, show_vie

# add target
if self.env_cfg["visualize_target"]:
self.target = self.scene.add_entity(morph=gs.morphs.Mesh(
file="meshes/sphere.obj",
scale=0.05,
fixed=True,
collision=False,
),
surface=gs.surfaces.Rough(
diffuse_texture=gs.textures.ColorTexture(
color=(1.0, 0.5, 0.5),
),
),
)
self.target = self.scene.add_entity(
morph=gs.morphs.Mesh(
file="meshes/sphere.obj",
scale=0.05,
fixed=True,
collision=False,
),
surface=gs.surfaces.Rough(
diffuse_texture=gs.textures.ColorTexture(
color=(1.0, 0.5, 0.5),
),
),
)

# add camera
if self.env_cfg["visualize_camera"]:
Expand Down Expand Up @@ -118,9 +121,7 @@ def _resample_commands(self, envs_idx):

def _at_target(self):
at_target = (
(torch.norm(self.rel_pos, dim=1) < self.env_cfg["at_target_threshold"])
.nonzero(as_tuple=False)
.flatten()
(torch.norm(self.rel_pos, dim=1) < self.env_cfg["at_target_threshold"]).nonzero(as_tuple=False).flatten()
)
return at_target

Expand All @@ -132,7 +133,7 @@ def step(self, actions):
# self.drone.control_dofs_position(target_dof_pos)

# 14468 is hover rpm
self.drone.set_propellels_rpm((1 + exec_actions*0.8) * 14468.429183500699)
self.drone.set_propellels_rpm((1 + exec_actions * 0.8) * 14468.429183500699)
self.scene.step()

# update buffers
Expand All @@ -155,12 +156,12 @@ def step(self, actions):

# check termination and reset
self.crash_condition = (
(torch.abs(self.base_euler[:, 1]) > self.env_cfg["termination_if_pitch_greater_than"]) |
(torch.abs(self.base_euler[:, 0]) > self.env_cfg["termination_if_roll_greater_than"]) |
(torch.abs(self.rel_pos[:, 0]) > self.env_cfg["termination_if_x_greater_than"]) |
(torch.abs(self.rel_pos[:, 1]) > self.env_cfg["termination_if_y_greater_than"]) |
(torch.abs(self.rel_pos[:, 2]) > self.env_cfg["termination_if_z_greater_than"]) |
(self.base_pos[:, 2] < self.env_cfg["termination_if_close_to_ground"])
(torch.abs(self.base_euler[:, 1]) > self.env_cfg["termination_if_pitch_greater_than"])
| (torch.abs(self.base_euler[:, 0]) > self.env_cfg["termination_if_roll_greater_than"])
| (torch.abs(self.rel_pos[:, 0]) > self.env_cfg["termination_if_x_greater_than"])
| (torch.abs(self.rel_pos[:, 1]) > self.env_cfg["termination_if_y_greater_than"])
| (torch.abs(self.rel_pos[:, 2]) > self.env_cfg["termination_if_z_greater_than"])
| (self.base_pos[:, 2] < self.env_cfg["termination_if_close_to_ground"])
)
self.reset_buf = (self.episode_length_buf > self.max_episode_length) | self.crash_condition

Expand Down Expand Up @@ -246,15 +247,15 @@ def _reward_smooth(self):

def _reward_yaw(self):
yaw = self.base_euler[:, 2]
yaw = torch.where(yaw > 180, yaw - 360, yaw)/180*3.14159 # use rad for yaw_reward
yaw = torch.where(yaw > 180, yaw - 360, yaw) / 180 * 3.14159 # use rad for yaw_reward
yaw_rew = torch.exp(self.reward_cfg["yaw_lambda"] * torch.abs(yaw))
return yaw_rew

def _reward_angular(self):
angular_rew = torch.norm(self.base_ang_vel/3.14159, dim=1)
angular_rew = torch.norm(self.base_ang_vel / 3.14159, dim=1)
return angular_rew

def _reward_crash(self):
crash_rew = torch.zeros((self.num_envs,), device=self.device, dtype=gs.tc_float)
crash_rew[self.crash_condition] = 1
return crash_rew
return crash_rew
3 changes: 2 additions & 1 deletion examples/drone/hover_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def main():

obs, _ = env.reset()

max_sim_step = int(env_cfg["episode_length_s"]*env_cfg["max_visualize_FPS"])
max_sim_step = int(env_cfg["episode_length_s"] * env_cfg["max_visualize_FPS"])
with torch.no_grad():
if args.record:
env.cam.start_recording()
Expand All @@ -59,6 +59,7 @@ def main():
actions = policy(obs)
obs, _, rews, dones, infos = env.step(actions)


if __name__ == "__main__":
main()

Expand Down
2 changes: 1 addition & 1 deletion examples/drone/hover_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def get_cfgs():
"yaw": 0.01,
"angular": -2e-4,
"crash": -10.0,
}
},
}
command_cfg = {
"num_commands": 3,
Expand Down
17 changes: 10 additions & 7 deletions genesis/ext/pyrender/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class Viewer(pyglet.window.Window):
(scene default, flip wireframes, all wireframe, or all solid).
- ``z``: Resets the camera to the initial view.
- ``x``: Resets simulation to initial state.
- ``Space``: Pause rendering util pressing Space again
- ``Space``: Pause simulation util pressing Space again
Note
----
Expand Down Expand Up @@ -186,6 +186,7 @@ def __init__(
self,
scene,
context,
viewer,
viewport_size=None,
render_flags=None,
viewer_flags=None,
Expand All @@ -204,6 +205,7 @@ def __init__(
viewport_size = (640, 480)
self.gs_context = context
self.render_scene = scene
self._viewer = viewer
self._scene = context._scene
self._viewport_size = viewport_size
self._render_lock = RLock()
Expand All @@ -214,9 +216,8 @@ def __init__(
self._should_close = False
self._run_in_thread = run_in_thread
self._seg_node_map = context.seg_node_map

self._pause_draw = False
self._video_saver = None

self._default_render_flags = {
"flip_wireframe": False,
"all_wireframe": False,
Expand Down Expand Up @@ -297,7 +298,7 @@ def __init__(
" [x]: reset simulation",
" [d]: wireframe",
" [c]: camera & frustrum",
" [Space]: pause rendering",
" [Space]: pause simulation",
" [F11]: full-screen mode",
],
]
Expand Down Expand Up @@ -630,7 +631,7 @@ def draw_offscreen(self):

def on_draw(self):
"""Redraw the scene into the viewing window."""
if self._renderer is None:
if self._renderer is None or self._pause_draw:
return

if self.run_in_thread or not self.auto_start:
Expand Down Expand Up @@ -860,12 +861,14 @@ def on_key_press(self, symbol, modifiers):
self._save_image()
elif symbol == pyglet.window.key.SPACE:
if not self.gs_context.pause_rendering_shown:
self.render_lock.acquire(blocking=False)
self._pause_draw = True
self.gs_context.pause_rendering_shown = True
self._viewer._pause_render_flag = True
# gs.logger.info("pause_rendering......")
else:
self.render_lock.release()
self._pause_draw = False
self.gs_context.pause_rendering_shown = False
self._viewer._pause_render_flag = False
# gs.logger.info("start_rendering......")
elif symbol == pyglet.window.key.X:
self.render_scene.reset()
Expand Down
2 changes: 1 addition & 1 deletion genesis/vis/rasterizer_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def init_meshes(self):
self.world_frame_shown = False
self.link_frame_shown = False
self.camera_frustum_shown = False

self.pause_rendering_shown = False
self.world_frame_mesh = mu.create_frame(
origin_radius=0.012,
axis_radius=0.005,
Expand Down
10 changes: 6 additions & 4 deletions genesis/vis/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,17 @@ def __init__(self, options, context):
self._camera_init_lookat = options.camera_lookat
self._camera_up = options.camera_up
self._camera_fov = options.camera_fov

self._pause_render_flag = False
self.context = context

if self._max_FPS is not None:
self.rate = Rate(self._max_FPS)

def build(self, scene):
def build(self, scene, viewer):
self.scene = scene

# set viewer camera
self.setup_camera()

self.viewer = viewer
# viewer
if gs.platform == "Linux":
run_in_thread = True
Expand All @@ -68,13 +67,15 @@ def build(self, scene):

self._pyrender_viewer = pyrender.Viewer(
scene=self.scene,
viewer=self.viewer,
context=self.context,
viewport_size=self._res,
run_in_thread=run_in_thread,
auto_start=auto_start,
view_center=self._camera_init_lookat,
shadow=self.context.shadow,
plane_reflection=self.context.plane_reflection,
update_flag=self._pause_render_flag,
viewer_flags={
"window_title": f"Genesis {gs.__version__}",
"refresh_rate": self._refresh_rate,
Expand Down Expand Up @@ -112,6 +113,7 @@ def setup_camera(self):
def update(self):
with self.lock:
buffer_updates = self.context.update()
# gs.logger.info(f"Updating viewer....")
for buffer_id, buffer_data in buffer_updates.items():
self._pyrender_viewer.pending_buffer_updates[buffer_id] = buffer_data

Expand Down
8 changes: 5 additions & 3 deletions genesis/vis/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def build(self):
self._context.build(self._scene)

if self._viewer is not None:
self._viewer.build(self._scene)
self._viewer.build(self._scene, self._viewer)
self.viewer_lock = self._viewer.lock
else:
self.viewer_lock = DummyViewerLock()
Expand Down Expand Up @@ -129,12 +129,14 @@ def build(self):
def update(self, force=True):
if force: # force update
self.reset()

if self._viewer is not None:
if self._viewer is not None and not self._viewer._pause_render_flag:
# gs.logger.info("Updating viewer in visualizer.....")
if self._viewer.is_alive():
self._viewer.update()
else:
gs.raise_exception("Viewer closed.")
# else:
# gs.logger.info("Skip updating viewer in visualizer.....")

def update_visual_states(self):
"""
Expand Down

0 comments on commit abf4771

Please sign in to comment.