Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rigid contacts model #149

Closed
wants to merge 37 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
4e86135
Add `contact_state` attribute to `JaxSimModel`
flferretti Apr 22, 2024
c43ce11
Dynamically import the correct contact state
flferretti Apr 22, 2024
22cf6b0
Add `ContactParams`, `ContactModel`, `ContactsState` abstract classes
flferretti Apr 22, 2024
18205be
Use generic variable names in `js.data.JaxSimModelData`
flferretti Apr 22, 2024
66fd327
Fix import path in tests
flferretti Apr 22, 2024
cc130ea
Add rigid contacts model
flferretti Apr 22, 2024
ed7c1c7
Move `soft_contacts` and `rigid_contacts` to `api`
flferretti Apr 23, 2024
9856aaa
Fix circular import errors
flferretti Apr 23, 2024
dd3b8a8
Add `valid` method to contact base classes
flferretti Apr 26, 2024
31172c5
Fix `JaxSimModelData` and `RigidContacts` init
flferretti Apr 26, 2024
801ebba
Fix `ODEData` build
flferretti Apr 26, 2024
eebba88
Allow kwargs in `system_velocity_dynamics`
flferretti May 6, 2024
5bfd37e
Add computation of rigid contacts forces
flferretti May 6, 2024
2cf56e1
Remove unused import
flferretti May 6, 2024
4aee3c3
Import `soft_contacts` and `rigid_contacts` in api
flferretti May 6, 2024
cb1de62
Fix AD test
flferretti May 7, 2024
6a9759b
Inherit contact model during model reduction
flferretti May 8, 2024
42c81e7
Fallback to default soft contact model in ODEState
flferretti May 8, 2024
01acdde
Add Heaviside contact detection
flferretti May 8, 2024
edddc12
Update documentation
flferretti May 8, 2024
616fd3b
Allow to build ODEState without passing model
flferretti May 8, 2024
ec36695
Add PD stabilization of contact force response
flferretti May 8, 2024
64e5fc1
Use single contact jacobian
flferretti May 13, 2024
ce635c5
[wip] Add test notebook for contact models
flferretti May 15, 2024
06564b6
Expose quantities related to generic frames (#148)
xela-95 May 23, 2024
5dc6c8d
Import `soft_contacts` and `rigid_contacts` in api
flferretti May 6, 2024
b6ffb32
Update `ruff` and `black` version. Add nb pre-commit
flferretti May 23, 2024
13b38de
DROP! Update Notebook
flferretti May 23, 2024
ea1c5b7
Implement __hash__ and __eq__ methods of JaxSimModel
diegoferigo Jun 3, 2024
ffc7ccd
Update usage of HashlessObject in JaxSimModel
diegoferigo Jun 3, 2024
8abd08c
Maintain same terrain when reducing a model
xela-95 Jun 4, 2024
7e5413b
Add `contact_state` attribute to `JaxSimModel`
flferretti Apr 22, 2024
1520eec
Dynamically import the correct contact state
flferretti Apr 22, 2024
fe2829c
Add `ContactParams`, `ContactModel`, `ContactsState` abstract classes
flferretti Apr 22, 2024
fdfa5cc
Fix circular import errors
flferretti Apr 23, 2024
630016b
Inherit contact model during model reduction
flferretti May 8, 2024
e480be6
Fallback to default soft contact model in ODEState
flferretti May 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.2.0
rev: 24.4.2
hooks:
- id: black
language_version: python3.11
Expand All @@ -21,6 +21,11 @@ repos:
name: isort (python)

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.2
rev: v0.4.4
hooks:
- id: ruff

- repo: https://github.com/kynan/nbstripout
rev: 0.7.1
hooks:
- id: nbstripout
6 changes: 6 additions & 0 deletions docs/modules/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ Contact
.. automodule:: jaxsim.api.contact
:members:

.. automodule:: jaxsim.api.soft_contact
:members:

.. automodule:: jaxsim.api.rigid_contact
:members:

KinDynParameters
~~~~~~~~~~~~~~~~

Expand Down
6 changes: 0 additions & 6 deletions docs/modules/rbda.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,6 @@ Jacobians
.. automodule:: jaxsim.rbda.jacobian
:members:

Soft Contacts
~~~~~~~~~~~~~

.. automodule:: jaxsim.rbda.soft_contacts
:members:

Utilities
~~~~~~~~~

Expand Down
224 changes: 206 additions & 18 deletions examples/Parallel_computing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
"source": [
"# @title Imports and setup\n",
"import sys\n",
"import os\n",
"import pathlib\n",
"\n",
"# Deactivate GPU to avoid out of memory errors\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n",
"\n",
"from IPython.display import HTML, clear_output, display\n",
"\n",
Expand All @@ -46,10 +51,19 @@
"import jax.numpy as jnp\n",
"import jax_dataclasses\n",
"import rod\n",
"from rod.builder.primitives import SphereBuilder\n",
"from rod.builder.primitives import SphereBuilder, BoxBuilder\n",
"\n",
"import jaxsim.typing as jtp\n",
"from jaxsim import logging\n",
"from jaxsim.api.common import VelRepr\n",
"\n",
"from jaxsim.mujoco import (\n",
" MujocoVideoRecorder,\n",
" MujocoModelHelper,\n",
" RodModelToMjcf,\n",
" SdfToMjcf,\n",
" UrdfToMjcf,\n",
")\n",
"\n",
"logging.set_logging_level(logging.LoggingLevel.INFO)\n",
"logging.info(f\"Running on {jax.devices()}\")"
Expand All @@ -73,14 +87,21 @@
"# @title Create a sphere model\n",
"model_sdf_string = rod.Sdf(\n",
" version=\"1.7\",\n",
" model=SphereBuilder(radius=0.10, mass=1.0, name=\"sphere\")\n",
" # model=BoxBuilder(x=0.30, y=0.30, z=0.30, mass=1.0, name=\"box\")\n",
" model=SphereBuilder(radius=0.15, mass=1.0, name=\"sphere\")\n",
" .build_model()\n",
" .add_link()\n",
" .add_inertial()\n",
" .add_visual()\n",
" .add_collision()\n",
" .build(),\n",
").serialize(pretty=True)"
").serialize(pretty=True)\n",
"# import urllib\n",
"\n",
"# url = \"https://raw.githubusercontent.com/icub-tech-iit/ergocub-gazebo-simulations/master/models/stickBot/model.urdf\"\n",
"\n",
"# model_sdf_string = urllib.request.urlopen(url).read().decode()\n",
"# # model_sdf_string = pathlib.Path(\"/home/flferretti/git/element_rl-for-codesign/assets/model/hopper.sdf\")"
]
},
{
Expand All @@ -103,24 +124,71 @@
"source": [
"import jaxsim.api as js\n",
"from jaxsim import integrators\n",
"import jaxsim\n",
"\n",
"dt = 0.001\n",
"integration_time = 1500\n",
"\n",
"model = js.model.JaxSimModel.build_from_model_description(\n",
" model_description=model_sdf_string\n",
" model_description=model_sdf_string,\n",
" contact_model=js.rigid_contacts.RigidContacts(),\n",
" is_urdf=True,\n",
")\n",
"\n",
"model = js.model.reduce(\n",
" model=model,\n",
" considered_joints=tuple(\n",
" [\n",
" j\n",
" for j in model.joint_names()\n",
" if \"camera\" not in j\n",
" and \"neck\" not in j\n",
" and \"wrist\" not in j\n",
" and \"thumb\" not in j\n",
" and \"index\" not in j\n",
" and \"middle\" not in j\n",
" and \"ring\" not in j\n",
" and \"pinkie\" not in j\n",
" and \"elbow\" not in j\n",
" and \"shoulder\" not in j\n",
" and \"hip\" not in j\n",
" and \"knee\" not in j\n",
" and \"lidar\" not in j\n",
" and \"torso\" not in j\n",
" ]\n",
" ),\n",
")\n",
"model = js.model.reduce(model=model, considered_joints=tuple())\n",
"\n",
"data = js.data.JaxSimModelData.build(\n",
" model=model, velocity_representation=VelRepr.Inertial\n",
")\n",
"data = js.data.JaxSimModelData.build(model=model)\n",
"integrator = integrators.fixed_step.RungeKutta4SO3.build(\n",
" dynamics=js.ode.wrap_system_dynamics_for_integration(\n",
" model=model,\n",
" data=data,\n",
" system_dynamics=js.ode.system_dynamics,\n",
" ),\n",
")\n",
"# with jax.disable_jit():\n",
"integrator_state = integrator.init(x0=data.state, t0=0.0, dt=dt)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mcjf_string, assets = UrdfToMjcf.convert(urdf=model_sdf_string)\n",
"mj_helper = MujocoModelHelper.build_from_xml(\n",
" mjcf_description=mcjf_string, assets=assets\n",
")\n",
"recorder = MujocoVideoRecorder(\n",
" model=mj_helper.model, data=mj_helper.data, fps=int(1 / dt), width=640, height=480\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -138,11 +206,11 @@
"metadata": {},
"outputs": [],
"source": [
"data = data.replace(\n",
" soft_contacts_params=js.contact.estimate_good_soft_contacts_parameters(\n",
" model, number_of_active_collidable_points_steady_state=3\n",
" )\n",
")"
"# data = data.replace(\n",
"# soft_contacts_params=js.contact.estimate_good_soft_contacts_parameters(\n",
"# model, number_of_active_collidable_points_steady_state=3\n",
"# )\n",
"# )"
]
},
{
Expand All @@ -159,8 +227,8 @@
"outputs": [],
"source": [
"# Primary Calculations\n",
"envs_per_row = 4 # @slider(2, 10, 1)\n",
"\n",
"envs_per_row = 1 # @slider(2, 10, 1)\n",
"initial_height = 0.7\n",
"env_spacing = 0.5\n",
"edge_len = env_spacing * (2 * envs_per_row - 1)\n",
"\n",
Expand All @@ -171,7 +239,7 @@
" xx, yy = jnp.meshgrid(edge, edge)\n",
"\n",
" poses = [\n",
" [[xx[i, j], yy[i, j], 0.2 + 0.1 * (i * envs_per_row + j)], [0, 0, 0]]\n",
" [[xx[i, j], yy[i, j], initial_height], [0, 0, 0]]\n",
" for i in range(xx.shape[0])\n",
" for j in range(yy.shape[0])\n",
" ]\n",
Expand Down Expand Up @@ -203,20 +271,71 @@
"\n",
" data = data.reset_base_position(base_position=pose)\n",
" x_t_i = []\n",
" forces = []\n",
"\n",
" S = jnp.block([jnp.zeros(shape=(model.dofs(), 6)), jnp.eye(model.dofs())]).T\n",
" τ = jnp.zeros(model.dofs())\n",
"\n",
" # l_foot = model.link_names().index(\"l_ankle_2\")\n",
" # r_foot = model.link_names().index(\"r_ankle_2\")\n",
"\n",
" for _ in range(integration_time):\n",
" F = []\n",
"\n",
" h = js.model.free_floating_bias_forces(model=model, data=data)\n",
"\n",
" M = js.model.free_floating_mass_matrix(model=model, data=data)\n",
"\n",
" J̇ν = js.model.link_bias_accelerations(model=model, data=data)\n",
"\n",
" M_inv = jnp.linalg.inv(M)\n",
"\n",
" # idxs = (0,) # (l_foot, r_foot)\n",
" # O_JL = jax.vmap(\n",
" # lambda body: js.link.jacobian(\n",
" # model=model,\n",
" # data=data,\n",
" # link_index=body,\n",
" # # output_vel_repr=VelRepr.Inertial,\n",
" # )\n",
" # )(jnp.array(idxs))\n",
" O_JL = js.link.jacobian(\n",
" model=model,\n",
" data=data,\n",
" link_index=0,\n",
" output_vel_repr=VelRepr.Mixed,\n",
" )\n",
"\n",
" # O_JL = O_JL.reshape(6 * len(idxs), 10)\n",
"\n",
" # W_H_L = js.link.transform(model=model, data=data, link_index=body)\n",
" # W_X_L = jaxsim.math.Adjoint.from_transform(W_H_L).T\n",
" # F = -jnp.linalg.inv(O_JL @ M_inv @ O_JL.T) @ (\n",
" # J̇ν[l_foot:r_foot+1].ravel() + O_JL @ M_inv @ (S @ τ - h)\n",
" # )\n",
" F = -jnp.linalg.inv(O_JL.squeeze() @ M_inv @ O_JL.squeeze().T) @ (\n",
" J̇ν[0] + O_JL.squeeze() @ M_inv @ (S @ τ - h)\n",
" )\n",
"\n",
" # F = F.reshape(-1, 6)\n",
"\n",
" # link_forces = jnp.zeros((model.number_of_links(), 6)).at[l_foot:r_foot+1].set(jnp.array(F))\n",
" link_forces = jnp.zeros((model.number_of_links(), 6)).at[0].set(jnp.array(F))\n",
"\n",
" data, integrator_state = js.model.step(\n",
" dt=dt,\n",
" model=model,\n",
" data=data,\n",
" integrator=integrator,\n",
" integrator_state=integrator_state,\n",
" joint_forces=None,\n",
" link_forces=None,\n",
" link_forces=link_forces,\n",
" )\n",
"\n",
" x_t_i.append(data.base_position())\n",
" forces.append(F)\n",
"\n",
" return x_t_i"
" return x_t_i, forces"
]
},
{
Expand All @@ -238,12 +357,13 @@
"outputs": [],
"source": [
"# Define a function to simulate multiple model instances\n",
"simulate_vectorized = jax.vmap(simulate, in_axes=(None, None, 0))\n",
"# simulate_vectorized = jax.vmap(simulate, in_axes=(None, None, 0))\n",
"\n",
"# Run and time the simulation\n",
"now = time.perf_counter()\n",
"\n",
"x_t = simulate_vectorized(data, integrator_state, poses[:, 0])\n",
"# x_t = simulate_vectorized(data, integrator_state, poses[:, 0]).\n",
"x_t, forces = simulate(data, integrator_state, poses[:, 0])\n",
"\n",
"comp_time = time.perf_counter() - now\n",
"\n",
Expand All @@ -255,6 +375,27 @@
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"for pose in x_t:\n",
" mj_helper.set_base_position(pose)\n",
" recorder.record_frame()\n",
"\n",
"import datetime\n",
"\n",
"import mediapy as media\n",
"\n",
"media.show_video(recorder.frames, fps=1 / dt)\n",
"\n",
"recorder.write_video(path=Path.cwd() / Path(f\"video_{datetime.datetime.now()}.mp4\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -271,13 +412,60 @@
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"plt.plot(np.arange(len(x_t)) * dt, np.array(x_t)[:, :, 2])\n",
"plt.plot(np.arange(len(x_t[:])) * dt, np.array(x_t)[:, 2])\n",
"plt.grid(True)\n",
"plt.xlabel(\"Time [s]\")\n",
"plt.ylabel(\"Height [m]\")\n",
"plt.title(\"Trajectory of the model's base\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"forces = np.array([force for force in forces])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"forces.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"plt.plot(\n",
" np.arange(len(forces[:600])) * dt,\n",
" forces[:600],\n",
" label=[\"X\", \"Y\", \"Z\", \"Rx\", \"Ry\", \"Rz\"],\n",
")\n",
"plt.grid(True)\n",
"plt.xlabel(\"Time [s]\")\n",
"plt.ylabel(\"Force [N]\")\n",
"plt.title(\"Contact forces\")\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
Loading
Loading