Skip to content

Commit

Permalink
Revert rnea modification
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Oct 25, 2023
1 parent 6a78dc6 commit 69800af
Showing 1 changed file with 15 additions and 59 deletions.
74 changes: 15 additions & 59 deletions src/jaxsim/physics/algos/rnea.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,24 +41,6 @@ def rnea(
S = model.motion_subspaces(q=q)
i_X_λi = jnp.zeros_like(pre_X_λi)

Γ = (
jnp.array(list(model._joint_motor_gear_ratio.values()))
if hasattr(model, "_joint_motor_gear_ratio")
else jnp.ones(model.dofs)
)
IM = (
jnp.array(list(model._joint_motor_inertia.values()))
if hasattr(model, "_joint_motor_inertia")
else jnp.zeros(model.dofs)
)
K_v = (
jnp.array(list(model._joint_motor_viscous_friction.values()))
if hasattr(model, "_joint_motor_viscous_friction")
else jnp.zeros(model.dofs)
)
K̅ᵥ = jnp.diag(Γ.T * jnp.diag(K_v) * Γ)
S_m = jnp.concatenate([S[:1], S[1:] * Γ[:, None, None]], axis=0)

i_X_0 = jnp.zeros_like(pre_X_λi)
i_X_0 = i_X_0.at[0].set(jnp.eye(6))

Expand All @@ -70,10 +52,6 @@ def rnea(
a = jnp.array([jnp.zeros([6, 1])] * model.NB)
f = jnp.array([jnp.zeros([6, 1])] * model.NB)

v_m = jnp.array([jnp.zeros([6, 1])] * model.NB)
a_m = jnp.array([jnp.zeros([6, 1])] * model.NB)
f_m = jnp.array([jnp.zeros([6, 1])] * model.NB)

# 6D transform of base velocity
B_X_W = Adjoint.from_quaternion_and_translation(
quaternion=xfb[0:4],
Expand Down Expand Up @@ -103,41 +81,26 @@ def rnea(
f = f.at[0].set(f_0)

ForwardPassCarry = Tuple[
jtp.MatrixJax,
jtp.MatrixJax,
jtp.MatrixJax,
jtp.MatrixJax,
jtp.MatrixJax,
jtp.MatrixJax,
jtp.MatrixJax,
jtp.MatrixJax,
jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax
]
forward_pass_carry = (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m)
forward_pass_carry = (i_X_λi, v, a, i_X_0, f)

def forward_pass(
carry: ForwardPassCarry, i: jtp.Int
) -> Tuple[ForwardPassCarry, None]:
ii = i - 1
i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m = carry
i_X_λi, v, a, i_X_0, f = carry

vJ = S[i] * qd[ii]
vJ_m = S_m[i] * qd[ii]

i_X_λi_i = i_X_pre[i] @ pre_X_λi[i]
i_X_λi = i_X_λi.at[i].set(i_X_λi_i)

v_i = i_X_λi[i] @ v[λ[i]] + vJ
v = v.at[i].set(v_i)

v_i_m = i_X_λi[i] @ v_m[λ[i]] + vJ_m
v_m = v_m.at[i].set(v_i_m)

a_i = i_X_λi[i] @ a[λ[i]] + S[i] * qdd[ii] + Cross.vx(v[i]) @ vJ
a = a.at[i].set(a_i)

a_i_m = i_X_λi[i] @ a_m[λ[i]] + S_m[i] * qdd[ii] + Cross.vx(v_m[i]) @ vJ_m
a_m = a_m.at[i].set(a_i_m)

i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]]
i_X_0 = i_X_0.at[i].set(i_X_0_i)
i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T
Expand All @@ -149,50 +112,43 @@ def forward_pass(
)
f = f.at[i].set(f_i)

f_i_m = IM[i] * a_m[i] + Cross.vx_star(v_m[i]) * IM[i] @ v_m[i]
f_m = f_m.at[i].set(f_i_m)
return (i_X_λi, v, a, i_X_0, f), None

return (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m), None

(i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m), _ = jax.lax.scan(
(i_X_λi, v, a, i_X_0, f), _ = jax.lax.scan(
f=forward_pass,
init=forward_pass_carry,
xs=np.arange(start=1, stop=model.NB),
)

tau = jnp.zeros_like(q)

BackwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax]
backward_pass_carry = (tau, f, f_m)
BackwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax]
backward_pass_carry = (tau, f)

def backward_pass(
carry: BackwardPassCarry, i: jtp.Int
) -> Tuple[BackwardPassCarry, None]:
ii = i - 1
tau, f, f_m = carry
tau, f = carry

value = S[i].T @ f[i] + S_m[i].T @ f_m[i] # + K̅ᵥ[i] * qd[ii]
value = S[i].T @ f[i]
tau = tau.at[ii].set(value.squeeze())

def update_f(ffm: Tuple[jtp.MatrixJax, jtp.MatrixJax]) -> jtp.MatrixJax:
f, f_m = ffm
def update_f(f: jtp.MatrixJax) -> jtp.MatrixJax:
f_λi = f[λ[i]] + i_X_λi[i].T @ f[i]
f = f.at[λ[i]].set(f_λi)
return f

f_m_λi = f_m[λ[i]] + i_X_λi[i].T @ f_m[i]
f_m = f_m.at[λ[i]].set(f_m_λi)
return f, f_m

f, f_m = jax.lax.cond(
f = jax.lax.cond(
pred=jnp.array([λ[i] != 0, model.is_floating_base]).any(),
true_fun=update_f,
false_fun=lambda f: f,
operand=(f, f_m),
operand=f,
)

return (tau, f, f_m), None
return (tau, f), None

(tau, f, f_m), _ = jax.lax.scan(
(tau, f), _ = jax.lax.scan(
f=backward_pass,
init=backward_pass_carry,
xs=np.flip(np.arange(start=1, stop=model.NB)),
Expand Down

0 comments on commit 69800af

Please sign in to comment.