Skip to content

Commit

Permalink
Fix lax.scan loops
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Nov 30, 2023
1 parent ab1b3b8 commit 6da12a0
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions src/jaxsim/physics/algos/coriolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ def coriolis(model: PhysicsModel, q: jnp.ndarray, qd: jnp.ndarray) -> None:
Ic = jnp.zeros([6, 6])
Bc = jnp.zeros([6, 6])

Pass1Carry = Tuple[
jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax
]

def loop_pass_1(carry: Pass1Carry, i: jtp.Int) -> Tuple[Pass1Carry, None]:
vJ = S[i] * qd[i]
v_i = i_X_λi[i] @ v[λ[i]] + vJ
Expand All @@ -41,16 +45,20 @@ def loop_pass_1(carry: Pass1Carry, i: jtp.Int) -> Tuple[Pass1Carry, None]:

return (i_X_λi, v, Sd, BC, IC), None

jax.lax.scan(
(i_X_λi, v, Sd, BC, IC), _ = jax.lax.scan(
loop_pass_1,
(i_X_λi, v, Sd, BC, IC),
jnp.arange(1, n + 1),
jnp.arange(1, model.NB + 1),
)

F_1 = jnp.zeros([6, 6])
F_2 = jnp.zeros([6, 6])
F_3 = jnp.zeros([6, 6])

Pass2Carry = Tuple[
jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax
]

def loop_pass_2(carry: Pass2Carry, i: jtp.Int) -> Tuple[Pass2Carry, None]:
ii = i - 1
i_X_λi, v, Sd, BC, IC = carry
Expand Down Expand Up @@ -103,10 +111,10 @@ def inner_loop_body(
Bc = Bc + i_X_λi[i] @ BC[i] @ i_X_λi[i].T
return (i_X_λi, v, Sd, BC, IC), None

jax.lax.scan(
(i_X_λi, v, Sd, BC, IC), _ = jax.lax.scan(
loop_pass_2,
(i_X_λi, v, Sd, BC, IC),
jnp.arange(1, n + 1),
jnp.arange(1, model.NB + 1),
)

return Ic, Bc

0 comments on commit 6da12a0

Please sign in to comment.