From 69800af95e1e33d4d8180ecb0d6d636ff7422c73 Mon Sep 17 00:00:00 2001
From: Filippo Luca Ferretti
Date: Wed, 25 Oct 2023 15:58:34 +0200
Subject: [PATCH] Revert `rnea` modification

---
 src/jaxsim/physics/algos/rnea.py | 74 +++++++------------------------
 1 file changed, 15 insertions(+), 59 deletions(-)

diff --git a/src/jaxsim/physics/algos/rnea.py b/src/jaxsim/physics/algos/rnea.py
index 2b85e9efa..09ef0d264 100644
--- a/src/jaxsim/physics/algos/rnea.py
+++ b/src/jaxsim/physics/algos/rnea.py
@@ -41,24 +41,6 @@ def rnea(
     S = model.motion_subspaces(q=q)
     i_X_λi = jnp.zeros_like(pre_X_λi)
 
-    Γ = (
-        jnp.array(list(model._joint_motor_gear_ratio.values()))
-        if hasattr(model, "_joint_motor_gear_ratio")
-        else jnp.ones(model.dofs)
-    )
-    IM = (
-        jnp.array(list(model._joint_motor_inertia.values()))
-        if hasattr(model, "_joint_motor_inertia")
-        else jnp.zeros(model.dofs)
-    )
-    K_v = (
-        jnp.array(list(model._joint_motor_viscous_friction.values()))
-        if hasattr(model, "_joint_motor_viscous_friction")
-        else jnp.zeros(model.dofs)
-    )
-    K̅ᵥ = jnp.diag(Γ.T * jnp.diag(K_v) * Γ)
-    S_m = jnp.concatenate([S[:1], S[1:] * Γ[:, None, None]], axis=0)
-
     i_X_0 = jnp.zeros_like(pre_X_λi)
     i_X_0 = i_X_0.at[0].set(jnp.eye(6))
 
@@ -70,10 +52,6 @@ def rnea(
     a = jnp.array([jnp.zeros([6, 1])] * model.NB)
     f = jnp.array([jnp.zeros([6, 1])] * model.NB)
 
-    v_m = jnp.array([jnp.zeros([6, 1])] * model.NB)
-    a_m = jnp.array([jnp.zeros([6, 1])] * model.NB)
-    f_m = jnp.array([jnp.zeros([6, 1])] * model.NB)
-
     # 6D transform of base velocity
     B_X_W = Adjoint.from_quaternion_and_translation(
         quaternion=xfb[0:4],
@@ -103,41 +81,26 @@ def rnea(
     f = f.at[0].set(f_0)
 
     ForwardPassCarry = Tuple[
-        jtp.MatrixJax,
-        jtp.MatrixJax,
-        jtp.MatrixJax,
-        jtp.MatrixJax,
-        jtp.MatrixJax,
-        jtp.MatrixJax,
-        jtp.MatrixJax,
-        jtp.MatrixJax,
+        jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax
     ]
-    forward_pass_carry = (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m)
+    forward_pass_carry = (i_X_λi, v, a, i_X_0, f)
 
     def forward_pass(
         carry: ForwardPassCarry, i: jtp.Int
     ) -> Tuple[ForwardPassCarry, None]:
         ii = i - 1
-        i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m = carry
+        i_X_λi, v, a, i_X_0, f = carry
 
         vJ = S[i] * qd[ii]
-        vJ_m = S_m[i] * qd[ii]
-
         i_X_λi_i = i_X_pre[i] @ pre_X_λi[i]
         i_X_λi = i_X_λi.at[i].set(i_X_λi_i)
 
         v_i = i_X_λi[i] @ v[λ[i]] + vJ
         v = v.at[i].set(v_i)
 
-        v_i_m = i_X_λi[i] @ v_m[λ[i]] + vJ_m
-        v_m = v_m.at[i].set(v_i_m)
-
         a_i = i_X_λi[i] @ a[λ[i]] + S[i] * qdd[ii] + Cross.vx(v[i]) @ vJ
         a = a.at[i].set(a_i)
 
-        a_i_m = i_X_λi[i] @ a_m[λ[i]] + S_m[i] * qdd[ii] + Cross.vx(v_m[i]) @ vJ_m
-        a_m = a_m.at[i].set(a_i_m)
-
         i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]]
         i_X_0 = i_X_0.at[i].set(i_X_0_i)
         i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T
@@ -149,12 +112,9 @@ def forward_pass(
         )
         f = f.at[i].set(f_i)
 
-        f_i_m = IM[i] * a_m[i] + Cross.vx_star(v_m[i]) * IM[i] @ v_m[i]
-        f_m = f_m.at[i].set(f_i_m)
+        return (i_X_λi, v, a, i_X_0, f), None
 
-        return (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m), None
-
-    (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m), _ = jax.lax.scan(
+    (i_X_λi, v, a, i_X_0, f), _ = jax.lax.scan(
         f=forward_pass,
         init=forward_pass_carry,
         xs=np.arange(start=1, stop=model.NB),
@@ -162,37 +122,33 @@ def forward_pass(
 
     tau = jnp.zeros_like(q)
 
-    BackwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax]
-    backward_pass_carry = (tau, f, f_m)
+    BackwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax]
+    backward_pass_carry = (tau, f)
 
     def backward_pass(
         carry: BackwardPassCarry, i: jtp.Int
    ) -> Tuple[BackwardPassCarry, None]:
         ii = i - 1
-        tau, f, f_m = carry
+        tau, f = carry
 
-        value = S[i].T @ f[i] + S_m[i].T @ f_m[i]  # + K̅ᵥ[i] * qd[ii]
+        value = S[i].T @ f[i]
         tau = tau.at[ii].set(value.squeeze())
 
-        def update_f(ffm: Tuple[jtp.MatrixJax, jtp.MatrixJax]) -> jtp.MatrixJax:
-            f, f_m = ffm
+        def update_f(f: jtp.MatrixJax) -> jtp.MatrixJax:
             f_λi = f[λ[i]] + i_X_λi[i].T @ f[i]
             f = f.at[λ[i]].set(f_λi)
-
-            f_m_λi = f_m[λ[i]] + i_X_λi[i].T @ f_m[i]
-            f_m = f_m.at[λ[i]].set(f_m_λi)
-            return f, f_m
+            return f
 
-        f, f_m = jax.lax.cond(
+        f = jax.lax.cond(
             pred=jnp.array([λ[i] != 0, model.is_floating_base]).any(),
             true_fun=update_f,
             false_fun=lambda f: f,
-            operand=(f, f_m),
+            operand=f,
         )
 
-        return (tau, f, f_m), None
+        return (tau, f), None
 
-    (tau, f, f_m), _ = jax.lax.scan(
+    (tau, f), _ = jax.lax.scan(
         f=backward_pass,
         init=backward_pass_carry,
         xs=np.flip(np.arange(start=1, stop=model.NB)),
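
Note on what the revert leaves in place: the reverted `rnea.py` still evaluates the inverse dynamics with two `jax.lax.scan` passes, a forward pass whose carry shrinks from `(i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m)` to `(i_X_λi, v, a, i_X_0, f)` and a backward pass whose carry shrinks from `(tau, f, f_m)` to `(tau, f)`, since the motor-side quantities (`Γ`, `IM`, `K_v`, `S_m`, `v_m`, `a_m`, `f_m`) are dropped. The snippet below is only a minimal, self-contained sketch of that scan-carry pattern on a toy chain; it is not the jaxsim API, and the names (`NB`, `λ`, `S`, `qd`, `update_f`) are illustrative stand-ins for the quantities appearing in the diff.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Toy kinematic tree: body 0 is the base, bodies 1..3 form a serial chain.
NB = 4
λ = jnp.array([0, 0, 1, 2])   # λ[i]: parent of body i (entry 0 is a placeholder)
S = jnp.ones((NB, 6, 1))      # toy per-joint motion subspaces
qd = jnp.ones(NB - 1)         # joint velocities


def forward_pass(carry, i):
    (v,) = carry
    # Propagate the parent velocity to body i and add the joint contribution.
    v_i = v[λ[i]] + S[i] * qd[i - 1]
    v = v.at[i].set(v_i)
    return (v,), None


def backward_pass(carry, i):
    tau, f = carry
    # Project the body force onto the joint motion subspace.
    tau = tau.at[i - 1].set((S[i].T @ f[i]).squeeze())

    # Accumulate the force on the parent, unless the parent is the base.
    def update_f(f):
        return f.at[λ[i]].set(f[λ[i]] + f[i])

    f = jax.lax.cond(λ[i] != 0, update_f, lambda f: f, f)
    return (tau, f), None


v = jnp.zeros((NB, 6, 1))
(v,), _ = jax.lax.scan(f=forward_pass, init=(v,), xs=np.arange(start=1, stop=NB))

tau, f = jnp.zeros(NB - 1), jnp.ones((NB, 6, 1))
(tau, f), _ = jax.lax.scan(
    f=backward_pass, init=(tau, f), xs=np.flip(np.arange(start=1, stop=NB))
)

print(tau)  # one generalized force per joint
```

Writing both passes as `jax.lax.scan` bodies with functional `.at[...].set(...)` updates keeps the algorithm jit-compilable, which is why the revert only strips the motor terms and leaves this structure untouched.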