Skip to content

Commit

Permalink
Add documentation for physics module
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Oct 4, 2023
1 parent 66f0bb7 commit 193c80e
Show file tree
Hide file tree
Showing 10 changed files with 256 additions and 5 deletions.
17 changes: 17 additions & 0 deletions src/jaxsim/physics/algos/aba.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,23 @@ def aba(
) -> Tuple[jtp.Vector, jtp.Vector]:
"""
Articulated Body Algorithm (ABA) algorithm for forward dynamics.
Args:
model (PhysicsModel): The physics model of the articulated body or robot.
xfb (jtp.Vector): The floating base state vector containing quaternion (4D) and position (3D).
q (jtp.Vector): Joint positions (Generalized coordinates).
qd (jtp.Vector): Joint velocities.
tau (jtp.Vector): Joint torques or forces.
f_ext (jtp.Matrix, optional): External forces and torques acting on each link. Defaults to None.
Returns:
Tuple[jtp.Vector, jtp.Vector]: A tuple containing the resulting base acceleration (in inertial-fixed representation)
and joint accelerations.
Note:
The ABA algorithm is used to compute the accelerations of the links in an articulated body or robot system given
inputs such as joint positions, velocities, torques, and external forces. The algorithm involves multiple passes
to calculate intermediate quantities required for simulating the motion of the robot.
"""

