Skip to content

Commit

Permalink
Refactor and add viscous frictions in ABA
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Oct 25, 2023
1 parent 69800af commit c1c2b45
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 18 deletions.
9 changes: 3 additions & 6 deletions src/jaxsim/high_level/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1138,15 +1138,12 @@ def forward_dynamics_crb(
τ = jnp.vstack(τ) if τ.size > 0 else jnp.empty(shape=(0, 1))

# Extract motor parameters from the physics model
GR = jnp.array(list(self.physics_model._joint_motor_gear_ratio.values()))
IM = jnp.diag(jnp.array(list(self.physics_model._joint_motor_inertia.values())))
KV = jnp.diag(
jnp.array(list(self.physics_model._joint_motor_viscous_friction.values()))
)
GR = self.motor_gear_ratios()
IM = self.motor_inertias()
KV = jnp.diag(self.motor_viscous_frictions())

# Compute auxiliary quantities
Γ = jnp.diag(GR)
Γ_inv = jnp.linalg.inv(Γ)
K̅ᵥ = Γ.T @ KV @ Γ

# Compute terms of the floating-base EoM
Expand Down
24 changes: 12 additions & 12 deletions src/jaxsim/physics/algos/aba.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,11 @@ def aba(
λ = model.parent_array()

# Extract motor parameters from the physics model
Γ = jnp.array(list(model._joint_motor_gear_ratio.values()))
Γ = jnp.array([*model._joint_motor_gear_ratio.values()])
IM = jnp.array(
[jnp.eye(6) * m for m in list(model._joint_motor_inertia.values())] * model.NB
)
K̅ᵥ = jnp.diag(
Γ.T
@ jnp.diag(jnp.array(list(model._joint_motor_viscous_friction.values())))
/ Γ
[jnp.eye(6) * m for m in [*model._joint_motor_inertia.values()]] * model.NB
)
K̅ᵥ = Γ.T * jnp.array([*model._joint_motor_viscous_friction.values()]) / Γ
m_S = jnp.concatenate([S[:1], S[1:] * Γ[:, None, None]], axis=0)

# Initialize buffers
Expand Down Expand Up @@ -137,10 +133,10 @@ def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> Tuple[Pass1Carry, None]:
i_X_0 = i_X_0.at[i].set(i_X_0_i)
i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T

pA_i = Cross.vx_star(v[i]) @ MA[i] @ v[i] - i_Xf_W @ jnp.vstack(f_ext[i])
pA_i = Cross.vx_star(v[i]) @ M[i] @ v[i] - i_Xf_W @ jnp.vstack(f_ext[i])
pA = pA.at[i].set(pA_i)

pR_i = Cross.vx_star(m_v[i]) @ IM[i] @ m_v[i]
pR_i = Cross.vx_star(m_v[i]) @ IM[i] @ m_v[i] + K̅ᵥ[i] * m_v[i]
pR = pR.at[i].set(pR_i)

return (i_X_λi, v, c, m_v, m_c, MA, pA, pR, i_X_0), None
Expand Down Expand Up @@ -178,8 +174,12 @@ def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> Tuple[Pass2Carry, None]:
u_i = tau[ii] - S[i].T @ pA[i] if tau.size != 0 else -S[i].T @ pA[i]
u = u.at[i].set(u_i.squeeze())

has_motors = (pR[i] != 0).any()

m_u_i = (
tau[ii] / Γ[i] - m_S[i].T @ pR[i] if tau.size != 0 else -m_S[i].T @ pR[i]
tau[ii] / Γ[i] * has_motors - m_S[i].T @ pR[i]
if tau.size != 0
else -m_S[i].T @ pR[i]
)
m_u = m_u.at[i].set(m_u_i.squeeze())

Expand Down Expand Up @@ -250,13 +250,13 @@ def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> Tuple[Pass3Carry, None]:
a, qdd = carry

# Propagate link accelerations
a_i = i_X_λi[i] @ a[λ[i]]
a_i = i_X_λi[i] @ a[λ[i]] + c[i]

# Compute joint accelerations
qdd_ii = (u[i] + m_u[i] - (U[i].T + m_U[i].T) @ a_i) / d[i]
qdd = qdd.at[ii].set(qdd_ii.squeeze()) if qdd.size != 0 else qdd

a_i = a_i + S[i] * qdd[ii] + c[i] if qdd.size != 0 else a_i
a_i = a_i + S[i] * qdd[ii] if qdd.size != 0 else a_i
a = a.at[i].set(a_i)

return (a, qdd), None
Expand Down

0 comments on commit c1c2b45

Please sign in to comment.