Skip to content

Commit

Permalink
add train hover demo
Browse files Browse the repository at this point in the history
  • Loading branch information
KafuuChikai committed Dec 26, 2024
1 parent 66e3f84 commit 99db3c1
Show file tree
Hide file tree
Showing 3 changed files with 422 additions and 0 deletions.
229 changes: 229 additions & 0 deletions examples/drone/hover_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
import torch
import math
import genesis as gs
from genesis.utils.geom import quat_to_xyz, transform_by_quat, inv_quat, transform_quat_by_quat

def gs_rand_float(lower, upper, shape, device):
return (upper - lower) * torch.rand(size=shape, device=device) + lower

class HoverEnv:
def __init__(self, num_envs, env_cfg, obs_cfg, reward_cfg, command_cfg, show_viewer=False, device="cuda"):
self.device = torch.device(device)

self.num_envs = num_envs
self.num_obs = obs_cfg["num_obs"]
self.num_privileged_obs = None
self.num_actions = env_cfg["num_actions"]
self.num_commands = command_cfg["num_commands"]

# self.simulate_action_latency = env_cfg["simulate_action_latency"]
self.dt = 0.01 # run in 100hz
self.max_episode_length = math.ceil(env_cfg["episode_length_s"] / self.dt)

self.env_cfg = env_cfg
self.obs_cfg = obs_cfg
self.reward_cfg = reward_cfg
self.command_cfg = command_cfg

self.obs_scales = obs_cfg["obs_scales"]
self.reward_scales = reward_cfg["reward_scales"]

# create scene
self.scene = gs.Scene(
sim_options=gs.options.SimOptions(dt=self.dt, substeps=2),
viewer_options=gs.options.ViewerOptions(
max_FPS=60,
camera_pos=(2.0, 0.0, 2.5),
camera_lookat=(0.0, 0.0, 1.0),
camera_fov=40,
),
vis_options=gs.options.VisOptions(n_rendered_envs=1),
rigid_options=gs.options.RigidOptions(
dt=self.dt,
constraint_solver=gs.constraint_solver.Newton,
enable_collision=True,
enable_joint_limit=True,
),
show_viewer=show_viewer,
)

# add plane
self.scene.add_entity(gs.morphs.Plane())

# add drone
self.base_init_pos = torch.tensor(self.env_cfg["base_init_pos"], device=self.device)
self.base_init_quat = torch.tensor(self.env_cfg["base_init_quat"], device=self.device)
self.inv_base_init_quat = inv_quat(self.base_init_quat)
# self.base_init_pos = torch.tensor(self.env_cfg["base_init_pos"], device=self.device)
self.drone = self.scene.add_entity(gs.morphs.Drone(file="urdf/drones/cf2x.urdf"))

# build scene
self.scene.build(n_envs=num_envs)

# prepare reward functions and multiply reward scales by dt
self.reward_functions, self.episode_sums = dict(), dict()
for name in self.reward_scales.keys():
self.reward_scales[name] *= self.dt
self.reward_functions[name] = getattr(self, "_reward_" + name)
self.episode_sums[name] = torch.zeros((self.num_envs,), device=self.device, dtype=gs.tc_float)

# initialize buffers
self.obs_buf = torch.zeros((self.num_envs, self.num_obs), device=self.device, dtype=gs.tc_float)
self.rew_buf = torch.zeros((self.num_envs,), device=self.device, dtype=gs.tc_float)
self.reset_buf = torch.ones((self.num_envs,), device=self.device, dtype=gs.tc_int)
self.episode_length_buf = torch.zeros((self.num_envs,), device=self.device, dtype=gs.tc_int)
self.commands = torch.zeros((self.num_envs, self.num_commands), device=self.device, dtype=gs.tc_float)

self.actions = torch.zeros((self.num_envs, self.num_actions), device=self.device, dtype=gs.tc_float)
self.last_actions = torch.zeros_like(self.actions)

self.base_pos = torch.zeros((self.num_envs, 3), device=self.device, dtype=gs.tc_float)
self.base_quat = torch.zeros((self.num_envs, 4), device=self.device, dtype=gs.tc_float)
self.base_lin_vel = torch.zeros((self.num_envs, 3), device=self.device, dtype=gs.tc_float)
self.base_ang_vel = torch.zeros((self.num_envs, 3), device=self.device, dtype=gs.tc_float)
self.last_base_pos = torch.zeros_like(self.base_pos)

