diff --git a/examples/Parallel_computing.ipynb b/examples/Parallel_computing.ipynb index 38a1482a6..f591a9df9 100644 --- a/examples/Parallel_computing.ipynb +++ b/examples/Parallel_computing.ipynb @@ -19,10 +19,11 @@ "metadata": {}, "outputs": [], "source": [ - "# @title Imports and setup\n", - "from IPython.display import clear_output, HTML, display\n", + "#@title Imports and setup\n", "import sys\n", "\n", + "from IPython.display import HTML, clear_output, display\n", + "\n", "IS_COLAB = \"google.colab\" in sys.modules\n", "\n", "# Install JAX and Gazebo\n", @@ -30,11 +31,20 @@ " !{sys.executable} -m pip install -U -q jaxsim\n", " !apt -qq update && apt install -qq --no-install-recommends gazebo\n", " clear_output()\n", + "else:\n", + " # Set environment variable to avoid GPU out of memory errors\n", + " %env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false\n", + "\n", + "import time\n", + "from typing import Dict, Tuple\n", "\n", "import jax\n", "import jax.numpy as jnp\n", + "import jax_dataclasses\n", "import rod\n", "from rod.builder.primitives import SphereBuilder\n", + "\n", + "import jaxsim.typing as jtp\n", "from jaxsim import logging\n", "\n", "logging.set_logging_level(logging.LoggingLevel.INFO)\n", @@ -45,7 +55,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We will use a simple cartpole model for this example. The cartpole model is a 2D model with a cart that can move horizontally and a pole that can rotate around the cart. The state of the cartpole is given by the position of the cart, the angle of the pole, the velocity of the cart, and the angular velocity of the pole. The control input is the horizontal force applied to the cart." + "We will use a simple sphere model to simulate a free-falling body. The spheres set will be composed of 9 spheres, each with a different position. The spheres will be simulated in parallel, and the simulation will be run for 3000 steps corresponding to 3 seconds of simulation." ] }, { @@ -54,7 +64,7 @@ "metadata": {}, "outputs": [], "source": [ - "# @title Create a sphere model\n", + "#@title Create a sphere model\n", "model_sdf_string = rod.Sdf(\n", " version=\"1.7\",\n", " model=SphereBuilder(radius=0.10, mass=1.0, name=\"sphere\")\n", @@ -67,26 +77,6 @@ ").serialize(pretty=True)" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "JAXsim offers a simple high-level API in order to extract quantities needed in most robotic applications. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from jaxsim.high_level.model import Model\n", - "\n", - "model = Model.build_from_model_description(\n", - " model_description=model_sdf_string, is_urdf=True\n", - ")" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -100,10 +90,10 @@ "metadata": {}, "outputs": [], "source": [ - "from jaxsim.simulation.ode_integration import IntegratorType\n", - "from jaxsim.simulation.simulator import JaxSim, SimulatorData, StepData\n", "from jaxsim.high_level.model import VelRepr\n", "from jaxsim.physics.algos.soft_contacts import SoftContactsParams\n", + "from jaxsim.simulation.ode_integration import IntegratorType\n", + "from jaxsim.simulation.simulator import JaxSim, SimulatorData, StepData\n", "\n", "# Simulation Step Parameters\n", "integration_time = 3.0 # seconds\n", @@ -122,15 +112,16 @@ "\n", "\n", "# Add model to simulator\n", - "\n", - "model = simulator.insert_model_from_description(model_description=model_sdf_string).mutable(validate=True)" + "model = simulator.insert_model_from_description(\n", + " model_description=model_sdf_string\n", + ").mutable(validate=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Let's create a position vector for a 8x8 grid of sphere positions." + "Let's create a position vector for a 3x3 grid. Every sphere will be placed at a different height." ] }, { @@ -140,41 +131,33 @@ "outputs": [], "source": [ "# Primary Calculations\n", - "radius = 0.1\n", - "envs_per_row = 8\n", - "num_envs = envs_per_row**2\n", - "edge_len = env_spacing * envs_per_row + env_spacing * (envs_per_row - 1)\n", + "env_spacing = 0.5\n", + "envs_per_row = 3\n", + "edge_len = env_spacing * (2 * envs_per_row - 1)\n", "\n", "\n", "# Create Grid\n", - "def grid(num_envs, edge_len, envs_per_row):\n", - " poses = []\n", - " x = 0\n", - " y = 0\n", - "\n", - " for env in range(num_envs):\n", - " x = jnp.linspace(-edge_len, edge_len, envs_per_row)\n", - " y = jnp.linspace(-edge_len, edge_len, envs_per_row)\n", - " xx, yy = jnp.meshgrid(x, y)\n", - "\n", - " poses = [\n", - " [[xx[i, j], yy[i, j], 1], [0, 0, 0]]\n", - " for i in range(xx.shape[0])\n", - " for j in range(yy.shape[0])\n", - " ]\n", + "def grid(edge_len, envs_per_row):\n", + " edge = jnp.linspace(-edge_len, edge_len, envs_per_row)\n", + " xx, yy = jnp.meshgrid(edge, edge)\n", + "\n", + " poses = [\n", + " [[xx[i, j], yy[i, j], 0.2 + 0.1 * (i * envs_per_row + j)], [0, 0, 0]]\n", + " for i in range(xx.shape[0])\n", + " for j in range(yy.shape[0])\n", + " ]\n", "\n", " return jnp.array(poses)\n", "\n", "\n", - "poses = grid(num_envs, edge_len, envs_per_row)\n", - "model.reset_joint_positions(positions=random_positions)" + "poses = grid(edge_len, envs_per_row)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The visualization is done using mujoco package, to be able to render easily the animations also on Google Colab. If you are not interested in the animation, execute but do not try to understand deeply this cell." + "In order to parallelize the simulation, we first need to define a function `simulate` for a single element of the batch." ] }, { @@ -183,135 +166,9 @@ "metadata": {}, "outputs": [], "source": [ - "# @title Set up MuJoCo renderer\n", - "!{sys.executable} -m pip install -U -q mujoco\n", - "!{sys.executable} -m pip install -q mediapy\n", - "\n", - "import mediapy as media\n", - "import tempfile\n", - "import xml.etree.ElementTree as ET\n", - "import numpy as np\n", - "\n", - "import distutils.util\n", - "import os\n", - "import subprocess\n", - "\n", - "if IS_COLAB:\n", - " if subprocess.run(\"ffmpeg -version\", shell=True).returncode:\n", - " !command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)\n", - " clear_output()\n", - "\n", - " if subprocess.run(\"nvidia-smi\").returncode:\n", - " raise RuntimeError(\n", - " \"Cannot communicate with GPU. \"\n", - " \"Make sure you are using a GPU Colab runtime. \"\n", - " \"Go to the Runtime menu and select Choose runtime type.\"\n", - " )\n", - "\n", - " # Add an ICD config so that glvnd can pick up the Nvidia EGL driver.\n", - " # This is usually installed as part of an Nvidia driver package, but the Colab\n", - " # kernel doesn't install its driver via APT, and as a result the ICD is missing.\n", - " # (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)\n", - " NVIDIA_ICD_CONFIG_PATH = \"/usr/share/glvnd/egl_vendor.d/10_nvidia.json\"\n", - " if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):\n", - " with open(NVIDIA_ICD_CONFIG_PATH, \"w\") as f:\n", - " f.write(\n", - " \"\"\"{\n", - " \"file_format_version\" : \"1.0.0\",\n", - " \"ICD\" : {\n", - " \"library_path\" : \"libEGL_nvidia.so.0\"\n", - " }\n", - " }\n", - " \"\"\"\n", - " )\n", - "\n", - "%env MUJOCO_GL=egl\n", - "\n", - "try:\n", - " import mujoco\n", - "except Exception as e:\n", - " raise e from RuntimeError(\n", - " \"Something went wrong during installation. Check the shell output above \"\n", - " \"for more information.\\n\"\n", - " \"If using a hosted Colab runtime, make sure you enable GPU acceleration \"\n", - " 'by going to the Runtime menu and selecting \"Choose runtime type\".'\n", - " )\n", - "\n", - "\n", - "def load_mujoco_model_with_camera(xml_string, camera_pos, camera_xyaxes):\n", - " def to_mjcf_string(list_to_str):\n", - " return \" \".join(map(str, list_to_str))\n", - "\n", - " mj_model_raw = mujoco.MjModel.from_xml_string(model_urdf_string)\n", - " path_temp_xml = tempfile.NamedTemporaryFile(mode=\"w+\")\n", - " mujoco.mj_saveLastXML(path_temp_xml.name, mj_model_raw)\n", - " # Add camera in mujoco model\n", - " tree = ET.parse(path_temp_xml)\n", - " for elem in tree.getroot().iter(\"worldbody\"):\n", - " worldbody_elem = elem\n", - " camera_elem = ET.Element(\"camera\")\n", - " # Set attributes\n", - " camera_elem.set(\"name\", \"side\")\n", - " camera_elem.set(\"pos\", to_mjcf_string(camera_pos))\n", - " camera_elem.set(\"xyaxes\", to_mjcf_string(camera_xyaxes))\n", - " camera_elem.set(\"mode\", \"fixed\")\n", - " worldbody_elem.append(camera_elem)\n", - "\n", - " # Save new model\n", - " mujoco_xml_with_camera = ET.tostring(tree.getroot(), encoding=\"unicode\")\n", - " mj_model = mujoco.MjModel.from_xml_string(mujoco_xml_with_camera)\n", - " return mj_model\n", - "\n", - "\n", - "def from_jaxsim_to_mujoco_pos(jaxsim_jointpos, mjmodel, jaxsimmodel):\n", - " mujocoqposaddr2jaxindex = {}\n", - " for jaxjnt in jaxsimmodel.joints():\n", - " jntname = jaxjnt.name()\n", - " mujocoqposaddr2jaxindex[mjmodel.joint(jntname).qposadr[0]] = jaxjnt.index() - 1\n", - "\n", - " mujoco_jointpos = jaxsim_jointpos\n", - " for i in range(0, len(mujoco_jointpos)):\n", - " mujoco_jointpos[i] = jaxsim_jointpos[mujocoqposaddr2jaxindex[i]]\n", - "\n", - " return mujoco_jointpos\n", - "\n", - "\n", - "# To get a good camera location, you can use \"Copy camera\" functionality in MuJoCo GUI\n", - "mj_model = load_mujoco_model_with_camera(\n", - " model_urdf_string,\n", - " [3.954, 3.533, 2.343],\n", - " [-0.594, 0.804, -0.000, -0.163, -0.120, 0.979],\n", - ")\n", - "renderer = mujoco.Renderer(mj_model, height=480, width=640)\n", - "\n", - "\n", - "def get_image(camera, mujocojointpos) -> np.ndarray:\n", - " \"\"\"Renders the environment state.\"\"\"\n", - " # Copy joint data in mjdata state\n", - " d = mujoco.MjData(mj_model)\n", - " d.qpos = mujocojointpos\n", + "from jaxsim.simulation import simulator_callbacks\n", "\n", - " # Forward kinematics\n", - " mujoco.mj_forward(mj_model, d)\n", "\n", - " # use the mjData object to update the renderer\n", - " renderer.update_scene(d, camera=camera)\n", - " return renderer.render()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In order to parallelize the simulation, we need to define a function for a single element of the batch. This function will be called in parallel, vectorizing the simulation over the chosen dimension or parameter." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ "# Create a logger to store simulation data\n", "@jax_dataclasses.pytree_dataclass\n", "class SimulatorLogger(simulator_callbacks.PostStepCallback):\n", @@ -331,36 +188,21 @@ " m = sim.get_model(model.name())\n", " m.data = model.data\n", "\n", - " sim, ((_, cb), step_data) = simulator.step_over_horizon(\n", + " sim, (cb, (_, step_data)) = simulator.step_over_horizon(\n", " horizon_steps=integration_time // step_size,\n", " callback_handler=SimulatorLogger(),\n", " clear_inputs=True,\n", " )\n", "\n", - " return step_data\n", - "\n", - "\n", - "for _ in range(300):\n", - " sim_images.append(\n", - " get_image(\n", - " \"side\",\n", - " from_jaxsim_to_mujoco_pos(\n", - " np.array(model.joint_positions()), mj_model, model\n", - " ),\n", - " )\n", - " )\n", - " model.integrate(\n", - " t0=0.0, tf=timestep, integrator_type=IntegratorType.EulerSemiImplicit\n", - " )\n", - "\n", - "media.show_video(sim_images, fps=1 / timestep)" + " return step_data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We will make use of `jax.vmap` to simulate multiple models in parallel. This is a very powerful feature of JAX that allows us to write code that is very similar to the single-model case, but can be executed in parallel on multiple models.\n", + "We will make use of `jax.vmap` to simulate multiple models in parallel. This is a very powerful feature of JAX that allows to write code that is very similar to the single-model case, but can be executed in parallel on multiple models.\n", + "In order to do so, we need to first apply `jax.vmap` to the `simulate` function, and then call the resulting function with the batch of different poses as input.\n", "\n", "Note that in our case we are vectorizing over the `pose` argument of the function `simulate`, this correspond to the value assigned to the `in_axes` parameter of `jax.vmap`:\n", "\n", @@ -381,14 +223,21 @@ "\n", "time_history = simulate_vectorized(simulator, poses[:, 0])\n", "\n", - "logging.info(f\"Running simulation with {num_models} models\")" + "comp_time = time.perf_counter() - now\n", + "\n", + "logging.info(\n", + " f\"Running simulation with {envs_per_row**2} models took {comp_time} seconds.\"\n", + ")\n", + "logging.info(\n", + " f\"This corresponds to an RTF (Real Time Factor) of {envs_per_row**2 *integration_time/comp_time}\"\n", + ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Now let's extract the data from the simulation and plot it." + "Now let's extract the data from the simulation and plot it. We expect to see the height time series of each sphere starting from a different value." ] }, { @@ -403,49 +252,13 @@ "\n", "import matplotlib.pyplot as plt\n", "\n", - "plt.plot(\n", - " time_history[model.name()].tf[0], x_t.base_position[obj_id], label=[\"x\", \"y\", \"z\"]\n", - ")\n", + "plt.plot(time_history[model.name()].tf[0], x_t.base_position[:, :, 2].T)\n", "plt.grid(True)\n", - "plt.legend()\n", "plt.xlabel(\"Time [s]\")\n", - "plt.ylabel(\"Position [m]\")\n", + "plt.ylabel(\"Height [m]\")\n", "plt.title(\"Trajectory of the model's base\")\n", "plt.show()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "sim_images = []\n", - "timestep = 0.01\n", - "\n", - "for _ in range(300):\n", - " sim_images.append(\n", - " get_image(\n", - " \"side\",\n", - " from_jaxsim_to_mujoco_pos(\n", - " np.array(model.joint_positions()), mj_model, model\n", - " ),\n", - " )\n", - " )\n", - " model.set_joint_generalized_force_targets(\n", - " forces=pd_controller(\n", - " q=model.joint_positions(),\n", - " q_d=jnp.array([0.0, 0.0]),\n", - " q_dot=model.joint_velocities(),\n", - " q_dot_d=jnp.array([0.0, 0.0]),\n", - " )\n", - " )\n", - " model.integrate(\n", - " t0=0.0, tf=timestep, integrator_type=IntegratorType.EulerSemiImplicit\n", - " )\n", - "\n", - "media.show_video(sim_images, fps=1 / timestep)" - ] } ], "metadata": {