Skip to content

Commit

Permalink
Correct typos in ABA and comment oop decorators
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Sep 20, 2023
1 parent b0657e7 commit ec3d4a1
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 13 deletions.
30 changes: 24 additions & 6 deletions src/jaxsim/high_level/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,8 +993,16 @@ def forward_dynamics_crb(
KV = jnp.diag(
jnp.array(list(self.physics_model._joint_motor_viscous_friction.values()))
)
# ! The following line raises ArrayImpl -> bool conversion error in JIT
Γ = jnp.diag(GR) if ((jnp.diag(IM) != 0).any()).item() else jnp.eye(GR.size)
Γ = jnp.diag(GR)

# Check on the motor parameters
jax.lax.cond(
pred=(jnp.diag(IM) != 0).any(),
operand=Γ,
true_fun=lambda Γ: jnp.diag(GR),
false_fun=lambda Γ: jnp.eye(GR.size),
)

Γ_inv = jnp.linalg.inv(Γ)

K̅ᵥ = Γ.T @ KV @ Γ
Expand All @@ -1013,7 +1021,7 @@ def forward_dynamics_crb(
# Add the motor related terms to the EoM
M = M.at[sl_m, sl_m].set(M[sl_m, sl_m] + Γ.T @ IM @ Γ)
h = h.at[sl_m].set(h[sl_m] + K̅ᵥ @ self.joint_velocities()[:, None])
S = S.at[sl_m].set(S[sl_m]) # + Γ_inv @ τ)
S = S.at[sl_m].set(S[sl_m])

# Compute the generalized acceleration by inverting the EoM
ν̇ = (
Expand Down Expand Up @@ -1346,7 +1354,7 @@ def integrate(
# Motor dynamics
# ==============

@functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
# @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
def set_motor_inertias(
self, inertias: jtp.Vector, joint_names: List[str] = None
) -> None:
Expand All @@ -1367,7 +1375,7 @@ def set_motor_inertias(

logging.info("Setting attribute 'motor_inertias'")

@functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
# @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
def set_motor_gear_ratios(
self, gear_ratios: jtp.Vector, joint_names: List[str] = None
) -> None:
Expand All @@ -1377,6 +1385,16 @@ def set_motor_gear_ratios(
if gear_ratios.size != len(joint_names):
raise ValueError("Wrong arguments size", gear_ratios.size, len(joint_names))

# Check on gear ratios if motor_inertias are not zero
jax.lax.cond(
pred=(jnp.diag(self.physics_model._joint_motor_inertia) != 0).any(),
operand=gear_ratios,
true_fun=lambda gr: gr,
false_fun=lambda: (_ for _ in _).throw(
ValueError("Motor inertias are zero")
),
)

self.physics_model._joint_motor_gear_ratio.update(
{
j: gr
Expand All @@ -1388,7 +1406,7 @@ def set_motor_gear_ratios(

logging.info("Setting attribute 'motor_gear_ratios'")

@functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
# @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
def set_motor_viscous_frictions(
self, viscous_frictions: Tuple, joint_names: List[str] = None
) -> None:
Expand Down
14 changes: 7 additions & 7 deletions src/jaxsim/physics/algos/aba.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> Tuple[Pass1Carry, None]:

U = jnp.zeros_like(S)
m_U = jnp.zeros_like(S)
D = jnp.zeros(shape=(model.NB, 1))
d = jnp.zeros(shape=(model.NB, 1))
u = jnp.zeros(shape=(model.NB, 1))
m_u = jnp.zeros(shape=(model.NB, 1))

Expand Down Expand Up @@ -189,18 +189,18 @@ def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> Tuple[Pass2Carry, None]:
m_U_i = IM[i] @ m_S[i]
m_U = m_U.at[i].set(m_U_i)

D_i = S[i].T @ MA[i] @ S[i] + m_S[i].T @ IM[i] @ m_S[i]
D = D.at[i].set(D_i.squeeze())
d_i = S[i].T @ MA[i] @ S[i] + m_S[i].T @ IM[i] @ m_S[i]
d = d.at[i].set(d_i.squeeze())

# Compute the articulated-body inertia and bias forces of this link
Ma = MA[i] + IM[i] - U[i] / D[i] @ U[i].T - m_U[i] / D[i] @ m_U[i].T
Ma = MA[i] + IM[i] - U[i] / d[i] @ U[i].T - m_U[i] / d[i] @ m_U[i].T
pa = (
pA[i]
+ pR[i]
+ Ma[i] @ c[i]
+ IM[i] @ m_c[i]
+ U[i] / D[i] * u[i]
+ m_U[i] / D[i] * m_u[i]
+ U[i] / d[i] * u[i]
+ m_U[i] / d[i] * m_u[i]
)

# Propagate them to the parent, handling the base link
Expand Down Expand Up @@ -253,7 +253,7 @@ def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> Tuple[Pass3Carry, None]:
a_i = i_X_λi[i] @ a[λ[i]]

# Compute joint accelerations
qdd_ii = (u[i] + m_u[i] - (U[i].T + m_U[i].T) @ a_i) / D[i]
qdd_ii = (u[i] + m_u[i] - (U[i].T + m_U[i].T) @ a_i) / d[i]
qdd = qdd.at[ii].set(qdd_ii.squeeze()) if qdd.size != 0 else qdd

a_i = a_i + S[i] * qdd[ii] + c[i] if qdd.size != 0 else a_i
Expand Down

0 comments on commit ec3d4a1

Please sign in to comment.