Skip to content

Commit

Permalink
Update render usage
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Dec 12, 2023
1 parent 7e7b36e commit 73662de
Showing 1 changed file with 39 additions and 37 deletions.
76 changes: 39 additions & 37 deletions examples/PD_controller.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,19 @@
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import clear_output, HTML\n",
"import sys \n",
"from IPython.display import clear_output, HTML, display\n",
"import sys\n",
"\n",
"IS_COLAB = \"google.colab\" in sys.modules\n",
"IS_KAGGLE = \"kaggle_secrets\" in sys.modules\n",
"\n",
"# Install JAX and Gazebo\n",
"!{sys.executable} -m pip install -U -q \"jax[cuda12_pip]\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"!{sys.executable} -m pip install -U -q jaxsim\n",
"!apt -qq update && apt install -qq --no-install-recommends gazebo\n",
"clear_output()\n",
"if IS_COLAB or IS_KAGGLE:\n",
" !{sys.executable} -m pip install -U -q \"jax[cuda12_pip]\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
" !{sys.executable} -m pip install -U -q jaxsim\n",
" !{sys.executable} -m pip install -q git+https://github.com/ami-iit/meshcat-viz-python\n",
" !apt -qq update && apt install -qq --no-install-recommends gazebo\n",
" clear_output()\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
Expand Down Expand Up @@ -113,30 +118,26 @@
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" from meshcat_viz.world import MeshcatWorld\n",
"except:\n",
" !{sys.executable} -m pip install -q git+https://github.com/ami-iit/meshcat-viz-python\n",
" clear_output()\n",
" from meshcat_viz.world import MeshcatWorld\n",
"from meshcat_viz.world import MeshcatWorld\n",
"\n",
"world = MeshcatWorld()\n",
"world.meshcat_visualizer.render_static()\n",
"\n",
"world.insert_model(\n",
" model_description=model_urdf_path, model_name=\"Cartpole\", is_urdf=True\n",
")\n",
"render = lambda: display(\n",
" HTML(\n",
" \"\"\"\n",
" <div style=\"height: 400px; width: 100%; overflow-x: auto; overflow-y: hidden; resize: both\">\n",
" <iframe src=\"{url}\" style=\"width: 100%; height: 100%; border: none\"></iframe>\n",
" </div>\n",
" \"\"\".format(\n",
" url=world.web_url\n",
" )\n",
" )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"world.meshcat_visualizer.render_static()"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -152,15 +153,16 @@
"source": [
"from jaxsim.simulation.ode_integration import IntegratorType\n",
"\n",
"render()\n",
"\n",
"for _ in range(200):\n",
" model.integrate(t0=0.0, tf=0.01, integrator_type=IntegratorType.EulerSemiImplicit)\n",
" \n",
" world.update_model(\n",
" model_name=\"Cartpole\",\n",
" base_position=model.base_position(),\n",
" joint_positions=model.joint_positions(),\n",
" joint_names=model.joint_names(),\n",
" )"
" )\n",
" model.integrate(t0=0.0, tf=0.01, integrator_type=IntegratorType.EulerSemiImplicit)"
]
},
{
Expand All @@ -179,8 +181,9 @@
"outputs": [],
"source": [
"# Define the PD gains\n",
"KP = 10.0\n",
"KD = 1.0\n",
"KP = 30.0\n",
"KD = 4.0\n",
"\n",
"\n",
"def pd_controller(\n",
" q: jax.Array, q_d: jax.Array, q_dot: jax.Array, q_dot_d: jax.Array\n",
Expand All @@ -201,7 +204,15 @@
"metadata": {},
"outputs": [],
"source": [
"render()\n",
"\n",
"for _ in range(200):\n",
" world.update_model(\n",
" model_name=\"Cartpole\",\n",
" base_position=model.base_position(),\n",
" joint_positions=model.joint_positions(),\n",
" joint_names=model.joint_names(),\n",
" )\n",
" model.set_joint_generalized_force_targets(\n",
" forces=pd_controller(\n",
" q=model.joint_positions(),\n",
Expand All @@ -210,16 +221,7 @@
" q_dot_d=jnp.array([0.0, 0.0]),\n",
" )\n",
" )\n",
"\n",
" logging.info(f\"Joint generalized forces: {model.data.model_input.tau}\")\n",
" \n",
" model.integrate(t0=0.0, tf=0.01, integrator_type=IntegratorType.EulerSemiImplicit)\n",
" world.update_model(\n",
" model_name=\"Cartpole\",\n",
" base_position=model.base_position(),\n",
" joint_positions=model.joint_positions(),\n",
" joint_names=model.joint_names(),\n",
" )"
" model.integrate(t0=0.0, tf=0.01, integrator_type=IntegratorType.EulerSemiImplicit)"
]
}
],
Expand Down

0 comments on commit 73662de

Please sign in to comment.