Skip to content

Commit

Permalink
Merge pull request #319 from ami-iit/safe_norm
Browse files Browse the repository at this point in the history
Add safe norm function and refactor usages
  • Loading branch information
flferretti authored Jan 7, 2025
2 parents 1ee249c + acbe8b7 commit ceb1fab
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 36 deletions.
12 changes: 4 additions & 8 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,9 +382,8 @@ def base_orientation(self, dcm: jtp.BoolLike = False) -> jtp.Vector | jtp.Matrix
# we introduce a Baumgarte stabilization to let the quaternion converge to
# a unit quaternion. In this case, it is not guaranteed that the quaternion
# stored in the state is a unit quaternion.
W_Q_B = jnp.where(
jnp.allclose(W_Q_B.dot(W_Q_B), 1.0), W_Q_B, W_Q_B / jnp.linalg.norm(W_Q_B)
)
norm = jaxsim.math.safe_norm(W_Q_B)
W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))

return (W_Q_B if not dcm else jaxsim.math.Quaternion.to_dcm(W_Q_B)).astype(
float
Expand Down Expand Up @@ -611,11 +610,8 @@ def reset_base_quaternion(self, base_quaternion: jtp.VectorLike) -> Self:

W_Q_B = jnp.array(base_quaternion, dtype=float)

W_Q_B = jax.lax.select(
pred=jnp.allclose(jnp.linalg.norm(W_Q_B), 1.0, atol=1e-6, rtol=0.0),
on_true=W_Q_B,
on_false=W_Q_B / jnp.linalg.norm(W_Q_B),
)
norm = jaxsim.math.safe_norm(W_Q_B)
W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))

return self.replace(
validate=True,
Expand Down
1 change: 1 addition & 0 deletions src/jaxsim/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@
from .rotation import Rotation
from .skew import Skew
from .transform import Transform
from .utils import safe_norm

from .joint_model import JointModel, supported_joint_motion # isort:skip
11 changes: 4 additions & 7 deletions src/jaxsim/math/quaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import jaxsim.typing as jtp

from .utils import safe_norm


class Quaternion:
@staticmethod
Expand Down Expand Up @@ -111,18 +113,13 @@ def Q_inertial(q: jtp.Vector) -> jtp.Matrix:
operand=quaternion,
)

norm_ω = jax.lax.cond(
pred=ω.dot(ω) < (1e-6) ** 2,
true_fun=lambda _: 1e-6,
false_fun=lambda _: jnp.linalg.norm(ω),
operand=None,
)
norm_ω = safe_norm(ω)

