Skip to content

Commit

Permalink
rollout fixups and correctly initialize unspecified controls
Browse files Browse the repository at this point in the history
  • Loading branch information
aftersomemath committed Nov 27, 2024
1 parent 41b3257 commit e6eab04
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 22 deletions.
2 changes: 1 addition & 1 deletion doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Python bindings
- Added ``bind`` method and removed id attribute from :ref:`mjSpec` objects. Using ids is error prone in scenarios of repeated attachment and
detachment. Python users are encouraged to use names for unique identification of model elements.
- Removed ``nroll`` argument from :ref:`rollout<PyRollout>` because its value can always be inferred.
- :ref:`rollout<PyRollout>` can now accept lists of MjModel of length ``nroll``. ``nroll`` argument deprecated because
- :ref:`rollout<PyRollout>` can now accept sequences of MjModel of length ``nroll``. ``nroll`` argument deprecated because
its value can always be inferred.

Bug fixes
Expand Down
2 changes: 1 addition & 1 deletion doc/python.rst
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ states and sensor values. The basic usage form is
state, sensordata = rollout.rollout(model, data, initial_state, control)
``model`` is either a single instance of MjModel or a list of compatible MjModel of length ``nroll``.
``model`` is either a single instance of MjModel or a sequence of compatible MjModel of length ``nroll``.
``initial_state`` is an ``nroll x nstate`` array, with ``nroll`` initial states of size ``nstate``, where
``nstate = mj_stateSize(model, mjtState.mjSTATE_FULLPHYSICS)`` is the size of the
:ref:`full physics state<geFullPhysics>`. ``control`` is a ``nroll x nstep x ncontrol`` array of controls. Controls are
Expand Down
34 changes: 18 additions & 16 deletions python/mujoco/rollout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,26 +73,28 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int nroll, int n
if (!(control_spec & mjSTATE_XFRC_APPLIED)) {
mju_zero(d->xfrc_applied, 6*nbody);
}
if (!(control_spec & mjSTATE_MOCAP_POS)) {
for (int i = 0; i < nbody; i++) {
int id = m[0]->body_mocapid[i];
if (id >= 0) mju_copy3(d->mocap_pos+3*id, m[0]->body_pos+3*i);

// loop over rollouts
for (int r = 0; r < nroll; r++) {
// clear user inputs if unspecified
if (!(control_spec & mjSTATE_MOCAP_POS)) {
for (int i = 0; i < nbody; i++) {
int id = m[r]->body_mocapid[i];
if (id >= 0) mju_copy3(d->mocap_pos+3*id, m[r]->body_pos+3*i);
}
}
}
if (!(control_spec & mjSTATE_MOCAP_QUAT)) {
for (int i = 0; i < nbody; i++) {
int id = m[0]->body_mocapid[i];
if (id >= 0) mju_copy4(d->mocap_quat+4*id, m[0]->body_quat+4*i);
if (!(control_spec & mjSTATE_MOCAP_QUAT)) {
for (int i = 0; i < nbody; i++) {
int id = m[r]->body_mocapid[i];
if (id >= 0) mju_copy4(d->mocap_quat+4*id, m[r]->body_quat+4*i);
}
}
}
if (!(control_spec & mjSTATE_EQ_ACTIVE)) {
for (int i = 0; i < neq; i++) {
d->eq_active[i] = m[0]->eq_active0[i];
if (!(control_spec & mjSTATE_EQ_ACTIVE)) {
for (int i = 0; i < neq; i++) {
d->eq_active[i] = m[r]->eq_active0[i];
}
}
}

// loop over rollouts
for (int r = 0; r < nroll; r++) {
// set initial state
mj_setState(m[r], d, state0 + r*nstate, mjSTATE_FULLPHYSICS);

Expand Down
10 changes: 7 additions & 3 deletions python/mujoco/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# ==============================================================================
"""Roll out open-loop trajectories from initial states, get subsequent states and sensor values."""

from collections.abc import Sequence
from typing import Optional, Union

import mujoco
Expand All @@ -22,7 +23,7 @@
from numpy import typing as npt


def rollout(model: Union[mujoco.MjModel, list[mujoco.MjModel]],
def rollout(model: Union[mujoco.MjModel, Sequence[mujoco.MjModel]],
data: mujoco.MjData,
initial_state: npt.ArrayLike,
control: Optional[npt.ArrayLike] = None,
Expand All @@ -41,7 +42,7 @@ def rollout(model: Union[mujoco.MjModel, list[mujoco.MjModel]],
Allocates outputs if none are given.
Args:
model: An mjModel or a list of MjModel with the same size signature.
model: An mjModel or a sequence of MjModel with the same size signature.
data: An associated mjData instance.
initial_state: Array of initial states from which to roll out trajectories.
([nroll or 1] x nstate)
Expand Down Expand Up @@ -76,6 +77,9 @@ def rollout(model: Union[mujoco.MjModel, list[mujoco.MjModel]],
initial_warmstart, control, state, sensordata)
return state, sensordata

if not isinstance(model, mujoco.MjModel):
model = list(model)

# check control_spec
if control_spec & ~mujoco.mjtState.mjSTATE_USER.value:
raise ValueError('control_spec can only contain bits in mjSTATE_USER')
Expand Down Expand Up @@ -151,7 +155,7 @@ def rollout(model: Union[mujoco.MjModel, list[mujoco.MjModel]],
_check_trailing_dimension(nsensordata, sensordata=sensordata)

# tile input arrays/lists if required (singleton expansion)
model = model*nroll if len(model) == 1 else model
model = model * nroll if len(model) == 1 else model
initial_state = _tile_if_required(initial_state, nroll)
initial_warmstart = _tile_if_required(initial_warmstart, nroll)
control = _tile_if_required(control, nroll, nstep)
Expand Down
2 changes: 1 addition & 1 deletion python/mujoco/rollout_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def test_threading(self):
def thread_initializer():
thread_local.data = mujoco.MjData(model)

model_list = [model]*nroll
model_list = [model] * nroll
def call_rollout(initial_state, control, state, sensordata):
rollout.rollout(model_list, thread_local.data, initial_state, control,
skip_checks=True,
Expand Down

0 comments on commit e6eab04

Please sign in to comment.