From 332e58477d77a0432ecab81e2616eb8b7b39b756 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 30 Nov 2023 12:02:54 +0100 Subject: [PATCH] Refactor Coriolis and remove vx star bar --- src/jaxsim/high_level/model.py | 4 + src/jaxsim/math/cross.py | 10 -- src/jaxsim/physics/algos/coriolis.py | 183 +++++++++++++++++---------- 3 files changed, 119 insertions(+), 78 deletions(-) diff --git a/src/jaxsim/high_level/model.py b/src/jaxsim/high_level/model.py index 2582d7df1..cd92b7646 100644 --- a/src/jaxsim/high_level/model.py +++ b/src/jaxsim/high_level/model.py @@ -945,11 +945,15 @@ def com_position(self) -> jtp.Vector: # Algorithms # ========== + @functools.partial(oop.jax_tf.method_ro) def coriolis_matrix(self) -> jtp.Matrix: + from jaxsim.physics.algos.coriolis import coriolis + H, H_dot, C = jaxsim.physics.algos.coriolis.coriolis( model=self.physics_model, q=self.data.model_state.joint_positions, qd=self.data.model_state.joint_velocities, + xfb=self.data.model_state.xfb(), ) return H, H_dot, C diff --git a/src/jaxsim/math/cross.py b/src/jaxsim/math/cross.py index 812f0d96f..ff4195300 100644 --- a/src/jaxsim/math/cross.py +++ b/src/jaxsim/math/cross.py @@ -23,13 +23,3 @@ def vx(velocity_sixd: jtp.Vector) -> jtp.Matrix: def vx_star(velocity_sixd: jtp.Vector) -> jtp.Matrix: v_cross_star = -Cross.vx(velocity_sixd).T return v_cross_star - - @staticmethod - def vx_star_bar(velocity_sixd: jtp.Vector) -> jtp.Matrix: - v_cross_star_bar = jnp.block( - [ - [Skew.wedge(vector=velocity_sixd.squeeze()), jnp.zeros(shape=(3, 3))], - [jnp.zeros(shape=(3, 3)), Skew.wedge(vector=velocity_sixd.squeeze())], - ] - ) - return v_cross_star_bar diff --git a/src/jaxsim/physics/algos/coriolis.py b/src/jaxsim/physics/algos/coriolis.py index 167d42ec0..210545b45 100644 --- a/src/jaxsim/physics/algos/coriolis.py +++ b/src/jaxsim/physics/algos/coriolis.py @@ -1,16 +1,41 @@ +from typing import Tuple + import jax import jax.numpy as jnp +import numpy as np -import jaxsim +import jaxsim.typing as jtp +from jaxsim.math.adjoint import Adjoint from jaxsim.math.cross import Cross from jaxsim.physics.model.physics_model import PhysicsModel +from . import utils + -def coriolis(model: PhysicsModel, q: jnp.ndarray, qd: jnp.ndarray) -> None: +def coriolis( + model: PhysicsModel, + q: jnp.ndarray, + qd: jnp.ndarray, + xfb: jtp.Vector, +) -> Tuple[jtp.Vector, jtp.Vector, jtp.Vector]: """ Coriolis matrix """ + ( + x_fb, + q, + qd, + _, + _, + _, + ) = utils.process_inputs( + physics_model=model, + xfb=xfb, + q=q, + qd=qd, + ) + # Extract data from the physics model pre_X_λi = model.tree_transforms M = model.spatial_inertias @@ -21,116 +46,138 @@ def coriolis(model: PhysicsModel, q: jnp.ndarray, qd: jnp.ndarray) -> None: # Initialize buffers v = jnp.array([jnp.zeros([6, 1])] * model.NB) Sd = jnp.array([jnp.zeros([6, 1])] * model.NB) - BC = jnp.array([jnp.zeros([6, 1])] * model.NB) - IC = jnp.array([jnp.zeros([6, 1])] * model.NB) - Ic = jnp.zeros([6, 6]) - Bc = jnp.zeros([6, 6]) + BC = jnp.array([jnp.zeros([6, 6])] * model.NB) + IC = jnp.zeros_like(M) + + i_X_λi = jnp.zeros_like(i_X_pre) + + # Base pose B_X_W and velocity + base_quat = jnp.vstack(x_fb[0:4]) + base_pos = jnp.vstack(x_fb[4:7]) + + # 6D transform of base velocity + B_X_W = Adjoint.from_quaternion_and_translation( + quaternion=base_quat, + translation=base_pos, + inverse=True, + normalize_quaternion=True, + ) + i_X_λi = i_X_λi.at[0].set(B_X_W) + + # Transforms link -> base + i_X_0 = jnp.zeros_like(pre_X_λi) + i_X_0 = i_X_0.at[0].set(jnp.eye(6)) Pass1Carry = Tuple[ jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax ] def loop_pass_1(carry: Pass1Carry, i: jtp.Int) -> Tuple[Pass1Carry, None]: + i_X_λi, v, Sd, BC, IC = carry vJ = S[i] * qd[i] v_i = i_X_λi[i] @ v[λ[i]] + vJ v = v.at[i].set(v_i) Sd_i = Cross.vx(v[i]) @ S[i] + Sd = Sd.at[i].set(Sd_i) - IC = IC.at[i].set(MC[i]) + IC = IC.at[i].set(M[i]) BC_i = ( - Cross.vx_star(v[i]) @ Cross.vx_star_bar(IC[i] @ v[i]) - - IC[i] @ Cross.vx(v[i]) + Cross.vx_star(v[i]) @ Cross.vx(IC[i] @ v[i]) - IC[i] @ Cross.vx(v[i]) ) / 2 + BC = BC.at[i].set(BC_i) return (i_X_λi, v, Sd, BC, IC), None (i_X_λi, v, Sd, BC, IC), _ = jax.lax.scan( - loop_pass_1, - (i_X_λi, v, Sd, BC, IC), - jnp.arange(1, model.NB + 1), + f=loop_pass_1, + init=(i_X_λi, v, Sd, BC, IC), + xs=np.arange(1, model.NB + 1), ) - F_1 = jnp.zeros([6, 6]) - F_2 = jnp.zeros([6, 6]) - F_3 = jnp.zeros([6, 6]) + C = jnp.zeros([model.NB, model.NB]) + H = jnp.zeros([model.NB, model.NB]) + Hd = jnp.zeros([model.NB, model.NB]) Pass2Carry = Tuple[ jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax ] - def loop_pass_2(carry: Pass2Carry, i: jtp.Int) -> Tuple[Pass2Carry, None]: - ii = i - 1 - i_X_λi, v, Sd, BC, IC = carry + def loop_pass_2(carry: Pass2Carry, j: jtp.Int) -> Tuple[Pass2Carry, None]: + jj = λ[j] - 1 - # Compute parent-to-child transform - i_X_λi_i = i_X_pre[i] @ pre_X_λi[i] - i_X_λi = i_X_λi.at[i].set(i_X_λi_i) + C, H, Hd, IC, BC = carry - # Propagate link velocity - vJ = S[i] * qd[ii] * (qd.size != 0) - v_i = i_X_λi[i] @ v[λ[i]] + vJ - v = v.at[i].set(v_i) + F_1 = IC[j] @ Sd[j] + BC[j] @ S[j] + F_2 = IC[j] @ S[j] + F_3 = BC[j].T @ S[j] - Sd_i = Cross.vx(v[i]) @ S[i] + C = C.at[jj, jj].set((S[j].T @ F_1).squeeze()) + H = H.at[jj, jj].set((S[j].T @ F_2).squeeze()) + Hd = Hd.at[jj, jj].set((Sd[j].T @ F_2 + S[j].T @ F_3).squeeze()) - IC = IC.at[i].set(MC[i]) - BC_i = ( - Cross.vx_star(v[i]) @ Cross.vx_star_bar(IC[i] @ v[i]) - - IC[i] @ Cross.vx(v[i]) - ) / 2 + F_1 = i_X_λi[j] @ F_1 + F_2 = i_X_λi[j] @ F_2 + F_3 = i_X_λi[j] @ F_3 - InnerLoopCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax] + InnerLoopCarry = Tuple[ + jtp.MatrixJax, + jtp.MatrixJax, + jtp.MatrixJax, + jtp.MatrixJax, + jtp.MatrixJax, + jtp.MatrixJax, + jtp.MatrixJax, + ] - def inner_loop_body( - carry: InnerLoopCarry, i: jtp.Int - ) -> Tuple[InnerLoopCarry, None]: - F_1 = i_X_λi[i] @ F_1 - F_2 = i_X_λi[i] @ F_2 - F_3 = i_X_λi[i] @ F_3 + def inner_loop_body(carry: InnerLoopCarry) -> Tuple[InnerLoopCarry]: + C, H, Hd, F_1, F_2, F_3, i = carry + ii = λ[i] - 1 - C_ij = S[i].T @ F_1 - C_ji = (Sd[i].T @ F_2) + (S[i].T @ F_3).T + C = C.at[ii, jj].set((S[i].T @ F_1).squeeze()) + C = C.at[jj, ii].set((S[i].T @ F_1).squeeze()) - H_ij = S[i].T @ F_2 - H_ji = H_ij.T + H = H.at[ii, ii].set((S[i].T @ F_2).squeeze()) + Hd = Hd.at[ii].set((Sd[i].T @ F_2 + S[i].T @ F_3).squeeze()) - Hd_ij = Sd[i].T @ F_2 + S[i].T @ (F_1 + F_3) - Hd_ji = Hd_ij.T + F_1 = F_1 + i_X_λi[i] @ F_1 + F_2 = F_2 + i_X_λi[i] @ F_2 + F_3 = F_3 + i_X_λi[i] @ F_3 i = λ[i] - return (F_1, F_2, F_3), None + return C, H, Hd, F_1, F_2, F_3, i - jax.lax.while_loop( + (C, H, Hd, F_1, F_2, F_3, _) = jax.lax.while_loop( body_fun=inner_loop_body, - cond_fun=i > 0, - init_val=0, + cond_fun=lambda idx: idx[-1] > 0, + init_val=(C, H, Hd, F_1, F_2, F_3, 0), ) - Ic = Ic + i_X_λi[i] @ IC[i] @ i_X_λi[i].T - Bc = Bc + i_X_λi[i] @ BC[i] @ i_X_λi[i].T - return (i_X_λi, v, Sd, BC, IC), None + def propagate( + IC_BC: Tuple[jtp.MatrixJax, jtp.MatrixJax] + ) -> Tuple[jtp.MatrixJax, jtp.MatrixJax]: + IC, BC = IC_BC - (i_X_λi, v, Sd, BC, IC), _ = jax.lax.scan( - loop_pass_2, - (i_X_λi, v, Sd, BC, IC), - jnp.arange(1, model.NB + 1), - ) + IC = IC.at[λ[j]].set(IC[λ[j]] + i_X_λi[j] @ IC[j] @ i_X_λi[j].T) + BC = BC.at[λ[j]].set(BC[λ[j]] + i_X_λi[j] @ BC[j] @ i_X_λi[j].T) - return Ic, Bc + return IC, BC + IC, BC = jax.lax.cond( + pred=jnp.array([λ[j] != 0, model.is_floating_base]).any(), + true_fun=propagate, + false_fun=lambda IC_BC: IC_BC, + operand=(IC, BC), + ) -# if __name__ == "__main__": -# import jax.numpy as jnp -# import jaxsim -# from jaxsim.high_level.model import Model -# from pathlib import Path + return (C, H, Hd, IC, BC), None -# urdf_path = Path( -# "/home/flferretti/git/element_rl-for-codesign/assets/model/Hopper.sdf" -# ) + (C, H, Hd, IC, BC), _ = jax.lax.scan( + f=loop_pass_2, + init=(C, H, Hd, IC, BC), + xs=np.flip(np.arange(1, model.NB + 1)), + ) -# model = Model.build_from_model_description(model_description=urdf_path) + assert jnp.allclose(Hd - (C @ C.T), jnp.zeros_like(Hd)) -# with jax.disable_jit(): -# H, H_dot, C = model.coriolis_matrix() + return H, Hd, C