diff --git a/src/jaxsim/math/quaternion.py b/src/jaxsim/math/quaternion.py index 69163059f..56d365463 100644 --- a/src/jaxsim/math/quaternion.py +++ b/src/jaxsim/math/quaternion.py @@ -33,7 +33,7 @@ def derivative( omega_in_body_fixed: bool = False, K: float = 0.1, ) -> jtp.Vector: - w = omega.squeeze() + ω = omega.squeeze() quaternion = quaternion.squeeze() def Q_body(q: jtp.Vector) -> jtp.Matrix: @@ -67,10 +67,20 @@ def Q_inertial(q: jtp.Vector) -> jtp.Matrix: operand=quaternion, ) + norm_ω = jax.lax.cond( + pred=ω.dot(ω) < (1e-6) ** 2, + true_fun=lambda _: 1e-6, + false_fun=lambda _: jnp.linalg.norm(ω), + operand=None, + ) + qd = 0.5 * ( Q @ jnp.hstack( - [K * jnp.linalg.norm(w) * (1 - jnp.linalg.norm(quaternion)), w] + [ + K * norm_ω * (1 - jnp.linalg.norm(quaternion)), + ω, + ] ) ) diff --git a/src/jaxsim/physics/algos/soft_contacts.py b/src/jaxsim/physics/algos/soft_contacts.py index 572d053f2..07c81475d 100644 --- a/src/jaxsim/physics/algos/soft_contacts.py +++ b/src/jaxsim/physics/algos/soft_contacts.py @@ -229,6 +229,10 @@ def contact_model( m = tangential_deformation.squeeze() ṁ = jnp.zeros_like(m) + # Note: all the small hardcoded tolerances in this method have been introduced + # to allow jax differentiating through this algorithm. They should not affect + # the accuracy of the simulation, although they might make it less readable. + # ======================== # Normal force computation # ======================== @@ -249,7 +253,11 @@ def contact_model( # Non-linear spring-damper model. # This is the force magnitude along the direction normal to the terrain. - force_normal_mag = jnp.sqrt(δ) * (K * δ + D * δ̇) + force_normal_mag = jax.lax.select( + pred=δ >= 1e-9, + on_true=jnp.sqrt(δ + 1e-12) * (K * δ + D * δ̇), + on_false=jnp.array(0.0), + ) # Prevent negative normal forces that might occur when δ̇ is largely negative force_normal_mag = jnp.maximum(0.0, force_normal_mag) @@ -304,7 +312,7 @@ def below_terrain(): v_tangential = W_ṗ_C - v_normal # Compute the tangential force. If inside the friction cone, the contact - f_tangential = -jnp.sqrt(δ) * (K * m + D * v_tangential) + f_tangential = -jnp.sqrt(δ + 1e-12) * (K * m + D * v_tangential) def sticking_contact(): # Sum the normal and tangential forces, and create the 6D force @@ -319,9 +327,17 @@ def sticking_contact(): return CW_f, ṁ def slipping_contact(): + # Clip the tangential force if too small, allowing jax to + # differentiate through the norm computation + f_tangential_no_nan = jax.lax.select( + pred=f_tangential.dot(f_tangential) >= 1e-9**2, + on_true=f_tangential, + on_false=jnp.array([1e-12, 0, 0]), + ) + # Project the force to the friction cone boundary f_tangential_projected = (μ * force_normal_mag) * ( - f_tangential / jnp.linalg.norm(f_tangential) + f_tangential / jnp.linalg.norm(f_tangential_no_nan) ) # Sum the normal and tangential forces, and create the 6D force @@ -331,18 +347,18 @@ def slipping_contact(): # Correct the material deformation derivative for slipping contacts. # Basically we compute ṁ such that we get `f_tangential` on the cone # given the current (m, δ). - ε = 1e-6 - α = -K * jnp.sqrt(δ) + ε = 1e-9 δε = jnp.maximum(δ, ε) - βε = -D * jnp.sqrt(δε) - ṁ = (f_tangential_projected - α * m) / βε + α = -K * jnp.sqrt(δε) + β = -D * jnp.sqrt(δε) + ṁ = (f_tangential_projected - α * m) / β # Return the 6D force in the contact frame and # the deformation derivative return CW_f, ṁ CW_f, ṁ = jax.lax.cond( - pred=jnp.linalg.norm(f_tangential) > μ * force_normal_mag, + pred=f_tangential.dot(f_tangential) > (μ * force_normal_mag) ** 2, true_fun=lambda _: slipping_contact(), false_fun=lambda _: sticking_contact(), operand=None, diff --git a/tests/test_ad_physics.py b/tests/test_ad_physics.py new file mode 100644 index 000000000..9c7db2c0e --- /dev/null +++ b/tests/test_ad_physics.py @@ -0,0 +1,190 @@ +import jax.numpy as jnp +import numpy as np +import pytest +from jax.test_util import check_grads +from pytest import param as p + +from jaxsim.high_level.common import VelRepr +from jaxsim.high_level.model import Model + +from . import utils_models, utils_rng +from .utils_models import Robot + + +@pytest.mark.parametrize( + "robot, vel_repr", + [ + p(*[Robot.Ur10, VelRepr.Inertial], id="Ur10-Inertial"), + p(*[Robot.AnymalC, VelRepr.Inertial], id="AnymalC-Inertial"), + p(*[Robot.Cassie, VelRepr.Inertial], id="Cassie-Inertial"), + ], +) +def test_ad_physics(robot: utils_models.Robot, vel_repr: VelRepr) -> None: + """Unit test of the application of Automatic Differentiation on RBD algorithms.""" + + robot = Robot.Ur10 + vel_repr = VelRepr.Inertial + + # Initialize the gravity + gravity = np.array([0, 0, -10.0]) + + # Get the URDF of the robot + urdf_file_path = utils_models.ModelFactory.get_model_description(robot=robot) + + # Build the high-level model + model = Model.build_from_model_description( + model_description=urdf_file_path, + vel_repr=vel_repr, + gravity=gravity, + is_urdf=True, + ).mutable(mutable=True, validate=True) + + # Initialize the model with a random state + model.data.model_state = utils_rng.random_physics_model_state( + physics_model=model.physics_model + ) + + # Initialize the model with a random input + model.data.model_input = utils_rng.random_physics_model_input( + physics_model=model.physics_model + ) + + # ======================== + # Extract state and inputs + # ======================== + + # Extract the physics model used in the low-level physics algorithms + physics_model = model.physics_model + + # State + s = model.joint_positions() + ṡ = model.joint_velocities() + xfb = model.data.model_state.xfb() + + # Inputs + f_ext = model.external_forces() + tau = model.joint_generalized_forces_targets() + + # Perturbation used for computing finite differences + ε = jnp.finfo(jnp.array(0.0)).resolution ** (1 / 3) + + # ===================================================== + # Check first-order and second-order derivatives of ABA + # ===================================================== + + import jaxsim.physics.algos.aba + + aba = lambda xfb, s, ṡ, tau, f_ext: jaxsim.physics.algos.aba.aba( + model=physics_model, xfb=xfb, q=s, qd=ṡ, tau=tau, f_ext=f_ext + ) + + check_grads( + f=aba, + args=(xfb, s, ṡ, tau, f_ext), + order=2, + modes=["rev", "fwd"], + eps=ε, + ) + + # ====================================================== + # Check first-order and second-order derivatives of RNEA + # ====================================================== + + import jaxsim.physics.algos.rnea + + W_v̇_WB = utils_rng.get_rng().uniform(size=6, low=-1) + s̈ = utils_rng.get_rng().uniform(size=physics_model.dofs(), low=-1) + + rnea = lambda xfb, s, ṡ, s̈, W_v̇_WB, f_ext: jaxsim.physics.algos.rnea.rnea( + model=physics_model, xfb=xfb, q=s, qd=ṡ, qdd=s̈, a0fb=W_v̇_WB, f_ext=f_ext + ) + + check_grads( + f=rnea, + args=(xfb, s, ṡ, s̈, W_v̇_WB, f_ext), + order=2, + modes=["rev", "fwd"], + eps=ε, + ) + + # ====================================================== + # Check first-order and second-order derivatives of CRBA + # ====================================================== + + import jaxsim.physics.algos.crba + + crba = lambda s: jaxsim.physics.algos.crba.crba(model=physics_model, q=s) + + check_grads( + f=crba, + args=(s,), + order=2, + modes=["rev", "fwd"], + eps=ε, + ) + + # ==================================================== + # Check first-order and second-order derivatives of FK + # ==================================================== + + import jaxsim.physics.algos.forward_kinematics + + fk = ( + lambda xfb, s: jaxsim.physics.algos.forward_kinematics.forward_kinematics_model( + model=physics_model, xfb=xfb, q=s + ) + ) + + check_grads( + f=fk, + args=(xfb, s), + order=2, + modes=["rev", "fwd"], + eps=ε, + ) + + # ========================================================== + # Check first-order and second-order derivatives of Jacobian + # ========================================================== + + import jaxsim.physics.algos.jacobian + + link_indices = [l.index() for l in model.links()] + + jacobian = lambda s: jaxsim.physics.algos.jacobian.jacobian( + model=physics_model, q=s, body_index=link_indices[-1] + ) + + check_grads( + f=jacobian, + args=(s,), + order=2, + modes=["rev", "fwd"], + eps=ε, + ) + + # ===================================================================== + # Check first-order and second-order derivatives of soft contacts model + # ===================================================================== + + import jaxsim.physics.algos.soft_contacts + + p = utils_rng.get_rng().uniform(size=3, low=-1) + v = utils_rng.get_rng().uniform(size=3, low=-1) + m = utils_rng.get_rng().uniform(size=3, low=-1) + + parameters = jaxsim.physics.algos.soft_contacts.SoftContactsParams.build( + K=10_000, D=20.0, mu=0.5 + ) + + soft_contacts = lambda p, v, m: jaxsim.physics.algos.soft_contacts.SoftContacts( + parameters=parameters + ).contact_model(position=p, velocity=v, tangential_deformation=m) + + check_grads( + f=soft_contacts, + args=(p, v, m), + order=2, + modes=["rev", "fwd"], + eps=ε, + )