diff --git a/examples/locomotion/backflip/readme.md b/examples/locomotion/backflip/readme.md new file mode 100644 index 00000000..a12ea331 --- /dev/null +++ b/examples/locomotion/backflip/readme.md @@ -0,0 +1,15 @@ +# Go2 Backflips + +To try Go2 backflips, follow these steps: + +1. Download the checkpoints from the link below: + ``` + https://drive.google.com/drive/folders/1ZxBaDP4_Br0ZhriQx_8A3JIwZZLnxix4?usp=drive_link + ``` +2. Place the downloaded checkpoints into the `backflip` directory. + +For advanced training code, including features like domain randomization and reference rewards for backflip, visit this repository: + +``` +https://github.com/ziyanx02/Genesis-backflip +``` \ No newline at end of file diff --git a/examples/locomotion/go2_backflip.py b/examples/locomotion/go2_backflip.py new file mode 100644 index 00000000..474913b0 --- /dev/null +++ b/examples/locomotion/go2_backflip.py @@ -0,0 +1,149 @@ +import argparse +import os +import pickle + +import torch +from go2_env import Go2Env + +import genesis as gs + + +def get_cfgs(): + env_cfg = { + "num_actions": 12, + # joint/link names + "default_joint_angles": { # [rad] + "FL_hip_joint": 0.0, + "FR_hip_joint": 0.0, + "RL_hip_joint": 0.0, + "RR_hip_joint": 0.0, + "FL_thigh_joint": 0.8, + "FR_thigh_joint": 0.8, + "RL_thigh_joint": 1.0, + "RR_thigh_joint": 1.0, + "FL_calf_joint": -1.5, + "FR_calf_joint": -1.5, + "RL_calf_joint": -1.5, + "RR_calf_joint": -1.5, + }, + "dof_names": [ + "FR_hip_joint", + "FR_thigh_joint", + "FR_calf_joint", + "FL_hip_joint", + "FL_thigh_joint", + "FL_calf_joint", + "RR_hip_joint", + "RR_thigh_joint", + "RR_calf_joint", + "RL_hip_joint", + "RL_thigh_joint", + "RL_calf_joint", + ], + # PD + "kp": 70.0, + "kd": 3.0, + # termination + "termination_if_roll_greater_than": 1000, # degree + "termination_if_pitch_greater_than": 1000, + # base pose + "base_init_pos": [0.0, 0.0, 0.35], + "base_init_quat": [0.0, 0.0, 0.0, 1.0], + "episode_length_s": 20.0, + "resampling_time_s": 4.0, + "action_scale": 0.5, + "simulate_action_latency": True, + "clip_actions": 100.0, + } + obs_cfg = { + "num_obs": 60, + "obs_scales": { + "lin_vel": 2.0, + "ang_vel": 0.25, + "dof_pos": 1.0, + "dof_vel": 0.05, + }, + } + reward_cfg = {"reward_scales": {},} + command_cfg = { + "num_commands": 3, + "lin_vel_x_range": [0, 0], + "lin_vel_y_range": [0, 0], + "ang_vel_range": [0, 0], + } + + return env_cfg, obs_cfg, reward_cfg, command_cfg + + +class BackflipEnv(Go2Env): + + def get_observations(self): + + phase = torch.pi * self.episode_length_buf[:, None] / self.max_episode_length + self.obs_buf = torch.cat( + [ + self.base_ang_vel * self.obs_scales['ang_vel'], # 3 + self.projected_gravity, # 3 + (self.dof_pos - self.default_dof_pos) * self.obs_scales['dof_pos'], # 12 + self.dof_vel * self.obs_scales['dof_vel'], # 12 + self.actions, # 12 + self.last_actions, # 12 + torch.sin(phase), + torch.cos(phase), + torch.sin(phase / 2), + torch.cos(phase / 2), + torch.sin(phase / 4), + torch.cos(phase / 4), + ], + axis=-1, + ) + + return self.obs_buf + + def step(self, actions): + super().step(actions) + self.get_observations() + return self.obs_buf, None, self.rew_buf, self.reset_buf, self.extras + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("-e", "--exp_name", type=str, default="single") + args = parser.parse_args() + + gs.init() + + env_cfg, obs_cfg, reward_cfg, command_cfg = get_cfgs() + + if args.exp_name == "single": + env_cfg["episode_length_s"] = 2 + elif args.exp_name == "double": + env_cfg["episode_length_s"] = 3 + else: + raise RuntimeError + + env = BackflipEnv( + num_envs=1, + env_cfg=env_cfg, + obs_cfg=obs_cfg, + reward_cfg=reward_cfg, + command_cfg=command_cfg, + show_viewer=True, + ) + + policy = torch.jit.load(f"./backflip/{args.exp_name}.pt") + policy.to(device="cuda:0") + + obs, _ = env.reset() + with torch.no_grad(): + while True: + actions = policy(obs) + obs, _, rews, dones, infos = env.step(actions) + + +if __name__ == "__main__": + main() + +""" +# evaluation +python examples/locomotion/go2_eval.py -e go2-walking -v --ckpt 100 +""" diff --git a/examples/locomotion/go2_train.py b/examples/locomotion/go2_train.py index 15a2d20a..7c77990c 100644 --- a/examples/locomotion/go2_train.py +++ b/examples/locomotion/go2_train.py @@ -96,7 +96,7 @@ def get_cfgs(): "termination_if_pitch_greater_than": 10, # base pose "base_init_pos": [0.0, 0.0, 0.42], - "base_init_quat": [0.0, 0.0, 0.0, 1.0], + "base_init_quat": [1.0, 0.0, 0.0, 0.0], "episode_length_s": 20.0, "resampling_time_s": 4.0, "action_scale": 0.25,