Skip to content

Commit

Permalink
add backflip evaluation doc
Browse files Browse the repository at this point in the history
  • Loading branch information
ziyanx02 committed Dec 20, 2024
1 parent 85f1d87 commit 0305fff
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 1 deletion.
15 changes: 15 additions & 0 deletions examples/locomotion/backflip/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Go2 Backflips

To try Go2 backflips, follow these steps:

1. Download the checkpoints from the link below:
```
https://drive.google.com/drive/folders/1ZxBaDP4_Br0ZhriQx_8A3JIwZZLnxix4?usp=drive_link
```
2. Place the downloaded checkpoints into the `backflip` directory.
For advanced training code, including features like domain randomization and reference rewards for backflip, visit the repository:
```
https://github.com/ziyanx02/Genesis-backflip
```
149 changes: 149 additions & 0 deletions examples/locomotion/go2_backflip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import argparse
import os
import pickle

import torch
from go2_env import Go2Env

import genesis as gs


def get_cfgs():
env_cfg = {
"num_actions": 12,
# joint/link names
"default_joint_angles": { # [rad]
"FL_hip_joint": 0.0,
"FR_hip_joint": 0.0,
"RL_hip_joint": 0.0,
"RR_hip_joint": 0.0,
"FL_thigh_joint": 0.8,
"FR_thigh_joint": 0.8,
"RL_thigh_joint": 1.0,
"RR_thigh_joint": 1.0,
"FL_calf_joint": -1.5,
"FR_calf_joint": -1.5,
"RL_calf_joint": -1.5,
"RR_calf_joint": -1.5,
},
"dof_names": [
"FR_hip_joint",
"FR_thigh_joint",
"FR_calf_joint",
"FL_hip_joint",
"FL_thigh_joint",
"FL_calf_joint",
"RR_hip_joint",
"RR_thigh_joint",
"RR_calf_joint",
"RL_hip_joint",
"RL_thigh_joint",
"RL_calf_joint",
],
# PD
"kp": 70.0,
"kd": 3.0,
# termination
"termination_if_roll_greater_than": 1000, # degree
"termination_if_pitch_greater_than": 1000,
# base pose
"base_init_pos": [0.0, 0.0, 0.35],
"base_init_quat": [0.0, 0.0, 0.0, 1.0],
"episode_length_s": 20.0,
"resampling_time_s": 4.0,
"action_scale": 0.5,
"simulate_action_latency": True,
"clip_actions": 100.0,
}
obs_cfg = {
"num_obs": 60,
"obs_scales": {
"lin_vel": 2.0,
"ang_vel": 0.25,
"dof_pos": 1.0,
"dof_vel": 0.05,
},
}
reward_cfg = {"reward_scales": {},}
command_cfg = {
"num_commands": 3,
"lin_vel_x_range": [0, 0],
"lin_vel_y_range": [0, 0],
"ang_vel_range": [0, 0],
}

return env_cfg, obs_cfg, reward_cfg, command_cfg


class BackflipEnv(Go2Env):

def get_observations(self):

phase = torch.pi * self.episode_length_buf[:, None] / self.max_episode_length
self.obs_buf = torch.cat(
[
self.base_ang_vel * self.obs_scales['ang_vel'], # 3
self.projected_gravity, # 3
(self.dof_pos - self.default_dof_pos) * self.obs_scales['dof_pos'], # 12
self.dof_vel * self.obs_scales['dof_vel'], # 12
self.actions, # 12
self.last_actions, # 12
torch.sin(phase),
torch.cos(phase),
torch.sin(phase / 2),
torch.cos(phase / 2),
torch.sin(phase / 4),
torch.cos(phase / 4),
],
axis=-1,
)

return self.obs_buf

def step(self, actions):
super().step(actions)
self.get_observations()
return self.obs_buf, None, self.rew_buf, self.reset_buf, self.extras

def main():
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--exp_name", type=str, default="single")
args = parser.parse_args()

gs.init()

env_cfg, obs_cfg, reward_cfg, command_cfg = get_cfgs()

if args.exp_name == "single":
env_cfg["episode_length_s"] = 2
elif args.exp_name == "double":
env_cfg["episode_length_s"] = 3
else:
raise RuntimeError

env = BackflipEnv(
num_envs=1,
env_cfg=env_cfg,
obs_cfg=obs_cfg,
reward_cfg=reward_cfg,
command_cfg=command_cfg,
show_viewer=True,
)

policy = torch.jit.load(f"./backflip/{args.exp_name}.pt")
policy.to(device="cuda:0")

obs, _ = env.reset()
with torch.no_grad():
while True:
actions = policy(obs)
obs, _, rews, dones, infos = env.step(actions)


if __name__ == "__main__":
main()

"""
# evaluation
python examples/locomotion/go2_eval.py -e go2-walking -v --ckpt 100
"""
2 changes: 1 addition & 1 deletion examples/locomotion/go2_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def get_cfgs():
"termination_if_pitch_greater_than": 10,
# base pose
"base_init_pos": [0.0, 0.0, 0.42],
"base_init_quat": [0.0, 0.0, 0.0, 1.0],
"base_init_quat": [1.0, 0.0, 0.0, 0.0],
"episode_length_s": 20.0,
"resampling_time_s": 4.0,
"action_scale": 0.25,
Expand Down

0 comments on commit 0305fff

Please sign in to comment.