self.extras = dict() # extra information for logging

def _resample_commands(self, envs_idx):
self.commands[envs_idx, 0] = gs_rand_float(*self.command_cfg["pos_x_range"], (len(envs_idx),), self.device)
self.commands[envs_idx, 1] = gs_rand_float(*self.command_cfg["pos_y_range"], (len(envs_idx),), self.device)
self.commands[envs_idx, 2] = gs_rand_float(*self.command_cfg["pos_z_range"], (len(envs_idx),), self.device)

def step(self, actions):
self.actions = torch.clip(actions, -self.env_cfg["clip_actions"], self.env_cfg["clip_actions"])
exec_actions = self.actions.cpu()

# exec_actions = self.last_actions if self.simulate_action_latency else self.actions
# target_dof_pos = exec_actions * self.env_cfg["action_scale"] + self.default_dof_pos
# self.drone.control_dofs_position(target_dof_pos)

# 14468 is hover rpm
self.drone.set_propellels_rpm((1 + exec_actions) * 14468.429183500699)
self.scene.step()

# update buffers
self.episode_length_buf += 1
self.last_base_pos[:] = self.base_pos[:]
self.base_pos[:] = self.drone.get_pos()
self.rel_pos = self.commands - self.base_pos
self.last_rel_pos = self.commands - self.last_base_pos
self.base_quat[:] = self.drone.get_quat()
# self.base_euler = quat_to_xyz(self.base_quat)
self.base_euler = quat_to_xyz(
transform_quat_by_quat(torch.ones_like(self.base_quat) * self.inv_base_init_quat, self.base_quat)
)
inv_base_quat = inv_quat(self.base_quat)
self.base_lin_vel[:] = transform_by_quat(self.drone.get_vel(), inv_base_quat)
self.base_ang_vel[:] = transform_by_quat(self.drone.get_ang(), inv_base_quat)

# resample commands
# envs_idx = (
# (self.episode_length_buf % int(self.env_cfg["resampling_time_s"] / self.dt) == 0)
# .nonzero(as_tuple=False)
# .flatten()
# )
# self._resample_commands(envs_idx)

# check termination and reset
self.reset_buf = self.episode_length_buf > self.max_episode_length
self.reset_buf |= torch.abs(self.base_euler[:, 1]) > self.env_cfg["termination_if_pitch_greater_than"]
self.reset_buf |= torch.abs(self.base_euler[:, 0]) > self.env_cfg["termination_if_roll_greater_than"]
self.reset_buf |= torch.abs(self.rel_pos[:, 0]) > self.env_cfg["termination_if_x_greater_than"]
self.reset_buf |= torch.abs(self.rel_pos[:, 1]) > self.env_cfg["termination_if_y_greater_than"]
self.reset_buf |= torch.abs(self.rel_pos[:, 2]) > self.env_cfg["termination_if_z_greater_than"]
self.reset_buf |= self.base_pos[:, 2] < self.env_cfg["termination_if_close_to_ground"]

time_out_idx = (self.episode_length_buf > self.max_episode_length).nonzero(as_tuple=False).flatten()
self.extras["time_outs"] = torch.zeros_like(self.reset_buf, device=self.device, dtype=gs.tc_float)
self.extras["time_outs"][time_out_idx] = 1.0

self.reset_idx(self.reset_buf.nonzero(as_tuple=False).flatten())

# compute reward
self.rew_buf[:] = 0.0
for name, reward_func in self.reward_functions.items():
rew = reward_func() * self.reward_scales[name]
self.rew_buf += rew
self.episode_sums[name] += rew

# compute observations
self.obs_buf = torch.cat(
[
torch.clip(self.rel_pos * self.obs_scales["rel_pos"], -1, 1),
self.base_quat,
torch.clip(self.base_lin_vel * self.obs_scales["lin_vel"], -1, 1),
torch.clip(self.base_ang_vel * self.obs_scales["ang_vel"], -1, 1),
self.last_actions,
],
axis=-1,
)

self.last_actions[:] = self.actions[:]

return self.obs_buf, None, self.rew_buf, self.reset_buf, self.extras

