Skip to content

Commit

Permalink
Merge pull request #300 from CarlottaSartore/add_frame_example
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti authored Nov 28, 2024
2 parents 9739cdf + c20e426 commit 524ad42
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 35 deletions.
145 changes: 112 additions & 33 deletions examples/jaxsim_as_multibody_dynamics_library.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -320,12 +320,18 @@
"\n",
"- **`jaxsim.api.model`**: vectorized functions operating on the whole model.\n",
"- **`jaxsim.api.link`**: functions operating on individual links.\n",
"- **`jaxsim.api.frame`**: functions operating on individual frames. \n",
"\n",
"Due to JAX limitations on vectorizable data types, many APIs operate on indices instead of names. Since using indices can be error prone, JaxSim provides conversion functions:\n",
"Due to JAX limitations on vectorizable data types, many APIs operate on indices instead of names. Since using indices can be error prone, JaxSim provides conversion functions for both links:\n",
"\n",
"- **jaxsim.api.link.names_to_idxs()**\n",
"- **jaxsim.api.link.idxs_to_names()**\n",
"\n",
"and frames: \n",
"\n",
"- **jaxsim.api.frame.names_to_idxs()**\n",
"- **jaxsim.api.frame.idxs_to_names()**\n",
"\n",
"We recommend using names whenever possible to avoid hard-to-trace errors.\n"
]
},
Expand Down Expand Up @@ -354,7 +360,7 @@
},
"outputs": [],
"source": [
"# @title Pose\n",
"# @title Link Pose\n",
"\n",
"# Compute its pose w.r.t. the world frame through forward kinematics.\n",
"W_H_L = js.link.transform(model=model, data=data, link_index=link_index)\n",
Expand All @@ -380,7 +386,7 @@
},
"outputs": [],
"source": [
"# @title 6D Velocity\n",
"# @title Link 6D Velocity\n",
"\n",
"# JaxSim allows to select the so-called representation of the frame velocity.\n",
"L_v_WL = js.link.velocity(model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Body)\n",
Expand Down Expand Up @@ -413,10 +419,87 @@
]
},
{
"cell_type": "markdown",
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SSoziCShtwZ9"
},
"outputs": [],
"source": [
"# Find the index of a frame.\n",
"frame_name = \"l_foot_front\"\n",
"frame_index = js.frame.name_to_idx(model=model, frame_name=frame_name)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fVp_xP_1twZ9",
"outputId": "cfaa0569-d768-4708-c98c-a5867c056d04"
},
"outputs": [],
"source": [
"# @title Frame Pose\n",
"\n",
"# Compute its pose w.r.t. the world frame through forward kinematics.\n",
"W_H_F = js.frame.transform(model=model, data=data, frame_index=frame_index)\n",
"\n",
"print(f\"Transform of '{frame_name}': shape={W_H_F.shape}\\n{W_H_F}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QqaqxneEFYiW"
},
"outputs": [],
"source": [
"# @title Frame 6D Velocity\n",
"\n",
"# JaxSim allows to select the so-called representation of the frame velocity.\n",
"F_v_WF = js.frame.velocity(model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Body)\n",
"FW_v_WF = js.frame.velocity(model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Mixed)\n",
"W_v_WF = js.frame.velocity(model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Inertial)\n",
"\n",
"print(f\"Body-fixed velocity F_v_WF={F_v_WF}\")\n",
"print(f\"Mixed velocity: FW_v_WF={FW_v_WF}\")\n",
"print(f\"Inertial-fixed velocity: W_v_WF={W_v_WF}\")\n",
"\n",
"# These can also be computed passing through the frame free-floating Jacobian.\n",
"# This type of Jacobian has a input velocity representation that corresponds\n",
"# the velocity representation of ν, and an output velocity representation that\n",
"# corresponds to the velocity representation of the desired 6D velocity.\n",
"\n",
"# You can use the following context manager to easily switch between representations.\n",
"with data.switch_velocity_representation(VelRepr.Body):\n",
"\n",
" # Body-fixed generalized velocity.\n",
" B_ν = data.generalized_velocity()\n",
"\n",
" # Free-floating Jacobian accepting a body-fixed generalized velocity and\n",
" # returning an inertial-fixed link velocity.\n",
" W_J_WF_B = js.frame.jacobian(\n",
" model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Inertial\n",
" )\n",
"\n",
"# Now the following relation should hold.\n",
"assert jnp.allclose(W_v_WF, W_J_WF_B @ B_ν)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "d_vp6D74GoVZ",
"outputId": "798b9283-792e-4339-b56c-df2595fac974"
},
"source": [
"## Robot Dynamics\n",
"\n",
Expand Down Expand Up @@ -460,11 +543,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fVp_xP_1twZ9",
"outputId": "cfaa0569-d768-4708-c98c-a5867c056d04"
"id": "oOKJOVfsH4Ki"
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -509,7 +588,11 @@
{
"cell_type": "markdown",
"metadata": {
"id": "QqaqxneEFYiW"
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "FlNo8dNWKKtu",
"outputId": "313e939b-f88f-4407-c9ee-b5b3b7443061"
},
"source": [
"### Forward Dynamics\n",
Expand All @@ -530,11 +613,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "d_vp6D74GoVZ",
"outputId": "798b9283-792e-4339-b56c-df2595fac974"
"id": "LXARuRu1Ly1K"
},
"outputs": [],
"source": [
Expand All @@ -554,7 +633,11 @@
{
"cell_type": "markdown",
"metadata": {
"id": "oOKJOVfsH4Ki"
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "g5GOYXDnLySU",
"outputId": "ad4ce77d-d06f-473a-9c32-040680d76aa5"
},
"source": [
"### Inverse Dynamics\n",
Expand All @@ -575,11 +658,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "FlNo8dNWKKtu",
"outputId": "313e939b-f88f-4407-c9ee-b5b3b7443061"
"id": "UTae5MjhaP2H"
},
"outputs": [],
"source": [
Expand All @@ -604,7 +683,11 @@
{
"cell_type": "markdown",
"metadata": {
"id": "LXARuRu1Ly1K"
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gYZ1jK1Neg1H",
"outputId": "0de79770-1e18-4027-bb47-5713bc1b4a72"
},
"source": [
"### Centroidal Dynamics\n",
Expand Down Expand Up @@ -656,11 +739,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "g5GOYXDnLySU",
"outputId": "ad4ce77d-d06f-473a-9c32-040680d76aa5"
"id": "rrSfxp8lh9YZ"
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -716,7 +795,11 @@
{
"cell_type": "markdown",
"metadata": {
"id": "UTae5MjhaP2H"
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Ot6HePB_twaE",
"outputId": "02a6abae-257e-45ee-e9de-6a607cdbeb9a"
},
"source": [
"## Contact Frames\n",
Expand Down Expand Up @@ -746,11 +829,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gYZ1jK1Neg1H",
"outputId": "0de79770-1e18-4027-bb47-5713bc1b4a72"
"id": "LITRC3STliKR"
},
"outputs": [],
"source": [
Expand Down
54 changes: 52 additions & 2 deletions src/jaxsim/api/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> jtp.Int:
The index of the frame.
"""

if frame_name not in model.kin_dyn_parameters.frame_parameters.name:
if frame_name not in model.frame_names():
raise ValueError(f"Frame '{frame_name}' not found in the model.")

return (
Expand Down Expand Up @@ -180,6 +180,56 @@ def transform(
return W_H_L @ L_H_F


@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
def velocity(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
frame_index: jtp.IntLike,
output_vel_repr: VelRepr | None = None,
) -> jtp.Vector:
"""
Compute the 6D velocity of the frame.
Args:
model: The model to consider.
data: The data of the considered model.
frame_index: The index of the frame.
output_vel_repr:
The output velocity representation of the frame velocity.
Returns:
The 6D velocity of the frame in the specified velocity representation.
"""
n_l = model.number_of_links()
n_f = model.number_of_frames()

exceptions.raise_value_error_if(
condition=jnp.array([frame_index < n_l, frame_index >= n_l + n_f]).any(),
msg="Invalid frame index '{idx}'",
idx=frame_index,
)

output_vel_repr = (
output_vel_repr if output_vel_repr is not None else data.velocity_representation
)

# Get the frame jacobian having I as input representation (taken from data)
# and O as output representation, specified by the user (or taken from data).
O_J_WF_I = jacobian(
model=model,
data=data,
frame_index=frame_index,
output_vel_repr=output_vel_repr,
)

# Get the generalized velocity in the input velocity representation.
I_ν = data.generalized_velocity()

# Compute the frame velocity in the output velocity representation.
return O_J_WF_I @ I_ν


@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
def jacobian(
model: js.model.JaxSimModel,
Expand Down Expand Up @@ -207,7 +257,7 @@ def jacobian(
"""

n_l = model.number_of_links()
n_f = len(model.frame_names())
n_f = model.number_of_frames()

exceptions.raise_value_error_if(
condition=jnp.array([frame_index < n_l, frame_index >= n_l + n_f]).any(),
Expand Down
10 changes: 10 additions & 0 deletions src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,16 @@ def number_of_joints(self) -> int:

return len(self.joint_model.joint_names) - 1

def number_of_frames(self) -> int:
"""
Return the number of frames of the model.
Returns:
The number of frames of the model.
"""

return len(self.frame_parameters.name)

def support_body_array(self, link_index: jtp.IntLike) -> jtp.Vector:
r"""
Return the support parent array :math:`\kappa(i)` of a link.
Expand Down
11 changes: 11 additions & 0 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,17 @@ def number_of_joints(self) -> int:

return self.kin_dyn_parameters.number_of_joints()

def number_of_frames(self) -> int:
"""
Return the number of frames in the model.
Returns:
The number of frames in the model.
"""

return self.kin_dyn_parameters.number_of_frames()

# =================
# Base link methods
# =================
Expand Down
6 changes: 6 additions & 0 deletions tests/test_api_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@ def test_frame_jacobians(
J_WL_idt = kin_dyn.jacobian_frame(frame_name=frame_name)
assert J_WL_js == pytest.approx(J_WL_idt, abs=1e-9)

for frame_name, frame_index in zip(frame_names, frame_indices, strict=True):

v_WF_idt = kin_dyn.frame_velocity(frame_name=frame_name)
v_WF_js = js.frame.velocity(model=model, data=data, frame_index=frame_index)
assert v_WF_js == pytest.approx(v_WF_idt), frame_name


def test_frame_jacobian_derivative(
jaxsim_models_types: js.model.JaxSimModel,
Expand Down

0 comments on commit 524ad42

Please sign in to comment.