diff --git a/src/jaxsim/high_level/joint.py b/src/jaxsim/high_level/joint.py index aef8b8360..73c7b9729 100644 --- a/src/jaxsim/high_level/joint.py +++ b/src/jaxsim/high_level/joint.py @@ -102,6 +102,27 @@ def position_limit(self, dof: int = None) -> tuple[jtp.Float, jtp.Float]: return jnp.array(low, dtype=float), jnp.array(high, dtype=float) + # ============= + # Motor methods + # ============= + @functools.partial(oop.jax_tf.method_ro) + def motor_inertia(self) -> jtp.Vector: + """""" + + return jnp.array(self.joint_description.motor_inertia, dtype=float) + + @functools.partial(oop.jax_tf.method_ro) + def motor_gear_ratio(self) -> jtp.Vector: + """""" + + return jnp.array(self.joint_description.motor_gear_ratio, dtype=float) + + @functools.partial(oop.jax_tf.method_ro) + def motor_viscous_friction(self) -> jtp.Vector: + """""" + + return jnp.array(self.joint_description.motor_viscous_friction, dtype=float) + # ================= # Multi-DoF methods # ================= diff --git a/src/jaxsim/high_level/model.py b/src/jaxsim/high_level/model.py index db64e4294..63bea11ef 100644 --- a/src/jaxsim/high_level/model.py +++ b/src/jaxsim/high_level/model.py @@ -1137,6 +1137,15 @@ def forward_dynamics_crb( τ = jnp.atleast_1d(τ.squeeze()) τ = jnp.vstack(τ) if τ.size > 0 else jnp.empty(shape=(0, 1)) + # Extract motor parameters from the physics model + GR = self.motor_gear_ratios() + IM = self.motor_inertias() + KV = jnp.diag(self.motor_viscous_frictions()) + + # Compute auxiliary quantities + Γ = jnp.diag(GR) + K̅ᵥ = Γ.T @ KV @ Γ + # Compute terms of the floating-base EoM M = self.free_floating_mass_matrix() h = jnp.vstack(self.free_floating_bias_forces()) @@ -1144,6 +1153,15 @@ def forward_dynamics_crb( f_ext = jnp.vstack(self.external_forces().flatten()) S = jnp.block([jnp.zeros(shape=(self.dofs(), 6)), jnp.eye(self.dofs())]).T + # Configure the slice for fixed/floating base robots + sl = np.s_[0:] if self.physics_model.is_floating_base else np.s_[6:] + sl_m = np.s_[M.shape[0] - self.dofs() :] + + # Add the motor related terms to the EoM + M = M.at[sl_m, sl_m].set(M[sl_m, sl_m] + jnp.diag(Γ.T @ IM @ Γ)) + h = h.at[sl_m].set(h[sl_m] + K̅ᵥ @ self.joint_velocities()[:, None]) + S = S.at[sl_m].set(S[sl_m]) + # Compute the generalized acceleration by inverting the EoM ν̇ = jax.lax.select( pred=self.floating_base(), @@ -1479,6 +1497,87 @@ def integrate( }, ) + # ============== + # Motor dynamics + # ============== + + @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"]) + def set_motor_inertias( + self, inertias: jtp.Vector, joint_names: tuple[str, ...] = None + ) -> None: + joint_names = joint_names or self.joint_names() + + if inertias.size != len(joint_names): + raise ValueError("Wrong arguments size", inertias.size, len(joint_names)) + + self.physics_model._joint_motor_inertia.update( + dict(zip(self.physics_model._joint_motor_inertia, inertias)) + ) + + logging.info("Setting attribute `motor_inertias`") + + @functools.partial(oop.jax_tf.method_rw, jit=False) + def set_motor_gear_ratios( + self, gear_ratios: jtp.Vector, joint_names: tuple[str, ...] = None + ) -> None: + joint_names = joint_names or self.joint_names() + + if gear_ratios.size != len(joint_names): + raise ValueError("Wrong arguments size", gear_ratios.size, len(joint_names)) + + # Check on gear ratios if motor_inertias are not zero + for idx, gr in enumerate(gear_ratios): + if gr != 0 and self.motor_inertias()[idx] == 0: + raise ValueError( + f"Zero motor inertia with non-zero gear ratio found in position {idx}" + ) + + self.physics_model._joint_motor_gear_ratio.update( + dict(zip(self.physics_model._joint_motor_gear_ratio, gear_ratios)) + ) + + logging.info("Setting attribute `motor_gear_ratios`") + + @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"]) + def set_motor_viscous_frictions( + self, viscous_frictions: jtp.Vector, joint_names: tuple[str, ...] = None + ) -> None: + joint_names = joint_names or self.joint_names() + + if viscous_frictions.size != len(joint_names): + raise ValueError( + "Wrong arguments size", viscous_frictions.size, len(joint_names) + ) + + self.physics_model._joint_motor_viscous_friction.update( + dict( + zip( + self.physics_model._joint_motor_viscous_friction, + viscous_frictions, + ) + ) + ) + + logging.info("Setting attribute `motor_viscous_frictions`") + + @functools.partial(oop.jax_tf.method_ro, jit=False) + def motor_inertias(self) -> jtp.Vector: + return jnp.array( + [*self.physics_model._joint_motor_inertia.values()], dtype=float + ) + + @functools.partial(oop.jax_tf.method_ro, jit=False) + def motor_gear_ratios(self) -> jtp.Vector: + return jnp.array( + [*self.physics_model._joint_motor_gear_ratio.values()], dtype=float + ) + + @functools.partial(oop.jax_tf.method_ro, jit=False) + def motor_viscous_frictions(self) -> jtp.Vector: + return jnp.array( + [*self.physics_model._joint_motor_viscous_friction.values()], dtype=float + ) + # =============== # Private methods # =============== diff --git a/src/jaxsim/parsers/descriptions/joint.py b/src/jaxsim/parsers/descriptions/joint.py index e1622bbdc..4a6d54fc5 100644 --- a/src/jaxsim/parsers/descriptions/joint.py +++ b/src/jaxsim/parsers/descriptions/joint.py @@ -74,6 +74,10 @@ class JointDescription(JaxsimDataclass): position_limit: Tuple[float, float] = (0.0, 0.0) initial_position: Union[float, npt.NDArray] = 0.0 + motor_inertia: float = 0.0 + motor_viscous_friction: float = 0.0 + motor_gear_ratio: float = 1.0 + def __post_init__(self): if self.axis is not None: with self.mutable_context( diff --git a/src/jaxsim/physics/algos/aba_motors.py b/src/jaxsim/physics/algos/aba_motors.py new file mode 100644 index 000000000..0b3b1d61f --- /dev/null +++ b/src/jaxsim/physics/algos/aba_motors.py @@ -0,0 +1,284 @@ +from typing import Tuple + +import jax +import jax.numpy as jnp +import numpy as np + +import jaxsim.typing as jtp +from jaxsim.math.adjoint import Adjoint +from jaxsim.math.cross import Cross +from jaxsim.physics.model.physics_model import PhysicsModel + +from . import utils + + +def aba( + model: PhysicsModel, + xfb: jtp.Vector, + q: jtp.Vector, + qd: jtp.Vector, + tau: jtp.Vector, + f_ext: jtp.Matrix = None, +) -> Tuple[jtp.Vector, jtp.Vector]: + """ + Articulated Body Algorithm (ABA) algorithm with motor dynamics for forward dynamics. + """ + + x_fb, q, qd, _, tau, f_ext = utils.process_inputs( + physics_model=model, xfb=xfb, q=q, qd=qd, tau=tau, f_ext=f_ext + ) + + # Extract data from the physics model + pre_X_λi = model.tree_transforms + M = model.spatial_inertias + i_X_pre = model.joint_transforms(q=q) + S = model.motion_subspaces(q=q) + λ = model.parent_array() + + # Extract motor parameters from the physics model + Γ = jnp.array([*model._joint_motor_gear_ratio.values()]) + IM = jnp.array( + [jnp.eye(6) * m for m in [*model._joint_motor_inertia.values()]] * model.NB + ) + K̅ᵥ = Γ.T * jnp.array([*model._joint_motor_viscous_friction.values()]) * Γ + m_S = jnp.concatenate([S[:1], S[1:] * Γ[:, None, None]], axis=0) + + # Initialize buffers + v = jnp.array([jnp.zeros([6, 1])] * model.NB) + MA = jnp.array([jnp.zeros([6, 6])] * model.NB) + pA = jnp.array([jnp.zeros([6, 1])] * model.NB) + c = jnp.array([jnp.zeros([6, 1])] * model.NB) + i_X_λi = jnp.zeros_like(i_X_pre) + + m_v = jnp.array([jnp.zeros([6, 1])] * model.NB) + m_c = jnp.array([jnp.zeros([6, 1])] * model.NB) + pR = jnp.array([jnp.zeros([6, 1])] * model.NB) + + # Base pose B_X_W and velocity + base_quat = jnp.vstack(x_fb[0:4]) + base_pos = jnp.vstack(x_fb[4:7]) + base_vel = jnp.vstack(jnp.hstack([x_fb[10:13], x_fb[7:10]])) + + # 6D transform of base velocity + B_X_W = Adjoint.from_quaternion_and_translation( + quaternion=base_quat, + translation=base_pos, + inverse=True, + normalize_quaternion=True, + ) + i_X_λi = i_X_λi.at[0].set(B_X_W) + + # Transforms link -> base + i_X_0 = jnp.zeros_like(pre_X_λi) + i_X_0 = i_X_0.at[0].set(jnp.eye(6)) + + # Initialize base quantities + if model.is_floating_base: + # Base velocity v₀ + v_0 = B_X_W @ base_vel + v = v.at[0].set(v_0) + + # AB inertia (Mᴬ) and AB bias forces (pᴬ) + MA_0 = M[0] + MA = MA.at[0].set(MA_0) + pA_0 = Cross.vx_star(v[0]) @ MA_0 @ v[0] - Adjoint.inverse( + B_X_W + ).T @ jnp.vstack(f_ext[0]) + pA = pA.at[0].set(pA_0) + + Pass1Carry = Tuple[ + jtp.MatrixJax, + jtp.MatrixJax, + jtp.MatrixJax, + jtp.MatrixJax, + jtp.MatrixJax, + jtp.MatrixJax, + jtp.MatrixJax, + jtp.MatrixJax, + jtp.MatrixJax, + ] + + pass_1_carry = (i_X_λi, v, c, m_v, m_c, MA, pA, pR, i_X_0) + + # Pass 1 + def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> Tuple[Pass1Carry, None]: + ii = i - 1 + i_X_λi, v, c, m_v, m_c, MA, pA, pR, i_X_0 = carry + + # Compute parent-to-child transform + i_X_λi_i = i_X_pre[i] @ pre_X_λi[i] + i_X_λi = i_X_λi.at[i].set(i_X_λi_i) + + # Propagate link velocity + vJ = S[i] * qd[ii] * (qd.size != 0) + m_vJ = m_S[i] * qd[ii] * (qd.size != 0) + + v_i = i_X_λi[i] @ v[λ[i]] + vJ + v = v.at[i].set(v_i) + + m_v_i = i_X_λi[i] @ v[λ[i]] + m_vJ + m_v = m_v.at[i].set(v_i) + + c_i = Cross.vx(v[i]) @ vJ + c = c.at[i].set(c_i) + m_c_i = Cross.vx(m_v[i]) @ m_vJ + m_c = m_c.at[i].set(m_c_i) + + # Initialize articulated-body inertia + MA_i = jnp.array(M[i]) + MA = MA.at[i].set(MA_i) + + # Initialize articulated-body bias forces + i_X_0_i = i_X_λi[i] @ i_X_0[model.parent[i]] + i_X_0 = i_X_0.at[i].set(i_X_0_i) + i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T + + pA_i = Cross.vx_star(v[i]) @ M[i] @ v[i] - i_Xf_W @ jnp.vstack(f_ext[i]) + pA = pA.at[i].set(pA_i) + + pR_i = Cross.vx_star(m_v[i]) @ IM[i] @ m_v[i] - K̅ᵥ[i] * m_v[i] + pR = pR.at[i].set(pR_i) + + return (i_X_λi, v, c, m_v, m_c, MA, pA, pR, i_X_0), None + + (i_X_λi, v, c, m_v, m_c, MA, pA, pR, i_X_0), _ = jax.lax.scan( + f=loop_body_pass1, + init=pass_1_carry, + xs=np.arange(start=1, stop=model.NB), + ) + + U = jnp.zeros_like(S) + m_U = jnp.zeros_like(S) + d = jnp.zeros(shape=(model.NB, 1)) + u = jnp.zeros(shape=(model.NB, 1)) + m_u = jnp.zeros(shape=(model.NB, 1)) + + Pass2Carry = Tuple[ + jtp.MatrixJax, + jtp.MatrixJax, + jtp.MatrixJax, + jtp.MatrixJax, + jtp.MatrixJax, + jtp.MatrixJax, + jtp.MatrixJax, + ] + + pass_2_carry = (U, m_U, d, u, m_u, MA, pA) + + # Pass 2 + def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> Tuple[Pass2Carry, None]: + ii = i - 1 + U, m_U, d, u, m_u, MA, pA = carry + + # Compute intermediate results + u_i = tau[ii] - S[i].T @ pA[i] if tau.size != 0 else -S[i].T @ pA[i] + u = u.at[i].set(u_i.squeeze()) + + has_motors = jnp.allclose(Γ[i], 1.0) + + m_u_i = ( + tau[ii] / Γ[i] * has_motors - m_S[i].T @ pR[i] + if tau.size != 0 + else -m_S[i].T @ pR[i] + ) + m_u = m_u.at[i].set(m_u_i.squeeze()) + + U_i = MA[i] @ S[i] + U = U.at[i].set(U_i) + + m_U_i = IM[i] @ m_S[i] + m_U = m_U.at[i].set(m_U_i) + + d_i = S[i].T @ MA[i] @ S[i] + m_S[i].T @ IM[i] @ m_S[i] + d = d.at[i].set(d_i.squeeze()) + + # Compute the articulated-body inertia and bias forces of this link + Ma = MA[i] + IM[i] - U[i] / d[i] @ U[i].T - m_U[i] / d[i] @ m_U[i].T + pa = ( + pA[i] + + pR[i] + + Ma[i] @ c[i] + + IM[i] @ m_c[i] + + U[i] / d[i] * u[i] + + m_U[i] / d[i] * m_u[i] + ) + + # Propagate them to the parent, handling the base link + def propagate( + MA_pA: Tuple[jtp.MatrixJax, jtp.MatrixJax] + ) -> Tuple[jtp.MatrixJax, jtp.MatrixJax]: + MA, pA = MA_pA + + MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i] + MA = MA.at[λ[i]].set(MA_λi) + + pA_λi = pA[λ[i]] + i_X_λi[i].T @ pa + pA = pA.at[λ[i]].set(pA_λi) + + return MA, pA + + MA, pA = jax.lax.cond( + pred=jnp.array([λ[i] != 0, model.is_floating_base]).any(), + true_fun=propagate, + false_fun=lambda MA_pA: MA_pA, + operand=(MA, pA), + ) + + return (U, m_U, d, u, m_u, MA, pA), None + + (U, m_U, d, u, m_u, MA, pA), _ = jax.lax.scan( + f=loop_body_pass2, + init=pass_2_carry, + xs=np.flip(np.arange(start=1, stop=model.NB)), + ) + + if model.is_floating_base: + a0 = jnp.linalg.solve(-MA[0], pA[0]) + else: + a0 = -B_X_W @ jnp.vstack(model.gravity) + + a = jnp.zeros_like(S) + a = a.at[0].set(a0) + qdd = jnp.zeros_like(q) + + Pass3Carry = Tuple[jtp.MatrixJax, jtp.VectorJax] + pass_3_carry = (a, qdd) + + # Pass 3 + def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> Tuple[Pass3Carry, None]: + ii = i - 1 + a, qdd = carry + + # Propagate link accelerations + a_i = i_X_λi[i] @ a[λ[i]] + c[i] + + # Compute joint accelerations + qdd_ii = (u[i] + m_u[i] - (U[i].T + m_U[i].T) @ a_i) / d[i] + qdd = qdd.at[ii].set(qdd_ii.squeeze()) if qdd.size != 0 else qdd + + a_i = a_i + S[i] * qdd[ii] if qdd.size != 0 else a_i + a = a.at[i].set(a_i) + + return (a, qdd), None + + (a_, qdd), _ = jax.lax.scan( + f=loop_body_pass3, + init=pass_3_carry, + xs=np.arange(1, model.NB), + ) + + # Handle 1 DoF models + qdd = jnp.atleast_1d(qdd.squeeze()) + qdd = jnp.vstack(qdd) if qdd.size > 0 else jnp.empty(shape=(0, 1)) + + # Get the resulting base acceleration (w/o gravity) in body-fixed representation + B_a_WB = a[0] + + # Convert the base acceleration to inertial-fixed representation, and add gravity + W_a_WB = jnp.vstack( + jnp.linalg.solve(B_X_W, B_a_WB) + jnp.vstack(model.gravity) + if model.is_floating_base + else jnp.zeros(6) + ) + + return W_a_WB, qdd diff --git a/src/jaxsim/physics/algos/rnea_motors.py b/src/jaxsim/physics/algos/rnea_motors.py new file mode 100644 index 000000000..f80407700 --- /dev/null +++ b/src/jaxsim/physics/algos/rnea_motors.py @@ -0,0 +1,196 @@ +from typing import Tuple + +import jax +import jax.numpy as jnp +import numpy as np + +import jaxsim.typing as jtp +from jaxsim.math.adjoint import Adjoint +from jaxsim.math.cross import Cross +from jaxsim.physics.model.physics_model import PhysicsModel + +from . import utils + + +def rnea( + model: PhysicsModel, + xfb: jtp.Vector, + q: jtp.Vector, + qd: jtp.Vector, + qdd: jtp.Vector, + a0fb: jtp.Vector = jnp.zeros(6), + f_ext: jtp.Matrix = None, +) -> Tuple[jtp.Vector, jtp.Vector]: + """ + Recursive Newton-Euler Algorithm (RNEA) algorithm for inverse dynamics. + """ + + xfb, q, qd, qdd, _, f_ext = utils.process_inputs( + physics_model=model, xfb=xfb, q=q, qd=qd, qdd=qdd, f_ext=f_ext + ) + + a0fb = a0fb.squeeze() + gravity = model.gravity.squeeze() + + if a0fb.shape[0] != 6: + raise ValueError(a0fb.shape) + + M = model.spatial_inertias + pre_X_λi = model.tree_transforms + i_X_pre = model.joint_transforms(q=q) + S = model.motion_subspaces(q=q) + i_X_λi = jnp.zeros_like(pre_X_λi) + + Γ = jnp.array([*model._joint_motor_gear_ratio.values()]) + IM = jnp.array([*model._joint_motor_inertia.values()]) + K_v = jnp.array([*model._joint_motor_viscous_friction.values()]) + K̅ᵥ = jnp.diag(Γ.T * jnp.diag(K_v) * Γ) + m_S = jnp.concatenate([S[:1], S[1:] * Γ[:, None, None]], axis=0) + + i_X_0 = jnp.zeros_like(pre_X_λi) + i_X_0 = i_X_0.at[0].set(jnp.eye(6)) + + # Parent array mapping: i -> λ(i). + # Exception: λ(0) must not be used, it's initialized to -1. + λ = model.parent_array() + + v = jnp.array([jnp.zeros([6, 1])] * model.NB) + a = jnp.array([jnp.zeros([6, 1])] * model.NB) + f = jnp.array([jnp.zeros([6, 1])] * model.NB) + + v_m = jnp.array([jnp.zeros([6, 1])] * model.NB) + a_m = jnp.array([jnp.zeros([6, 1])] * model.NB) + f_m = jnp.array([jnp.zeros([6, 1])] * model.NB) + + # 6D transform of base velocity + B_X_W = Adjoint.from_quaternion_and_translation( + quaternion=xfb[0:4], + translation=xfb[4:7], + inverse=True, + normalize_quaternion=True, + ) + i_X_λi = i_X_λi.at[0].set(B_X_W) + + a_0 = -B_X_W @ jnp.vstack(gravity) + a = a.at[0].set(a_0) + + if model.is_floating_base: + W_v_WB = jnp.vstack(jnp.hstack([xfb[10:13], xfb[7:10]])) + + v_0 = B_X_W @ W_v_WB + v = v.at[0].set(v_0) + + a_0 = B_X_W @ (jnp.vstack(a0fb) - jnp.vstack(gravity)) + a = a.at[0].set(a_0) + + f_0 = ( + M[0] @ a[0] + + Cross.vx_star(v[0]) @ M[0] @ v[0] + - Adjoint.inverse(B_X_W).T @ jnp.vstack(f_ext[0]) + ) + f = f.at[0].set(f_0) + + ForwardPassCarry = Tuple[ + jtp.MatrixJax, + jtp.MatrixJax, + jtp.MatrixJax, + jtp.MatrixJax, + jtp.MatrixJax, + jtp.MatrixJax, + jtp.MatrixJax, + jtp.MatrixJax, + ] + forward_pass_carry = (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m) + + def forward_pass( + carry: ForwardPassCarry, i: jtp.Int + ) -> Tuple[ForwardPassCarry, None]: + ii = i - 1 + i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m = carry + + vJ = S[i] * qd[ii] + vJ_m = m_S[i] * qd[ii] + + i_X_λi_i = i_X_pre[i] @ pre_X_λi[i] + i_X_λi = i_X_λi.at[i].set(i_X_λi_i) + + v_i = i_X_λi[i] @ v[λ[i]] + vJ + v = v.at[i].set(v_i) + + v_i_m = i_X_λi[i] @ v_m[λ[i]] + vJ_m + v_m = v_m.at[i].set(v_i_m) + + a_i = i_X_λi[i] @ a[λ[i]] + S[i] * qdd[ii] + Cross.vx(v[i]) @ vJ + a = a.at[i].set(a_i) + + a_i_m = i_X_λi[i] @ a_m[λ[i]] + m_S[i] * qdd[ii] + Cross.vx(v_m[i]) @ vJ_m + a_m = a_m.at[i].set(a_i_m) + + i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]] + i_X_0 = i_X_0.at[i].set(i_X_0_i) + i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T + + f_i = ( + M[i] @ a[i] + + Cross.vx_star(v[i]) @ M[i] @ v[i] + - i_Xf_W @ jnp.vstack(f_ext[i]) + ) + f = f.at[i].set(f_i) + + f_i_m = IM[i] * a_m[i] + Cross.vx_star(v_m[i]) * IM[i] @ v_m[i] + f_m = f_m.at[i].set(f_i_m) + + return (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m), None + + (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m), _ = jax.lax.scan( + f=forward_pass, + init=forward_pass_carry, + xs=np.arange(start=1, stop=model.NB), + ) + + tau = jnp.zeros_like(q) + + BackwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax] + backward_pass_carry = (tau, f, f_m) + + def backward_pass( + carry: BackwardPassCarry, i: jtp.Int + ) -> Tuple[BackwardPassCarry, None]: + ii = i - 1 + tau, f, f_m = carry + + value = S[i].T @ f[i] + m_S[i].T @ f_m[i] # + K̅ᵥ[i] * qd[ii] + tau = tau.at[ii].set(value.squeeze()) + + def update_f(ffm: Tuple[jtp.MatrixJax, jtp.MatrixJax]) -> jtp.MatrixJax: + f, f_m = ffm + f_λi = f[λ[i]] + i_X_λi[i].T @ f[i] + f = f.at[λ[i]].set(f_λi) + + f_m_λi = f_m[λ[i]] + i_X_λi[i].T @ f_m[i] + f_m = f_m.at[λ[i]].set(f_m_λi) + return f, f_m + + f, f_m = jax.lax.cond( + pred=jnp.array([λ[i] != 0, model.is_floating_base]).any(), + true_fun=update_f, + false_fun=lambda f: f, + operand=(f, f_m), + ) + + return (tau, f, f_m), None + + (tau, f, f_m), _ = jax.lax.scan( + f=backward_pass, + init=backward_pass_carry, + xs=np.flip(np.arange(start=1, stop=model.NB)), + ) + + # Handle 1 DoF models + tau = jnp.atleast_1d(tau.squeeze()) + tau = jnp.vstack(tau) if tau.size > 0 else jnp.empty(shape=(0, 1)) + + # Express the base 6D force in the world frame + W_f0 = B_X_W.T @ jnp.vstack(f[0]) + + return W_f0, tau diff --git a/src/jaxsim/physics/model/physics_model.py b/src/jaxsim/physics/model/physics_model.py index 7eb4dbcc3..689276cfe 100644 --- a/src/jaxsim/physics/model/physics_model.py +++ b/src/jaxsim/physics/model/physics_model.py @@ -53,6 +53,12 @@ class PhysicsModel(JaxsimDataclass): _joint_limit_spring: Dict[int, float] = dataclasses.field(default_factory=dict) _joint_limit_damper: Dict[int, float] = dataclasses.field(default_factory=dict) + _joint_motor_inertia: Dict[int, float] = dataclasses.field(default_factory=dict) + _joint_motor_gear_ratio: Dict[int, float] = dataclasses.field(default_factory=dict) + _joint_motor_viscous_friction: Dict[int, float] = dataclasses.field( + default_factory=dict + ) + def __post_init__(self): if self.initial_state is None: initial_state = PhysicsModelState.zero(physics_model=self) @@ -119,6 +125,21 @@ def build_from( for joint in model_description.joints } + # Dicts from the joint index to the motor inertia, gear ratio and viscous friction. + # Note: the joint index is equal to its child link index. + joint_motor_inertia = { + joint.index: jnp.array(joint.motor_inertia, dtype=float) + for joint in model_description.joints + } + joint_motor_gear_ratio = { + joint.index: jnp.array(joint.motor_gear_ratio, dtype=float) + for joint in model_description.joints + } + joint_motor_viscous_friction = { + joint.index: jnp.array(joint.motor_viscous_friction, dtype=float) + for joint in model_description.joints + } + # Transform between model's root and model's base link # (this is just the pose of the base link in the SDF description) base_link = model_description.links_dict[model_description.link_names()[0]] @@ -179,6 +200,9 @@ def build_from( _joint_friction_viscous=joint_friction_viscous, _joint_limit_spring=joint_limit_spring, _joint_limit_damper=joint_limit_damper, + _joint_motor_gear_ratio=joint_motor_gear_ratio, + _joint_motor_inertia=joint_motor_inertia, + _joint_motor_viscous_friction=joint_motor_viscous_friction, gravity=jnp.hstack([gravity.squeeze(), np.zeros(3)]), is_floating_base=True, gc=GroundContact.build_from(model_description=model_description),