def get_observations(self):
return self.obs_buf

def get_privileged_observations(self):
return None

def reset_idx(self, envs_idx):
if len(envs_idx) == 0:
return

# reset base
self.base_pos[envs_idx] = self.base_init_pos
self.last_base_pos[envs_idx] = self.base_init_pos
self.rel_pos = self.commands - self.base_pos
self.last_rel_pos = self.commands - self.last_base_pos
self.base_quat[envs_idx] = self.base_init_quat.reshape(1, -1)
self.drone.set_pos(self.base_pos[envs_idx], zero_velocity=True, envs_idx=envs_idx)
self.drone.set_quat(self.base_quat[envs_idx], zero_velocity=True, envs_idx=envs_idx)
self.base_lin_vel[envs_idx] = 0
self.base_ang_vel[envs_idx] = 0
self.drone.zero_all_dofs_velocity(envs_idx)

# reset buffers
self.last_actions[envs_idx] = 0.0
self.episode_length_buf[envs_idx] = 0
self.reset_buf[envs_idx] = True

# fill extras
self.extras["episode"] = {}
for key in self.episode_sums.keys():
self.extras["episode"]["rew_" + key] = (
torch.mean(self.episode_sums[key][envs_idx]).item() / self.env_cfg["episode_length_s"]
)
self.episode_sums[key][envs_idx] = 0.0

self._resample_commands(envs_idx)

def reset(self):
self.reset_buf[:] = True
self.reset_idx(torch.arange(self.num_envs, device=self.device))
return self.obs_buf, None

# ------------ reward functions----------------
def _reward_target(self):
target_rew = torch.sum(torch.square(self.last_rel_pos), dim=1) - torch.sum(torch.square(self.rel_pos), dim=1)
return target_rew

def _reward_smooth(self):
smooth_rew = torch.sum(torch.square(self.actions - self.last_actions), dim=1)
return smooth_rew

def _reward_crash(self):
crash_rew = torch.zeros((self.num_envs,), device=self.device, dtype=gs.tc_float)

crash_condition = (
(torch.abs(self.base_euler[:, 1]) > self.env_cfg["termination_if_pitch_greater_than"]) |
(torch.abs(self.base_euler[:, 0]) > self.env_cfg["termination_if_roll_greater_than"]) |
(torch.abs(self.rel_pos[:, 0]) > self.env_cfg["termination_if_x_greater_than"]) |
(torch.abs(self.rel_pos[:, 1]) > self.env_cfg["termination_if_y_greater_than"]) |
(torch.abs(self.rel_pos[:, 2]) > self.env_cfg["termination_if_z_greater_than"]) |
(self.base_pos[:, 2] < self.env_cfg["termination_if_close_to_ground"])
)
crash_rew[crash_condition] = -1
return crash_rew
51 changes: 51 additions & 0 deletions examples/drone/hover_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import argparse
import os
import pickle

import torch
from hover_env import HoverEnv
from rsl_rl.runners import OnPolicyRunner

import genesis as gs


def main():
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--exp_name", type=str, default="drone-hovering")
parser.add_argument("--ckpt", type=int, default=100)
args = parser.parse_args()

gs.init()

log_dir = f"logs/{args.exp_name}"
env_cfg, obs_cfg, reward_cfg, command_cfg, train_cfg = pickle.load(open(f"logs/{args.exp_name}/cfgs.pkl", "rb"))
reward_cfg["reward_scales"] = {}

env = HoverEnv(
num_envs=1,
env_cfg=env_cfg,
obs_cfg=obs_cfg,
reward_cfg=reward_cfg,
command_cfg=command_cfg,
show_viewer=True,
)

runner = OnPolicyRunner(env, train_cfg, log_dir, device="cuda:0")
resume_path = os.path.join(log_dir, f"model_{args.ckpt}.pt")
runner.load(resume_path)
policy = runner.get_inference_policy(device="cuda:0")

obs, _ = env.reset()
with torch.no_grad():
while True:
actions = policy(obs)
obs, _, rews, dones, infos = env.step(actions)


if __name__ == "__main__":
main()

"""
# evaluation
python examples/drone/hover_eval.py
"""
Loading

0 comments on commit 99db3c1

Please sign in to comment.