Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into tsung-domain-randomiz…
Browse files Browse the repository at this point in the history
…ation
  • Loading branch information
middleyuan committed Dec 2, 2024
2 parents 2798130 + 7a8ef09 commit ad996d2
Show file tree
Hide file tree
Showing 14 changed files with 647 additions and 75 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/ruff.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
name: Ruff
on: [ push, pull_request ]
jobs:
ruff:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: astral-sh/ruff-action@v1
22 changes: 22 additions & 0 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
name: Testing # Skips RL tests because stable-baselines3 comes with a lot of heavy-weight dependencies

on: [push]

jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: mamba-org/setup-micromamba@v1
with:
micromamba-version: '2.0.2-1' # any version from https://github.com/mamba-org/micromamba-releases
environment-name: test-env
init-shell: bash
create-args: python=3.11
cache-environment: true
- name: Install dependencies and package
run: pip install .[test]
shell: micromamba-shell {0}
- name: Test with pytest
run: pytest tests --cov=crazyflow
shell: micromamba-shell {0}
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,26 @@
# crazyflow
Fast, parallelizable simulations of Crazyflies with JAX and MuJoCo.

[![Python Version]][Python Version URL] [![Ruff Check]][Ruff Check URL] [![Documentation Status]][Documentation Status URL] [![Tests]][Tests URL]

[Python Version]: https://img.shields.io/badge/python-3.10+-blue.svg
[Python Version URL]: https://www.python.org

[Ruff Check]: https://github.com/utiasDSL/crazyflow/actions/workflows/ruff.yml/badge.svg?style=flat-square
[Ruff Check URL]: https://github.com/utiasDSL/crazyflow/actions/workflows/ruff.yml

[Documentation Status]: https://readthedocs.org/projects/crazyflow/badge/?version=latest
[Documentation Status URL]: https://crazyflow.readthedocs.io/en/latest/?badge=latest

[Tests]: https://github.com/utiasDSL/crazyflow/actions/workflows/testing.yml/badge.svg
[Tests URL]: https://github.com/utiasDSL/crazyflow/actions/workflows/testing.yml


## Architecture

<img src="/docs/img/architecture.png" width="75%" alt="Architecture">


