diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 530c5dd64..dc70b499e 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -393,7 +393,7 @@ def joint_transforms_and_motion_subspaces( pre_H_suc_J, S_J = jax.vmap(supported_joint_motion)( jnp.array(self.joint_model.joint_types[1:]).astype(int), jnp.array(joint_positions), - jnp.array(self.joint_model.joint_axis), + jnp.array([j.axis for j in self.joint_model.joint_axis]), ) # Extract the transforms and motion subspaces of the joints. diff --git a/src/jaxsim/math/joint_model.py b/src/jaxsim/math/joint_model.py index 2e29553e9..106409b88 100644 --- a/src/jaxsim/math/joint_model.py +++ b/src/jaxsim/math/joint_model.py @@ -39,7 +39,7 @@ class JointModel: joint_dofs: Static[tuple[int, ...]] joint_names: Static[tuple[str, ...]] - joint_types: Static[tuple[JointType, ...]] + joint_types: Static[tuple[int, ...]] joint_axis: Static[tuple[JointGenericAxis, ...]] @staticmethod @@ -109,7 +109,7 @@ def build(description: ModelDescription) -> JointModel: joint_dofs=tuple([base_dofs] + [int(1) for _ in ordered_joints]), joint_names=tuple(["world_to_base"] + [j.name for j in ordered_joints]), joint_types=tuple([JointType.Fixed] + [j.jtype for j in ordered_joints]), - joint_axis=tuple([j.axis for j in ordered_joints]), + joint_axis=tuple(JointGenericAxis(axis=j.axis) for j in ordered_joints), ) def parent_H_child( @@ -201,7 +201,7 @@ def predecessor_H_successor( pre_H_suc, S = supported_joint_motion( self.joint_types[joint_index], joint_position, - self.joint_axis[joint_index], + self.joint_axis[joint_index].axis, ) return pre_H_suc, S @@ -224,9 +224,9 @@ def successor_H_child(self, joint_index: jtp.IntLike) -> jtp.Matrix: @jax.jit def supported_joint_motion( - joint_type: JointType, + joint_type: jtp.IntLike, joint_position: jtp.VectorLike, - joint_axis: JointGenericAxis, + joint_axis: jtp.VectorLike | None = None, /, ) -> tuple[jtp.Matrix, jtp.Array]: """ @@ -234,8 +234,8 @@ def supported_joint_motion( Args: joint_type: The type of the joint. - joint_axis: The axis of rotation or translation of the joint. joint_position: The position of the joint. + joint_axis: The optional 3D axis of rotation or translation of the joint. Returns: A tuple containing the homogeneous transformation and the motion subspace. @@ -244,26 +244,33 @@ def supported_joint_motion( # Prepare the joint position s = jnp.array(joint_position).astype(float) - def compute_F(): + def compute_F() -> tuple[jtp.Matrix, jtp.Array]: return jaxlie.SE3.identity(), jnp.zeros(shape=(6, 1)) - def compute_R(): + def compute_R() -> tuple[jtp.Matrix, jtp.Array]: + + # Get the additional argument specifying the joint axis. + # This is a metadata required by only some joint types. + axis = jnp.array(joint_axis).astype(float).squeeze() + pre_H_suc = jaxlie.SE3.from_rotation( - rotation=jaxlie.SO3.from_matrix( - Rotation.from_axis_angle(vector=s * joint_axis) - ) + rotation=jaxlie.SO3.from_matrix(Rotation.from_axis_angle(vector=s * axis)) ) - S = jnp.vstack(jnp.hstack([jnp.zeros(3), joint_axis.squeeze()])) + S = jnp.vstack(jnp.hstack([jnp.zeros(3), axis])) + return pre_H_suc, S - def compute_P(): - pre_H_suc = jaxlie.SE3.from_rotation_and_translation( - rotation=jaxlie.SO3.identity(), - translation=jnp.array(s * joint_axis), - ) + def compute_P() -> tuple[jtp.Matrix, jtp.Array]: + + # Get the additional argument specifying the joint axis. + # This is a metadata required by only some joint types. + axis = jnp.array(joint_axis).astype(float).squeeze() + + pre_H_suc = jaxlie.SE3.from_translation(translation=jnp.array(s * axis)) + + S = jnp.vstack(jnp.hstack([axis, jnp.zeros(3)])) - S = jnp.vstack(jnp.hstack([joint_axis.squeeze(), jnp.zeros(3)])) return pre_H_suc, S pre_H_suc, S = jax.lax.switch( diff --git a/src/jaxsim/parsers/descriptions/joint.py b/src/jaxsim/parsers/descriptions/joint.py index 9abf7d257..a2139e0e6 100644 --- a/src/jaxsim/parsers/descriptions/joint.py +++ b/src/jaxsim/parsers/descriptions/joint.py @@ -30,9 +30,11 @@ class JointGenericAxis: axis: jtp.Vector def __hash__(self) -> int: - return hash((tuple(np.array(self.axis).tolist()))) + + return hash(tuple(self.axis.tolist())) def __eq__(self, other: JointGenericAxis) -> bool: + if not isinstance(other, JointGenericAxis): return False