Skip to content

Commit

Permalink
Refactor Coriolis and remove vx star bar
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Nov 30, 2023
1 parent 104b9b1 commit 332e584
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 78 deletions.
4 changes: 4 additions & 0 deletions src/jaxsim/high_level/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,11 +945,15 @@ def com_position(self) -> jtp.Vector:
# Algorithms
# ==========

@functools.partial(oop.jax_tf.method_ro)
def coriolis_matrix(self) -> jtp.Matrix:
from jaxsim.physics.algos.coriolis import coriolis

H, H_dot, C = jaxsim.physics.algos.coriolis.coriolis(
model=self.physics_model,
q=self.data.model_state.joint_positions,
qd=self.data.model_state.joint_velocities,
xfb=self.data.model_state.xfb(),
)

return H, H_dot, C
Expand Down
10 changes: 0 additions & 10 deletions src/jaxsim/math/cross.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,3 @@ def vx(velocity_sixd: jtp.Vector) -> jtp.Matrix:
def vx_star(velocity_sixd: jtp.Vector) -> jtp.Matrix:
v_cross_star = -Cross.vx(velocity_sixd).T
return v_cross_star

@staticmethod
def vx_star_bar(velocity_sixd: jtp.Vector) -> jtp.Matrix:
v_cross_star_bar = jnp.block(
[
[Skew.wedge(vector=velocity_sixd.squeeze()), jnp.zeros(shape=(3, 3))],
[jnp.zeros(shape=(3, 3)), Skew.wedge(vector=velocity_sixd.squeeze())],
]
)
return v_cross_star_bar
183 changes: 115 additions & 68 deletions src/jaxsim/physics/algos/coriolis.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,41 @@
from typing import Tuple

import jax
import jax.numpy as jnp
import numpy as np

import jaxsim
import jaxsim.typing as jtp
from jaxsim.math.adjoint import Adjoint
from jaxsim.math.cross import Cross
from jaxsim.physics.model.physics_model import PhysicsModel

from . import utils


def coriolis(model: PhysicsModel, q: jnp.ndarray, qd: jnp.ndarray) -> None:
def coriolis(
model: PhysicsModel,
q: jnp.ndarray,
qd: jnp.ndarray,
xfb: jtp.Vector,
) -> Tuple[jtp.Vector, jtp.Vector, jtp.Vector]:
"""
Coriolis matrix
"""

(
x_fb,
q,
qd,
_,
_,
_,
) = utils.process_inputs(
physics_model=model,
xfb=xfb,
q=q,
qd=qd,
)

# Extract data from the physics model
pre_X_λi = model.tree_transforms
M = model.spatial_inertias
Expand All @@ -21,116 +46,138 @@ def coriolis(model: PhysicsModel, q: jnp.ndarray, qd: jnp.ndarray) -> None:
# Initialize buffers
v = jnp.array([jnp.zeros([6, 1])] * model.NB)
Sd = jnp.array([jnp.zeros([6, 1])] * model.NB)
BC = jnp.array([jnp.zeros([6, 1])] * model.NB)
IC = jnp.array([jnp.zeros([6, 1])] * model.NB)
Ic = jnp.zeros([6, 6])
Bc = jnp.zeros([6, 6])
BC = jnp.array([jnp.zeros([6, 6])] * model.NB)
IC = jnp.zeros_like(M)

i_X_λi = jnp.zeros_like(i_X_pre)

# Base pose B_X_W and velocity
base_quat = jnp.vstack(x_fb[0:4])
base_pos = jnp.vstack(x_fb[4:7])

# 6D transform of base velocity
B_X_W = Adjoint.from_quaternion_and_translation(
quaternion=base_quat,
translation=base_pos,
inverse=True,
normalize_quaternion=True,
)
i_X_λi = i_X_λi.at[0].set(B_X_W)

# Transforms link -> base
i_X_0 = jnp.zeros_like(pre_X_λi)
i_X_0 = i_X_0.at[0].set(jnp.eye(6))

Pass1Carry = Tuple[
jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax
]

def loop_pass_1(carry: Pass1Carry, i: jtp.Int) -> Tuple[Pass1Carry, None]:
i_X_λi, v, Sd, BC, IC = carry
vJ = S[i] * qd[i]
v_i = i_X_λi[i] @ v[λ[i]] + vJ
v = v.at[i].set(v_i)

Sd_i = Cross.vx(v[i]) @ S[i]
Sd = Sd.at[i].set(Sd_i)

IC = IC.at[i].set(MC[i])
IC = IC.at[i].set(M[i])
BC_i = (
Cross.vx_star(v[i]) @ Cross.vx_star_bar(IC[i] @ v[i])
- IC[i] @ Cross.vx(v[i])
Cross.vx_star(v[i]) @ Cross.vx(IC[i] @ v[i]) - IC[i] @ Cross.vx(v[i])
) / 2
BC = BC.at[i].set(BC_i)

return (i_X_λi, v, Sd, BC, IC), None

(i_X_λi, v, Sd, BC, IC), _ = jax.lax.scan(
loop_pass_1,
(i_X_λi, v, Sd, BC, IC),
jnp.arange(1, model.NB + 1),
f=loop_pass_1,
init=(i_X_λi, v, Sd, BC, IC),
xs=np.arange(1, model.NB + 1),
)

F_1 = jnp.zeros([6, 6])
F_2 = jnp.zeros([6, 6])
F_3 = jnp.zeros([6, 6])
C = jnp.zeros([model.NB, model.NB])
H = jnp.zeros([model.NB, model.NB])
Hd = jnp.zeros([model.NB, model.NB])

Pass2Carry = Tuple[
jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax
]