## Known Issues
- `"RuntimeError: MUJOCO_PATH environment variable is not set"` upon installing this package. This error can be resolved by using `venv` instead of `conda`. Somtimes the `mujoco` install can [fail with `conda`](https://github.com/google-deepmind/mujoco/issues/1004).
- If using `zsh` don't forget to escape brackets when installing additional dependencies: `pip install .\[gpu\]`.
114 changes: 89 additions & 25 deletions benchmark/main.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,118 @@
import time

import gymnasium
import jax
import jax.numpy as jnp
import numpy as np
from ml_collections import config_dict

import crazyflow # noqa: F401, ensure gymnasium envs are registered
from crazyflow.sim.core import Sim


def profile_step(sim: Sim, n_steps: int, device: str):
def analyze_timings(times: list[float], n_steps: int, n_worlds: int, freq: float) -> None:
"""Analyze timing results and print performance metrics."""
if not times:
raise ValueError("The list of timing results is empty.")

tmin, idx_tmin = np.min(times), np.argmin(times)
tmax, idx_tmax = np.max(times), np.argmax(times)

# Check for significant variance
if tmax / tmin > 5:
print("Warning: step time varies by more than 5x. Is JIT compiling during the benchmark?")
print(f"Times: max {tmax:.2e} @ {idx_tmax}, min {tmin:.2e} @ {idx_tmin}")

# Performance metrics
n_frames = n_steps * n_worlds # Number of frames simulated
total_time = np.sum(times)
avg_step_time = np.mean(times)
step_time_std = np.std(times)
fps = n_frames / total_time
real_time_factor = (n_steps / freq) * n_worlds / total_time

print(
f"Avg step time: {avg_step_time:.2e}s, std: {step_time_std:.2e}"
f"\nFPS: {fps:.3e}, Real time factor: {real_time_factor:.2e}"
)


def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
"""Profile the Crazyflow gym environment step performance."""
times = []
device = jax.devices(device)[0]

envs = gymnasium.make_vec(
"DroneReachPos-v0",
max_episode_steps=200,
return_datatype="numpy",
num_envs=sim_config.n_worlds,
**sim_config,
)

# Action for going up (in attitude control)
action = np.zeros((sim_config.n_worlds, 4), dtype=np.float32)
action[..., 0] = -0.3

# Step through env once to ensure JIT compilation
envs.reset_all(seed=42)
envs.step(action)
envs.step(action)

jax.block_until_ready(envs.unwrapped.sim.states.pos) # Ensure JIT compiled dynamics

# Step through the environment
for _ in range(n_steps):
tstart = time.perf_counter()
envs.step(action)
jax.block_until_ready(envs.unwrapped.sim.states.pos)
times.append(time.perf_counter() - tstart)

envs.close()

analyze_timings(times, n_steps, envs.unwrapped.sim.n_worlds, envs.unwrapped.sim.freq)


def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
"""Profile the Crazyflow simulator step performance."""
sim = Sim(**sim_config)
times = []
device = jax.devices(device)[0]

cmd = jnp.zeros((sim.n_worlds, sim.n_drones, 4), device=device)
cmd = cmd.at[0, 0, 0].set(1)

sim.reset()
sim.attitude_control(cmd)
sim.step()
sim.reset()
jax.block_until_ready(sim._mjx_data) # Ensure JIT compiled dynamics
jax.block_until_ready(sim.states.pos) # Ensure JIT compiled dynamics

for _ in range(n_steps):
tstart = time.perf_counter()
sim.attitude_control(cmd)
sim.step()
jax.block_until_ready(sim._mjx_data)
jax.block_until_ready(sim.states.pos)
times.append(time.perf_counter() - tstart)
if max(times) / min(times) > 5:
tmin, idx_tmin = np.min(times), np.argmin(times)
tmax, idx_tmax = np.max(times), np.argmax(times)
print("Warning: step time varies by more than 5x. Is JIT compiling during the benchmark?")
print(f"Times: max {tmax:.2e}@{idx_tmax}, min {tmin:.2e}@{idx_tmin}")
n_frames = n_steps * sim.n_worlds # Number of frames simulated
total_time = np.sum(times)
real_time_factor = (n_steps / sim.freq) * sim.n_worlds / total_time
print(
f"Avg step time: {np.mean(times):.2e}s, std: {np.std(times):.2e}"
f"\nFPS: {n_frames / total_time:.3e}, Real time factor: {real_time_factor:.2e}"
)

analyze_timings(times, n_steps, sim.n_worlds, sim.freq)


def main():
"""Main entry point for profiling."""
device = "cpu"
sim = Sim(
n_worlds=1,
n_drones=1,
physics="sys_id",
control="attitude",
controller="emulatefirmware",
device=device,
)
profile_step(sim, 100, device)
sim_config = config_dict.ConfigDict()
sim_config.n_worlds = 1
sim_config.n_drones = 1
sim_config.physics = "sys_id"
sim_config.control = "attitude"
sim_config.controller = "emulatefirmware"
sim_config.device = device

print("Simulator performance")
profile_step(sim_config, 100, device)

print("\nGymnasium environment performance")
profile_gym_env_step(sim_config, 100, device)


if __name__ == "__main__":
Expand Down
77 changes: 56 additions & 21 deletions benchmark/performance.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import gymnasium
import jax
import numpy as np
from ml_collections import config_dict
from pyinstrument import Profiler
from pyinstrument.renderers.html import HTMLRenderer

import crazyflow # noqa: F401, ensure gymnasium envs are registered
from crazyflow.sim.core import Sim

if TYPE_CHECKING:
from crazyflow.gymnasium_envs import CrazyflowEnvReachGoal


def profile_step(sim: Sim, n_steps: int, device: str):
def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
sim = Sim(**sim_config)
device = jax.devices(device)[0]
ndim = 13 if sim.control == "state" else 4
control_fn = sim.state_control if sim.control == "state" else sim.attitude_control
Expand All @@ -15,7 +26,6 @@ def profile_step(sim: Sim, n_steps: int, device: str):
sim.reset()
control_fn(cmd)
sim.step()
control_fn(cmd)
sim.step()
sim.reset()
jax.block_until_ready(sim.states.pos)
Expand All @@ -33,27 +43,52 @@ def profile_step(sim: Sim, n_steps: int, device: str):
renderer.open_in_browser(profiler.last_session)


def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
device = jax.devices(device)[0]

envs: CrazyflowEnvReachGoal = gymnasium.make_vec(
"DroneReachPos-v0",
max_episode_steps=200,
return_datatype="numpy",
num_envs=sim_config.n_worlds,
**sim_config,
)

# Action for going up (in attitude control)
action = np.zeros((sim_config.n_worlds, 4), dtype=np.float32)
action[..., 0] = -0.3

# Step through env once to ensure JIT compilation.
envs.reset_all(seed=42)
envs.step(action)
envs.step(action) # Ensure all paths have been taken at least once
envs.reset_all(seed=42)

profiler = Profiler()
profiler.start()

for _ in range(n_steps):
envs.step(action)
jax.block_until_ready(envs.unwrapped.sim.states.pos)

profiler.stop()
renderer = HTMLRenderer()
renderer.open_in_browser(profiler.last_session)
envs.close()


def main():
device = "cpu"
sim = Sim(
n_worlds=1,
n_drones=1,
physics="analytical",
control="state",
controller="emulatefirmware",
device=device,
)
profile_step(sim, 1000, device)
# old | new
# sys_id + attitude:
# 0.61 reset, 0.61 step | 0.61 reset, 0.61 step
# sys_id + state:
# 14.53 step, 0.53 reset | 0.75 reset, 0.88 step

# Analytical + attitude:
# 0.75 reset, 9.38 step | 0.75 reset, 0.89 step
# Analytical + state:
# 0.75 reset, 15.1 step | 0.75 reset, 0.5 step
sim_config = config_dict.ConfigDict()
sim_config.n_worlds = 1
sim_config.n_drones = 1
sim_config.physics = "analytical"
sim_config.control = "attitude"
sim_config.controller = "emulatefirmware"
sim_config.device = device

profile_step(sim_config, 1000, device)
profile_gym_env_step(sim_config, 1000, device)


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions crazyflow/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
import crazyflow.gymnasium_envs # noqa: F401, ensure gymnasium envs are registered
15 changes: 15 additions & 0 deletions crazyflow/gymnasium_envs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from gymnasium.envs.registration import register

from crazyflow.gymnasium_envs.crazyflow import CrazyflowEnvReachGoal, CrazyflowEnvTargetVelocity

__all__ = ["CrazyflowEnvReachGoal", "CrazyflowEnvTargetVelocity"]

register(
id="DroneReachPos-v0",
vector_entry_point="crazyflow.gymnasium_envs.crazyflow:CrazyflowEnvReachGoal",
)

register(
id="DroneReachVel-v0",
vector_entry_point="crazyflow.gymnasium_envs.crazyflow:CrazyflowEnvTargetVelocity",
)
Loading

0 comments on commit ad996d2

Please sign in to comment.