Merge pull request #338 from KafuuChikai/drone_rl
Add Drone Hovering Reinforcement Learning Environment and Training Scripts
Showing 3 changed files with 422 additions and 0 deletions.
examples/drone/hover_env.py
@@ -0,0 +1,229 @@
import torch
import math
import genesis as gs
from genesis.utils.geom import quat_to_xyz, transform_by_quat, inv_quat, transform_quat_by_quat


def gs_rand_float(lower, upper, shape, device):
    return (upper - lower) * torch.rand(size=shape, device=device) + lower


class HoverEnv:
    def __init__(self, num_envs, env_cfg, obs_cfg, reward_cfg, command_cfg, show_viewer=False, device="cuda"):
        self.device = torch.device(device)

        self.num_envs = num_envs
        self.num_obs = obs_cfg["num_obs"]
        self.num_privileged_obs = None
        self.num_actions = env_cfg["num_actions"]
        self.num_commands = command_cfg["num_commands"]

        # self.simulate_action_latency = env_cfg["simulate_action_latency"]
        self.dt = 0.01  # run at 100 Hz
        self.max_episode_length = math.ceil(env_cfg["episode_length_s"] / self.dt)

        self.env_cfg = env_cfg
        self.obs_cfg = obs_cfg
        self.reward_cfg = reward_cfg
        self.command_cfg = command_cfg

        self.obs_scales = obs_cfg["obs_scales"]
        self.reward_scales = reward_cfg["reward_scales"]

        # create scene
        self.scene = gs.Scene(
            sim_options=gs.options.SimOptions(dt=self.dt, substeps=2),
            viewer_options=gs.options.ViewerOptions(
                max_FPS=60,
                camera_pos=(2.0, 0.0, 2.5),
                camera_lookat=(0.0, 0.0, 1.0),
                camera_fov=40,
            ),
            vis_options=gs.options.VisOptions(n_rendered_envs=1),
            rigid_options=gs.options.RigidOptions(
                dt=self.dt,
                constraint_solver=gs.constraint_solver.Newton,
                enable_collision=True,
                enable_joint_limit=True,
            ),
            show_viewer=show_viewer,
        )

        # add plane
        self.scene.add_entity(gs.morphs.Plane())

        # add drone
        self.base_init_pos = torch.tensor(self.env_cfg["base_init_pos"], device=self.device)
        self.base_init_quat = torch.tensor(self.env_cfg["base_init_quat"], device=self.device)
        self.inv_base_init_quat = inv_quat(self.base_init_quat)
        self.drone = self.scene.add_entity(gs.morphs.Drone(file="urdf/drones/cf2x.urdf"))

        # build scene
        self.scene.build(n_envs=num_envs)

        # prepare reward functions and multiply reward scales by dt
        self.reward_functions, self.episode_sums = dict(), dict()
        for name in self.reward_scales.keys():
            self.reward_scales[name] *= self.dt
            self.reward_functions[name] = getattr(self, "_reward_" + name)
            self.episode_sums[name] = torch.zeros((self.num_envs,), device=self.device, dtype=gs.tc_float)

        # initialize buffers
        self.obs_buf = torch.zeros((self.num_envs, self.num_obs), device=self.device, dtype=gs.tc_float)
        self.rew_buf = torch.zeros((self.num_envs,), device=self.device, dtype=gs.tc_float)
        self.reset_buf = torch.ones((self.num_envs,), device=self.device, dtype=gs.tc_int)
        self.episode_length_buf = torch.zeros((self.num_envs,), device=self.device, dtype=gs.tc_int)
        self.commands = torch.zeros((self.num_envs, self.num_commands), device=self.device, dtype=gs.tc_float)

        self.actions = torch.zeros((self.num_envs, self.num_actions), device=self.device, dtype=gs.tc_float)
        self.last_actions = torch.zeros_like(self.actions)

        self.base_pos = torch.zeros((self.num_envs, 3), device=self.device, dtype=gs.tc_float)
        self.base_quat = torch.zeros((self.num_envs, 4), device=self.device, dtype=gs.tc_float)
        self.base_lin_vel = torch.zeros((self.num_envs, 3), device=self.device, dtype=gs.tc_float)
        self.base_ang_vel = torch.zeros((self.num_envs, 3), device=self.device, dtype=gs.tc_float)
        self.last_base_pos = torch.zeros_like(self.base_pos)

        self.extras = dict()  # extra information for logging

    def _resample_commands(self, envs_idx):
        self.commands[envs_idx, 0] = gs_rand_float(*self.command_cfg["pos_x_range"], (len(envs_idx),), self.device)
        self.commands[envs_idx, 1] = gs_rand_float(*self.command_cfg["pos_y_range"], (len(envs_idx),), self.device)
        self.commands[envs_idx, 2] = gs_rand_float(*self.command_cfg["pos_z_range"], (len(envs_idx),), self.device)

    def step(self, actions):
        self.actions = torch.clip(actions, -self.env_cfg["clip_actions"], self.env_cfg["clip_actions"])
        exec_actions = self.actions.cpu()

        # exec_actions = self.last_actions if self.simulate_action_latency else self.actions
        # target_dof_pos = exec_actions * self.env_cfg["action_scale"] + self.default_dof_pos
        # self.drone.control_dofs_position(target_dof_pos)

        # 14468.43 is the hover RPM; actions scale it as (1 + a) * hover_rpm
        self.drone.set_propellels_rpm((1 + exec_actions) * 14468.429183500699)
        self.scene.step()

        # update buffers
        self.episode_length_buf += 1
        self.last_base_pos[:] = self.base_pos[:]
        self.base_pos[:] = self.drone.get_pos()
        self.rel_pos = self.commands - self.base_pos
        self.last_rel_pos = self.commands - self.last_base_pos
        self.base_quat[:] = self.drone.get_quat()
        self.base_euler = quat_to_xyz(
            transform_quat_by_quat(torch.ones_like(self.base_quat) * self.inv_base_init_quat, self.base_quat)
        )
        inv_base_quat = inv_quat(self.base_quat)
        self.base_lin_vel[:] = transform_by_quat(self.drone.get_vel(), inv_base_quat)
        self.base_ang_vel[:] = transform_by_quat(self.drone.get_ang(), inv_base_quat)

        # resample commands
        # envs_idx = (
        #     (self.episode_length_buf % int(self.env_cfg["resampling_time_s"] / self.dt) == 0)
        #     .nonzero(as_tuple=False)
        #     .flatten()
        # )
        # self._resample_commands(envs_idx)

        # check termination and reset
        self.reset_buf = self.episode_length_buf > self.max_episode_length
        self.reset_buf |= torch.abs(self.base_euler[:, 1]) > self.env_cfg["termination_if_pitch_greater_than"]
        self.reset_buf |= torch.abs(self.base_euler[:, 0]) > self.env_cfg["termination_if_roll_greater_than"]
        self.reset_buf |= torch.abs(self.rel_pos[:, 0]) > self.env_cfg["termination_if_x_greater_than"]
        self.reset_buf |= torch.abs(self.rel_pos[:, 1]) > self.env_cfg["termination_if_y_greater_than"]
        self.reset_buf |= torch.abs(self.rel_pos[:, 2]) > self.env_cfg["termination_if_z_greater_than"]
        self.reset_buf |= self.base_pos[:, 2] < self.env_cfg["termination_if_close_to_ground"]

        time_out_idx = (self.episode_length_buf > self.max_episode_length).nonzero(as_tuple=False).flatten()
        self.extras["time_outs"] = torch.zeros_like(self.reset_buf, device=self.device, dtype=gs.tc_float)
        self.extras["time_outs"][time_out_idx] = 1.0

        self.reset_idx(self.reset_buf.nonzero(as_tuple=False).flatten())

        # compute reward
        self.rew_buf[:] = 0.0
        for name, reward_func in self.reward_functions.items():
            rew = reward_func() * self.reward_scales[name]
            self.rew_buf += rew
            self.episode_sums[name] += rew

        # compute observations
        self.obs_buf = torch.cat(
            [
                torch.clip(self.rel_pos * self.obs_scales["rel_pos"], -1, 1),
                self.base_quat,
                torch.clip(self.base_lin_vel * self.obs_scales["lin_vel"], -1, 1),
                torch.clip(self.base_ang_vel * self.obs_scales["ang_vel"], -1, 1),
                self.last_actions,
            ],
            dim=-1,
        )

        self.last_actions[:] = self.actions[:]

        return self.obs_buf, None, self.rew_buf, self.reset_buf, self.extras

    def get_observations(self):
        return self.obs_buf

    def get_privileged_observations(self):
        return None

    def reset_idx(self, envs_idx):
        if len(envs_idx) == 0:
            return

        # reset base
        self.base_pos[envs_idx] = self.base_init_pos
        self.last_base_pos[envs_idx] = self.base_init_pos
        self.rel_pos = self.commands - self.base_pos
        self.last_rel_pos = self.commands - self.last_base_pos
        self.base_quat[envs_idx] = self.base_init_quat.reshape(1, -1)
        self.drone.set_pos(self.base_pos[envs_idx], zero_velocity=True, envs_idx=envs_idx)
        self.drone.set_quat(self.base_quat[envs_idx], zero_velocity=True, envs_idx=envs_idx)
        self.base_lin_vel[envs_idx] = 0
        self.base_ang_vel[envs_idx] = 0
        self.drone.zero_all_dofs_velocity(envs_idx)

        # reset buffers
        self.last_actions[envs_idx] = 0.0
        self.episode_length_buf[envs_idx] = 0
        self.reset_buf[envs_idx] = True

        # fill extras
        self.extras["episode"] = {}
        for key in self.episode_sums.keys():
            self.extras["episode"]["rew_" + key] = (
                torch.mean(self.episode_sums[key][envs_idx]).item() / self.env_cfg["episode_length_s"]
            )
            self.episode_sums[key][envs_idx] = 0.0

        self._resample_commands(envs_idx)

    def reset(self):
        self.reset_buf[:] = True
        self.reset_idx(torch.arange(self.num_envs, device=self.device))
        return self.obs_buf, None

    # ------------ reward functions ----------------
    def _reward_target(self):
        # reward progress toward the commanded position (decrease in squared distance)
        target_rew = torch.sum(torch.square(self.last_rel_pos), dim=1) - torch.sum(torch.square(self.rel_pos), dim=1)
        return target_rew

    def _reward_smooth(self):
        # squared change between consecutive actions (penalized via a negative reward scale)
        smooth_rew = torch.sum(torch.square(self.actions - self.last_actions), dim=1)
        return smooth_rew

    def _reward_crash(self):
        # -1 for every environment that satisfies a crash/termination condition
        crash_rew = torch.zeros((self.num_envs,), device=self.device, dtype=gs.tc_float)

        crash_condition = (
            (torch.abs(self.base_euler[:, 1]) > self.env_cfg["termination_if_pitch_greater_than"]) |
            (torch.abs(self.base_euler[:, 0]) > self.env_cfg["termination_if_roll_greater_than"]) |
            (torch.abs(self.rel_pos[:, 0]) > self.env_cfg["termination_if_x_greater_than"]) |
            (torch.abs(self.rel_pos[:, 1]) > self.env_cfg["termination_if_y_greater_than"]) |
            (torch.abs(self.rel_pos[:, 2]) > self.env_cfg["termination_if_z_greater_than"]) |
            (self.base_pos[:, 2] < self.env_cfg["termination_if_close_to_ground"])
        )
        crash_rew[crash_condition] = -1
        return crash_rew
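
HoverEnv is driven entirely by the four configuration dictionaries passed to its constructor; every key it reads appears in the class above. The sketch below is only an illustration of a compatible configuration: the key names come from the code, but the numeric values are placeholders rather than the settings used by this PR's training script, and the quaternion-convention comment is an assumption.

# Illustrative HoverEnv configuration sketch (placeholder values, not the PR's defaults).
env_cfg = {
    "num_actions": 4,  # one RPM offset per propeller
    "episode_length_s": 15.0,
    "clip_actions": 1.0,
    "base_init_pos": [0.0, 0.0, 1.0],
    "base_init_quat": [1.0, 0.0, 0.0, 0.0],  # identity rotation (assuming a w-first convention)
    "termination_if_pitch_greater_than": 180,  # same units as quat_to_xyz output
    "termination_if_roll_greater_than": 180,
    "termination_if_x_greater_than": 3.0,  # max |target - drone| along x, in meters
    "termination_if_y_greater_than": 3.0,
    "termination_if_z_greater_than": 2.0,
    "termination_if_close_to_ground": 0.1,
}
obs_cfg = {
    # rel_pos (3) + quat (4) + lin_vel (3) + ang_vel (3) + last_actions (4) = 17
    "num_obs": 17,
    "obs_scales": {"rel_pos": 1 / 3.0, "lin_vel": 1 / 3.0, "ang_vel": 1 / 3.14},
}
reward_cfg = {
    # keys must match the _reward_* methods; "smooth" is negative so it acts as a penalty
    "reward_scales": {"target": 10.0, "smooth": -1e-4, "crash": 10.0},
}
command_cfg = {
    "num_commands": 3,
    "pos_x_range": [-1.0, 1.0],
    "pos_y_range": [-1.0, 1.0],
    "pos_z_range": [0.5, 1.5],
}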
examples/drone/hover_eval.py
@@ -0,0 +1,51 @@
import argparse
import os
import pickle

import torch
from hover_env import HoverEnv
from rsl_rl.runners import OnPolicyRunner

import genesis as gs


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-e", "--exp_name", type=str, default="drone-hovering")
    parser.add_argument("--ckpt", type=int, default=100)
    args = parser.parse_args()

    gs.init()

    log_dir = f"logs/{args.exp_name}"
    env_cfg, obs_cfg, reward_cfg, command_cfg, train_cfg = pickle.load(open(f"logs/{args.exp_name}/cfgs.pkl", "rb"))
    reward_cfg["reward_scales"] = {}

    env = HoverEnv(
        num_envs=1,
        env_cfg=env_cfg,
        obs_cfg=obs_cfg,
        reward_cfg=reward_cfg,
        command_cfg=command_cfg,
        show_viewer=True,
    )

    runner = OnPolicyRunner(env, train_cfg, log_dir, device="cuda:0")
    resume_path = os.path.join(log_dir, f"model_{args.ckpt}.pt")
    runner.load(resume_path)
    policy = runner.get_inference_policy(device="cuda:0")

    obs, _ = env.reset()
    with torch.no_grad():
        while True:
            actions = policy(obs)
            obs, _, rews, dones, infos = env.step(actions)


if __name__ == "__main__":
    main()


"""
# evaluation
python examples/drone/hover_eval.py
"""
(The diff for the third changed file did not load and is not shown here.)
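
The evaluation script depends on rsl_rl and a trained checkpoint. As a quick environment-only smoke test, a sketch along the following lines could drive HoverEnv with random actions; it assumes configuration dictionaries such as the illustrative env_cfg / obs_cfg / reward_cfg / command_cfg sketched earlier and a CUDA device, none of which come from this PR.

# Smoke-test sketch: random actions through HoverEnv, no policy or checkpoint required.
# Assumes env_cfg, obs_cfg, reward_cfg, command_cfg dictionaries like the illustrative
# sketch above (they are not defined anywhere in this PR's eval script).
import torch
import genesis as gs
from hover_env import HoverEnv

gs.init()

env = HoverEnv(
    num_envs=16,
    env_cfg=env_cfg,
    obs_cfg=obs_cfg,
    reward_cfg=reward_cfg,
    command_cfg=command_cfg,
    show_viewer=False,
)

obs, _ = env.reset()
with torch.no_grad():
    for _ in range(200):
        # uniform random RPM offsets in [-1, 1], matching clip_actions above
        actions = torch.rand((env.num_envs, env.num_actions), device=env.device) * 2 - 1
        obs, _, rews, dones, extras = env.step(actions)
        print(f"mean reward: {rews.mean().item():.4f}")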