x_fb, q, qd, _, tau, f_ext = utils.process_inputs(
Expand Down
11 changes: 11 additions & 0 deletions src/jaxsim/physics/algos/crba.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,17 @@


def crba(model: PhysicsModel, q: jtp.Vector) -> jtp.Matrix:
"""
Compute the Composite Rigid-Body Inertia Matrix (CRBA) for an articulated body or robot given joint positions.
Args:
model (PhysicsModel): The physics model of the articulated body or robot.
q (jtp.Vector): Joint positions (Generalized coordinates).
Returns:
jtp.Matrix: The Composite Rigid-Body Inertia Matrix (CRBA) of the articulated body or robot.
"""

_, q, _, _, _, _ = utils.process_inputs(
physics_model=model, xfb=None, q=q, qd=None, tau=None, f_ext=None
)
Expand Down
12 changes: 12 additions & 0 deletions src/jaxsim/physics/algos/forward_kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@
def forward_kinematics_model(
model: PhysicsModel, q: jtp.Vector, xfb: jtp.Vector
) -> jtp.Array:
"""
Compute the forward kinematics transformations for all links in an articulated body or robot.
Args:
model (PhysicsModel): The physics model of the articulated body or robot.
q (jtp.Vector): Joint positions (Generalized coordinates).
xfb (jtp.Vector): The base pose vector, including the quaternion (first 4 elements) and translation (last 3 elements).
Returns:
jtp.Array: A 3D array containing the forward kinematics transformations for all links.
"""

x_fb, q, _, _, _, _ = utils.process_inputs(
physics_model=model, xfb=xfb, q=q, qd=None, tau=None, f_ext=None
)
Expand Down
11 changes: 11 additions & 0 deletions src/jaxsim/physics/algos/jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,17 @@


def jacobian(model: PhysicsModel, body_index: jtp.Int, q: jtp.Vector) -> jtp.Matrix:
"""
Compute the Jacobian matrix for a specific link in an articulated body or robot.
Args:
model (PhysicsModel): The physics model of the articulated body or robot.
body_index (jtp.Int): The index of the link for which to compute the Jacobian matrix.
q (jtp.Vector): Joint positions (Generalized coordinates).
Returns:
jtp.Matrix: The Jacobian matrix for the specified link.
"""
_, q, _, _, _, _ = utils.process_inputs(physics_model=model, q=q)

S = model.motion_subspaces(q=q)
Expand Down
18 changes: 17 additions & 1 deletion src/jaxsim/physics/algos/rnea.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,23 @@ def rnea(
f_ext: jtp.Matrix = None,
) -> Tuple[jtp.Vector, jtp.Vector]:
"""
Recursive Newton-Euler Algorithm (RNEA) algorithm for inverse dynamics.
Perform Inverse Dynamics Calculation using the Recursive Newton-Euler Algorithm (RNEA).
This function calculates the joint torques (forces) required to achieve a desired motion
given the robot's configuration, velocities, accelerations, and external forces.
Args:
model (PhysicsModel): The robot's physics model containing dynamic parameters.
xfb (jtp.Vector): The floating base state, including orientation and position.
q (jtp.Vector): Joint positions (angles).
qd (jtp.Vector): Joint velocities.
qdd (jtp.Vector): Joint accelerations.
a0fb (jtp.Vector, optional): Base acceleration. Defaults to zeros.
f_ext (jtp.Matrix, optional): External forces acting on the robot. Defaults to None.
Returns:
W_f0 (jtp.Vector): The base 6D force expressed in the world frame.
tau (jtp.Vector): Joint torques (forces) required for the desired motion.
"""

xfb, q, qd, qdd, _, f_ext = utils.process_inputs(
Expand Down
66 changes: 65 additions & 1 deletion src/jaxsim/physics/algos/soft_contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,41 @@

@jax_dataclasses.pytree_dataclass
class SoftContactsState:
"""State of the soft contacts model."""

tangential_deformation: jtp.Matrix

@staticmethod
def zero(
physics_model: jaxsim.physics.model.physics_model.PhysicsModel,
) -> "SoftContactsState":
"""
Modify the SoftContactsState instance imposing zero tangential deformation.
Args:
physics_model (jaxsim.physics.model.physics_model.PhysicsModel): The physics model.
Returns:
SoftContactsState: A SoftContactsState instance with zero tangential deformation.
"""

return SoftContactsState(
tangential_deformation=jnp.zeros(shape=(3, physics_model.gc.body.size))
)

def valid(
self, physics_model: jaxsim.physics.model.physics_model.PhysicsModel
) -> bool:
"""
Check if the soft contacts state has valid shape.
Args:
physics_model (jaxsim.physics.model.physics_model.PhysicsModel): The physics model.
Returns:
bool: True if the state has a valid shape, otherwise False.
"""

from jaxsim.simulation.utils import check_valid_shape

return check_valid_shape(
Expand All @@ -43,6 +65,17 @@ def valid(
)

def replace(self, validate: bool = True, **kwargs) -> "SoftContactsState":
"""
Replace attributes of the soft contacts state.
Args:
validate (bool, optional): Whether to validate the state after replacement. Defaults to True.
**kwargs: Keyword arguments for attribute replacement.
Returns:
SoftContactsState: A new SoftContactsState instance with replaced attributes.
"""

with jax_dataclasses.copy_and_mutate(self, validate=validate) as updated_state:
_ = [updated_state.__setattr__(k, v) for k, v in kwargs.items()]

Expand All @@ -57,6 +90,15 @@ def collidable_points_pos_vel(
) -> Tuple[jtp.Matrix, jtp.Matrix]:
"""
Compute the position and linear velocity of collidable points in the world frame.
Args:
model (PhysicsModel): The physics model.
q (jtp.Vector): The joint positions.
qd (jtp.Vector): The joint velocities.
xfb (jtp.Vector, optional): The floating base state. Defaults to None.
Returns:
Tuple[jtp.Matrix, jtp.Matrix]: A tuple containing the position and velocity of collidable points.
"""

# Make sure that shape and size are correct
Expand Down Expand Up @@ -195,7 +237,17 @@ class SoftContactsParams:
def build(
K: float = 1e6, D: float = 2_000, mu: float = 0.5
) -> "SoftContactsParams":
""""""
"""
Create a SoftContactsParams instance with specified parameters.
Args:
K (float, optional): The stiffness parameter. Defaults to 1e6.
D (float, optional): The damping parameter. Defaults to 2000.
mu (float, optional): The friction coefficient. Defaults to 0.5.
Returns:
SoftContactsParams: A SoftContactsParams instance with the specified parameters.
"""

return SoftContactsParams(
K=jnp.array(K, dtype=float),
Expand All @@ -220,6 +272,18 @@ def contact_model(
velocity: jtp.Vector,
tangential_deformation: jtp.Vector,
) -> Tuple[jtp.Vector, jtp.Vector]:
"""
Compute the contact forces and material deformation rate.
Args:
position (jtp.Vector): The position of the collidable point.
velocity (jtp.Vector): The linear velocity of the collidable point.
tangential_deformation (jtp.Vector): The tangential deformation.
Returns:
Tuple[jtp.Vector, jtp.Vector]: A tuple containing the contact force and material deformation rate.
"""

# Short name of parameters
K = self.parameters.K
D = self.parameters.D
Expand Down
34 changes: 32 additions & 2 deletions src/jaxsim/physics/algos/terrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,17 @@ def height(self, x: float, y: float) -> float:
pass

def normal(self, x: float, y: float) -> jtp.Vector:
"""
Compute the normal vector of the terrain at a specific (x, y) location.
Args:
x (float): The x-coordinate of the location.
y (float): The y-coordinate of the location.
Returns:
jtp.Vector: The normal vector of the terrain surface at the specified location.
"""

# https://stackoverflow.com/a/5282364
h_xp = self.height(x=x + self.delta, y=y)
h_xm = self.height(x=x - self.delta, y=y)
Expand Down Expand Up @@ -41,10 +52,29 @@ class PlaneTerrain(Terrain):

@staticmethod
def build(plane_normal: jtp.Vector) -> "PlaneTerrain":
""""""
"""
Create a PlaneTerrain instance with a specified plane normal vector.
Args:
plane_normal (jtp.Vector): The normal vector of the terrain plane.
Returns:
PlaneTerrain: A PlaneTerrain instance.
"""

return PlaneTerrain(plane_normal=jnp.array(plane_normal, dtype=float))

def height(self, x: float, y: float) -> float:
"""
Compute the height of the terrain at a specific (x, y) location on a plane.
Args:
x (float): The x-coordinate of the location.
y (float): The y-coordinate of the location.
Returns:
float: The height of the terrain at the specified location on the plane.
"""

a, b, c = self.plane_normal
return -(a * x + b * x) / c
return -(a * x + b * x) / c
15 changes: 14 additions & 1 deletion src/jaxsim/physics/model/ground_contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,20 @@
@jax_dataclasses.pytree_dataclass
class GroundContact:
"""
Class to store the collidable points of a robot model.
A class for managing collidable points in a robot model.
This class is used to store and manage information about collidable points on a robot model,
such as their positions and the corresponding bodies (links) they are associated with.
Attributes:
point (npt.NDArray): An array of shape (3, N) representing the 3D positions of collidable points.
body (Static[npt.NDArray]): An array of integers representing the indices of the bodies (links)
associated with each collidable point.
Methods:
build_from(model_description: ModelDescription) -> GroundContact:
A static method to build a GroundContact object from a ModelDescription instance.
It extracts collidable points' positions and their associated bodies from the model description.
"""

point: npt.NDArray = dataclasses.field(default_factory=lambda: jnp.array([]))
Expand Down
45 changes: 45 additions & 0 deletions src/jaxsim/physics/model/physics_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,51 @@
class PhysicsModel(JaxsimDataclass):
"""
A read-only class to store all the information necessary to run RBDAs on a model.
This class contains information about the physics model, including the number of bodies, initial state, gravity,
floating base configuration, ground contact points, and more.
Attributes:
NB (Static[int]): The number of bodies in the physics model.
initial_state (PhysicsModelState): The initial state of the physics model (default: None).
gravity (jtp.Vector): The gravity vector (default: [0, 0, 0, 0, 0, 0]).
is_floating_base (Static[bool]): A flag indicating whether the model has a floating base (default: False).
gc (GroundContact): The ground contact points of the model (default: empty GroundContact instance).
description (Static[jaxsim.parsers.descriptions.model.ModelDescription]): A description of the model (default: None).
Methods:
build_from(
model_description: jaxsim.parsers.descriptions.model.ModelDescription,
gravity: jtp.Vector = default_gravity()
) -> PhysicsModel:
Create a PhysicsModel instance from a model description and gravity vector.
dofs() -> int:
Get the number of degrees of freedom (DOFs) in the model.
set_gravity(gravity: jtp.Vector) -> None:
Set the gravity vector for the model.
parent_array() -> jtp.Vector:
Get the parent array (λ(i)) for the model.
support_body_array(body_index: jtp.Int) -> jtp.Vector:
Get an array of body indices (κ(i)) that support the specified body.
tree_transforms() -> jtp.Array:
Get an array of tree transforms (pre(i)_X_λ(i)) for all bodies.
spatial_inertias() -> jtp.Array:
Get an array of spatial inertias (M_links) for all bodies.
jtype(joint_index: int) -> JointType:
Get the joint type for the specified joint index.
joint_transforms(q: jtp.Vector) -> jtp.Array:
Compute joint transforms (Xj) for the given joint positions (q).
motion_subspaces(q: jtp.Vector) -> jtp.Array:
Compute motion subspaces (SS) for the given joint positions (q).
"""

NB: Static[int]
Expand Down
32 changes: 32 additions & 0 deletions src/jaxsim/physics/model/physics_model_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,38 @@

@jax_dataclasses.pytree_dataclass
class PhysicsModelState(JaxsimDataclass):
"""
A class representing the state of a physics model.
This class stores the joint positions, joint velocities, and the base state (position, orientation, linear velocity,
and angular velocity) of a physics model.
Attributes:
joint_positions (jtp.Vector): An array representing the joint positions.
joint_velocities (jtp.Vector): An array representing the joint velocities.
base_position (jtp.Vector): An array representing the base position (default: zeros).
base_quaternion (jtp.Vector): An array representing the base quaternion (default: [1.0, 0, 0, 0]).
base_linear_velocity (jtp.Vector): An array representing the base linear velocity (default: zeros).
base_angular_velocity (jtp.Vector): An array representing the base angular velocity (default: zeros).
Methods:
zero(physics_model: "jaxsim.physics.model.physics_model.PhysicsModel") -> PhysicsModelState:
Create a zero-initialized PhysicsModelState for the given physics model.
position() -> jtp.Vector:
Get the full state vector, including joint positions, joint velocities, base position, and base quaternion.
velocity() -> jtp.Vector:
Get the full velocity vector, including base linear velocity, base angular velocity, and joint velocities.
xfb() -> jtp.Vector:
Get the full state vector in the "xfb" format, which includes base quaternion, base position, base angular
velocity, and base linear velocity.
valid(physics_model: "jaxsim.physics.model.physics_model.PhysicsModel") -> bool:
Check if the state has valid shapes for the given physics model.
"""

# Joint state
joint_positions: jtp.Vector
joint_velocities: jtp.Vector
Expand Down

0 comments on commit 193c80e

Please sign in to comment.