From b9e7b70ddd6c55da80aa085df483d1ca70d752f3 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 30 Nov 2023 09:38:56 +0100 Subject: [PATCH] Fix lax.scan loops --- src/jaxsim/physics/algos/coriolis.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/jaxsim/physics/algos/coriolis.py b/src/jaxsim/physics/algos/coriolis.py index 9ef840bc4..2d853b77c 100644 --- a/src/jaxsim/physics/algos/coriolis.py +++ b/src/jaxsim/physics/algos/coriolis.py @@ -26,6 +26,10 @@ def coriolis(model: PhysicsModel, q: jnp.ndarray, qd: jnp.ndarray) -> None: Ic = jnp.zeros([6, 6]) Bc = jnp.zeros([6, 6]) + Pass1Carry = Tuple[ + jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax + ] + def loop_pass_1(carry: Pass1Carry, i: jtp.Int) -> Tuple[Pass1Carry, None]: vJ = S[i] * qd[i] v_i = i_X_λi[i] @ v[λ[i]] + vJ @@ -41,16 +45,20 @@ def loop_pass_1(carry: Pass1Carry, i: jtp.Int) -> Tuple[Pass1Carry, None]: return (i_X_λi, v, Sd, BC, IC), None - jax.lax.scan( + (i_X_λi, v, Sd, BC, IC), _ = jax.lax.scan( loop_pass_1, (i_X_λi, v, Sd, BC, IC), - jnp.arange(1, n + 1), + jnp.arange(1, model.NB + 1), ) F_1 = jnp.zeros([6, 6]) F_2 = jnp.zeros([6, 6]) F_3 = jnp.zeros([6, 6]) + Pass2Carry = Tuple[ + jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax + ] + def loop_pass_2(carry: Pass2Carry, i: jtp.Int) -> Tuple[Pass2Carry, None]: ii = i - 1 i_X_λi, v, Sd, BC, IC = carry @@ -103,10 +111,10 @@ def inner_loop_body( Bc = Bc + i_X_λi[i] @ BC[i] @ i_X_λi[i].T return (i_X_λi, v, Sd, BC, IC), None - jax.lax.scan( + (i_X_λi, v, Sd, BC, IC), _ = jax.lax.scan( loop_pass_2, (i_X_λi, v, Sd, BC, IC), - jnp.arange(1, n + 1), + jnp.arange(1, model.NB + 1), ) return Ic, Bc