Skip to content

Commit

Permalink
Merge pull request ami-iit#54 from ami-iit/automatic_differentiation
Browse files Browse the repository at this point in the history
First steps towards automatic differentiation of RBDAs
  • Loading branch information
diegoferigo authored Oct 11, 2023
2 parents 5573944 + c5c799d commit e9bb166
Show file tree
Hide file tree
Showing 3 changed files with 226 additions and 10 deletions.
14 changes: 12 additions & 2 deletions src/jaxsim/math/quaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def derivative(
omega_in_body_fixed: bool = False,
K: float = 0.1,
) -> jtp.Vector:
w = omega.squeeze()
ω = omega.squeeze()
quaternion = quaternion.squeeze()

def Q_body(q: jtp.Vector) -> jtp.Matrix:
Expand Down Expand Up @@ -67,10 +67,20 @@ def Q_inertial(q: jtp.Vector) -> jtp.Matrix:
operand=quaternion,
)

norm_ω = jax.lax.cond(
pred=ω.dot(ω) < (1e-6) ** 2,
true_fun=lambda _: 1e-6,
false_fun=lambda _: jnp.linalg.norm(ω),
operand=None,
)

qd = 0.5 * (
Q
@ jnp.hstack(
[K * jnp.linalg.norm(w) * (1 - jnp.linalg.norm(quaternion)), w]
[
K * norm_ω * (1 - jnp.linalg.norm(quaternion)),
ω,
]
)
)

