From 258d60f8b3f1c58ba195e2914ebc58bb8059e1a1 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 14 Aug 2024 10:28:10 +0200 Subject: [PATCH 1/2] Allow models with a single link in RBDAs with `JAX_DISABLE_JIT` --- src/jaxsim/rbda/crba.py | 43 +++++++++++++++++++-------- src/jaxsim/rbda/forward_kinematics.py | 12 +++++--- src/jaxsim/rbda/jacobian.py | 36 ++++++++++++++-------- 3 files changed, 63 insertions(+), 28 deletions(-) diff --git a/src/jaxsim/rbda/crba.py b/src/jaxsim/rbda/crba.py index 904048832..74088b54b 100644 --- a/src/jaxsim/rbda/crba.py +++ b/src/jaxsim/rbda/crba.py @@ -59,10 +59,14 @@ def propagate_kinematics( return (i_X_0,), None - (i_X_0,), _ = jax.lax.scan( - f=propagate_kinematics, - init=forward_pass_carry, - xs=jnp.arange(start=1, stop=model.number_of_links()), + (i_X_0,), _ = ( + jax.lax.scan( + f=propagate_kinematics, + init=forward_pass_carry, + xs=jnp.arange(start=1, stop=model.number_of_links()), + ) + if model.number_of_links() > 1 + else [(i_X_0,), None] ) # =================== @@ -128,10 +132,21 @@ def compute_inner(carry: CarryInnerFn) -> tuple[CarryInnerFn, None]: operand=carry, ) - (j, Fi, M), _ = jax.lax.scan( - f=inner_fn, - init=carry_inner_fn, - xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())), + (j, Fi, M), _ = ( + jax.lax.scan( + f=inner_fn, + init=carry_inner_fn, + xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())), + ) + if model.number_of_links() > 1 + else [ + ( + j, + Fi, + M, + ), + None, + ] ) Fi = i_X_0[j].T @ Fi @@ -143,10 +158,14 @@ def compute_inner(carry: CarryInnerFn) -> tuple[CarryInnerFn, None]: # This scan performs the backward pass to compute Mbj, Mjb and Mjj, that # also includes a fake while loop implemented with a scan and two cond. - (Mc, M), _ = jax.lax.scan( - f=backward_pass, - init=backward_pass_carry, - xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())), + (Mc, M), _ = ( + jax.lax.scan( + f=backward_pass, + init=backward_pass_carry, + xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())), + ) + if model.number_of_links() > 1 + else [(Mc, M), None] ) # Store the locked 6D rigid-body inertia matrix Mbb ∈ ℝ⁶ˣ⁶. diff --git a/src/jaxsim/rbda/forward_kinematics.py b/src/jaxsim/rbda/forward_kinematics.py index 77fe3a362..d11e9b45d 100644 --- a/src/jaxsim/rbda/forward_kinematics.py +++ b/src/jaxsim/rbda/forward_kinematics.py @@ -75,10 +75,14 @@ def propagate_kinematics( return (W_X_i,), None - (W_X_i,), _ = jax.lax.scan( - f=propagate_kinematics, - init=propagate_kinematics_carry, - xs=jnp.arange(start=1, stop=model.number_of_links()), + (W_X_i,), _ = ( + jax.lax.scan( + f=propagate_kinematics, + init=propagate_kinematics_carry, + xs=jnp.arange(start=1, stop=model.number_of_links()), + ) + if model.number_of_links() > 1 + else [(W_X_i,), None] ) return jax.vmap(Adjoint.to_transform)(W_X_i) diff --git a/src/jaxsim/rbda/jacobian.py b/src/jaxsim/rbda/jacobian.py index 197a45ee2..4c8992ff7 100644 --- a/src/jaxsim/rbda/jacobian.py +++ b/src/jaxsim/rbda/jacobian.py @@ -67,10 +67,14 @@ def propagate_kinematics( return (i_X_0,), None - (i_X_0,), _ = jax.lax.scan( - f=propagate_kinematics, - init=propagate_kinematics_carry, - xs=np.arange(start=1, stop=model.number_of_links()), + (i_X_0,), _ = ( + jax.lax.scan( + f=propagate_kinematics, + init=propagate_kinematics_carry, + xs=np.arange(start=1, stop=model.number_of_links()), + ) + if model.number_of_links() > 1 + else [(i_X_0,), None] ) # ============================ @@ -105,10 +109,14 @@ def update_jacobian(J: jtp.Matrix, i: jtp.Int) -> jtp.Matrix: return J, None - L_J_WL_B, _ = jax.lax.scan( - f=compute_jacobian, - init=J, - xs=np.arange(start=1, stop=model.number_of_links()), + L_J_WL_B, _ = ( + jax.lax.scan( + f=compute_jacobian, + init=J, + xs=np.arange(start=1, stop=model.number_of_links()), + ) + if model.number_of_links() > 1 + else [J, None] ) return L_J_WL_B @@ -184,10 +192,14 @@ def compute_full_jacobian( return (B_X_i, J), None - (B_X_i, J), _ = jax.lax.scan( - f=compute_full_jacobian, - init=compute_full_jacobian_carry, - xs=np.arange(start=1, stop=model.number_of_links()), + (B_X_i, J), _ = ( + jax.lax.scan( + f=compute_full_jacobian, + init=compute_full_jacobian_carry, + xs=np.arange(start=1, stop=model.number_of_links()), + ) + if model.number_of_links() > 1 + else [(B_X_i, J), None] ) # Convert adjoints to SE(3) transforms. From d87e0766c35de09080232d9171fdcbec737f48b4 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 20 Aug 2024 09:44:51 +0200 Subject: [PATCH 2/2] Fix formatting Co-authored-by: Diego Ferigo --- src/jaxsim/rbda/crba.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/jaxsim/rbda/crba.py b/src/jaxsim/rbda/crba.py index 74088b54b..1d45dbd8d 100644 --- a/src/jaxsim/rbda/crba.py +++ b/src/jaxsim/rbda/crba.py @@ -139,14 +139,7 @@ def compute_inner(carry: CarryInnerFn) -> tuple[CarryInnerFn, None]: xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())), ) if model.number_of_links() > 1 - else [ - ( - j, - Fi, - M, - ), - None, - ] + else [(j, Fi, M), None] ) Fi = i_X_0[j].T @ Fi