diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index b880547b9..af3ea7ef2 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -382,9 +382,8 @@ def base_orientation(self, dcm: jtp.BoolLike = False) -> jtp.Vector | jtp.Matrix # we introduce a Baumgarte stabilization to let the quaternion converge to # a unit quaternion. In this case, it is not guaranteed that the quaternion # stored in the state is a unit quaternion. - W_Q_B = jnp.where( - jnp.allclose(W_Q_B.dot(W_Q_B), 1.0), W_Q_B, W_Q_B / jnp.linalg.norm(W_Q_B) - ) + norm = jaxsim.math.safe_norm(W_Q_B) + W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0)) return (W_Q_B if not dcm else jaxsim.math.Quaternion.to_dcm(W_Q_B)).astype( float @@ -611,11 +610,8 @@ def reset_base_quaternion(self, base_quaternion: jtp.VectorLike) -> Self: W_Q_B = jnp.array(base_quaternion, dtype=float) - W_Q_B = jax.lax.select( - pred=jnp.allclose(jnp.linalg.norm(W_Q_B), 1.0, atol=1e-6, rtol=0.0), - on_true=W_Q_B, - on_false=W_Q_B / jnp.linalg.norm(W_Q_B), - ) + norm = jaxsim.math.safe_norm(W_Q_B) + W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0)) return self.replace( validate=True, diff --git a/src/jaxsim/math/__init__.py b/src/jaxsim/math/__init__.py index 2e7b9c352..008a94630 100644 --- a/src/jaxsim/math/__init__.py +++ b/src/jaxsim/math/__init__.py @@ -8,5 +8,6 @@ from .rotation import Rotation from .skew import Skew from .transform import Transform +from .utils import safe_norm from .joint_model import JointModel, supported_joint_motion # isort:skip diff --git a/src/jaxsim/math/quaternion.py b/src/jaxsim/math/quaternion.py index e9115cb26..4870f1aa0 100644 --- a/src/jaxsim/math/quaternion.py +++ b/src/jaxsim/math/quaternion.py @@ -4,6 +4,8 @@ import jaxsim.typing as jtp +from .utils import safe_norm + class Quaternion: @staticmethod @@ -111,18 +113,13 @@ def Q_inertial(q: jtp.Vector) -> jtp.Matrix: operand=quaternion, ) - norm_ω = jax.lax.cond( - pred=ω.dot(ω) < (1e-6) ** 2, - true_fun=lambda _: 1e-6, - false_fun=lambda _: jnp.linalg.norm(ω), - operand=None, - ) + norm_ω = safe_norm(ω) qd = 0.5 * ( Q @ jnp.hstack( [ - K * norm_ω * (1 - jnp.linalg.norm(quaternion)), + K * norm_ω * (1 - safe_norm(quaternion)), ω, ] ) diff --git a/src/jaxsim/math/rotation.py b/src/jaxsim/math/rotation.py index f445e1d74..471f496b8 100644 --- a/src/jaxsim/math/rotation.py +++ b/src/jaxsim/math/rotation.py @@ -4,6 +4,7 @@ import jaxsim.typing as jtp from .skew import Skew +from .utils import safe_norm class Rotation: @@ -67,7 +68,7 @@ def from_axis_angle(vector: jtp.Vector) -> jtp.Matrix: def theta_is_not_zero(axis: jtp.Vector) -> jtp.Matrix: v = axis - theta = jnp.linalg.norm(v) + theta = safe_norm(v) s = jnp.sin(theta) c = jnp.cos(theta) @@ -81,19 +82,9 @@ def theta_is_not_zero(axis: jtp.Vector) -> jtp.Matrix: return R.transpose() - # Use the double-where trick to prevent JAX problems when the - # jax.jit and jax.grad transforms are applied. return jnp.where( - jnp.linalg.norm(vector) > 0, - theta_is_not_zero( - axis=jnp.where( - jnp.linalg.norm(vector) > 0, - vector, - # The following line is a workaround to prevent division by 0. - # Considering the outer where, this branch is never executed. - jnp.ones(3), - ) - ), + jnp.allclose(vector, 0.0), # Return an identity rotation matrix when the input vector is zero. jnp.eye(3), + theta_is_not_zero(axis=vector), ) diff --git a/src/jaxsim/math/utils.py b/src/jaxsim/math/utils.py new file mode 100644 index 000000000..64d7a24ca --- /dev/null +++ b/src/jaxsim/math/utils.py @@ -0,0 +1,31 @@ +import jax.numpy as jnp + +import jaxsim.typing as jtp + + +def safe_norm(array: jtp.ArrayLike, axis=None) -> jtp.Array: + """ + Provides a calculation for an array norm so that it is safe + to compute the gradient and handle NaNs. + + Args: + array: The array for which to compute the norm. + axis: The axis for which to compute the norm. + + Returns: + The norm of the array with handling for zero arrays to avoid NaNs. + """ + + # Check if the entire array is composed of zeros. + is_zero = jnp.allclose(array, 0.0) + + # Replace zeros with an array of ones temporarily to avoid division by zero. + # This ensures the computation of norm does not produce NaNs or Infs. + array = jnp.where(is_zero, jnp.ones_like(array), array) + + # Compute the norm of the array along the specified axis. + norm = jnp.linalg.norm(array, axis=axis) + + # Use `jnp.where` to set the norm to 0.0 where the input array was all zeros. + # This usage supports potential batch processing for future scalability. + return jnp.where(is_zero, 0.0, norm) diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index be27fc9da..8d4c0d545 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -309,19 +309,16 @@ def hunt_crossley_contact_model( # Compute the direction of the tangential force. # To prevent dividing by zero, we use a switch statement. - # The ε, instead, is needed to make AD happy. - f_tangential_direction = jnp.where( - f_tangential.dot(f_tangential) != 0, - f_tangential / jnp.linalg.norm(f_tangential + ε), - jnp.zeros(3), + norm = jaxsim.math.safe_norm(f_tangential) + f_tangential_direction = f_tangential / ( + norm + jnp.finfo(float).eps * (norm == 0) ) # Project the tangential force to the friction cone if slipping. f_tangential = jnp.where( sticking, f_tangential, - jnp.minimum(μ * force_normal_mag, jnp.linalg.norm(f_tangential + ε)) - * f_tangential_direction, + jnp.minimum(μ * force_normal_mag, norm) * f_tangential_direction, ) # Set the tangential force to zero if there is no contact. diff --git a/src/jaxsim/terrain/terrain.py b/src/jaxsim/terrain/terrain.py index 9b2316425..f6b4ddcc2 100644 --- a/src/jaxsim/terrain/terrain.py +++ b/src/jaxsim/terrain/terrain.py @@ -7,6 +7,7 @@ import jax_dataclasses import numpy as np +import jaxsim.math import jaxsim.typing as jtp from jaxsim import exceptions @@ -41,7 +42,7 @@ def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector: [(h_xm - h_xp) / (2 * self.delta), (h_ym - h_yp) / (2 * self.delta), 1.0] ) - return n / jnp.linalg.norm(n) + return n / jaxsim.math.safe_norm(n) @jax_dataclasses.pytree_dataclass diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index 477f6245d..c4347c204 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -2,6 +2,7 @@ import jax import jax.numpy as jnp +import numpy as np from jax.test_util import check_grads import jaxsim.api as js @@ -413,3 +414,45 @@ def step( modes=["rev", "fwd"], eps=ε, ) + + +def test_ad_safe_norm( + prng_key: jax.Array, +): + + _, subkey = jax.random.split(prng_key, num=2) + array = jax.random.uniform(subkey, shape=(4,), minval=-5, maxval=5) + + # ==== + # Test + # ==== + + # Test that the safe_norm function is compatible with batching. + array = jnp.stack([array, array]) + assert jaxsim.math.safe_norm(array, axis=1).shape == (2,) + + # Test that the safe_norm function is correctly computing the norm. + assert np.allclose(jaxsim.math.safe_norm(array), np.linalg.norm(array)) + + # Function exposing only the parameters to be differentiated. + def safe_norm(array: jtp.Array) -> jtp.Array: + + return jaxsim.math.safe_norm(array) + + # Check derivatives against finite differences. + check_grads( + f=safe_norm, + args=(array,), + order=AD_ORDER, + modes=["rev", "fwd"], + eps=ε, + ) + + # Check derivatives against finite differences when the array is zero. + check_grads( + f=safe_norm, + args=(jnp.zeros_like(array),), + order=AD_ORDER, + modes=["rev", "fwd"], + eps=ε, + )