Merge pull request #358 from KafuuChikai/drone_rl
Update Drone-Env with Random Target and Improved Documentation
zhouxian authored Dec 27, 2024
2 parents c7200f5 + 613bfdc commit ad8317c
Showing 4 changed files with 147 additions and 52 deletions.
44 changes: 44 additions & 0 deletions examples/drone/README.md
@@ -27,6 +27,50 @@ Run with:
python fly.py -v -m
```

### 3. Hover Environment (`hover_env.py`, `hover_train.py`, `hover_eval.py`)

The hover environment (`hover_env.py`) trains a drone to reach randomly generated target points and hold a stable hover there. The environment includes:

- Initialization of the scene and entities (plane, drone, and target).
- Reward functions that give the agent feedback on its progress toward the target point.
- **Command resampling to generate new random target points** and environment-reset logic to keep training continuous.
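
For orientation, here is a minimal usage sketch. It is not part of this commit, and it assumes that `get_cfgs()` in `hover_train.py` returns the four config dictionaries it defines and that a CUDA device is available:

```python
import torch
import genesis as gs

from hover_env import HoverEnv
from hover_train import get_cfgs  # assumed to return the four config dicts

gs.init()

env_cfg, obs_cfg, reward_cfg, command_cfg = get_cfgs()

env = HoverEnv(
    num_envs=2,
    env_cfg=env_cfg,
    obs_cfg=obs_cfg,
    reward_cfg=reward_cfg,
    command_cfg=command_cfg,
    show_viewer=False,
)

obs, _ = env.reset()
# Zero actions map to hover rpm in hover_env.py, so the drone roughly holds pose.
actions = torch.zeros((2, env.num_actions), device="cuda")
obs, _, rews, dones, infos = env.step(actions)
```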

**Acknowledgement**: The reward design is inspired by [Champion-level drone racing using deep reinforcement learning (Nature 2023)](https://www.nature.com/articles/s41586-023-06419-4.pdf).

#### 3.0 Installation

With the environment defined, we use the PPO implementation from `rsl-rl` to train the policy. Install the dependencies as follows:

```bash
# Install rsl_rl.
git clone https://github.com/leggedrobotics/rsl_rl
cd rsl_rl && git checkout v1.0.2 && pip install -e .

# Install tensorboard.
pip install tensorboard
```
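
Training writes TensorBoard logs under `logs/<exp_name>`; you can then monitor runs with, e.g., `tensorboard --logdir logs`.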

#### 3.1 Training

Train the drone hovering policy using the `HoverEnv` environment.

Run with:

```bash
python hover_train.py -e drone-hovering -B 8192 --max_iterations 500
```
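
Here `-e` names the experiment (logs and configs are written under `logs/drone-hovering`), `-B` sets the number of parallel simulated environments, and `--max_iterations` the number of PPO iterations.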

#### 3.2 Evaluation

Evaluate the trained drone hovering policy.

Run with:

```bash
python hover_eval.py -e drone-hovering --ckpt 500 --record
```
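
`--ckpt 500` selects the saved checkpoint to load, and `--record` additionally attaches a camera and saves the rollout to `video.mp4` (see `hover_eval.py` below).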

## Technical Details

- The drone model used is the Crazyflie 2.X (`urdf/drones/cf2x.urdf`)
97 changes: 63 additions & 34 deletions examples/drone/hover_env.py
@@ -3,11 +3,9 @@
import genesis as gs
from genesis.utils.geom import quat_to_xyz, transform_by_quat, inv_quat, transform_quat_by_quat


def gs_rand_float(lower, upper, shape, device):
return (upper - lower) * torch.rand(size=shape, device=device) + lower
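# Example (illustrative, not in this commit): draw 5 values uniformly from [-1, 1) on the GPU:
#   xs = gs_rand_float(-1.0, 1.0, (5,), torch.device("cuda"))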


class HoverEnv:
def __init__(self, num_envs, env_cfg, obs_cfg, reward_cfg, command_cfg, show_viewer=False, device="cuda"):
self.device = torch.device(device)
@@ -18,7 +16,7 @@ def __init__(self, num_envs, env_cfg, obs_cfg, reward_cfg, command_cfg, show_vie
self.num_actions = env_cfg["num_actions"]
self.num_commands = command_cfg["num_commands"]

# self.simulate_action_latency = env_cfg["simulate_action_latency"]
self.simulate_action_latency = env_cfg["simulate_action_latency"]
self.dt = 0.01  # run at 100 Hz
self.max_episode_length = math.ceil(env_cfg["episode_length_s"] / self.dt)

@@ -34,8 +32,8 @@ def __init__(self, num_envs, env_cfg, obs_cfg, reward_cfg, command_cfg, show_vie
self.scene = gs.Scene(
sim_options=gs.options.SimOptions(dt=self.dt, substeps=2),
viewer_options=gs.options.ViewerOptions(
max_FPS=60,
camera_pos=(2.0, 0.0, 2.5),
max_FPS=env_cfg["max_visualize_FPS"],
camera_pos=(3.0, 0.0, 3.0),
camera_lookat=(0.0, 0.0, 1.0),
camera_fov=40,
),
@@ -52,11 +50,35 @@ def __init__(self, num_envs, env_cfg, obs_cfg, reward_cfg, command_cfg, show_vie
# add plane
self.scene.add_entity(gs.morphs.Plane())

# add target
if self.env_cfg["visualize_target"]:
self.target = self.scene.add_entity(morph=gs.morphs.Mesh(
file="meshes/sphere.obj",
scale=0.05,
fixed=True,
collision=False,
),
surface=gs.surfaces.Rough(
diffuse_texture=gs.textures.ColorTexture(
color=(1.0, 0.5, 0.5),
),
),
)

# add camera
if self.env_cfg["visualize_camera"]:
self.cam = self.scene.add_camera(
res=(640, 480),
pos=(3.5, 0.0, 2.5),
lookat=(0, 0, 0.5),
fov=30,
GUI=True,
)

# add drone
self.base_init_pos = torch.tensor(self.env_cfg["base_init_pos"], device=self.device)
self.base_init_quat = torch.tensor(self.env_cfg["base_init_quat"], device=self.device)
self.inv_base_init_quat = inv_quat(self.base_init_quat)
# self.base_init_pos = torch.tensor(self.env_cfg["base_init_pos"], device=self.device)
self.drone = self.scene.add_entity(gs.morphs.Drone(file="urdf/drones/cf2x.urdf"))

# build scene
@@ -91,17 +113,26 @@ def _resample_commands(self, envs_idx):
self.commands[envs_idx, 0] = gs_rand_float(*self.command_cfg["pos_x_range"], (len(envs_idx),), self.device)
self.commands[envs_idx, 1] = gs_rand_float(*self.command_cfg["pos_y_range"], (len(envs_idx),), self.device)
self.commands[envs_idx, 2] = gs_rand_float(*self.command_cfg["pos_z_range"], (len(envs_idx),), self.device)
if self.target is not None:
self.target.set_pos(self.commands[envs_idx], zero_velocity=True, envs_idx=envs_idx)

def _at_target(self):
at_target = (
(torch.norm(self.rel_pos, dim=1) < self.env_cfg["at_target_threshold"])
.nonzero(as_tuple=False)
.flatten()
)
return at_target
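# Illustrative example: with at_target_threshold = 0.1 and rel_pos rows
# [[0.05, 0.0, 0.0], [2.0, 0.0, 0.0]], _at_target() returns tensor([0]) —
# only env 0 is close enough to be given a fresh target in step().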

def step(self, actions):
self.actions = torch.clip(actions, -self.env_cfg["clip_actions"], self.env_cfg["clip_actions"])
exec_actions = self.actions.cpu()

# exec_actions = self.last_actions if self.simulate_action_latency else self.actions
# exec_actions = self.last_actions.cpu() if self.simulate_action_latency else self.actions.cpu()
# target_dof_pos = exec_actions * self.env_cfg["action_scale"] + self.default_dof_pos
# self.drone.control_dofs_position(target_dof_pos)

# 14468 is hover rpm
self.drone.set_propellels_rpm((1 + exec_actions) * 14468.429183500699)
self.drone.set_propellels_rpm((1 + exec_actions*0.8) * 14468.429183500699)
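# Action-to-rpm mapping implied by the updated line above (with the 0.8 scale):
#   action  0 -> 1.0x hover rpm, action +1 -> 1.8x, action -1 -> 0.2x,
# so even the most extreme clipped action keeps the rotors spinning.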
self.scene.step()

# update buffers
@@ -111,7 +142,6 @@ def step(self, actions):
self.rel_pos = self.commands - self.base_pos
self.last_rel_pos = self.commands - self.last_base_pos
self.base_quat[:] = self.drone.get_quat()
# self.base_euler = quat_to_xyz(self.base_quat)
self.base_euler = quat_to_xyz(
transform_quat_by_quat(torch.ones_like(self.base_quat) * self.inv_base_init_quat, self.base_quat)
)
@@ -120,21 +150,19 @@ def step(self, actions):
self.base_ang_vel[:] = transform_by_quat(self.drone.get_ang(), inv_base_quat)

# resample commands
# envs_idx = (
# (self.episode_length_buf % int(self.env_cfg["resampling_time_s"] / self.dt) == 0)
# .nonzero(as_tuple=False)
# .flatten()
# )
# self._resample_commands(envs_idx)
envs_idx = self._at_target()
self._resample_commands(envs_idx)

# check termination and reset
self.reset_buf = self.episode_length_buf > self.max_episode_length
self.reset_buf |= torch.abs(self.base_euler[:, 1]) > self.env_cfg["termination_if_pitch_greater_than"]
self.reset_buf |= torch.abs(self.base_euler[:, 0]) > self.env_cfg["termination_if_roll_greater_than"]
self.reset_buf |= torch.abs(self.rel_pos[:, 0]) > self.env_cfg["termination_if_x_greater_than"]
self.reset_buf |= torch.abs(self.rel_pos[:, 1]) > self.env_cfg["termination_if_y_greater_than"]
self.reset_buf |= torch.abs(self.rel_pos[:, 2]) > self.env_cfg["termination_if_z_greater_than"]
self.reset_buf |= self.base_pos[:, 2] < self.env_cfg["termination_if_close_to_ground"]
self.crash_condition = (
(torch.abs(self.base_euler[:, 1]) > self.env_cfg["termination_if_pitch_greater_than"]) |
(torch.abs(self.base_euler[:, 0]) > self.env_cfg["termination_if_roll_greater_than"]) |
(torch.abs(self.rel_pos[:, 0]) > self.env_cfg["termination_if_x_greater_than"]) |
(torch.abs(self.rel_pos[:, 1]) > self.env_cfg["termination_if_y_greater_than"]) |
(torch.abs(self.rel_pos[:, 2]) > self.env_cfg["termination_if_z_greater_than"]) |
(self.base_pos[:, 2] < self.env_cfg["termination_if_close_to_ground"])
)
self.reset_buf = (self.episode_length_buf > self.max_episode_length) | self.crash_condition

time_out_idx = (self.episode_length_buf > self.max_episode_length).nonzero(as_tuple=False).flatten()
self.extras["time_outs"] = torch.zeros_like(self.reset_buf, device=self.device, dtype=gs.tc_float)
@@ -216,16 +244,17 @@ def _reward_smooth(self):
smooth_rew = torch.sum(torch.square(self.actions - self.last_actions), dim=1)
return smooth_rew

def _reward_yaw(self):
yaw = self.base_euler[:, 2]
yaw = torch.where(yaw > 180, yaw - 360, yaw)/180*3.14159 # use rad for yaw_reward
yaw_rew = torch.exp(self.reward_cfg["yaw_lambda"] * torch.abs(yaw))
return yaw_rew
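# With yaw_lambda = -10: |yaw| = 0 deg gives reward 1.0, while |yaw| = 18 deg
# (~0.314 rad) gives exp(-3.14) ~ 0.043, so the reward decays quickly once
# the drone yaws away from its initial heading.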

def _reward_angular(self):
angular_rew = torch.norm(self.base_ang_vel/3.14159, dim=1)
return angular_rew

def _reward_crash(self):
crash_rew = torch.zeros((self.num_envs,), device=self.device, dtype=gs.tc_float)

crash_condition = (
(torch.abs(self.base_euler[:, 1]) > self.env_cfg["termination_if_pitch_greater_than"])
| (torch.abs(self.base_euler[:, 0]) > self.env_cfg["termination_if_roll_greater_than"])
| (torch.abs(self.rel_pos[:, 0]) > self.env_cfg["termination_if_x_greater_than"])
| (torch.abs(self.rel_pos[:, 1]) > self.env_cfg["termination_if_y_greater_than"])
| (torch.abs(self.rel_pos[:, 2]) > self.env_cfg["termination_if_z_greater_than"])
| (self.base_pos[:, 2] < self.env_cfg["termination_if_close_to_ground"])
)
crash_rew[crash_condition] = -1
return crash_rew
crash_rew[self.crash_condition] = 1
return crash_rew
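# How these terms are presumably combined (the aggregation loop sits outside
# the hunks shown here, so this is an assumption based on the config keys):
# each _reward_<name>() is weighted by reward_cfg["reward_scales"][<name>]
# and summed, e.g. crash_rew = 1 with scale -10.0 yields a -10 penalty.
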
27 changes: 22 additions & 5 deletions examples/drone/hover_eval.py
@@ -12,7 +12,8 @@
def main():
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--exp_name", type=str, default="drone-hovering")
parser.add_argument("--ckpt", type=int, default=100)
parser.add_argument("--ckpt", type=int, default=500)
parser.add_argument("--record", action="store_true", default=False)
args = parser.parse_args()

gs.init()
@@ -21,6 +22,13 @@ def main():
env_cfg, obs_cfg, reward_cfg, command_cfg, train_cfg = pickle.load(open(f"logs/{args.exp_name}/cfgs.pkl", "rb"))
reward_cfg["reward_scales"] = {}

# visualize the target
env_cfg["visualize_target"] = True
# for video recording
env_cfg["visualize_camera"] = args.record
# set the max FPS for visualization
env_cfg["max_visualize_FPS"] = 60

env = HoverEnv(
num_envs=1,
env_cfg=env_cfg,
@@ -36,11 +44,20 @@ def main():
policy = runner.get_inference_policy(device="cuda:0")

obs, _ = env.reset()
with torch.no_grad():
while True:
actions = policy(obs)
obs, _, rews, dones, infos = env.step(actions)

max_sim_step = int(env_cfg["episode_length_s"]*env_cfg["max_visualize_FPS"])
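# With the configs above (episode_length_s = 15, max_visualize_FPS = 60), this is 900 steps.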
with torch.no_grad():
if args.record:
env.cam.start_recording()
for _ in range(max_sim_step):
actions = policy(obs)
obs, _, rews, dones, infos = env.step(actions)
env.cam.render()
env.cam.stop_recording(save_to_filename="video.mp4", fps=env_cfg["max_visualize_FPS"])
else:
for _ in range(max_sim_step):
actions = policy(obs)
obs, _, rews, dones, infos = env.step(actions)

if __name__ == "__main__":
main()
31 changes: 18 additions & 13 deletions examples/drone/hover_train.py
@@ -15,10 +15,10 @@ def get_train_cfg(exp_name, max_iterations):
"algorithm": {
"clip_param": 0.2,
"desired_kl": 0.01,
"entropy_coef": 0.01,
"entropy_coef": 0.002,
"gamma": 0.99,
"lam": 0.95,
"learning_rate": 0.001,
"learning_rate": 0.0003,
"max_grad_norm": 1.0,
"num_learning_epochs": 5,
"num_mini_batches": 4,
@@ -69,27 +69,32 @@ def get_cfgs():
# base pose
"base_init_pos": [0.0, 0.0, 1.0],
"base_init_quat": [1.0, 0.0, 0.0, 0.0],
"episode_length_s": 5.0,
"resampling_time_s": 5.0,
# "action_scale": 0.25,
# "simulate_action_latency": True,
"episode_length_s": 15.0,
"at_target_threshold": 0.1,
"resampling_time_s": 3.0,
"simulate_action_latency": True,
"clip_actions": 1.0,
# visualization
"visualize_target": False,
"visualize_camera": False,
"max_visualize_FPS": 60,
}
obs_cfg = {
"num_obs": 17,
"obs_scales": {
"rel_pos": 1 / 3.0,
"euler_xy": 1 / 180,
"euler_z": 1 / 360,
"lin_vel": 1 / 3.0,
"ang_vel": 1 / 3.14159,
},
}
reward_cfg = {
"yaw_lambda": -10.0,
"reward_scales": {
"target": 5.0,
"smooth": -0.001,
"crash": 1.0,
"target": 10.0,
"smooth": -1e-4,
"yaw": 0.01,
"angular": -2e-4,
"crash": -10.0,
}
}
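# Each key in reward_scales matches a _reward_<key>() method in hover_env.py
# (e.g. "crash" -> _reward_crash); the new yaw and angular terms penalize
# heading drift and high angular velocity, respectively.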
command_cfg = {
@@ -105,8 +110,8 @@ def get_cfgs():
def main():
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--exp_name", type=str, default="drone-hovering")
parser.add_argument("-B", "--num_envs", type=int, default=4096)
parser.add_argument("--max_iterations", type=int, default=1000)
parser.add_argument("-B", "--num_envs", type=int, default=8192)
parser.add_argument("--max_iterations", type=int, default=500)
args = parser.parse_args()

gs.init(logging_level="warning")
