Skip to content

Commit

Permalink
[wip] Add state controller functions
Browse files Browse the repository at this point in the history
  • Loading branch information
amacati committed Dec 3, 2024
1 parent f61961c commit bd89276
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 13 deletions.
28 changes: 25 additions & 3 deletions crazyflow/sim/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(
self.states = default_state(n_worlds, n_drones, self.device)
self.controls = default_controls(n_worlds, n_drones, self.device)
self.params = default_params(n_worlds, n_drones, 0.025, J, J_INV, self.device)
self.data = SimData(states=self.states, controls=self.controls, params=self.params)
# Initialize MuJoCo world and data
self._xml_path = xml_path or self.default_path
self._spec, self._mj_model, self._mj_data, self._mjx_model, self._mjx_data = self.setup_mj()
Expand Down Expand Up @@ -232,9 +233,9 @@ def contacts(self, body: str | None = None) -> Array:
def _control_fn(self) -> Callable[[SimData], SimData]:
match self.control:
case Control.state:
return self._step_state_controller
return step_state_controller
case Control.attitude:
return self._step_attitude_controller
return step_attitude_controller
case _:
raise NotImplementedError(f"Control mode {self.control} not implemented")

Expand Down Expand Up @@ -293,7 +294,7 @@ def _step_emulate_firmware(self) -> SimControls:
return fused_masked_attitude2rpm(mask, self.states, self.controls, self.dt)

@staticmethod
def _sync_mjx(states: SimState, mjx_data: Data) -> Data:
def _sync_mjx(states: SimState, mjx_data: Data, mjx_model: Model) -> Data:
"""Sync the states to the MuJoCo data.
We initialize this function in Sim.setup() to compile it with the finalized MuJoCo model.
Expand Down Expand Up @@ -401,5 +402,26 @@ def contacts(geom_start: int, geom_count: int, data: Data) -> Array:
return data.contact.dist < 0 & (geom1_valid | geom2_valid)


def step_state_controller(data: SimData) -> SimData:

Check failure on line 405 in crazyflow/sim/core.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (D103)

crazyflow/sim/core.py:405:5: D103 Missing docstring in public function
controls = data.controls
mask = controllable(data.steps, controls.steps, data.freq, controls.state_freq)
controls = commit_state_controls(mask, controls)
controls = state2attitude(mask, data.states, controls, 1 / data.freq)

Check failure on line 409 in crazyflow/sim/core.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

crazyflow/sim/core.py:409:16: F821 Undefined name `state2attitude`
return data.replace(controls=controls)


def controllable(step: Array, ctrl_step: Array, ctrl_freq: int, freq: int) -> Array:

Check failure on line 413 in crazyflow/sim/core.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (D103)

crazyflow/sim/core.py:413:5: D103 Missing docstring in public function
return (step - ctrl_step) >= (freq / ctrl_freq)


def commit_state_controls(mask: Array, controls: SimControls) -> SimControls:

Check failure on line 417 in crazyflow/sim/core.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (D103)

crazyflow/sim/core.py:417:5: D103 Missing docstring in public function
cmd, staged_cmd = controls.state, controls.staged_state
return controls.replace(state=jnp.where(mask[:, None, None], staged_cmd, cmd))


def step_attitude_controller(data: SimData) -> SimData:

Check failure on line 422 in crazyflow/sim/core.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (D103)

crazyflow/sim/core.py:422:5: D103 Missing docstring in public function
pass


mjx_kinematics = jax.vmap(mjx.kinematics, in_axes=(None, 0))
mjx_collision = jax.vmap(mjx.collision, in_axes=(None, 0))
5 changes: 3 additions & 2 deletions crazyflow/sim/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ def default_params(
) -> SimParams:
"""Create a default set of parameters for the simulation."""
mass = jnp.ones((n_worlds, n_drones, 1), device=device) * mass
J = jnp.tile(J[None, None, :, :], (n_worlds, n_drones, 1, 1))
J_INV = jnp.tile(J_INV[None, None, :, :], (n_worlds, n_drones, 1, 1))
j, j_inv = jnp.array(J, device=device), jnp.array(J_INV, device=device)
J = jnp.tile(j[None, None, :, :], (n_worlds, n_drones, 1, 1))
J_INV = jnp.tile(j_inv[None, None, :, :], (n_worlds, n_drones, 1, 1))
return SimParams(mass=mass, J=J, J_INV=J_INV)


Expand Down
8 changes: 0 additions & 8 deletions tests/unit/test_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,6 @@ def create_sim() -> Sim:
assert sim.controls.state.device == jax.devices(device)[0]


@pytest.mark.unit
@pytest.mark.parametrize("device", ["gpu", "cpu"])
def test_setup(device: str):
skip_unavailable_device(device)
sim = Sim(n_worlds=2, n_drones=3, device=device)
sim.setup()


@pytest.mark.unit
@pytest.mark.parametrize("device", ["gpu", "cpu"])
@pytest.mark.parametrize("physics", Physics)
Expand Down

0 comments on commit bd89276

Please sign in to comment.