Skip to content

Commit

Permalink
Add function to safely compute the norm
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Dec 16, 2024
1 parent fe7998a commit 30472ed
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 36 deletions.
12 changes: 4 additions & 8 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,9 +382,8 @@ def base_orientation(self, dcm: jtp.BoolLike = False) -> jtp.Vector | jtp.Matrix
# we introduce a Baumgarte stabilization to let the quaternion converge to
# a unit quaternion. In this case, it is not guaranteed that the quaternion
# stored in the state is a unit quaternion.
W_Q_B = jnp.where(
jnp.allclose(W_Q_B.dot(W_Q_B), 1.0), W_Q_B, W_Q_B / jnp.linalg.norm(W_Q_B)
)
norm = jaxsim.math.safe_norm(W_Q_B)
W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))

return (W_Q_B if not dcm else jaxsim.math.Quaternion.to_dcm(W_Q_B)).astype(
float
Expand Down Expand Up @@ -611,11 +610,8 @@ def reset_base_quaternion(self, base_quaternion: jtp.VectorLike) -> Self:

W_Q_B = jnp.array(base_quaternion, dtype=float)

W_Q_B = jax.lax.select(
pred=jnp.allclose(jnp.linalg.norm(W_Q_B), 1.0, atol=1e-6, rtol=0.0),
on_true=W_Q_B,
on_false=W_Q_B / jnp.linalg.norm(W_Q_B),
)
norm = jaxsim.math.safe_norm(W_Q_B)
W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))

return self.replace(
validate=True,
Expand Down
1 change: 1 addition & 0 deletions src/jaxsim/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@
from .rotation import Rotation
from .skew import Skew
from .transform import Transform
from .utils import safe_norm

from .joint_model import JointModel, supported_joint_motion # isort:skip
11 changes: 4 additions & 7 deletions src/jaxsim/math/quaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import jaxsim.typing as jtp

from .utils import safe_norm


class Quaternion:
@staticmethod
Expand Down Expand Up @@ -111,18 +113,13 @@ def Q_inertial(q: jtp.Vector) -> jtp.Matrix:
operand=quaternion,
)

norm_ω = jax.lax.cond(
pred=ω.dot(ω) < (1e-6) ** 2,
true_fun=lambda _: 1e-6,
false_fun=lambda _: jnp.linalg.norm(ω),
operand=None,
)
norm_ω = safe_norm(ω)

qd = 0.5 * (
Q
@ jnp.hstack(
[
K * norm_ω * (1 - jnp.linalg.norm(quaternion)),
K * norm_ω * (1 - safe_norm(quaternion)),
ω,
]
)
Expand Down
17 changes: 4 additions & 13 deletions src/jaxsim/math/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import jaxsim.typing as jtp

from .skew import Skew
from .utils import safe_norm


class Rotation:
Expand Down Expand Up @@ -67,7 +68,7 @@ def from_axis_angle(vector: jtp.Vector) -> jtp.Matrix:
def theta_is_not_zero(axis: jtp.Vector) -> jtp.Matrix:

v = axis
theta = jnp.linalg.norm(v)
theta = safe_norm(v)

s = jnp.sin(theta)
c = jnp.cos(theta)
Expand All @@ -81,19 +82,9 @@ def theta_is_not_zero(axis: jtp.Vector) -> jtp.Matrix:

return R.transpose()

# Use the double-where trick to prevent JAX problems when the
# jax.jit and jax.grad transforms are applied.
return jnp.where(
jnp.linalg.norm(vector) > 0,
theta_is_not_zero(
axis=jnp.where(
jnp.linalg.norm(vector) > 0,
vector,
# The following line is a workaround to prevent division by 0.
# Considering the outer where, this branch is never executed.
jnp.ones(3),
)
),
jnp.allclose(vector, 0.0),
# Return an identity rotation matrix when the input vector is zero.
jnp.eye(3),
theta_is_not_zero(axis=vector),
)
19 changes: 19 additions & 0 deletions src/jaxsim/math/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import jax.numpy as jnp

import jaxsim.typing as jtp


def safe_norm(array: jtp.ArrayLike, axis=None) -> jtp.Array:
"""
Provides a calculation for an array norm so that it is safe
to compute the gradient and the NaNs are handled.
Args:
array: The array for which to compute the norm
axis: The axis for which to compute the norm
"""
is_zero = jnp.allclose(array, 0.0)
array = jnp.where(is_zero, jnp.ones_like(array), array)

norm = jnp.linalg.norm(array, axis=axis)
return jnp.where(is_zero, 0.0, norm)
11 changes: 4 additions & 7 deletions src/jaxsim/rbda/contacts/soft.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,19 +309,16 @@ def hunt_crossley_contact_model(

# Compute the direction of the tangential force.
# To prevent dividing by zero, we use a switch statement.
# The ε, instead, is needed to make AD happy.
f_tangential_direction = jnp.where(
f_tangential.dot(f_tangential) != 0,
f_tangential / jnp.linalg.norm(f_tangential + ε),
jnp.zeros(3),
norm = jaxsim.math.safe_norm(f_tangential)
f_tangential_direction = f_tangential / (
norm + jnp.finfo(float).eps * (norm == 0)
)

# Project the tangential force to the friction cone if slipping.
f_tangential = jnp.where(
sticking,
f_tangential,
jnp.minimum(μ * force_normal_mag, jnp.linalg.norm(f_tangential + ε))
* f_tangential_direction,
jnp.minimum(μ * force_normal_mag, norm) * f_tangential_direction,
)

# Set the tangential force to zero if there is no contact.
Expand Down
3 changes: 2 additions & 1 deletion src/jaxsim/terrain/terrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import jax_dataclasses
import numpy as np

import jaxsim.math
import jaxsim.typing as jtp
from jaxsim import exceptions

Expand Down Expand Up @@ -41,7 +42,7 @@ def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
[(h_xm - h_xp) / (2 * self.delta), (h_ym - h_yp) / (2 * self.delta), 1.0]
)

return n / jnp.linalg.norm(n)
return n / jaxsim.math.safe_norm(n)


@jax_dataclasses.pytree_dataclass
Expand Down

0 comments on commit 30472ed

Please sign in to comment.