Expand Down
32 changes: 24 additions & 8 deletions src/jaxsim/physics/algos/soft_contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,10 @@ def contact_model(
m = tangential_deformation.squeeze()
= jnp.zeros_like(m)

# Note: all the small hardcoded tolerances in this method have been introduced
# to allow jax differentiating through this algorithm. They should not affect
# the accuracy of the simulation, although they might make it less readable.

# ========================
# Normal force computation
# ========================
Expand All @@ -249,7 +253,11 @@ def contact_model(

# Non-linear spring-damper model.
# This is the force magnitude along the direction normal to the terrain.
force_normal_mag = jnp.sqrt(δ) * (K * δ + D * δ̇)
force_normal_mag = jax.lax.select(
pred=δ >= 1e-9,
on_true=jnp.sqrt(δ + 1e-12) * (K * δ + D * δ̇),
on_false=jnp.array(0.0),
)

# Prevent negative normal forces that might occur when δ̇ is largely negative
force_normal_mag = jnp.maximum(0.0, force_normal_mag)
Expand Down Expand Up @@ -304,7 +312,7 @@ def below_terrain():
v_tangential = W_ṗ_C - v_normal

# Compute the tangential force. If inside the friction cone, the contact
f_tangential = -jnp.sqrt(δ) * (K * m + D * v_tangential)
f_tangential = -jnp.sqrt(δ + 1e-12) * (K * m + D * v_tangential)

def sticking_contact():
# Sum the normal and tangential forces, and create the 6D force
Expand All @@ -319,9 +327,17 @@ def sticking_contact():
return CW_f,

def slipping_contact():
# Clip the tangential force if too small, allowing jax to
# differentiate through the norm computation
f_tangential_no_nan = jax.lax.select(
pred=f_tangential.dot(f_tangential) >= 1e-9**2,
on_true=f_tangential,
on_false=jnp.array([1e-12, 0, 0]),
)

# Project the force to the friction cone boundary
f_tangential_projected = (μ * force_normal_mag) * (
f_tangential / jnp.linalg.norm(f_tangential)
f_tangential / jnp.linalg.norm(f_tangential_no_nan)
)

# Sum the normal and tangential forces, and create the 6D force
Expand All @@ -331,18 +347,18 @@ def slipping_contact():
# Correct the material deformation derivative for slipping contacts.
# Basically we compute ṁ such that we get `f_tangential` on the cone
# given the current (m, δ).
ε = 1e-6
α = -K * jnp.sqrt(δ)
ε = 1e-9
δε = jnp.maximum(δ, ε)
βε = -D * jnp.sqrt(δε)
= (f_tangential_projected - α * m) / βε
α = -K * jnp.sqrt(δε)
β = -D * jnp.sqrt(δε)
= (f_tangential_projected - α * m) / β

# Return the 6D force in the contact frame and
# the deformation derivative
return CW_f,

CW_f, = jax.lax.cond(
pred=jnp.linalg.norm(f_tangential) > μ * force_normal_mag,
pred=f_tangential.dot(f_tangential) > (μ * force_normal_mag) ** 2,
true_fun=lambda _: slipping_contact(),
false_fun=lambda _: sticking_contact(),
operand=None,
Expand Down
190 changes: 190 additions & 0 deletions tests/test_ad_physics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
import jax.numpy as jnp
import numpy as np
import pytest
from jax.test_util import check_grads
from pytest import param as p

from jaxsim.high_level.common import VelRepr
from jaxsim.high_level.model import Model

from . import utils_models, utils_rng
from .utils_models import Robot


@pytest.mark.parametrize(
"robot, vel_repr",
[
p(*[Robot.Ur10, VelRepr.Inertial], id="Ur10-Inertial"),
p(*[Robot.AnymalC, VelRepr.Inertial], id="AnymalC-Inertial"),
p(*[Robot.Cassie, VelRepr.Inertial], id="Cassie-Inertial"),
],
)
def test_ad_physics(robot: utils_models.Robot, vel_repr: VelRepr) -> None:
"""Unit test of the application of Automatic Differentiation on RBD algorithms."""

robot = Robot.Ur10
vel_repr = VelRepr.Inertial

# Initialize the gravity
gravity = np.array([0, 0, -10.0])

# Get the URDF of the robot
urdf_file_path = utils_models.ModelFactory.get_model_description(robot=robot)

# Build the high-level model
model = Model.build_from_model_description(
model_description=urdf_file_path,
vel_repr=vel_repr,
gravity=gravity,
is_urdf=True,
).mutable(mutable=True, validate=True)

# Initialize the model with a random state
model.data.model_state = utils_rng.random_physics_model_state(
physics_model=model.physics_model
)

# Initialize the model with a random input
model.data.model_input = utils_rng.random_physics_model_input(
physics_model=model.physics_model
)

# ========================
# Extract state and inputs
# ========================

# Extract the physics model used in the low-level physics algorithms
physics_model = model.physics_model

# State
s = model.joint_positions()
= model.joint_velocities()
xfb = model.data.model_state.xfb()

# Inputs
f_ext = model.external_forces()
tau = model.joint_generalized_forces_targets()

# Perturbation used for computing finite differences
ε = jnp.finfo(jnp.array(0.0)).resolution ** (1 / 3)

# =====================================================
# Check first-order and second-order derivatives of ABA
# =====================================================

import jaxsim.physics.algos.aba

aba = lambda xfb, s, , tau, f_ext: jaxsim.physics.algos.aba.aba(
model=physics_model, xfb=xfb, q=s, qd=, tau=tau, f_ext=f_ext
)

check_grads(
f=aba,
args=(xfb, s, , tau, f_ext),
order=2,
modes=["rev", "fwd"],
eps=ε,
)

# ======================================================
# Check first-order and second-order derivatives of RNEA
# ======================================================

import jaxsim.physics.algos.rnea

W_v̇_WB = utils_rng.get_rng().uniform(size=6, low=-1)
= utils_rng.get_rng().uniform(size=physics_model.dofs(), low=-1)

rnea = lambda xfb, s, , , W_v̇_WB, f_ext: jaxsim.physics.algos.rnea.rnea(
model=physics_model, xfb=xfb, q=s, qd=, qdd=, a0fb=W_v̇_WB, f_ext=f_ext
)

check_grads(
f=rnea,
args=(xfb, s, , , W_v̇_WB, f_ext),
order=2,
modes=["rev", "fwd"],
eps=ε,
)

# ======================================================
# Check first-order and second-order derivatives of CRBA
# ======================================================

import jaxsim.physics.algos.crba

crba = lambda s: jaxsim.physics.algos.crba.crba(model=physics_model, q=s)

check_grads(
f=crba,
args=(s,),
order=2,
modes=["rev", "fwd"],
eps=ε,
)

# ====================================================
# Check first-order and second-order derivatives of FK
# ====================================================

import jaxsim.physics.algos.forward_kinematics

fk = (
lambda xfb, s: jaxsim.physics.algos.forward_kinematics.forward_kinematics_model(
model=physics_model, xfb=xfb, q=s
)
)

check_grads(
f=fk,
args=(xfb, s),
order=2,
modes=["rev", "fwd"],
eps=ε,
)

# ==========================================================
# Check first-order and second-order derivatives of Jacobian
# ==========================================================

import jaxsim.physics.algos.jacobian

link_indices = [l.index() for l in model.links()]

jacobian = lambda s: jaxsim.physics.algos.jacobian.jacobian(
model=physics_model, q=s, body_index=link_indices[-1]
)

check_grads(
f=jacobian,
args=(s,),
order=2,
modes=["rev", "fwd"],
eps=ε,
)

# =====================================================================
# Check first-order and second-order derivatives of soft contacts model
# =====================================================================

import jaxsim.physics.algos.soft_contacts

p = utils_rng.get_rng().uniform(size=3, low=-1)
v = utils_rng.get_rng().uniform(size=3, low=-1)
m = utils_rng.get_rng().uniform(size=3, low=-1)

parameters = jaxsim.physics.algos.soft_contacts.SoftContactsParams.build(
K=10_000, D=20.0, mu=0.5
)

soft_contacts = lambda p, v, m: jaxsim.physics.algos.soft_contacts.SoftContacts(
parameters=parameters
).contact_model(position=p, velocity=v, tangential_deformation=m)

check_grads(
f=soft_contacts,
args=(p, v, m),
order=2,
modes=["rev", "fwd"],
eps=ε,
)

0 comments on commit e9bb166

Please sign in to comment.