From 742c1c63a6efe7b3e655249f34ab46feefcdb7c7 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 23 May 2024 12:18:02 +0200 Subject: [PATCH] DROP! Update Notebook --- examples/Parallel_computing.ipynb | 178 +++++++++++++++++++++++++++--- 1 file changed, 164 insertions(+), 14 deletions(-) diff --git a/examples/Parallel_computing.ipynb b/examples/Parallel_computing.ipynb index 60a03a60b..532257820 100644 --- a/examples/Parallel_computing.ipynb +++ b/examples/Parallel_computing.ipynb @@ -23,6 +23,7 @@ "# @title Imports and setup\n", "import sys\n", "import os\n", + "import pathlib\n", "\n", "# Deactivate GPU to avoid out of memory errors\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n", @@ -54,8 +55,15 @@ "\n", "import jaxsim.typing as jtp\n", "from jaxsim import logging\n", - "\n", - "from jaxsim.mujoco import MujocoVideoRecorder, MujocoModelHelper, RodModelToMjcf\n", + "from jaxsim.api.common import VelRepr\n", + "\n", + "from jaxsim.mujoco import (\n", + " MujocoVideoRecorder,\n", + " MujocoModelHelper,\n", + " RodModelToMjcf,\n", + " SdfToMjcf,\n", + " UrdfToMjcf,\n", + ")\n", "\n", "logging.set_logging_level(logging.LoggingLevel.INFO)\n", "logging.info(f\"Running on {jax.devices()}\")" @@ -79,14 +87,21 @@ "# @title Create a sphere model\n", "model_sdf_string = rod.Sdf(\n", " version=\"1.7\",\n", - " model=BoxBuilder(x=0.30, y=0.30, z=0.30, mass=1.0, name=\"box\")\n", + " # model=BoxBuilder(x=0.30, y=0.30, z=0.30, mass=1.0, name=\"box\")\n", + " model=SphereBuilder(radius=0.15, mass=1.0, name=\"sphere\")\n", " .build_model()\n", " .add_link()\n", " .add_inertial()\n", " .add_visual()\n", " .add_collision()\n", " .build(),\n", - ").serialize(pretty=True)" + ").serialize(pretty=True)\n", + "# import urllib\n", + "\n", + "# url = \"https://raw.githubusercontent.com/icub-tech-iit/ergocub-gazebo-simulations/master/models/stickBot/model.urdf\"\n", + "\n", + "# model_sdf_string = urllib.request.urlopen(url).read().decode()\n", + "# # model_sdf_string = pathlib.Path(\"/home/flferretti/git/element_rl-for-codesign/assets/model/hopper.sdf\")" ] }, { @@ -109,14 +124,45 @@ "source": [ "import jaxsim.api as js\n", "from jaxsim import integrators\n", + "import jaxsim\n", "\n", "dt = 0.001\n", "integration_time = 1500\n", "\n", "model = js.model.JaxSimModel.build_from_model_description(\n", - " model_description=model_sdf_string\n", + " model_description=model_sdf_string,\n", + " contact_model=js.rigid_contacts.RigidContacts(),\n", + " is_urdf=True,\n", + ")\n", + "\n", + "model = js.model.reduce(\n", + " model=model,\n", + " considered_joints=tuple(\n", + " [\n", + " j\n", + " for j in model.joint_names()\n", + " if \"camera\" not in j\n", + " and \"neck\" not in j\n", + " and \"wrist\" not in j\n", + " and \"thumb\" not in j\n", + " and \"index\" not in j\n", + " and \"middle\" not in j\n", + " and \"ring\" not in j\n", + " and \"pinkie\" not in j\n", + " and \"elbow\" not in j\n", + " and \"shoulder\" not in j\n", + " and \"hip\" not in j\n", + " and \"knee\" not in j\n", + " and \"lidar\" not in j\n", + " and \"torso\" not in j\n", + " ]\n", + " ),\n", + ")\n", + "model = js.model.reduce(model=model, considered_joints=tuple())\n", + "\n", + "data = js.data.JaxSimModelData.build(\n", + " model=model, velocity_representation=VelRepr.Inertial\n", ")\n", - "data = js.data.JaxSimModelData.build(model=model)\n", "integrator = integrators.fixed_step.RungeKutta4SO3.build(\n", " dynamics=js.ode.wrap_system_dynamics_for_integration(\n", " model=model,\n", @@ -124,6 +170,7 @@ " system_dynamics=js.ode.system_dynamics,\n", " ),\n", ")\n", + "# with jax.disable_jit():\n", "integrator_state = integrator.init(x0=data.state, t0=0.0, dt=dt)" ] }, @@ -133,12 +180,12 @@ "metadata": {}, "outputs": [], "source": [ - "mcjf_string, assets = RodModelToMjcf.convert(rod_model=model_sdf_string.model)\n", + "mcjf_string, assets = UrdfToMjcf.convert(urdf=model_sdf_string)\n", "mj_helper = MujocoModelHelper.build_from_xml(\n", " mjcf_description=mcjf_string, assets=assets\n", ")\n", "recorder = MujocoVideoRecorder(\n", - " model=mj_helper.model, assets=mj_helper.data, fps=int(1 / dt), width=640, height=480\n", + " model=mj_helper.model, data=mj_helper.data, fps=int(1 / dt), width=640, height=480\n", ")" ] }, @@ -224,8 +271,57 @@ "\n", " data = data.reset_base_position(base_position=pose)\n", " x_t_i = []\n", + " forces = []\n", + "\n", + " S = jnp.block([jnp.zeros(shape=(model.dofs(), 6)), jnp.eye(model.dofs())]).T\n", + " τ = jnp.zeros(model.dofs())\n", + "\n", + " # l_foot = model.link_names().index(\"l_ankle_2\")\n", + " # r_foot = model.link_names().index(\"r_ankle_2\")\n", "\n", " for _ in range(integration_time):\n", + " F = []\n", + "\n", + " h = js.model.free_floating_bias_forces(model=model, data=data)\n", + "\n", + " M = js.model.free_floating_mass_matrix(model=model, data=data)\n", + "\n", + " J̇ν = js.model.link_bias_accelerations(model=model, data=data)\n", + "\n", + " M_inv = jnp.linalg.inv(M)\n", + "\n", + " # idxs = (0,) # (l_foot, r_foot)\n", + " # O_JL = jax.vmap(\n", + " # lambda body: js.link.jacobian(\n", + " # model=model,\n", + " # data=data,\n", + " # link_index=body,\n", + " # # output_vel_repr=VelRepr.Inertial,\n", + " # )\n", + " # )(jnp.array(idxs))\n", + " O_JL = js.link.jacobian(\n", + " model=model,\n", + " data=data,\n", + " link_index=0,\n", + " output_vel_repr=VelRepr.Mixed,\n", + " )\n", + "\n", + " # O_JL = O_JL.reshape(6 * len(idxs), 10)\n", + "\n", + " # W_H_L = js.link.transform(model=model, data=data, link_index=body)\n", + " # W_X_L = jaxsim.math.Adjoint.from_transform(W_H_L).T\n", + " # F = -jnp.linalg.inv(O_JL @ M_inv @ O_JL.T) @ (\n", + " # J̇ν[l_foot:r_foot+1].ravel() + O_JL @ M_inv @ (S @ τ - h)\n", + " # )\n", + " F = -jnp.linalg.inv(O_JL.squeeze() @ M_inv @ O_JL.squeeze().T) @ (\n", + " J̇ν[0] + O_JL.squeeze() @ M_inv @ (S @ τ - h)\n", + " )\n", + "\n", + " # F = F.reshape(-1, 6)\n", + "\n", + " # link_forces = jnp.zeros((model.number_of_links(), 6)).at[l_foot:r_foot+1].set(jnp.array(F))\n", + " link_forces = jnp.zeros((model.number_of_links(), 6)).at[0].set(jnp.array(F))\n", + "\n", " data, integrator_state = js.model.step(\n", " dt=dt,\n", " model=model,\n", @@ -233,11 +329,13 @@ " integrator=integrator,\n", " integrator_state=integrator_state,\n", " joint_forces=None,\n", - " link_forces=None,\n", + " link_forces=link_forces,\n", " )\n", + "\n", " x_t_i.append(data.base_position())\n", + " forces.append(F)\n", "\n", - " return x_t_i" + " return x_t_i, forces" ] }, { @@ -265,7 +363,7 @@ "now = time.perf_counter()\n", "\n", "# x_t = simulate_vectorized(data, integrator_state, poses[:, 0]).\n", - "x_t = simulate(data, integrator_state, poses)\n", + "x_t, forces = simulate(data, integrator_state, poses[:, 0])\n", "\n", "comp_time = time.perf_counter() - now\n", "\n", @@ -289,8 +387,13 @@ " mj_helper.set_base_position(pose)\n", " recorder.record_frame()\n", "\n", + "import datetime\n", + "\n", + "import mediapy as media\n", "\n", - "recorder.write_video(path=Path.cwd() / Path(\"sphere.mp4\"), exists_ok=True)" + "media.show_video(recorder.frames, fps=1 / dt)\n", + "\n", + "recorder.write_video(path=Path.cwd() / Path(f\"video_{datetime.datetime.now()}.mp4\"))" ] }, { @@ -309,13 +412,60 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", - "plt.plot(np.arange(len(x_t)) * dt, np.array(x_t)[:, :, 2])\n", + "plt.plot(np.arange(len(x_t[:])) * dt, np.array(x_t)[:, 2])\n", "plt.grid(True)\n", "plt.xlabel(\"Time [s]\")\n", "plt.ylabel(\"Height [m]\")\n", "plt.title(\"Trajectory of the model's base\")\n", "plt.show()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "forces = np.array([force for force in forces])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "forces.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "plt.plot(\n", + " np.arange(len(forces[:600])) * dt,\n", + " forces[:600],\n", + " label=[\"X\", \"Y\", \"Z\", \"Rx\", \"Ry\", \"Rz\"],\n", + ")\n", + "plt.grid(True)\n", + "plt.xlabel(\"Time [s]\")\n", + "plt.ylabel(\"Force [N]\")\n", + "plt.title(\"Contact forces\")\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -347,7 +497,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.11.8" } }, "nbformat": 4,