Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change path for MJX Predictive Sampling Bimanual task model #306

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions python/mujoco_mpc/mjx/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# MJX Predictive Sampling

Run `handover` example:

```sh
python visualize.py
```

##
Requires: mujoco, mujoco-mjx, jax[cuda], matplotlib, mediapy (Python), ffmpeg
25 changes: 9 additions & 16 deletions python/mujoco_mpc/mjx/tasks/bimanual/handover.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
# limitations under the License.
# ==============================================================================

from etils import epath
from typing import Callable
from pathlib import Path
import jax
from jax import numpy as jp
import mujoco
from mujoco import mjx
from mujoco_mpc.mjx import predictive_sampling

CostFn = Callable[[mjx.Model, mjx.Data], jax.Array]

def bring_to_target(m: mjx.Model, d: mjx.Data) -> jax.Array:
"""Returns cost for bimanual bring to target task."""
Expand Down Expand Up @@ -48,22 +49,14 @@ def bring_to_target(m: mjx.Model, d: mjx.Data) -> jax.Array:


def get_models_and_cost_fn() -> (
tuple[mujoco.MjModel, mujoco.MjModel, predictive_sampling.CostFn]
tuple[mujoco.MjModel, mujoco.MjModel, CostFn]
):
"""Returns a tuple of the model and the cost function."""
path = epath.Path(
'build/mjpc/tasks/bimanual/'
model_path = (
Path(__file__).parent.parent.parent
/ "../../../build/mjpc/tasks/bimanual/mjx_scene.xml"
)
model_file_name = 'mjx_scene.xml'
xml = (path / model_file_name).read_text()
assets = {}
for f in path.glob('*.xml'):
if f.name == model_file_name:
continue
assets[f.name] = f.read_bytes()
for f in (path / 'assets').glob('*'):
assets[f.name] = f.read_bytes()
sim_model = mujoco.MjModel.from_xml_string(xml, assets)
plan_model = mujoco.MjModel.from_xml_string(xml, assets)
sim_model = mujoco.MjModel.from_xml_path(str(model_path))
plan_model = mujoco.MjModel.from_xml_path(str(model_path))
plan_model.opt.timestep = 0.01 # incidentally, already the case
return sim_model, plan_model, bring_to_target
4 changes: 2 additions & 2 deletions python/mujoco_mpc/mjx/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
import matplotlib.pyplot as plt
import mediapy
import mujoco
from mujoco_mpc.mjx import predictive_sampling
from mujoco_mpc.mjx.tasks.bimanual import handover
import predictive_sampling
from tasks.bimanual import handover
# %%
nsteps = 500
steps_per_plan = 4
Expand Down
Loading