Commit

Fix minor bugs, fix formatting, enforce consistent type hinting
amacati committed Nov 26, 2024
1 parent d7ea128 commit 1a65f13
Showing 5 changed files with 62 additions and 63 deletions.
8 changes: 4 additions & 4 deletions benchmark/main.py
@@ -21,7 +21,7 @@ def analyze_timings(times: list[float], n_steps: int, n_worlds: int, freq: float
# Check for significant variance
if tmax / tmin > 5:
print("Warning: step time varies by more than 5x. Is JIT compiling during the benchmark?")
print(f"Times: max {tmax:.2e}@{idx_tmax}, min {tmin:.2e}@{idx_tmin}")
print(f"Times: max {tmax:.2e} @ {idx_tmax}, min {tmin:.2e} @ {idx_tmin}")

# Performance metrics
n_frames = n_steps * n_worlds # Number of frames simulated
@@ -43,7 +43,7 @@ def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, devic
device = jax.devices(device)[0]

envs = gymnasium.make_vec(
"CrazyflowEnvReachGoal-v0",
"DroneReachPos-v0",
max_episode_steps=200,
return_datatype="numpy",
num_envs=sim_config.n_worlds,
@@ -114,10 +114,10 @@ def main():
sim_config.controller = "emulatefirmware"
sim_config.device = device

print("SIM PERFORMANCE")
print("Simulator performance")
profile_step(sim_config, 100, device)

print("\nGYM ENV PERFORMANCE")
print("\nGymnasium environment performance")
profile_gym_env_step(sim_config, 100, device)


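Note: the 5x check in analyze_timings above flags benchmark runs whose slowest step is far slower than the fastest, which usually means JIT compilation happened inside the measured loop. A standalone sketch of that heuristic; the helper name and default factor are illustrative, not part of the repository:

```python
import numpy as np


def check_step_time_variance(times: list[float], factor: float = 5.0) -> None:
    """Warn if the slowest step is much slower than the fastest (likely JIT compilation)."""
    t = np.asarray(times)
    idx_tmin, idx_tmax = int(t.argmin()), int(t.argmax())
    tmin, tmax = t[idx_tmin], t[idx_tmax]
    if tmax / tmin > factor:
        print(f"Warning: step time varies by more than {factor:.0f}x. Is JIT compiling during the benchmark?")
        print(f"Times: max {tmax:.2e} @ {idx_tmax}, min {tmin:.2e} @ {idx_tmin}")
```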
51 changes: 22 additions & 29 deletions benchmark/performance.py
@@ -1,3 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import gymnasium
import jax
import numpy as np
@@ -8,6 +12,9 @@
import crazyflow # noqa: F401, ensure gymnasium envs are registered
from crazyflow.sim.core import Sim

if TYPE_CHECKING:
from crazyflow.gymnasium_envs import CrazyflowEnvReachGoal


def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
sim = Sim(**sim_config)
@@ -40,8 +47,8 @@ def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
device = jax.devices(device)[0]

envs = gymnasium.make_vec(
"CrazyflowEnvReachGoal-v0",
envs: CrazyflowEnvReachGoal = gymnasium.make_vec(
"DroneReachPos-v0",
max_episode_steps=200,
return_datatype="numpy",
num_envs=sim_config.n_worlds,
@@ -50,28 +57,25 @@ def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, devic
)

# Action for going up (in attitude control)
action = np.array(
[[[-0.3, 0, 0, 0] for _ in range(sim_config.n_drones)] for _ in range(sim_config.n_worlds)],
dtype=np.float32,
).reshape(sim_config.n_worlds, -1)

# step through env once to ensure JIT compilation
_, _ = envs.reset_all(seed=42)
_, _, _, _, _ = envs.step(action)
_, _ = envs.reset_all(seed=42)
_, _, _, _, _ = envs.step(action)
_, _ = envs.reset_all(seed=42)
_, _, _, _, _ = envs.step(action)
_, _ = envs.reset_all(seed=42)

jax.block_until_ready(envs.unwrapped.sim._mjx_data) # Ensure JIT compiled dynamics
action = np.zeros((sim_config.n_worlds, 4), dtype=np.float32)
action[..., 0] = -0.3

# Step through env once to ensure JIT compilation.
# TODO: Currently triggering recompiles also after the first full run. Investigate why and fix
# envs accordingly.
envs.reset_all(seed=42)

for _ in range(envs.max_episode_steps + 1): # Ensure all paths have been taken at least once
envs.step(action)

envs.reset_all(seed=42)

profiler = Profiler()
profiler.start()

for _ in range(n_steps):
_, _, _, _, _ = envs.step(action)
jax.block_until_ready(envs.unwrapped.sim._mjx_data)

profiler.stop()
renderer = HTMLRenderer()
renderer.open_in_browser(profiler.last_session)
@@ -89,17 +93,6 @@ def main():
sim_config.device = device

profile_step(sim_config, 1000, device)
# old | new
# sys_id + attitude:
# 0.61 reset, 0.61 step | 0.61 reset, 0.61 step
# sys_id + state:
# 14.53 step, 0.53 reset | 0.75 reset, 0.88 step

# Analytical + attitude:
# 0.75 reset, 9.38 step | 0.75 reset, 0.80 step
# Analytical + state:
# 0.75 reset, 15.1 step | 0.75 reset, 0.82 step

profile_gym_env_step(sim_config, 1000, device)


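Note: profile_gym_env_step warms the environment up for a full episode before starting the profiler so that every JIT-compiled code path (including the reset path after truncation) is traced outside the measured region, and it blocks on the underlying MJX data to account for JAX's asynchronous dispatch. A minimal sketch of that pattern, assuming a vectorized env exposing reset_all, step, and max_episode_steps as used in the benchmark; warmup_and_profile is an illustrative helper, not repository code:

```python
import time

import jax
import numpy as np


def warmup_and_profile(envs, action: np.ndarray, n_steps: int) -> float:
    """Warm up JIT caches for one full episode, then time n_steps of envs.step."""
    envs.reset_all(seed=42)
    for _ in range(envs.max_episode_steps + 1):  # hit every code path at least once
        envs.step(action)
    envs.reset_all(seed=42)

    start = time.perf_counter()
    for _ in range(n_steps):
        envs.step(action)
    jax.block_until_ready(envs.unwrapped.sim._mjx_data)  # wait for async dispatch to finish
    return time.perf_counter() - start
```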
1 change: 1 addition & 0 deletions crazyflow/__init__.py
@@ -0,0 +1 @@
import crazyflow.gymnasium_envs # noqa: F401, ensure gymnasium envs are registered
12 changes: 8 additions & 4 deletions crazyflow/gymnasium_envs/__init__.py
@@ -1,11 +1,15 @@
from gymnasium.envs.registration import register

from crazyflow.gymnasium_envs.crazyflow import CrazyflowEnvReachGoal, CrazyflowEnvTargetVelocity

__all__ = ["CrazyflowEnvReachGoal", "CrazyflowEnvTargetVelocity"]

register(
id="CrazyflowEnvReachGoal-v0",
vector_entry_point="crazyflow.gymnasium_envs:CrazyflowEnvReachGoal",
id="DroneReachPos-v0",
vector_entry_point="crazyflow.gymnasium_envs.crazyflow:CrazyflowEnvReachGoal",
)

register(
id="CrazyflowEnvTargetVelocity-v0",
vector_entry_point="crazyflow.gymnasium_envs:CrazyflowEnvTargetVelocity",
id="DroneReachVel-v0",
vector_entry_point="crazyflow.gymnasium_envs.crazyflow:CrazyflowEnvTargetVelocity",
)
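Note: with this change the environments are registered under the shorter IDs DroneReachPos-v0 and DroneReachVel-v0, and the entry points reference the crazyflow.gymnasium_envs.crazyflow module explicitly. A minimal usage sketch based on the calls in benchmark/performance.py; the keyword values are illustrative, and the benchmark additionally forwards the simulator config (n_worlds, n_drones, controller, device) as keyword arguments:

```python
import gymnasium

import crazyflow  # noqa: F401, registers the gymnasium environments

envs = gymnasium.make_vec(
    "DroneReachPos-v0",
    num_envs=4,
    max_episode_steps=200,
    return_datatype="numpy",
)
obs, info = envs.reset_all(seed=42)
```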
53 changes: 27 additions & 26 deletions crazyflow/gymnasium_envs/crazyflow.py
@@ -1,16 +1,16 @@
import math
import warnings
from functools import partial
from typing import Dict, Literal, Optional, Tuple
from typing import Literal

import jax
import jax.numpy as jnp
import numpy as np
from flax.struct import dataclass
from gymnasium import spaces
from gymnasium.vector import VectorEnv
from gymnasium.vector.utils import batch_space
from jax import Array
from numpy.typing import NDArray

from crazyflow.control.controller import MAX_THRUST, MIN_THRUST, Control
from crazyflow.sim.core import Sim
@@ -19,8 +19,8 @@

@dataclass
class RescaleParams:
scale_factor: jnp.ndarray
mean: jnp.ndarray
scale_factor: Array
mean: Array


CONTROL_RESCALE_PARAMS = {
@@ -35,6 +35,12 @@ class RescaleParams:
}


@partial(jax.jit, static_argnames=["convert"])
def maybe_to_numpy(data: Array, convert: bool) -> NDArray | Array:
"""Converts data to numpy array if convert is True."""
return jax.lax.cond(convert, lambda: jax.device_get(data), lambda: data)


class CrazyflowBaseEnv(VectorEnv):
"""JAX Gymnasium environment for Crazyflie simulation."""

@@ -103,7 +109,7 @@ def __init__(
)
self.observation_space = batch_space(self.single_observation_space, self.sim.n_worlds)

def step(self, action: Array) -> Tuple[Array, Array, Array, Array, Dict]:
def step(self, action: Array) -> tuple[Array, Array, Array, Array, dict]:
assert self.action_space.contains(action), f"{action!r} ({type(action)}) invalid"
action = jnp.array(action, device=self.device).reshape(
(self.sim.n_worlds, self.sim.n_drones, -1)
@@ -139,8 +145,8 @@ def step(self, action: Array) -> Tuple[Array, Array, Array, Array, Dict]:
return (
self._get_obs(),
reward,
self._maybe_to_numpy(terminated),
self._maybe_to_numpy(truncated),
maybe_to_numpy(terminated, self.return_datatype == "numpy"),
maybe_to_numpy(truncated, self.return_datatype == "numpy"),
{},
)

Expand All @@ -165,7 +171,7 @@ def _rescale_action(action: Array, control_type: str) -> Array:
return action * params.scale_factor + params.mean

def reset_all(
self, *, seed: Optional[int] = None, options: Optional[dict] = None
self, *, seed: int | None = None, options: dict | None = None
) -> tuple[dict[str, Array], dict]:
super().reset(seed=seed)

@@ -226,42 +232,37 @@ def _reward() -> None:

@staticmethod
@jax.jit
def _terminated(dones: jax.Array, states: SimState, contacts: jax.Array) -> jnp.ndarray:
def _terminated(dones: Array, states: SimState, contacts: Array) -> Array:
contact = jnp.any(contacts, axis=1)
z_coords = states.pos[..., 2]
below_ground = jnp.any(
z_coords < -0.1, axis=1
) # Should not be triggered due to collision checking
# Sanity check if we are below the ground. Should not be triggered due to collision checking
below_ground = jnp.any(z_coords < -0.1, axis=1)
terminated = jnp.logical_or(below_ground, contact) # no termination condition
return jnp.where(dones, False, terminated)

@staticmethod
@jax.jit
def _truncated(
dones: jax.Array, steps: jax.Array, max_episode_steps: jax.Array, n_substeps: jax.Array
) -> jnp.ndarray:
dones: Array, steps: Array, max_episode_steps: Array, n_substeps: Array
) -> Array:
truncated = steps / n_substeps >= max_episode_steps
return jnp.where(dones, False, truncated)

def render(self):
self.sim.render()

def _get_obs(self) -> Dict[str, jnp.ndarray]:
def _get_obs(self) -> dict[str, Array]:
obs = {
state: self._maybe_to_numpy(
state: maybe_to_numpy(
getattr(self.sim.states, state)[..., 2]
if state == "pos"
else getattr(self.sim.states, state)
else getattr(self.sim.states, state),
self.return_datatype == "numpy",
)
for state in self.states_to_include_in_obs
}
return obs

def _maybe_to_numpy(self, data: Array) -> np.ndarray:
if self.return_datatype == "numpy" and not isinstance(data, np.ndarray):
return jax.device_get(data)
return data


class CrazyflowEnvReachGoal(CrazyflowBaseEnv):
"""JAX Gymnasium environment for Crazyflie simulation."""
@@ -284,7 +285,7 @@ def reward(self) -> Array:

@staticmethod
@jax.jit
def _reward(terminated: jax.Array, states: SimState, goal: jax.Array) -> jnp.ndarray:
def _reward(terminated: Array, states: SimState, goal: Array) -> Array:
norm_distance = jnp.linalg.norm(states.pos - goal, axis=2)
reward = jnp.exp(-2.0 * norm_distance)
return jnp.where(terminated, -1.0, reward)
@@ -302,7 +303,7 @@ def reset(self, mask: Array) -> None:
)
self.goal = self.goal.at[mask].set(new_goals[mask])

def _get_obs(self) -> Dict[str, jnp.ndarray]:
def _get_obs(self) -> dict[str, Array]:
obs = super()._get_obs()
obs["difference_to_goal"] = [self.goal - self.sim.states.pos]
return obs
@@ -329,7 +330,7 @@ def reward(self) -> Array:

@staticmethod
@jax.jit
def _reward(terminated: jax.Array, states: SimState, target_vel: jax.Array) -> jnp.ndarray:
def _reward(terminated: Array, states: SimState, target_vel: Array) -> Array:
norm_distance = jnp.linalg.norm(states.vel - target_vel, axis=2)
reward = jnp.exp(-norm_distance)
return jnp.where(terminated, -1.0, reward)
@@ -347,7 +348,7 @@ def reset(self, mask: Array) -> None:
)
self.target_vel = self.target_vel.at[mask].set(new_target_vel[mask])

def _get_obs(self) -> Dict[str, jnp.ndarray]:
def _get_obs(self) -> dict[str, Array]:
obs = super()._get_obs()
obs["difference_to_target_vel"] = [self.target_vel - self.sim.states.vel]
return obs
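Note: the reach-goal reward above decays exponentially with the distance to the goal and is overridden with -1.0 for terminated worlds. A self-contained sketch of that computation; the array shapes (n_worlds, n_drones, 3) are an assumption based on this diff, and reach_goal_reward is an illustrative name rather than the environment's API:

```python
import jax
import jax.numpy as jnp
from jax import Array


@jax.jit
def reach_goal_reward(terminated: Array, pos: Array, goal: Array) -> Array:
    """Exponential distance-to-goal reward, overridden with -1.0 for terminated worlds."""
    norm_distance = jnp.linalg.norm(pos - goal, axis=2)  # (n_worlds, n_drones)
    reward = jnp.exp(-2.0 * norm_distance)
    return jnp.where(terminated, -1.0, reward)


# Example: two worlds with one drone each, second world terminated
pos = jnp.zeros((2, 1, 3))
goal = jnp.ones((2, 1, 3))
terminated = jnp.array([False, True])[:, None]  # broadcast over drones
print(reach_goal_reward(terminated, pos, goal))
```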
