Skip to content

Commit

Permalink
Add first commit PD example with cartpole model
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Nov 29, 2023
1 parent 31bf38e commit 2f929b6
Show file tree
Hide file tree
Showing 2 changed files with 263 additions and 0 deletions.
182 changes: 182 additions & 0 deletions examples/PD_controller.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# `JAXsim` Showcase: PD Controller\n",
"\n",
"<a target=\"_blank\" href=\"https://colab.research.google.com/github/flferretti/jaxsim/example/PD_controller.py\">\n",
" <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
"</a>\n",
"\n",
"First, we install the necessary packages and import them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade -q \"jax[cuda12_pip]\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"%pip install -q -e git+https://github.com/ami-iit/jaxsim@new_api#egg=jaxsim\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from jaxsim import logging\n",
"\n",
"logging.info(f\"Running on {jax.devices()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use a simple cartpole model for this example. The cartpole model is a 2D model with a cart that can move horizontally and a pole that can rotate around the cart. The state of the cartpole is given by the position of the cart, the angle of the pole, the velocity of the cart, and the angular velocity of the pole. The control input is the horizontal force applied to the cart."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_urdf_path = (\n",
" \"https://raw.githubusercontent.com/flferretti/jaxsim/examples/assets/cartpole.urdf\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"JAXsim offers a simple high-level API in order to extract quantities needed in most robotic applications. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from jaxsim.high_level import Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = Model.build_from_model_description(model_description=urdf_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The visualization is done using the [`meshcat-viz-python`](https://github.com/ami-iit/meshcat-viz-python) package. Let's import it and create a `Visualizer` object."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" from meshcat_viz.world import MeshcatWorld\n",
"except:\n",
" %pip install -q git+https://github.com/ami-iit/meshcat-viz-python\n",
" from meshcat_viz.world import MeshcatWorld\n",
"\n",
"world = MeshcatWorld()\n",
"world.open()\n",
"\n",
"from IPython.display import IFrame\n",
"IFrame(src=\"https://127.0.0.1:7010/static/\", width='100%', height='500px')\n",
"\n",
"world.insert_model(\n",
" model_description=model_urdf_path, model_name=\"Cartpole\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see how the model behaves when not controlled:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for _ in range(200):\n",
" model.integrate(0.01)\n",
" world.update_model(model_name=\"Cartpole\", base_position=model.base_position(), joint_positions=model.joint_positions())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's now define the PD controller. We will use the following equations:\n",
"\n",
"\\begin{align} \\tau &= K_p \\left( q_d - q \\right) + K_d \\left( \\dot{q}_d - \\dot{q} \\right) \\end{align}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def pd_controller(q: jax.Array, q_d: jax.Array, q_dot: jax.Array, q_dot_d: jax.Array) -> jax.Array:\n",
" return KP * (q_d - q) + KD * (q_dot_d - q_dot)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can use the `pd_controller` function to compute the torque to apply to the cartpole. Our aim is to stabilize the cartpole in the upright position, so we set the desired position `q_d` to 0 and the desired velocity `q_dot_d` to 0."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for _ in range(200):\n",
" model.integrate(0.01)\n",
" world.update_model(model_name=\"Cartpole\", base_position=model.base_position(), joint_positions=model.joint_positions())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "jaxsim",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
81 changes: 81 additions & 0 deletions examples/assets/cartpole.urdf
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
<?xml version="1.0" encoding="utf-8"?>
<!--Create with rod using build_cartpole_model.py -->
<robot name="cartpole">
<link name="world"/>
<link name="rail">
<inertial>
<origin xyz="0.0 0.0 1.2" rpy="1.5707963267948963 0.0 0.0"/>
<mass value="5.0"/>
<inertia ixx="10.416697916666665" ixy="0.0" ixz="0.0" iyy="10.416697916666665" iyz="0.0" izz="6.25e-05"/>
</inertial>
<visual name="rail_visual">
<origin xyz="0.0 0.0 1.2" rpy="1.5707963267948963 0.0 0.0"/>
<geometry>
<cylinder radius="0.005" length="5.0"/>
</geometry>
</visual>
<collision name="rail_collision">
<origin xyz="0.0 0.0 1.2" rpy="1.5707963267948963 0.0 0.0"/>
<geometry>
<cylinder radius="0.005" length="5.0"/>
</geometry>
</collision>
</link>
<link name="cart">
<inertial>
<origin xyz="0.0 0.0 0.0" rpy="0.0 0.0 0.0"/>
<mass value="1.0"/>
<inertia ixx="0.0035416666666666674" ixy="0.0" ixz="0.0" iyy="0.0010416666666666669" iyz="0.0" izz="0.0041666666666666675"/>
</inertial>
<visual name="cart_visual">
<origin xyz="0.0 0.0 0.0" rpy="0.0 0.0 0.0"/>
<geometry>
<box size="0.1 0.2 0.05"/>
</geometry>
</visual>
<collision name="cart_collision">
<origin xyz="0.0 0.0 0.0" rpy="0.0 0.0 0.0"/>
<geometry>
<box size="0.1 0.2 0.05"/>
</geometry>
</collision>
</link>
<link name="pole">
<inertial>
<origin xyz="0.0 0.0 0.5" rpy="0.0 0.0 0.0"/>
<mass value="0.1"/>
<inertia ixx="0.008333958333333334" ixy="0.0" ixz="0.0" iyy="0.008333958333333334" iyz="0.0" izz="1.25e-06"/>
</inertial>
<visual name="pole_visual">
<origin xyz="0.0 0.0 0.5" rpy="0.0 0.0 0.0"/>
<geometry>
<cylinder radius="0.005" length="1.0"/>
</geometry>
</visual>
<collision name="pole_collision">
<origin xyz="0.0 0.0 0.5" rpy="0.0 0.0 0.0"/>
<geometry>
<cylinder radius="0.005" length="1.0"/>
</geometry>
</collision>
</link>
<joint name="world_to_rail" type="fixed">
<origin xyz="0.0 0.0 0.0" rpy="0.0 0.0 0.0"/>
<parent link="world"/>
<child link="rail"/>
</joint>
<joint name="linear" type="prismatic">
<origin xyz="0.0 0.0 1.2" rpy="0.0 0.0 0.0"/>
<parent link="rail"/>
<child link="cart"/>
<axis xyz="0 1 0"/>
<limit effort="500.0" velocity="10.0" lower="-2.4" upper="2.4"/>
</joint>
<joint name="pivot" type="revolute">
<origin xyz="0.0 0.0 0.0" rpy="0.0 0.0 0.0"/>
<parent link="cart"/>
<child link="pole"/>
<axis xyz="1 0 0"/>
<limit effort="500.0" velocity="10.0" lower="-3.4028234663852886e+38" upper="3.4028234663852886e+38"/>
</joint>
</robot>

0 comments on commit 2f929b6

Please sign in to comment.