def loop_pass_2(carry: Pass2Carry, i: jtp.Int) -> Tuple[Pass2Carry, None]:
ii = i - 1
i_X_λi, v, Sd, BC, IC = carry
def loop_pass_2(carry: Pass2Carry, j: jtp.Int) -> Tuple[Pass2Carry, None]:
jj = λ[j] - 1

# Compute parent-to-child transform
i_X_λi_i = i_X_pre[i] @ pre_X_λi[i]
i_X_λi = i_X_λi.at[i].set(i_X_λi_i)
C, H, Hd, IC, BC = carry

# Propagate link velocity
vJ = S[i] * qd[ii] * (qd.size != 0)
v_i = i_X_λi[i] @ v[λ[i]] + vJ
v = v.at[i].set(v_i)
F_1 = IC[j] @ Sd[j] + BC[j] @ S[j]
F_2 = IC[j] @ S[j]
F_3 = BC[j].T @ S[j]

Sd_i = Cross.vx(v[i]) @ S[i]
C = C.at[jj, jj].set((S[j].T @ F_1).squeeze())
H = H.at[jj, jj].set((S[j].T @ F_2).squeeze())
Hd = Hd.at[jj, jj].set((Sd[j].T @ F_2 + S[j].T @ F_3).squeeze())

IC = IC.at[i].set(MC[i])
BC_i = (
Cross.vx_star(v[i]) @ Cross.vx_star_bar(IC[i] @ v[i])
- IC[i] @ Cross.vx(v[i])
) / 2
F_1 = i_X_λi[j] @ F_1
F_2 = i_X_λi[j] @ F_2
F_3 = i_X_λi[j] @ F_3

InnerLoopCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax]
InnerLoopCarry = Tuple[
jtp.MatrixJax,
jtp.MatrixJax,
jtp.MatrixJax,
jtp.MatrixJax,
jtp.MatrixJax,
jtp.MatrixJax,
jtp.MatrixJax,
]

def inner_loop_body(
carry: InnerLoopCarry, i: jtp.Int
) -> Tuple[InnerLoopCarry, None]:
F_1 = i_X_λi[i] @ F_1
F_2 = i_X_λi[i] @ F_2
F_3 = i_X_λi[i] @ F_3
def inner_loop_body(carry: InnerLoopCarry) -> Tuple[InnerLoopCarry]:
C, H, Hd, F_1, F_2, F_3, i = carry
ii = λ[i] - 1

C_ij = S[i].T @ F_1
C_ji = (Sd[i].T @ F_2) + (S[i].T @ F_3).T
C = C.at[ii, jj].set((S[i].T @ F_1).squeeze())
C = C.at[jj, ii].set((S[i].T @ F_1).squeeze())

H_ij = S[i].T @ F_2
H_ji = H_ij.T
H = H.at[ii, ii].set((S[i].T @ F_2).squeeze())
Hd = Hd.at[ii].set((Sd[i].T @ F_2 + S[i].T @ F_3).squeeze())

Hd_ij = Sd[i].T @ F_2 + S[i].T @ (F_1 + F_3)
Hd_ji = Hd_ij.T
F_1 = F_1 + i_X_λi[i] @ F_1
F_2 = F_2 + i_X_λi[i] @ F_2
F_3 = F_3 + i_X_λi[i] @ F_3

i = λ[i]
return (F_1, F_2, F_3), None
return C, H, Hd, F_1, F_2, F_3, i

jax.lax.while_loop(
(C, H, Hd, F_1, F_2, F_3, _) = jax.lax.while_loop(
body_fun=inner_loop_body,
cond_fun=i > 0,
init_val=0,
cond_fun=lambda idx: idx[-1] > 0,
init_val=(C, H, Hd, F_1, F_2, F_3, 0),
)

Ic = Ic + i_X_λi[i] @ IC[i] @ i_X_λi[i].T
Bc = Bc + i_X_λi[i] @ BC[i] @ i_X_λi[i].T
return (i_X_λi, v, Sd, BC, IC), None
def propagate(
IC_BC: Tuple[jtp.MatrixJax, jtp.MatrixJax]
) -> Tuple[jtp.MatrixJax, jtp.MatrixJax]:
IC, BC = IC_BC

(i_X_λi, v, Sd, BC, IC), _ = jax.lax.scan(
loop_pass_2,
(i_X_λi, v, Sd, BC, IC),
jnp.arange(1, model.NB + 1),
)
IC = IC.at[λ[j]].set(IC[λ[j]] + i_X_λi[j] @ IC[j] @ i_X_λi[j].T)
BC = BC.at[λ[j]].set(BC[λ[j]] + i_X_λi[j] @ BC[j] @ i_X_λi[j].T)

return Ic, Bc
return IC, BC

IC, BC = jax.lax.cond(
pred=jnp.array([λ[j] != 0, model.is_floating_base]).any(),
true_fun=propagate,
false_fun=lambda IC_BC: IC_BC,
operand=(IC, BC),
)

# if __name__ == "__main__":
# import jax.numpy as jnp
# import jaxsim
# from jaxsim.high_level.model import Model
# from pathlib import Path
return (C, H, Hd, IC, BC), None

# urdf_path = Path(
# "/home/flferretti/git/element_rl-for-codesign/assets/model/Hopper.sdf"
# )
(C, H, Hd, IC, BC), _ = jax.lax.scan(
f=loop_pass_2,
init=(C, H, Hd, IC, BC),
xs=np.flip(np.arange(1, model.NB + 1)),
)

# model = Model.build_from_model_description(model_description=urdf_path)
assert jnp.allclose(Hd - (C @ C.T), jnp.zeros_like(Hd))

# with jax.disable_jit():
# H, H_dot, C = model.coriolis_matrix()
return H, Hd, C

0 comments on commit 332e584

Please sign in to comment.