From 193c80ea10188b3e76e93f3a848eab921dca1e9c Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 4 Oct 2023 21:33:22 +0200 Subject: [PATCH] Add documentation for `physics` module --- src/jaxsim/physics/algos/aba.py | 17 +++++ src/jaxsim/physics/algos/crba.py | 11 ++++ .../physics/algos/forward_kinematics.py | 12 ++++ src/jaxsim/physics/algos/jacobian.py | 11 ++++ src/jaxsim/physics/algos/rnea.py | 18 ++++- src/jaxsim/physics/algos/soft_contacts.py | 66 ++++++++++++++++++- src/jaxsim/physics/algos/terrain.py | 34 +++++++++- src/jaxsim/physics/model/ground_contact.py | 15 ++++- src/jaxsim/physics/model/physics_model.py | 45 +++++++++++++ .../physics/model/physics_model_state.py | 32 +++++++++ 10 files changed, 256 insertions(+), 5 deletions(-) diff --git a/src/jaxsim/physics/algos/aba.py b/src/jaxsim/physics/algos/aba.py index 19abb7379..bede3f7c5 100644 --- a/src/jaxsim/physics/algos/aba.py +++ b/src/jaxsim/physics/algos/aba.py @@ -22,6 +22,23 @@ def aba( ) -> Tuple[jtp.Vector, jtp.Vector]: """ Articulated Body Algorithm (ABA) algorithm for forward dynamics. + + Args: + model (PhysicsModel): The physics model of the articulated body or robot. + xfb (jtp.Vector): The floating base state vector containing quaternion (4D) and position (3D). + q (jtp.Vector): Joint positions (Generalized coordinates). + qd (jtp.Vector): Joint velocities. + tau (jtp.Vector): Joint torques or forces. + f_ext (jtp.Matrix, optional): External forces and torques acting on each link. Defaults to None. + + Returns: + Tuple[jtp.Vector, jtp.Vector]: A tuple containing the resulting base acceleration (in inertial-fixed representation) + and joint accelerations. + + Note: + The ABA algorithm is used to compute the accelerations of the links in an articulated body or robot system given + inputs such as joint positions, velocities, torques, and external forces. The algorithm involves multiple passes + to calculate intermediate quantities required for simulating the motion of the robot. """ x_fb, q, qd, _, tau, f_ext = utils.process_inputs( diff --git a/src/jaxsim/physics/algos/crba.py b/src/jaxsim/physics/algos/crba.py index 166ff95aa..c57673d88 100644 --- a/src/jaxsim/physics/algos/crba.py +++ b/src/jaxsim/physics/algos/crba.py @@ -11,6 +11,17 @@ def crba(model: PhysicsModel, q: jtp.Vector) -> jtp.Matrix: + """ + Compute the Composite Rigid-Body Inertia Matrix (CRBA) for an articulated body or robot given joint positions. + + Args: + model (PhysicsModel): The physics model of the articulated body or robot. + q (jtp.Vector): Joint positions (Generalized coordinates). + + Returns: + jtp.Matrix: The Composite Rigid-Body Inertia Matrix (CRBA) of the articulated body or robot. + """ + _, q, _, _, _, _ = utils.process_inputs( physics_model=model, xfb=None, q=q, qd=None, tau=None, f_ext=None ) diff --git a/src/jaxsim/physics/algos/forward_kinematics.py b/src/jaxsim/physics/algos/forward_kinematics.py index 3d7e99fe8..b370c094e 100644 --- a/src/jaxsim/physics/algos/forward_kinematics.py +++ b/src/jaxsim/physics/algos/forward_kinematics.py @@ -14,6 +14,18 @@ def forward_kinematics_model( model: PhysicsModel, q: jtp.Vector, xfb: jtp.Vector ) -> jtp.Array: + """ + Compute the forward kinematics transformations for all links in an articulated body or robot. + + Args: + model (PhysicsModel): The physics model of the articulated body or robot. + q (jtp.Vector): Joint positions (Generalized coordinates). + xfb (jtp.Vector): The base pose vector, including the quaternion (first 4 elements) and translation (last 3 elements). + + Returns: + jtp.Array: A 3D array containing the forward kinematics transformations for all links. + """ + x_fb, q, _, _, _, _ = utils.process_inputs( physics_model=model, xfb=xfb, q=q, qd=None, tau=None, f_ext=None ) diff --git a/src/jaxsim/physics/algos/jacobian.py b/src/jaxsim/physics/algos/jacobian.py index 69b488f38..95e71e7c9 100644 --- a/src/jaxsim/physics/algos/jacobian.py +++ b/src/jaxsim/physics/algos/jacobian.py @@ -12,6 +12,17 @@ def jacobian(model: PhysicsModel, body_index: jtp.Int, q: jtp.Vector) -> jtp.Matrix: + """ + Compute the Jacobian matrix for a specific link in an articulated body or robot. + + Args: + model (PhysicsModel): The physics model of the articulated body or robot. + body_index (jtp.Int): The index of the link for which to compute the Jacobian matrix. + q (jtp.Vector): Joint positions (Generalized coordinates). + + Returns: + jtp.Matrix: The Jacobian matrix for the specified link. + """ _, q, _, _, _, _ = utils.process_inputs(physics_model=model, q=q) S = model.motion_subspaces(q=q) diff --git a/src/jaxsim/physics/algos/rnea.py b/src/jaxsim/physics/algos/rnea.py index 09ef0d264..551db05bd 100644 --- a/src/jaxsim/physics/algos/rnea.py +++ b/src/jaxsim/physics/algos/rnea.py @@ -22,7 +22,23 @@ def rnea( f_ext: jtp.Matrix = None, ) -> Tuple[jtp.Vector, jtp.Vector]: """ - Recursive Newton-Euler Algorithm (RNEA) algorithm for inverse dynamics. + Perform Inverse Dynamics Calculation using the Recursive Newton-Euler Algorithm (RNEA). + + This function calculates the joint torques (forces) required to achieve a desired motion + given the robot's configuration, velocities, accelerations, and external forces. + + Args: + model (PhysicsModel): The robot's physics model containing dynamic parameters. + xfb (jtp.Vector): The floating base state, including orientation and position. + q (jtp.Vector): Joint positions (angles). + qd (jtp.Vector): Joint velocities. + qdd (jtp.Vector): Joint accelerations. + a0fb (jtp.Vector, optional): Base acceleration. Defaults to zeros. + f_ext (jtp.Matrix, optional): External forces acting on the robot. Defaults to None. + + Returns: + W_f0 (jtp.Vector): The base 6D force expressed in the world frame. + tau (jtp.Vector): Joint torques (forces) required for the desired motion. """ xfb, q, qd, qdd, _, f_ext = utils.process_inputs( diff --git a/src/jaxsim/physics/algos/soft_contacts.py b/src/jaxsim/physics/algos/soft_contacts.py index 05815aa7a..4bb341ff9 100644 --- a/src/jaxsim/physics/algos/soft_contacts.py +++ b/src/jaxsim/physics/algos/soft_contacts.py @@ -20,12 +20,24 @@ @jax_dataclasses.pytree_dataclass class SoftContactsState: + """State of the soft contacts model.""" + tangential_deformation: jtp.Matrix @staticmethod def zero( physics_model: jaxsim.physics.model.physics_model.PhysicsModel, ) -> "SoftContactsState": + """ + Modify the SoftContactsState instance imposing zero tangential deformation. + + Args: + physics_model (jaxsim.physics.model.physics_model.PhysicsModel): The physics model. + + Returns: + SoftContactsState: A SoftContactsState instance with zero tangential deformation. + """ + return SoftContactsState( tangential_deformation=jnp.zeros(shape=(3, physics_model.gc.body.size)) ) @@ -33,6 +45,16 @@ def zero( def valid( self, physics_model: jaxsim.physics.model.physics_model.PhysicsModel ) -> bool: + """ + Check if the soft contacts state has valid shape. + + Args: + physics_model (jaxsim.physics.model.physics_model.PhysicsModel): The physics model. + + Returns: + bool: True if the state has a valid shape, otherwise False. + """ + from jaxsim.simulation.utils import check_valid_shape return check_valid_shape( @@ -43,6 +65,17 @@ def valid( ) def replace(self, validate: bool = True, **kwargs) -> "SoftContactsState": + """ + Replace attributes of the soft contacts state. + + Args: + validate (bool, optional): Whether to validate the state after replacement. Defaults to True. + **kwargs: Keyword arguments for attribute replacement. + + Returns: + SoftContactsState: A new SoftContactsState instance with replaced attributes. + """ + with jax_dataclasses.copy_and_mutate(self, validate=validate) as updated_state: _ = [updated_state.__setattr__(k, v) for k, v in kwargs.items()] @@ -57,6 +90,15 @@ def collidable_points_pos_vel( ) -> Tuple[jtp.Matrix, jtp.Matrix]: """ Compute the position and linear velocity of collidable points in the world frame. + + Args: + model (PhysicsModel): The physics model. + q (jtp.Vector): The joint positions. + qd (jtp.Vector): The joint velocities. + xfb (jtp.Vector, optional): The floating base state. Defaults to None. + + Returns: + Tuple[jtp.Matrix, jtp.Matrix]: A tuple containing the position and velocity of collidable points. """ # Make sure that shape and size are correct @@ -195,7 +237,17 @@ class SoftContactsParams: def build( K: float = 1e6, D: float = 2_000, mu: float = 0.5 ) -> "SoftContactsParams": - """""" + """ + Create a SoftContactsParams instance with specified parameters. + + Args: + K (float, optional): The stiffness parameter. Defaults to 1e6. + D (float, optional): The damping parameter. Defaults to 2000. + mu (float, optional): The friction coefficient. Defaults to 0.5. + + Returns: + SoftContactsParams: A SoftContactsParams instance with the specified parameters. + """ return SoftContactsParams( K=jnp.array(K, dtype=float), @@ -220,6 +272,18 @@ def contact_model( velocity: jtp.Vector, tangential_deformation: jtp.Vector, ) -> Tuple[jtp.Vector, jtp.Vector]: + """ + Compute the contact forces and material deformation rate. + + Args: + position (jtp.Vector): The position of the collidable point. + velocity (jtp.Vector): The linear velocity of the collidable point. + tangential_deformation (jtp.Vector): The tangential deformation. + + Returns: + Tuple[jtp.Vector, jtp.Vector]: A tuple containing the contact force and material deformation rate. + """ + # Short name of parameters K = self.parameters.K D = self.parameters.D diff --git a/src/jaxsim/physics/algos/terrain.py b/src/jaxsim/physics/algos/terrain.py index d3316f24c..90ca5cd38 100644 --- a/src/jaxsim/physics/algos/terrain.py +++ b/src/jaxsim/physics/algos/terrain.py @@ -14,6 +14,17 @@ def height(self, x: float, y: float) -> float: pass def normal(self, x: float, y: float) -> jtp.Vector: + """ + Compute the normal vector of the terrain at a specific (x, y) location. + + Args: + x (float): The x-coordinate of the location. + y (float): The y-coordinate of the location. + + Returns: + jtp.Vector: The normal vector of the terrain surface at the specified location. + """ + # https://stackoverflow.com/a/5282364 h_xp = self.height(x=x + self.delta, y=y) h_xm = self.height(x=x - self.delta, y=y) @@ -41,10 +52,29 @@ class PlaneTerrain(Terrain): @staticmethod def build(plane_normal: jtp.Vector) -> "PlaneTerrain": - """""" + """ + Create a PlaneTerrain instance with a specified plane normal vector. + + Args: + plane_normal (jtp.Vector): The normal vector of the terrain plane. + + Returns: + PlaneTerrain: A PlaneTerrain instance. + """ return PlaneTerrain(plane_normal=jnp.array(plane_normal, dtype=float)) def height(self, x: float, y: float) -> float: + """ + Compute the height of the terrain at a specific (x, y) location on a plane. + + Args: + x (float): The x-coordinate of the location. + y (float): The y-coordinate of the location. + + Returns: + float: The height of the terrain at the specified location on the plane. + """ + a, b, c = self.plane_normal - return -(a * x + b * x) / c + return -(a * x + b * x) / c \ No newline at end of file diff --git a/src/jaxsim/physics/model/ground_contact.py b/src/jaxsim/physics/model/ground_contact.py index 88313eca7..04a3be750 100644 --- a/src/jaxsim/physics/model/ground_contact.py +++ b/src/jaxsim/physics/model/ground_contact.py @@ -12,7 +12,20 @@ @jax_dataclasses.pytree_dataclass class GroundContact: """ - Class to store the collidable points of a robot model. + A class for managing collidable points in a robot model. + + This class is used to store and manage information about collidable points on a robot model, + such as their positions and the corresponding bodies (links) they are associated with. + + Attributes: + point (npt.NDArray): An array of shape (3, N) representing the 3D positions of collidable points. + body (Static[npt.NDArray]): An array of integers representing the indices of the bodies (links) + associated with each collidable point. + + Methods: + build_from(model_description: ModelDescription) -> GroundContact: + A static method to build a GroundContact object from a ModelDescription instance. + It extracts collidable points' positions and their associated bodies from the model description. """ point: npt.NDArray = dataclasses.field(default_factory=lambda: jnp.array([])) diff --git a/src/jaxsim/physics/model/physics_model.py b/src/jaxsim/physics/model/physics_model.py index c6dcbddc1..17faf4376 100644 --- a/src/jaxsim/physics/model/physics_model.py +++ b/src/jaxsim/physics/model/physics_model.py @@ -23,6 +23,51 @@ class PhysicsModel(JaxsimDataclass): """ A read-only class to store all the information necessary to run RBDAs on a model. + + This class contains information about the physics model, including the number of bodies, initial state, gravity, + floating base configuration, ground contact points, and more. + + Attributes: + NB (Static[int]): The number of bodies in the physics model. + initial_state (PhysicsModelState): The initial state of the physics model (default: None). + gravity (jtp.Vector): The gravity vector (default: [0, 0, 0, 0, 0, 0]). + is_floating_base (Static[bool]): A flag indicating whether the model has a floating base (default: False). + gc (GroundContact): The ground contact points of the model (default: empty GroundContact instance). + description (Static[jaxsim.parsers.descriptions.model.ModelDescription]): A description of the model (default: None). + + Methods: + build_from( + model_description: jaxsim.parsers.descriptions.model.ModelDescription, + gravity: jtp.Vector = default_gravity() + ) -> PhysicsModel: + Create a PhysicsModel instance from a model description and gravity vector. + + dofs() -> int: + Get the number of degrees of freedom (DOFs) in the model. + + set_gravity(gravity: jtp.Vector) -> None: + Set the gravity vector for the model. + + parent_array() -> jtp.Vector: + Get the parent array (λ(i)) for the model. + + support_body_array(body_index: jtp.Int) -> jtp.Vector: + Get an array of body indices (κ(i)) that support the specified body. + + tree_transforms() -> jtp.Array: + Get an array of tree transforms (pre(i)_X_λ(i)) for all bodies. + + spatial_inertias() -> jtp.Array: + Get an array of spatial inertias (M_links) for all bodies. + + jtype(joint_index: int) -> JointType: + Get the joint type for the specified joint index. + + joint_transforms(q: jtp.Vector) -> jtp.Array: + Compute joint transforms (Xj) for the given joint positions (q). + + motion_subspaces(q: jtp.Vector) -> jtp.Array: + Compute motion subspaces (SS) for the given joint positions (q). """ NB: Static[int] diff --git a/src/jaxsim/physics/model/physics_model_state.py b/src/jaxsim/physics/model/physics_model_state.py index d5f53bcdd..eae4a6072 100644 --- a/src/jaxsim/physics/model/physics_model_state.py +++ b/src/jaxsim/physics/model/physics_model_state.py @@ -8,6 +8,38 @@ @jax_dataclasses.pytree_dataclass class PhysicsModelState(JaxsimDataclass): + """ + A class representing the state of a physics model. + + This class stores the joint positions, joint velocities, and the base state (position, orientation, linear velocity, + and angular velocity) of a physics model. + + Attributes: + joint_positions (jtp.Vector): An array representing the joint positions. + joint_velocities (jtp.Vector): An array representing the joint velocities. + base_position (jtp.Vector): An array representing the base position (default: zeros). + base_quaternion (jtp.Vector): An array representing the base quaternion (default: [1.0, 0, 0, 0]). + base_linear_velocity (jtp.Vector): An array representing the base linear velocity (default: zeros). + base_angular_velocity (jtp.Vector): An array representing the base angular velocity (default: zeros). + + Methods: + zero(physics_model: "jaxsim.physics.model.physics_model.PhysicsModel") -> PhysicsModelState: + Create a zero-initialized PhysicsModelState for the given physics model. + + position() -> jtp.Vector: + Get the full state vector, including joint positions, joint velocities, base position, and base quaternion. + + velocity() -> jtp.Vector: + Get the full velocity vector, including base linear velocity, base angular velocity, and joint velocities. + + xfb() -> jtp.Vector: + Get the full state vector in the "xfb" format, which includes base quaternion, base position, base angular + velocity, and base linear velocity. + + valid(physics_model: "jaxsim.physics.model.physics_model.PhysicsModel") -> bool: + Check if the state has valid shapes for the given physics model. + """ + # Joint state joint_positions: jtp.Vector joint_velocities: jtp.Vector