Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add backflip ckpts; update locomotion quaternion #58

Merged
merged 2 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions examples/locomotion/backflip/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Go2 Backflips

To try Go2 backflips, follow these steps:

1. Download the checkpoints from the link below:
```
https://drive.google.com/drive/folders/1ZxBaDP4_Br0ZhriQx_8A3JIwZZLnxix4?usp=drive_link
```
2. Place the downloaded checkpoints into the `backflip` directory.
For advanced training code, including features like domain randomization and reference rewards for backflip, visit this repository:
```
https://github.com/ziyanx02/Genesis-backflip
```
149 changes: 149 additions & 0 deletions examples/locomotion/go2_backflip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import argparse
import os
import pickle

import torch
from go2_env import Go2Env

import genesis as gs


def get_cfgs():
env_cfg = {
"num_actions": 12,
# joint/link names
"default_joint_angles": { # [rad]
"FL_hip_joint": 0.0,
"FR_hip_joint": 0.0,
"RL_hip_joint": 0.0,
"RR_hip_joint": 0.0,
"FL_thigh_joint": 0.8,
"FR_thigh_joint": 0.8,
"RL_thigh_joint": 1.0,
"RR_thigh_joint": 1.0,
"FL_calf_joint": -1.5,
"FR_calf_joint": -1.5,
"RL_calf_joint": -1.5,
"RR_calf_joint": -1.5,
},
"dof_names": [
"FR_hip_joint",
"FR_thigh_joint",
"FR_calf_joint",
"FL_hip_joint",
"FL_thigh_joint",
"FL_calf_joint",
"RR_hip_joint",
"RR_thigh_joint",
"RR_calf_joint",
"RL_hip_joint",
"RL_thigh_joint",
"RL_calf_joint",
],
# PD
"kp": 70.0,
"kd": 3.0,
# termination
"termination_if_roll_greater_than": 1000, # degree
"termination_if_pitch_greater_than": 1000,
# base pose
"base_init_pos": [0.0, 0.0, 0.35],
"base_init_quat": [0.0, 0.0, 0.0, 1.0],
"episode_length_s": 20.0,
"resampling_time_s": 4.0,
"action_scale": 0.5,
"simulate_action_latency": True,
"clip_actions": 100.0,
}
obs_cfg = {
"num_obs": 60,
"obs_scales": {
"lin_vel": 2.0,
"ang_vel": 0.25,
"dof_pos": 1.0,
"dof_vel": 0.05,
},
}
reward_cfg = {"reward_scales": {},}
command_cfg = {
"num_commands": 3,
"lin_vel_x_range": [0, 0],
"lin_vel_y_range": [0, 0],
"ang_vel_range": [0, 0],
}

return env_cfg, obs_cfg, reward_cfg, command_cfg


class BackflipEnv(Go2Env):

def get_observations(self):

phase = torch.pi * self.episode_length_buf[:, None] / self.max_episode_length
self.obs_buf = torch.cat(
[
self.base_ang_vel * self.obs_scales['ang_vel'], # 3
self.projected_gravity, # 3
(self.dof_pos - self.default_dof_pos) * self.obs_scales['dof_pos'], # 12
self.dof_vel * self.obs_scales['dof_vel'], # 12
self.actions, # 12
self.last_actions, # 12
torch.sin(phase),
torch.cos(phase),
torch.sin(phase / 2),
torch.cos(phase / 2),
torch.sin(phase / 4),
torch.cos(phase / 4),
],
axis=-1,
)

return self.obs_buf

def step(self, actions):
super().step(actions)
self.get_observations()
return self.obs_buf, None, self.rew_buf, self.reset_buf, self.extras

def main():
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--exp_name", type=str, default="single")
args = parser.parse_args()

gs.init()

env_cfg, obs_cfg, reward_cfg, command_cfg = get_cfgs()

if args.exp_name == "single":
env_cfg["episode_length_s"] = 2
elif args.exp_name == "double":
env_cfg["episode_length_s"] = 3
else:
raise RuntimeError

env = BackflipEnv(
num_envs=1,
env_cfg=env_cfg,
obs_cfg=obs_cfg,
reward_cfg=reward_cfg,
command_cfg=command_cfg,
show_viewer=True,
)

policy = torch.jit.load(f"./backflip/{args.exp_name}.pt")
policy.to(device="cuda:0")

obs, _ = env.reset()
with torch.no_grad():
while True:
actions = policy(obs)
obs, _, rews, dones, infos = env.step(actions)


if __name__ == "__main__":
main()

"""
# evaluation
python examples/locomotion/go2_eval.py -e go2-walking -v --ckpt 100
"""
2 changes: 1 addition & 1 deletion examples/locomotion/go2_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def get_cfgs():
"termination_if_pitch_greater_than": 10,
# base pose
"base_init_pos": [0.0, 0.0, 0.42],
"base_init_quat": [0.0, 0.0, 0.0, 1.0],
"base_init_quat": [1.0, 0.0, 0.0, 0.0],
"episode_length_s": 20.0,
"resampling_time_s": 4.0,
"action_scale": 0.25,
Expand Down