diff --git a/src/jaxsim/parsers/descriptions/joint.py b/src/jaxsim/parsers/descriptions/joint.py index f67cf7298..128c2094b 100644 --- a/src/jaxsim/parsers/descriptions/joint.py +++ b/src/jaxsim/parsers/descriptions/joint.py @@ -66,6 +66,9 @@ class JointDescription: index: int = 0 + friction_static: float = 0.0 + friction_viscous: float = 0.0 + position_limit: Tuple[float, float] = (0.0, 0.0) initial_position: Union[float, npt.NDArray] = 0.0 diff --git a/src/jaxsim/parsers/sdf/parser.py b/src/jaxsim/parsers/sdf/parser.py index a2b87d356..92f01d3af 100644 --- a/src/jaxsim/parsers/sdf/parser.py +++ b/src/jaxsim/parsers/sdf/parser.py @@ -184,6 +184,16 @@ def extract_data_from_sdf( if j.axis is not None and j.axis.limit is not None else np.finfo(float).max, ), + friction_static=j.axis.dynamics.friction + if j.axis is not None + and j.axis.dynamics is not None + and j.axis.dynamics.friction is not None + else 0.0, + friction_viscous=j.axis.dynamics.damping + if j.axis is not None + and j.axis.dynamics is not None + and j.axis.dynamics.friction is not None + else 0.0, ) for j in sdf_tree.model.joints if j.type in {"revolute", "prismatic", "fixed"} diff --git a/src/jaxsim/physics/model/physics_model.py b/src/jaxsim/physics/model/physics_model.py index 0de2d65db..4be78eeb0 100644 --- a/src/jaxsim/physics/model/physics_model.py +++ b/src/jaxsim/physics/model/physics_model.py @@ -44,6 +44,9 @@ class PhysicsModel(JaxsimDataclass): ) _link_inertias_dict: Dict[int, jtp.Matrix] = dataclasses.field(default_factory=dict) + _joint_friction_static: Dict[int, float] = dataclasses.field(default_factory=dict) + _joint_friction_viscous: Dict[int, float] = dataclasses.field(default_factory=dict) + def __post_init__(self): if self.initial_state is None: @@ -90,6 +93,15 @@ def build_from( joint.index: joint.jtype for joint in model_description.joints } + # Dicts from the joint index to the static and viscous friction. + # Note: the joint index is equal to its child link index. + joint_friction_static = { + joint.index: joint.friction_static for joint in model_description.joints + } + joint_friction_viscous = { + joint.index: joint.friction_viscous for joint in model_description.joints + } + # Transform between model's root and model's base link # (this is just the pose of the base link in the SDF description) base_link = model_description.links_dict[model_description.link_names()[0]] @@ -146,6 +158,8 @@ def build_from( _jtype_dict=joint_types_dict, _tree_transforms_dict=tree_transforms_dict, _link_inertias_dict=link_spatial_inertias_dict, + _joint_friction_static=joint_friction_static, + _joint_friction_viscous=joint_friction_viscous, gravity=jnp.hstack([gravity.squeeze(), np.zeros(3)]), is_floating_base=True, gc=GroundContact.build_from(model_description=model_description), diff --git a/src/jaxsim/simulation/ode.py b/src/jaxsim/simulation/ode.py index 6c9b1bcb9..3cf8c24fc 100644 --- a/src/jaxsim/simulation/ode.py +++ b/src/jaxsim/simulation/ode.py @@ -105,6 +105,20 @@ def dx_dt( terrain=terrain, ) + # ============== + # Joint friction + # ============== + + # Static and viscous joint friction parameters + kc = jnp.array(list(physics_model._joint_friction_static.values())) + kv = jnp.array(list(physics_model._joint_friction_viscous.values())) + + # Compute the joint friction torque + tau_friction = -( + jnp.diag(kc) @ jnp.sign(ode_state.physics_model.joint_positions) + + jnp.diag(kv) @ ode_state.physics_model.joint_velocities + ) + # ======================== # Compute forward dynamics # ======================== @@ -112,15 +126,8 @@ def dx_dt( # Compute the total forces applied to the bodies total_forces = ode_input.physics_model.f_ext + contact_forces_links - # Compute the mechanical joint torques (real torque sent to the joint) by - # subtracting the optional joint friction - # TODO: add support of coulomb/viscous parameters in parsers and PhysicsModel - kp_friction = jnp.array([0.0] * physics_model.dofs()) - kd_friction = jnp.array([0.0] * physics_model.dofs()) - tau = ode_input.physics_model.tau - ( - jnp.diag(kp_friction) @ jnp.sign(ode_state.physics_model.joint_positions) - + jnp.diag(kd_friction) @ ode_state.physics_model.joint_velocities - ) + # Compute the joint torques to actuate + tau = ode_input.physics_model.tau + tau_friction W_a_WB, qdd = algos.aba.aba( model=physics_model,