From 1440da520898ecb514b6455a292c754919823a77 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 15 Jan 2024 15:08:54 +0100 Subject: [PATCH] Add first version of parallel computing example --- examples/Parallel_computing.ipynb | 485 ++++++++++++++++++++++++++++++ examples/pixi.toml | 1 + 2 files changed, 486 insertions(+) create mode 100644 examples/Parallel_computing.ipynb diff --git a/examples/Parallel_computing.ipynb b/examples/Parallel_computing.ipynb new file mode 100644 index 000000000..fe6ad4237 --- /dev/null +++ b/examples/Parallel_computing.ipynb @@ -0,0 +1,485 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# `JAXsim` Showcase: Parallel Simulation of a free-falling body\n", + "\n", + "\n", + " \"Open\n", + "\n", + "\n", + "First, we install the necessary packages and import them." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title Imports and setup\n", + "from IPython.display import clear_output, HTML, display\n", + "import sys\n", + "\n", + "IS_COLAB = \"google.colab\" in sys.modules\n", + "\n", + "# Install JAX and Gazebo\n", + "if IS_COLAB:\n", + " !{sys.executable} -m pip install -U -q jaxsim\n", + " !apt -qq update && apt install -qq --no-install-recommends gazebo\n", + " clear_output()\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import rod\n", + "from rod.builder.primitives import SphereBuilder\n", + "from jaxsim import logging\n", + "\n", + "logging.set_logging_level(logging.LoggingLevel.INFO)\n", + "logging.info(f\"Running on {jax.devices()}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will use a simple cartpole model for this example. The cartpole model is a 2D model with a cart that can move horizontally and a pole that can rotate around the cart. The state of the cartpole is given by the position of the cart, the angle of the pole, the velocity of the cart, and the angular velocity of the pole. The control input is the horizontal force applied to the cart." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title Create a sphere model\n", + "model_sdf_string = rod.Sdf(\n", + " version=\"1.7\",\n", + " model=SphereBuilder(radius=0.10, mass=1.0, name=\"sphere\")\n", + " .build_model()\n", + " .add_link()\n", + " .add_inertial()\n", + " .add_visual()\n", + " .add_collision()\n", + " .build(),\n", + ").serialize(pretty=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "JAXsim offers a simple high-level API in order to extract quantities needed in most robotic applications. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from jaxsim.high_level.model import Model\n", + "\n", + "model = Model.build_from_model_description(\n", + " model_description=model_sdf_string, is_urdf=True\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we can create a simulator instance and load the model into it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from jaxsim.simulation.ode_integration import IntegratorType\n", + "from jaxsim.simulation.simulator import JaxSim, SimulatorData, StepData\n", + "from jaxsim.high_level.model import VelRepr\n", + "from jaxsim.physics.algos.soft_contacts import SoftContactsParams\n", + "\n", + "# Simulation Step Parameters\n", + "integration_time = 3.0 # seconds\n", + "step_size = 0.001\n", + "steps_per_run = 1\n", + "\n", + "simulator = JaxSim.build(\n", + " step_size=step_size,\n", + " steps_per_run=steps_per_run,\n", + " velocity_representation=VelRepr.Body,\n", + " integrator_type=IntegratorType.EulerSemiImplicit,\n", + " simulator_data=SimulatorData(\n", + " contact_parameters=SoftContactsParams(K=1e6, D=2e3, mu=0.5),\n", + " ),\n", + ").mutable(validate=False)\n", + "\n", + "\n", + "# Add model to simulator\n", + "\n", + "model = simulator.insert_model_from_description(model_description=model_sdf_string).mutable(validate=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's create a position vector for a 8x8 grid of sphere positions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Primary Calculations\n", + "radius = 0.1\n", + "envs_per_row = 8\n", + "num_envs = envs_per_row**2\n", + "edge_len = env_spacing * envs_per_row + env_spacing * (envs_per_row - 1)\n", + "\n", + "\n", + "# Create Grid\n", + "def grid(num_envs, edge_len, envs_per_row):\n", + " poses = []\n", + " x = 0\n", + " y = 0\n", + "\n", + " for env in range(num_envs):\n", + " x = jnp.linspace(-edge_len, edge_len, envs_per_row)\n", + " y = jnp.linspace(-edge_len, edge_len, envs_per_row)\n", + " xx, yy = jnp.meshgrid(x, y)\n", + "\n", + " poses = [\n", + " [[xx[i, j], yy[i, j], 1], [0, 0, 0]]\n", + " for i in range(xx.shape[0])\n", + " for j in range(yy.shape[0])\n", + " ]\n", + "\n", + " return jnp.array(poses)\n", + "\n", + "\n", + "poses = grid(num_envs, edge_len, envs_per_row)\n", + "model.reset_joint_positions(positions=random_positions)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The visualization is done using mujoco package, to be able to render easily the animations also on Google Colab. If you are not interested in the animation, execute but do not try to understand deeply this cell." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title Set up MuJoCo renderer\n", + "!{sys.executable} -m pip install -U -q mujoco\n", + "!{sys.executable} -m pip install -q mediapy\n", + "\n", + "import mediapy as media\n", + "import tempfile\n", + "import xml.etree.ElementTree as ET\n", + "import numpy as np\n", + "\n", + "import distutils.util\n", + "import os\n", + "import subprocess\n", + "\n", + "if IS_COLAB:\n", + " if subprocess.run(\"ffmpeg -version\", shell=True).returncode:\n", + " !command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)\n", + " clear_output()\n", + "\n", + " if subprocess.run(\"nvidia-smi\").returncode:\n", + " raise RuntimeError(\n", + " \"Cannot communicate with GPU. \"\n", + " \"Make sure you are using a GPU Colab runtime. \"\n", + " \"Go to the Runtime menu and select Choose runtime type.\"\n", + " )\n", + "\n", + " # Add an ICD config so that glvnd can pick up the Nvidia EGL driver.\n", + " # This is usually installed as part of an Nvidia driver package, but the Colab\n", + " # kernel doesn't install its driver via APT, and as a result the ICD is missing.\n", + " # (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)\n", + " NVIDIA_ICD_CONFIG_PATH = \"/usr/share/glvnd/egl_vendor.d/10_nvidia.json\"\n", + " if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):\n", + " with open(NVIDIA_ICD_CONFIG_PATH, \"w\") as f:\n", + " f.write(\n", + " \"\"\"{\n", + " \"file_format_version\" : \"1.0.0\",\n", + " \"ICD\" : {\n", + " \"library_path\" : \"libEGL_nvidia.so.0\"\n", + " }\n", + " }\n", + " \"\"\"\n", + " )\n", + "\n", + "%env MUJOCO_GL=egl\n", + "\n", + "try:\n", + " import mujoco\n", + "except Exception as e:\n", + " raise e from RuntimeError(\n", + " \"Something went wrong during installation. Check the shell output above \"\n", + " \"for more information.\\n\"\n", + " \"If using a hosted Colab runtime, make sure you enable GPU acceleration \"\n", + " 'by going to the Runtime menu and selecting \"Choose runtime type\".'\n", + " )\n", + "\n", + "\n", + "def load_mujoco_model_with_camera(xml_string, camera_pos, camera_xyaxes):\n", + " def to_mjcf_string(list_to_str):\n", + " return \" \".join(map(str, list_to_str))\n", + "\n", + " mj_model_raw = mujoco.MjModel.from_xml_string(model_urdf_string)\n", + " path_temp_xml = tempfile.NamedTemporaryFile(mode=\"w+\")\n", + " mujoco.mj_saveLastXML(path_temp_xml.name, mj_model_raw)\n", + " # Add camera in mujoco model\n", + " tree = ET.parse(path_temp_xml)\n", + " for elem in tree.getroot().iter(\"worldbody\"):\n", + " worldbody_elem = elem\n", + " camera_elem = ET.Element(\"camera\")\n", + " # Set attributes\n", + " camera_elem.set(\"name\", \"side\")\n", + " camera_elem.set(\"pos\", to_mjcf_string(camera_pos))\n", + " camera_elem.set(\"xyaxes\", to_mjcf_string(camera_xyaxes))\n", + " camera_elem.set(\"mode\", \"fixed\")\n", + " worldbody_elem.append(camera_elem)\n", + "\n", + " # Save new model\n", + " mujoco_xml_with_camera = ET.tostring(tree.getroot(), encoding=\"unicode\")\n", + " mj_model = mujoco.MjModel.from_xml_string(mujoco_xml_with_camera)\n", + " return mj_model\n", + "\n", + "\n", + "def from_jaxsim_to_mujoco_pos(jaxsim_jointpos, mjmodel, jaxsimmodel):\n", + " mujocoqposaddr2jaxindex = {}\n", + " for jaxjnt in jaxsimmodel.joints():\n", + " jntname = jaxjnt.name()\n", + " mujocoqposaddr2jaxindex[mjmodel.joint(jntname).qposadr[0]] = jaxjnt.index() - 1\n", + "\n", + " mujoco_jointpos = jaxsim_jointpos\n", + " for i in range(0, len(mujoco_jointpos)):\n", + " mujoco_jointpos[i] = jaxsim_jointpos[mujocoqposaddr2jaxindex[i]]\n", + "\n", + " return mujoco_jointpos\n", + "\n", + "\n", + "# To get a good camera location, you can use \"Copy camera\" functionality in MuJoCo GUI\n", + "mj_model = load_mujoco_model_with_camera(\n", + " model_urdf_string,\n", + " [3.954, 3.533, 2.343],\n", + " [-0.594, 0.804, -0.000, -0.163, -0.120, 0.979],\n", + ")\n", + "renderer = mujoco.Renderer(mj_model, height=480, width=640)\n", + "\n", + "\n", + "def get_image(camera, mujocojointpos) -> np.ndarray:\n", + " \"\"\"Renders the environment state.\"\"\"\n", + " # Copy joint data in mjdata state\n", + " d = mujoco.MjData(mj_model)\n", + " d.qpos = mujocojointpos\n", + "\n", + " # Forward kinematics\n", + " mujoco.mj_forward(mj_model, d)\n", + "\n", + " # use the mjData object to update the renderer\n", + " renderer.update_scene(d, camera=camera)\n", + " return renderer.render()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In order to parallelize the simulation, we need to define a function for a single element of the batch. This function will be called in parallel, vectorizing the simulation over the chosen dimension or parameter." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a logger to store simulation data\n", + "@jax_dataclasses.pytree_dataclass\n", + "class SimulatorLogger(simulator_callbacks.PostStepCallback):\n", + " def post_step(\n", + " self, sim: JaxSim, step_data: Dict[str, StepData]\n", + " ) -> Tuple[JaxSim, jtp.PyTree]:\n", + " \"\"\"Return the StepData object of each simulated model\"\"\"\n", + " return sim, step_data\n", + "\n", + "\n", + "# Define a function to simulate a single model instance\n", + "def simulate(sim: JaxSim, pose) -> JaxSim:\n", + " model.zero()\n", + " model.reset_base_position(position=jnp.array(pose))\n", + "\n", + " with sim.editable(validate=True) as sim:\n", + " m = sim.get_model(model.name())\n", + " m.data = model.data\n", + "\n", + " sim, ((_, cb), step_data) = simulator.step_over_horizon(\n", + " horizon_steps=integration_time // step_size,\n", + " callback_handler=SimulatorLogger(),\n", + " clear_inputs=True,\n", + " )\n", + "\n", + " return step_data\n", + "\n", + "\n", + "for _ in range(300):\n", + " sim_images.append(\n", + " get_image(\n", + " \"side\",\n", + " from_jaxsim_to_mujoco_pos(\n", + " np.array(model.joint_positions()), mj_model, model\n", + " ),\n", + " )\n", + " )\n", + " model.integrate(\n", + " t0=0.0, tf=timestep, integrator_type=IntegratorType.EulerSemiImplicit\n", + " )\n", + "\n", + "media.show_video(sim_images, fps=1 / timestep)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will make use of `jax.vmap` to simulate multiple models in parallel. This is a very powerful feature of JAX that allows us to write code that is very similar to the single-model case, but can be executed in parallel on multiple models.\n", + "\n", + "Note that in our case we are vectorizing over the `pose` argument of the function `simulate`, this correspond to the value assigned to the `in_axes` parameter of `jax.vmap`:\n", + "\n", + "`in_axes=(None, 0)` means that the first argument of `simulate` is not vectorized, while the second argument is vectorized over the zero-th dimension." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define a function to simulate multiple model instances\n", + "simulate_vectorized = jax.vmap(simulate, in_axes=(None, 0))\n", + "\n", + "# Run and time the simulation\n", + "now = time.perf_counter()\n", + "\n", + "time_history = simulate_vectorized(simulator, poses[:, 0])\n", + "\n", + "logging.info(f\"Running simulation with {num_models} models\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now let's extract the data from the simulation and plot it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "time_history: Dict[str, StepData]\n", + "x_t = time_history[model.name()].tf_model_state\n", + "\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "plt.plot(\n", + " time_history[model.name()].tf[0], x_t.base_position[obj_id], label=[\"x\", \"y\", \"z\"]\n", + ")\n", + "plt.grid(True)\n", + "plt.legend()\n", + "plt.xlabel(\"Time [s]\")\n", + "plt.ylabel(\"Position [m]\")\n", + "plt.title(\"Trajectory of the model's base\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sim_images = []\n", + "timestep = 0.01\n", + "\n", + "for _ in range(300):\n", + " sim_images.append(\n", + " get_image(\n", + " \"side\",\n", + " from_jaxsim_to_mujoco_pos(\n", + " np.array(model.joint_positions()), mj_model, model\n", + " ),\n", + " )\n", + " )\n", + " model.set_joint_generalized_force_targets(\n", + " forces=pd_controller(\n", + " q=model.joint_positions(),\n", + " q_d=jnp.array([0.0, 0.0]),\n", + " q_dot=model.joint_velocities(),\n", + " q_dot_d=jnp.array([0.0, 0.0]),\n", + " )\n", + " )\n", + " model.integrate(\n", + " t0=0.0, tf=timestep, integrator_type=IntegratorType.EulerSemiImplicit\n", + " )\n", + "\n", + "media.show_video(sim_images, fps=1 / timestep)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuClass": "premium", + "gpuType": "V100", + "private_outputs": true, + "provenance": [ + { + "file_id": "1QsuS7EJhdPEHxxAu9XwozvA7eb4ZnlAb", + "timestamp": 1701993737024 + } + ], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.1" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/pixi.toml b/examples/pixi.toml index 4229fe34c..25bff9834 100644 --- a/examples/pixi.toml +++ b/examples/pixi.toml @@ -12,6 +12,7 @@ unix = true [tasks] PD_controller = {cmd = "jupyter notebook PD_controller.ipynb", depends_on = ["install"]} +Parallel_computing = {cmd = "jupyter notebook Parallel_computing.ipynb", depends_on = ["install"]} install = "python -m pip install git+https://github.com/ami-iit/jaxsim.git" [dependencies]