From 99db3c1ff34379e6d67addfe7c064b3da834418e Mon Sep 17 00:00:00 2001
From: KafuuChikai <810658920qq@gmail.com>
Date: Thu, 26 Dec 2024 23:55:06 +0800
Subject: [PATCH] add train hover demo

---
 examples/drone/hover_env.py   | 229 ++++++++++++++++++++++++++++++++++
 examples/drone/hover_eval.py  |  51 ++++++++
 examples/drone/hover_train.py | 142 +++++++++++++++++++++
 3 files changed, 422 insertions(+)
 create mode 100644 examples/drone/hover_env.py
 create mode 100644 examples/drone/hover_eval.py
 create mode 100644 examples/drone/hover_train.py

diff --git a/examples/drone/hover_env.py b/examples/drone/hover_env.py
new file mode 100644
index 0000000..b45f5f3
--- /dev/null
+++ b/examples/drone/hover_env.py
@@ -0,0 +1,229 @@
+import torch
+import math
+import genesis as gs
+from genesis.utils.geom import quat_to_xyz, transform_by_quat, inv_quat, transform_quat_by_quat
+
+def gs_rand_float(lower, upper, shape, device):
+    return (upper - lower) * torch.rand(size=shape, device=device) + lower
+
+class HoverEnv:
+    def __init__(self, num_envs, env_cfg, obs_cfg, reward_cfg, command_cfg, show_viewer=False, device="cuda"):
+        self.device = torch.device(device)
+
+        self.num_envs = num_envs
+        self.num_obs = obs_cfg["num_obs"]
+        self.num_privileged_obs = None
+        self.num_actions = env_cfg["num_actions"]
+        self.num_commands = command_cfg["num_commands"]
+
+        # self.simulate_action_latency = env_cfg["simulate_action_latency"]
+        self.dt = 0.01  # run at 100 Hz
+        self.max_episode_length = math.ceil(env_cfg["episode_length_s"] / self.dt)
+
+        self.env_cfg = env_cfg
+        self.obs_cfg = obs_cfg
+        self.reward_cfg = reward_cfg
+        self.command_cfg = command_cfg
+
+        self.obs_scales = obs_cfg["obs_scales"]
+        self.reward_scales = reward_cfg["reward_scales"]
+
+        # create scene
+        self.scene = gs.Scene(
+            sim_options=gs.options.SimOptions(dt=self.dt, substeps=2),
+            viewer_options=gs.options.ViewerOptions(
+                max_FPS=60,
+                camera_pos=(2.0, 0.0, 2.5),
+                camera_lookat=(0.0, 0.0, 1.0),
+                camera_fov=40,
+            ),
+            vis_options=gs.options.VisOptions(n_rendered_envs=1),
+            rigid_options=gs.options.RigidOptions(
+                dt=self.dt,
+                constraint_solver=gs.constraint_solver.Newton,
+                enable_collision=True,
+                enable_joint_limit=True,
+            ),
+            show_viewer=show_viewer,
+        )
+
+        # add plane
+        self.scene.add_entity(gs.morphs.Plane())
+
+        # add drone
+        self.base_init_pos = torch.tensor(self.env_cfg["base_init_pos"], device=self.device)
+        self.base_init_quat = torch.tensor(self.env_cfg["base_init_quat"], device=self.device)
+        self.inv_base_init_quat = inv_quat(self.base_init_quat)
+        # self.base_init_pos = torch.tensor(self.env_cfg["base_init_pos"], device=self.device)
+        self.drone = self.scene.add_entity(gs.morphs.Drone(file="urdf/drones/cf2x.urdf"))
+
+        # build scene
+        self.scene.build(n_envs=num_envs)
+
+        # prepare reward functions and multiply reward scales by dt
+        self.reward_functions, self.episode_sums = dict(), dict()
+        for name in self.reward_scales.keys():
+            self.reward_scales[name] *= self.dt
+            self.reward_functions[name] = getattr(self, "_reward_" + name)
+            self.episode_sums[name] = torch.zeros((self.num_envs,), device=self.device, dtype=gs.tc_float)
+
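+        # Note: the 17-dim observation assembled in step() is the concatenation of
+        # rel_pos (3) + base_quat (4) + base_lin_vel (3) + base_ang_vel (3) + last_actions (4),
+        # which is what obs_cfg["num_obs"] = 17 accounts for.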
+        # initialize buffers
+        self.obs_buf = torch.zeros((self.num_envs, self.num_obs), device=self.device, dtype=gs.tc_float)
+        self.rew_buf = torch.zeros((self.num_envs,), device=self.device, dtype=gs.tc_float)
+        self.reset_buf = torch.ones((self.num_envs,), device=self.device, dtype=gs.tc_int)
+        self.episode_length_buf = torch.zeros((self.num_envs,), device=self.device, dtype=gs.tc_int)
+        self.commands = torch.zeros((self.num_envs, self.num_commands), device=self.device, dtype=gs.tc_float)
+
+        self.actions = torch.zeros((self.num_envs, self.num_actions), device=self.device, dtype=gs.tc_float)
+        self.last_actions = torch.zeros_like(self.actions)
+
+        self.base_pos = torch.zeros((self.num_envs, 3), device=self.device, dtype=gs.tc_float)
+        self.base_quat = torch.zeros((self.num_envs, 4), device=self.device, dtype=gs.tc_float)
+        self.base_lin_vel = torch.zeros((self.num_envs, 3), device=self.device, dtype=gs.tc_float)
+        self.base_ang_vel = torch.zeros((self.num_envs, 3), device=self.device, dtype=gs.tc_float)
+        self.last_base_pos = torch.zeros_like(self.base_pos)
+
+        self.extras = dict()  # extra information for logging
+
+    def _resample_commands(self, envs_idx):
+        self.commands[envs_idx, 0] = gs_rand_float(*self.command_cfg["pos_x_range"], (len(envs_idx),), self.device)
+        self.commands[envs_idx, 1] = gs_rand_float(*self.command_cfg["pos_y_range"], (len(envs_idx),), self.device)
+        self.commands[envs_idx, 2] = gs_rand_float(*self.command_cfg["pos_z_range"], (len(envs_idx),), self.device)
+
+    def step(self, actions):
+        self.actions = torch.clip(actions, -self.env_cfg["clip_actions"], self.env_cfg["clip_actions"])
+        exec_actions = self.actions.cpu()
+
+        # exec_actions = self.last_actions if self.simulate_action_latency else self.actions
+        # target_dof_pos = exec_actions * self.env_cfg["action_scale"] + self.default_dof_pos
+        # self.drone.control_dofs_position(target_dof_pos)
+
+        # 14468 is hover rpm
+        self.drone.set_propellels_rpm((1 + exec_actions) * 14468.429183500699)
+        self.scene.step()
+
+        # update buffers
+        self.episode_length_buf += 1
+        self.last_base_pos[:] = self.base_pos[:]
+        self.base_pos[:] = self.drone.get_pos()
+        self.rel_pos = self.commands - self.base_pos
+        self.last_rel_pos = self.commands - self.last_base_pos
+        self.base_quat[:] = self.drone.get_quat()
+        # self.base_euler = quat_to_xyz(self.base_quat)
+        self.base_euler = quat_to_xyz(
+            transform_quat_by_quat(torch.ones_like(self.base_quat) * self.inv_base_init_quat, self.base_quat)
+        )
+        inv_base_quat = inv_quat(self.base_quat)
+        self.base_lin_vel[:] = transform_by_quat(self.drone.get_vel(), inv_base_quat)
+        self.base_ang_vel[:] = transform_by_quat(self.drone.get_ang(), inv_base_quat)
+
+        # resample commands
+        # envs_idx = (
+        #     (self.episode_length_buf % int(self.env_cfg["resampling_time_s"] / self.dt) == 0)
+        #     .nonzero(as_tuple=False)
+        #     .flatten()
+        # )
+        # self._resample_commands(envs_idx)
+
+        # check termination and reset
+        self.reset_buf = self.episode_length_buf > self.max_episode_length
+        self.reset_buf |= torch.abs(self.base_euler[:, 1]) > self.env_cfg["termination_if_pitch_greater_than"]
+        self.reset_buf |= torch.abs(self.base_euler[:, 0]) > self.env_cfg["termination_if_roll_greater_than"]
+        self.reset_buf |= torch.abs(self.rel_pos[:, 0]) > self.env_cfg["termination_if_x_greater_than"]
+        self.reset_buf |= torch.abs(self.rel_pos[:, 1]) > self.env_cfg["termination_if_y_greater_than"]
+        self.reset_buf |= torch.abs(self.rel_pos[:, 2]) > self.env_cfg["termination_if_z_greater_than"]
+        self.reset_buf |= self.base_pos[:, 2] < self.env_cfg["termination_if_close_to_ground"]
+
+        time_out_idx = (self.episode_length_buf > self.max_episode_length).nonzero(as_tuple=False).flatten()
+        self.extras["time_outs"] = torch.zeros_like(self.reset_buf, device=self.device, dtype=gs.tc_float)
+        self.extras["time_outs"][time_out_idx] = 1.0
+
+        self.reset_idx(self.reset_buf.nonzero(as_tuple=False).flatten())
+
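+        # The total reward is the sum of each reward term times its configured scale;
+        # the scales were pre-multiplied by dt in __init__, so the accumulated episode
+        # reward behaves like a time integral of the raw terms.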
+        # compute reward
+        self.rew_buf[:] = 0.0
+        for name, reward_func in self.reward_functions.items():
+            rew = reward_func() * self.reward_scales[name]
+            self.rew_buf += rew
+            self.episode_sums[name] += rew
+
+        # compute observations
+        self.obs_buf = torch.cat(
+            [
+                torch.clip(self.rel_pos * self.obs_scales["rel_pos"], -1, 1),
+                self.base_quat,
+                torch.clip(self.base_lin_vel * self.obs_scales["lin_vel"], -1, 1),
+                torch.clip(self.base_ang_vel * self.obs_scales["ang_vel"], -1, 1),
+                self.last_actions,
+            ],
+            axis=-1,
+        )
+
+        self.last_actions[:] = self.actions[:]
+
+        return self.obs_buf, None, self.rew_buf, self.reset_buf, self.extras
+
+    def get_observations(self):
+        return self.obs_buf
+
+    def get_privileged_observations(self):
+        return None
+
+    def reset_idx(self, envs_idx):
+        if len(envs_idx) == 0:
+            return
+
+        # reset base
+        self.base_pos[envs_idx] = self.base_init_pos
+        self.last_base_pos[envs_idx] = self.base_init_pos
+        self.rel_pos = self.commands - self.base_pos
+        self.last_rel_pos = self.commands - self.last_base_pos
+        self.base_quat[envs_idx] = self.base_init_quat.reshape(1, -1)
+        self.drone.set_pos(self.base_pos[envs_idx], zero_velocity=True, envs_idx=envs_idx)
+        self.drone.set_quat(self.base_quat[envs_idx], zero_velocity=True, envs_idx=envs_idx)
+        self.base_lin_vel[envs_idx] = 0
+        self.base_ang_vel[envs_idx] = 0
+        self.drone.zero_all_dofs_velocity(envs_idx)
+
+        # reset buffers
+        self.last_actions[envs_idx] = 0.0
+        self.episode_length_buf[envs_idx] = 0
+        self.reset_buf[envs_idx] = True
+
+        # fill extras
+        self.extras["episode"] = {}
+        for key in self.episode_sums.keys():
+            self.extras["episode"]["rew_" + key] = (
+                torch.mean(self.episode_sums[key][envs_idx]).item() / self.env_cfg["episode_length_s"]
+            )
+            self.episode_sums[key][envs_idx] = 0.0
+
+        self._resample_commands(envs_idx)
+
+    def reset(self):
+        self.reset_buf[:] = True
+        self.reset_idx(torch.arange(self.num_envs, device=self.device))
+        return self.obs_buf, None
+
+    # ------------ reward functions----------------
+    def _reward_target(self):
+        target_rew = torch.sum(torch.square(self.last_rel_pos), dim=1) - torch.sum(torch.square(self.rel_pos), dim=1)
+        return target_rew
+
+    def _reward_smooth(self):
+        smooth_rew = torch.sum(torch.square(self.actions - self.last_actions), dim=1)
+        return smooth_rew
+
+    def _reward_crash(self):
+        crash_rew = torch.zeros((self.num_envs,), device=self.device, dtype=gs.tc_float)
+
+        crash_condition = (
+            (torch.abs(self.base_euler[:, 1]) > self.env_cfg["termination_if_pitch_greater_than"]) |
+            (torch.abs(self.base_euler[:, 0]) > self.env_cfg["termination_if_roll_greater_than"]) |
+            (torch.abs(self.rel_pos[:, 0]) > self.env_cfg["termination_if_x_greater_than"]) |
+            (torch.abs(self.rel_pos[:, 1]) > self.env_cfg["termination_if_y_greater_than"]) |
+            (torch.abs(self.rel_pos[:, 2]) > self.env_cfg["termination_if_z_greater_than"]) |
+            (self.base_pos[:, 2] < self.env_cfg["termination_if_close_to_ground"])
+        )
+        crash_rew[crash_condition] = -1
+        return crash_rew
\ No newline at end of file
diff --git a/examples/drone/hover_eval.py b/examples/drone/hover_eval.py
new file mode 100644
index 0000000..7935239
--- /dev/null
+++ b/examples/drone/hover_eval.py
@@ -0,0 +1,51 @@
+import argparse
+import os
+import pickle
+
+import torch
+from hover_env import HoverEnv
+from rsl_rl.runners import OnPolicyRunner
+
+import genesis as gs
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-e", "--exp_name", type=str, default="drone-hovering")
+    parser.add_argument("--ckpt", type=int, default=100)
+    args = parser.parse_args()
+
+    gs.init()
+
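+    # Evaluation reuses the configs pickled by hover_train.py (logs/<exp_name>/cfgs.pkl)
+    # and the checkpoint model_<ckpt>.pt saved during training; the reward scales are
+    # cleared because rewards are not needed when simply rolling out the policy.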
+    log_dir = f"logs/{args.exp_name}"
+    env_cfg, obs_cfg, reward_cfg, command_cfg, train_cfg = pickle.load(open(f"logs/{args.exp_name}/cfgs.pkl", "rb"))
+    reward_cfg["reward_scales"] = {}
+
+    env = HoverEnv(
+        num_envs=1,
+        env_cfg=env_cfg,
+        obs_cfg=obs_cfg,
+        reward_cfg=reward_cfg,
+        command_cfg=command_cfg,
+        show_viewer=True,
+    )
+
+    runner = OnPolicyRunner(env, train_cfg, log_dir, device="cuda:0")
+    resume_path = os.path.join(log_dir, f"model_{args.ckpt}.pt")
+    runner.load(resume_path)
+    policy = runner.get_inference_policy(device="cuda:0")
+
+    obs, _ = env.reset()
+    with torch.no_grad():
+        while True:
+            actions = policy(obs)
+            obs, _, rews, dones, infos = env.step(actions)
+
+
+if __name__ == "__main__":
+    main()
+
+"""
+# evaluation
+python examples/drone/hover_eval.py
+"""
diff --git a/examples/drone/hover_train.py b/examples/drone/hover_train.py
new file mode 100644
index 0000000..a842934
--- /dev/null
+++ b/examples/drone/hover_train.py
@@ -0,0 +1,142 @@
+import argparse
+import os
+import pickle
+import shutil
+
+from hover_env import HoverEnv
+from rsl_rl.runners import OnPolicyRunner
+
+import genesis as gs
+
+
+def get_train_cfg(exp_name, max_iterations):
+
+    train_cfg_dict = {
+        "algorithm": {
+            "clip_param": 0.2,
+            "desired_kl": 0.01,
+            "entropy_coef": 0.01,
+            "gamma": 0.99,
+            "lam": 0.95,
+            "learning_rate": 0.001,
+            "max_grad_norm": 1.0,
+            "num_learning_epochs": 5,
+            "num_mini_batches": 4,
+            "schedule": "adaptive",
+            "use_clipped_value_loss": True,
+            "value_loss_coef": 1.0,
+        },
+        "init_member_classes": {},
+        "policy": {
+            "activation": "tanh",
+            "actor_hidden_dims": [128, 128],
+            "critic_hidden_dims": [128, 128],
+            "init_noise_std": 1.0,
+        },
+        "runner": {
+            "algorithm_class_name": "PPO",
+            "checkpoint": -1,
+            "experiment_name": exp_name,
+            "load_run": -1,
+            "log_interval": 1,
+            "max_iterations": max_iterations,
+            "num_steps_per_env": 24,
+            "policy_class_name": "ActorCritic",
+            "record_interval": -1,
+            "resume": False,
+            "resume_path": None,
+            "run_name": "",
+            "runner_class_name": "runner_class_name",
+            "save_interval": 100,
+        },
+        "runner_class_name": "OnPolicyRunner",
+        "seed": 1,
+    }
+
+    return train_cfg_dict
+
+
+def get_cfgs():
+    env_cfg = {
+        "num_actions": 4,
+        # termination
+        "termination_if_roll_greater_than": 180,  # degree
+        "termination_if_pitch_greater_than": 180,
+        "termination_if_close_to_ground": 0.1,
+        "termination_if_x_greater_than": 3.0,
+        "termination_if_y_greater_than": 3.0,
+        "termination_if_z_greater_than": 2.0,
+        # base pose
+        "base_init_pos": [0.0, 0.0, 1.0],
+        "base_init_quat": [1.0, 0.0, 0.0, 0.0],
+        "episode_length_s": 5.0,
+        "resampling_time_s": 5.0,
+        # "action_scale": 0.25,
+        # "simulate_action_latency": True,
+        "clip_actions": 1.0,
+    }
+    obs_cfg = {
+        "num_obs": 17,
+        "obs_scales": {
+            "rel_pos": 1 / 3.0,
+            "euler_xy": 1 / 180,
+            "euler_z": 1 / 360,
+            "lin_vel": 1 / 3.0,
+            "ang_vel": 1 / 3.14159,
+        },
+    }
+    reward_cfg = {
+        "reward_scales": {
+            "target": 5.0,
+            "smooth": -0.001,
+            "crash": 1.0,
+        }
+    }
+    command_cfg = {
+        "num_commands": 3,
+        "pos_x_range": [-1.0, 1.0],
+        "pos_y_range": [-1.0, 1.0],
+        "pos_z_range": [1.0, 1.0],
+    }
+
+    return env_cfg, obs_cfg, reward_cfg, command_cfg
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-e", "--exp_name", type=str, default="drone-hovering")
+    parser.add_argument("-B", "--num_envs", type=int, default=4096)
+    parser.add_argument("--max_iterations", type=int, default=1000)
+    args = parser.parse_args()
+
+    gs.init(logging_level="warning")
+
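+    # Any existing log directory for this experiment is removed below, and the five
+    # config dicts are pickled to cfgs.pkl so hover_eval.py can rebuild an identical
+    # environment later.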
f"logs/{args.exp_name}" + env_cfg, obs_cfg, reward_cfg, command_cfg = get_cfgs() + train_cfg = get_train_cfg(args.exp_name, args.max_iterations) + + if os.path.exists(log_dir): + shutil.rmtree(log_dir) + os.makedirs(log_dir, exist_ok=True) + + env = HoverEnv( + num_envs=args.num_envs, env_cfg=env_cfg, obs_cfg=obs_cfg, reward_cfg=reward_cfg, command_cfg=command_cfg + ) + + runner = OnPolicyRunner(env, train_cfg, log_dir, device="cuda:0") + + pickle.dump( + [env_cfg, obs_cfg, reward_cfg, command_cfg, train_cfg], + open(f"{log_dir}/cfgs.pkl", "wb"), + ) + + runner.learn(num_learning_iterations=args.max_iterations, init_at_random_ep_len=True) + + +if __name__ == "__main__": + main() + +""" +# training +python examples/drone/hover_train.py +"""