qd = 0.5 * (
Q
@ jnp.hstack(
[
K * norm_ω * (1 - jnp.linalg.norm(quaternion)),
K * norm_ω * (1 - safe_norm(quaternion)),
ω,
]
)
Expand Down
17 changes: 4 additions & 13 deletions src/jaxsim/math/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import jaxsim.typing as jtp

from .skew import Skew
from .utils import safe_norm


class Rotation:
Expand Down Expand Up @@ -67,7 +68,7 @@ def from_axis_angle(vector: jtp.Vector) -> jtp.Matrix:
def theta_is_not_zero(axis: jtp.Vector) -> jtp.Matrix:

v = axis
theta = jnp.linalg.norm(v)
theta = safe_norm(v)

s = jnp.sin(theta)
c = jnp.cos(theta)
Expand All @@ -81,19 +82,9 @@ def theta_is_not_zero(axis: jtp.Vector) -> jtp.Matrix:

return R.transpose()

# Use the double-where trick to prevent JAX problems when the
# jax.jit and jax.grad transforms are applied.
return jnp.where(
jnp.linalg.norm(vector) > 0,
theta_is_not_zero(
axis=jnp.where(
jnp.linalg.norm(vector) > 0,
vector,
# The following line is a workaround to prevent division by 0.
# Considering the outer where, this branch is never executed.
jnp.ones(3),
)
),
jnp.allclose(vector, 0.0),
# Return an identity rotation matrix when the input vector is zero.
jnp.eye(3),
theta_is_not_zero(axis=vector),
)
31 changes: 31 additions & 0 deletions src/jaxsim/math/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import jax.numpy as jnp

import jaxsim.typing as jtp


def safe_norm(array: jtp.ArrayLike, axis=None) -> jtp.Array:
"""
Provides a calculation for an array norm so that it is safe
to compute the gradient and handle NaNs.
Args:
array: The array for which to compute the norm.
axis: The axis for which to compute the norm.
Returns:
The norm of the array with handling for zero arrays to avoid NaNs.
"""

# Check if the entire array is composed of zeros.
is_zero = jnp.allclose(array, 0.0)

# Replace zeros with an array of ones temporarily to avoid division by zero.
# This ensures the computation of norm does not produce NaNs or Infs.
array = jnp.where(is_zero, jnp.ones_like(array), array)

# Compute the norm of the array along the specified axis.
norm = jnp.linalg.norm(array, axis=axis)

# Use `jnp.where` to set the norm to 0.0 where the input array was all zeros.
# This usage supports potential batch processing for future scalability.
return jnp.where(is_zero, 0.0, norm)
11 changes: 4 additions & 7 deletions src/jaxsim/rbda/contacts/soft.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,19 +309,16 @@ def hunt_crossley_contact_model(

# Compute the direction of the tangential force.
# To prevent dividing by zero, we use a switch statement.
# The ε, instead, is needed to make AD happy.
f_tangential_direction = jnp.where(
f_tangential.dot(f_tangential) != 0,
f_tangential / jnp.linalg.norm(f_tangential + ε),
jnp.zeros(3),
norm = jaxsim.math.safe_norm(f_tangential)
f_tangential_direction = f_tangential / (
norm + jnp.finfo(float).eps * (norm == 0)
)

# Project the tangential force to the friction cone if slipping.
f_tangential = jnp.where(
sticking,
f_tangential,
jnp.minimum(μ * force_normal_mag, jnp.linalg.norm(f_tangential + ε))
* f_tangential_direction,
jnp.minimum(μ * force_normal_mag, norm) * f_tangential_direction,
)

# Set the tangential force to zero if there is no contact.
Expand Down
3 changes: 2 additions & 1 deletion src/jaxsim/terrain/terrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import jax_dataclasses
import numpy as np

import jaxsim.math
import jaxsim.typing as jtp
from jaxsim import exceptions

Expand Down Expand Up @@ -41,7 +42,7 @@ def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
[(h_xm - h_xp) / (2 * self.delta), (h_ym - h_yp) / (2 * self.delta), 1.0]
)

return n / jnp.linalg.norm(n)
return n / jaxsim.math.safe_norm(n)


@jax_dataclasses.pytree_dataclass
Expand Down
43 changes: 43 additions & 0 deletions tests/test_automatic_differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import jax
import jax.numpy as jnp
import numpy as np
from jax.test_util import check_grads

import jaxsim.api as js
Expand Down Expand Up @@ -413,3 +414,45 @@ def step(
modes=["rev", "fwd"],
eps=ε,
)


def test_ad_safe_norm(
prng_key: jax.Array,
):

_, subkey = jax.random.split(prng_key, num=2)
array = jax.random.uniform(subkey, shape=(4,), minval=-5, maxval=5)

# ====
# Test
# ====

# Test that the safe_norm function is compatible with batching.
array = jnp.stack([array, array])
assert jaxsim.math.safe_norm(array, axis=1).shape == (2,)

# Test that the safe_norm function is correctly computing the norm.
assert np.allclose(jaxsim.math.safe_norm(array), np.linalg.norm(array))

# Function exposing only the parameters to be differentiated.
def safe_norm(array: jtp.Array) -> jtp.Array:

return jaxsim.math.safe_norm(array)

# Check derivatives against finite differences.
check_grads(
f=safe_norm,
args=(array,),
order=AD_ORDER,
modes=["rev", "fwd"],
eps=ε,
)

# Check derivatives against finite differences when the array is zero.
check_grads(
f=safe_norm,
args=(jnp.zeros_like(array),),
order=AD_ORDER,
modes=["rev", "fwd"],
eps=ε,
)

0 comments on commit ceb1fab

Please sign in to comment.