From c6bdb532f96061095de345eec2a435cd9eb0aeaf Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Sat, 4 Jan 2025 11:03:57 +0100 Subject: [PATCH 1/6] Remove unnecessary typehints in docstrings --- src/jaxsim/math/adjoint.py | 19 ++++++++--------- src/jaxsim/math/cross.py | 4 ++-- src/jaxsim/math/inertia.py | 8 +++---- src/jaxsim/math/quaternion.py | 12 +++++------ src/jaxsim/math/rotation.py | 6 +++--- src/jaxsim/math/skew.py | 4 ++-- src/jaxsim/parsers/descriptions/joint.py | 27 ++++++++++++------------ 7 files changed, 39 insertions(+), 41 deletions(-) diff --git a/src/jaxsim/math/adjoint.py b/src/jaxsim/math/adjoint.py index 0356fd026..5d763eb3b 100644 --- a/src/jaxsim/math/adjoint.py +++ b/src/jaxsim/math/adjoint.py @@ -18,11 +18,10 @@ def from_quaternion_and_translation( Create an adjoint matrix from a quaternion and a translation. Args: - quaternion (jtp.Vector): A quaternion vector (4D) representing orientation. - translation (jtp.Vector): A translation vector (3D). - inverse (bool): Whether to compute the inverse adjoint. Default is False. - normalize_quaternion (bool): Whether to normalize the quaternion before creating the adjoint. - Default is False. + quaternion: A quaternion vector (4D) representing orientation. + translation: A translation vector (3D). + inverse: Whether to compute the inverse adjoint. + normalize_quaternion: Whether to normalize the quaternion before creating the adjoint. Returns: jtp.Matrix: The adjoint matrix. @@ -69,9 +68,9 @@ def from_rotation_and_translation( Create an adjoint matrix from a rotation matrix and a translation vector. Args: - rotation (jtp.Matrix): A 3x3 rotation matrix. - translation (jtp.Vector): A translation vector (3D). - inverse (bool): Whether to compute the inverse adjoint. Default is False. + rotation: A 3x3 rotation matrix. + translation: A translation vector (3D). + inverse: Whether to compute the inverse adjoint. Default is False. Returns: jtp.Matrix: The adjoint matrix. @@ -105,7 +104,7 @@ def to_transform(adjoint: jtp.Matrix) -> jtp.Matrix: Convert an adjoint matrix to a transformation matrix. Args: - adjoint (jtp.Matrix): The adjoint matrix (6x6). + adjoint: The adjoint matrix (6x6). Returns: jtp.Matrix: The transformation matrix (4x4). @@ -131,7 +130,7 @@ def inverse(adjoint: jtp.Matrix) -> jtp.Matrix: Compute the inverse of an adjoint matrix. Args: - adjoint (jtp.Matrix): The adjoint matrix. + adjoint: The adjoint matrix. Returns: jtp.Matrix: The inverse adjoint matrix. diff --git a/src/jaxsim/math/cross.py b/src/jaxsim/math/cross.py index b94ceaa8a..0b1c5579b 100644 --- a/src/jaxsim/math/cross.py +++ b/src/jaxsim/math/cross.py @@ -12,7 +12,7 @@ def vx(velocity_sixd: jtp.Vector) -> jtp.Matrix: Compute the cross product matrix for 6D velocities. Args: - velocity_sixd (jtp.Vector): A 6D velocity vector [v, ω]. + velocity_sixd: A 6D velocity vector [v, ω]. Returns: jtp.Matrix: The cross product matrix (6x6). @@ -37,7 +37,7 @@ def vx_star(velocity_sixd: jtp.Vector) -> jtp.Matrix: Compute the negative transpose of the cross product matrix for 6D velocities. Args: - velocity_sixd (jtp.Vector): A 6D velocity vector [v, ω]. + velocity_sixd: A 6D velocity vector [v, ω]. Returns: jtp.Matrix: The negative transpose of the cross product matrix (6x6). diff --git a/src/jaxsim/math/inertia.py b/src/jaxsim/math/inertia.py index f8dc89dce..12349eade 100644 --- a/src/jaxsim/math/inertia.py +++ b/src/jaxsim/math/inertia.py @@ -12,9 +12,9 @@ def to_sixd(mass: jtp.Float, com: jtp.Vector, I: jtp.Matrix) -> jtp.Matrix: Convert mass, center of mass, and inertia matrix to a 6x6 inertia matrix. Args: - mass (jtp.Float): The mass of the body. - com (jtp.Vector): The center of mass position (3D). - I (jtp.Matrix): The 3x3 inertia matrix. + mass: The mass of the body. + com: The center of mass position (3D). + I: The 3x3 inertia matrix. Returns: jtp.Matrix: The 6x6 inertia matrix. @@ -42,7 +42,7 @@ def to_params(M: jtp.Matrix) -> tuple[jtp.Float, jtp.Vector, jtp.Matrix]: Convert a 6x6 inertia matrix to mass, center of mass, and inertia matrix. Args: - M (jtp.Matrix): The 6x6 inertia matrix. + M: The 6x6 inertia matrix. Returns: tuple[jtp.Float, jtp.Vector, jtp.Matrix]: A tuple containing mass, center of mass (3D), and inertia matrix (3x3). diff --git a/src/jaxsim/math/quaternion.py b/src/jaxsim/math/quaternion.py index 4870f1aa0..a87321bac 100644 --- a/src/jaxsim/math/quaternion.py +++ b/src/jaxsim/math/quaternion.py @@ -14,7 +14,7 @@ def to_xyzw(wxyz: jtp.Vector) -> jtp.Vector: Convert a quaternion from WXYZ to XYZW representation. Args: - wxyz (jtp.Vector): Quaternion in WXYZ representation. + wxyz: Quaternion in WXYZ representation. Returns: jtp.Vector: Quaternion in XYZW representation. @@ -27,7 +27,7 @@ def to_wxyz(xyzw: jtp.Vector) -> jtp.Vector: Convert a quaternion from XYZW to WXYZ representation. Args: - xyzw (jtp.Vector): Quaternion in XYZW representation. + xyzw: Quaternion in XYZW representation. Returns: jtp.Vector: Quaternion in WXYZ representation. @@ -40,7 +40,7 @@ def to_dcm(quaternion: jtp.Vector) -> jtp.Matrix: Convert a quaternion to a direction cosine matrix (DCM). Args: - quaternion (jtp.Vector): Quaternion in XYZW representation. + quaternion: Quaternion in XYZW representation. Returns: jtp.Matrix: Direction cosine matrix (DCM). @@ -53,7 +53,7 @@ def from_dcm(dcm: jtp.Matrix) -> jtp.Vector: Convert a direction cosine matrix (DCM) to a quaternion. Args: - dcm (jtp.Matrix): Direction cosine matrix (DCM). + dcm: Direction cosine matrix (DCM). Returns: jtp.Vector: Quaternion in XYZW representation. @@ -71,8 +71,8 @@ def derivative( Compute the derivative of a quaternion given angular velocity. Args: - quaternion (jtp.Vector): Quaternion in XYZW representation. - omega (jtp.Vector): Angular velocity vector. + quaternion: Quaternion in XYZW representation. + omega: Angular velocity vector. omega_in_body_fixed (bool): Whether the angular velocity is in the body-fixed frame. K (float): A scaling factor. diff --git a/src/jaxsim/math/rotation.py b/src/jaxsim/math/rotation.py index 471f496b8..a2d942a55 100644 --- a/src/jaxsim/math/rotation.py +++ b/src/jaxsim/math/rotation.py @@ -15,7 +15,7 @@ def x(theta: jtp.Float) -> jtp.Matrix: Generate a 3D rotation matrix around the X-axis. Args: - theta (jtp.Float): Rotation angle in radians. + theta: Rotation angle in radians. Returns: jtp.Matrix: 3D rotation matrix. @@ -29,7 +29,7 @@ def y(theta: jtp.Float) -> jtp.Matrix: Generate a 3D rotation matrix around the Y-axis. Args: - theta (jtp.Float): Rotation angle in radians. + theta: Rotation angle in radians. Returns: jtp.Matrix: 3D rotation matrix. @@ -43,7 +43,7 @@ def z(theta: jtp.Float) -> jtp.Matrix: Generate a 3D rotation matrix around the Z-axis. Args: - theta (jtp.Float): Rotation angle in radians. + theta: Rotation angle in radians. Returns: jtp.Matrix: 3D rotation matrix. diff --git a/src/jaxsim/math/skew.py b/src/jaxsim/math/skew.py index eb1163df4..3a0f0bc4d 100644 --- a/src/jaxsim/math/skew.py +++ b/src/jaxsim/math/skew.py @@ -14,7 +14,7 @@ def wedge(vector: jtp.Vector) -> jtp.Matrix: Compute the skew-symmetric matrix (wedge operator) of a 3D vector. Args: - vector (jtp.Vector): A 3D vector. + vector: A 3D vector. Returns: jtp.Matrix: The skew-symmetric matrix corresponding to the input vector. @@ -31,7 +31,7 @@ def vee(matrix: jtp.Matrix) -> jtp.Vector: Extract the 3D vector from a skew-symmetric matrix (vee operator). Args: - matrix (jtp.Matrix): A 3x3 skew-symmetric matrix. + matrix: A 3x3 skew-symmetric matrix. Returns: jtp.Vector: The 3D vector extracted from the input matrix. diff --git a/src/jaxsim/parsers/descriptions/joint.py b/src/jaxsim/parsers/descriptions/joint.py index 1698d8099..c4432365b 100644 --- a/src/jaxsim/parsers/descriptions/joint.py +++ b/src/jaxsim/parsers/descriptions/joint.py @@ -47,20 +47,19 @@ class JointDescription(JaxsimDataclass): In-memory description of a robot link. Attributes: - name (str): The name of the joint. - axis (npt.NDArray): The axis of rotation or translation for the joint. - pose (npt.NDArray): The pose transformation matrix of the joint. - jtype (JointType): The type of the joint. - child (LinkDescription): The child link attached to the joint. - parent (LinkDescription): The parent link attached to the joint. - index (Optional[int]): An optional index for the joint. - friction_static (float): The static friction coefficient for the joint. - friction_viscous (float): The viscous friction coefficient for the joint. - position_limit_damper (float): The damper coefficient for position limits. - position_limit_spring (float): The spring coefficient for position limits. - position_limit (Tuple[float, float]): The position limits for the joint. - initial_position (Union[float, npt.NDArray]): The initial position of the joint. - + name: The name of the joint. + axis: The axis of rotation or translation for the joint. + pose: The pose transformation matrix of the joint. + jtype: The type of the joint. + child: The child link attached to the joint. + parent: The parent link attached to the joint. + index: An optional index for the joint. + friction_static: The static friction coefficient for the joint. + friction_viscous: The viscous friction coefficient for the joint. + position_limit_damper: The damper coefficient for position limits. + position_limit_spring: The spring coefficient for position limits. + position_limit: The position limits for the joint. + initial_position: The initial position of the joint. """ name: jax_dataclasses.Static[str] From e1aec20414726a6cbcb2cced9c78dbdc41961e7e Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Sat, 4 Jan 2025 11:04:26 +0100 Subject: [PATCH 2/6] Add `pydocstyle` check in `ruff` configuration --- pyproject.toml | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 32d9823c5..d5fb15a05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -150,6 +150,7 @@ preview = true # https://docs.astral.sh/ruff/rules/ select = [ "B", + "D", "E", "F", "I", @@ -162,6 +163,15 @@ select = [ ignore = [ "B008", # Function call in default argument "B024", # Abstract base class without abstract methods + "D100", # Missing docstring in public module + "D104", # Missing docstring in public package + "D105", # Missing docstring in magic method + "D200", # One-line docstring should fit on one line with quotes + "D202", # No blank lines allowed after function docstring + "D205", # 1 blank line required between summary line and description + "D212", # Multi-line docstring summary should start at the first line + "D411", # Missing blank line before section + "D413", # Missing blank line after last section "E402", # Module level import not at top of file "E501", # Line too long "E731", # Do not assign a `lambda` expression, use a `def` @@ -173,9 +183,11 @@ ignore = [ [tool.ruff.lint.per-file-ignores] # Ignore `E402` (import violations) in all `__init__.py` files "**/{tests,docs,tools}/*" = ["E402"] -"**/{tests}/*" = ["B007"] +"**/{tests,examples}/*" = ["B007", "D100", "D102", "D103"] "__init__.py" = ["F401"] "docs/conf.py" = ["F401"] +"src/jaxsim/exceptions.py" = ["D401"] +"src/jaxsim/logging.py" = ["D101", "D103"] # ================== # Pixi configuration From 095b3e0562e4f92ad200a22dfbeb8553da98c22e Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Sat, 4 Jan 2025 11:26:07 +0100 Subject: [PATCH 3/6] Use imperative mood in docstrings first line --- src/jaxsim/api/com.py | 2 +- src/jaxsim/api/common.py | 2 +- src/jaxsim/api/model.py | 8 +++--- src/jaxsim/mujoco/loaders.py | 12 ++++----- src/jaxsim/mujoco/model.py | 30 +++++++++++------------ src/jaxsim/mujoco/visualizer.py | 10 ++++---- src/jaxsim/parsers/rod/meshes.py | 10 ++++---- src/jaxsim/parsers/rod/parser.py | 2 +- src/jaxsim/rbda/contacts/relaxed_rigid.py | 2 +- src/jaxsim/rbda/contacts/rigid.py | 3 ++- src/jaxsim/utils/jaxsim_dataclass.py | 6 ++--- src/jaxsim/utils/tracing.py | 4 +-- tests/conftest.py | 4 +-- tests/test_simulations.py | 3 ++- 14 files changed, 50 insertions(+), 48 deletions(-) diff --git a/src/jaxsim/api/com.py b/src/jaxsim/api/com.py index f2122ced1..e71c87e99 100644 --- a/src/jaxsim/api/com.py +++ b/src/jaxsim/api/com.py @@ -279,7 +279,7 @@ def other_representation_to_body( C_v̇_WL: jtp.Vector, C_v_WC: jtp.Vector, L_H_C: jtp.Matrix, L_v_LC: jtp.Vector ) -> jtp.Vector: """ - Helper to convert the body-fixed representation of the link bias acceleration + Convert the body-fixed representation of the link bias acceleration C_v̇_WL expressed in a generic frame C to the body-fixed representation L_v̇_WL. """ diff --git a/src/jaxsim/api/common.py b/src/jaxsim/api/common.py index 7d723120e..6b47ca03d 100644 --- a/src/jaxsim/api/common.py +++ b/src/jaxsim/api/common.py @@ -26,7 +26,7 @@ def named_scope(fn, name: str | None = None) -> Callable[_P, _R]: - """Applies a JAX named scope to a function for improved profiling and clarity.""" + """Apply a JAX named scope to a function for improved profiling and clarity.""" @functools.wraps(fn) def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index fc30bf66d..85ab825d1 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -1015,7 +1015,7 @@ def to_active( W_v̇_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WB: jtp.Vector, W_v_WC: jtp.Vector ) -> jtp.Vector: """ - Helper to convert the inertial-fixed apparent base acceleration W_v̇_WB to + Convert the inertial-fixed apparent base acceleration W_v̇_WB to another representation C_v̇_WB expressed in a generic frame C. """ @@ -1376,7 +1376,7 @@ def inverse_dynamics( def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): """ - Helper to convert the active representation of the base acceleration C_v̇_WB + Convert the active representation of the base acceleration C_v̇_WB expressed in a generic frame C to the inertial-fixed representation W_v̇_WB. """ @@ -1825,7 +1825,7 @@ def other_representation_to_inertial( C_v̇_WB: jtp.Vector, C_v_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WC: jtp.Vector ) -> jtp.Vector: """ - Helper to convert the active representation of the base acceleration C_v̇_WB + Convert the active representation of the base acceleration C_v̇_WB expressed in a generic frame C to the inertial-fixed representation W_v̇_WB. """ @@ -1961,7 +1961,7 @@ def body_to_other_representation( L_v̇_WL: jtp.Vector, L_v_WL: jtp.Vector, C_H_L: jtp.Matrix, L_v_CL: jtp.Vector ) -> jtp.Vector: """ - Helper to convert the body-fixed apparent acceleration L_v̇_WL to + Convert the body-fixed apparent acceleration L_v̇_WL to another representation C_v̇_WL expressed in a generic frame C. """ diff --git a/src/jaxsim/mujoco/loaders.py b/src/jaxsim/mujoco/loaders.py index 9dd056783..dd751f9ca 100644 --- a/src/jaxsim/mujoco/loaders.py +++ b/src/jaxsim/mujoco/loaders.py @@ -22,7 +22,7 @@ def load_rod_model( model_name: str | None = None, ) -> rod.Model: """ - Loads a ROD model from a URDF/SDF file or a ROD model. + Load a ROD model from a URDF/SDF file or a ROD model. Args: model_description: The URDF/SDF file or ROD model to load. @@ -69,7 +69,7 @@ def assets_from_rod_model( rod_model: rod.Model, ) -> dict[str, bytes]: """ - Generates a dictionary of assets from a ROD model. + Generate a dictionary of assets from a ROD model. Args: rod_model: The ROD model to extract the assets from. @@ -112,7 +112,7 @@ def add_floating_joint( floating_joint_name: str = "world_to_base", ) -> str: """ - Adds a floating joint to a URDF string. + Add a floating joint to a URDF string. Args: urdf_string: The URDF string to modify. @@ -171,7 +171,7 @@ def convert( cameras: MujocoCameraType = (), ) -> tuple[str, dict[str, Any]]: """ - Converts a ROD model to a Mujoco MJCF string. + Convert a ROD model to a Mujoco MJCF string. Args: rod_model: The ROD model to convert. @@ -532,7 +532,7 @@ def convert( cameras: MujocoCameraType = (), ) -> tuple[str, dict[str, Any]]: """ - Converts a URDF file to a Mujoco MJCF string. + Convert a URDF file to a Mujoco MJCF string. Args: urdf: The URDF file to convert. @@ -574,7 +574,7 @@ def convert( cameras: MujocoCameraType = (), ) -> tuple[str, dict[str, Any]]: """ - Converts a SDF file to a Mujoco MJCF string. + Convert a SDF file to a Mujoco MJCF string. Args: sdf: The SDF file to convert. diff --git a/src/jaxsim/mujoco/model.py b/src/jaxsim/mujoco/model.py index e9c207c89..38a6c928b 100644 --- a/src/jaxsim/mujoco/model.py +++ b/src/jaxsim/mujoco/model.py @@ -254,17 +254,17 @@ def is_dcm(R): # ================== def number_of_joints(self) -> int: - """Returns the number of joints in the model.""" + """Return the number of joints in the model.""" return self.model.njnt def number_of_dofs(self) -> int: - """Returns the number of DoFs in the model.""" + """Return the number of DoFs in the model.""" return self.model.nq def joint_names(self) -> list[str]: - """Returns the names of the joints in the model.""" + """Return the names of the joints in the model.""" return [ mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_JOINT, idx) @@ -272,7 +272,7 @@ def joint_names(self) -> list[str]: ] def joint_dofs(self, joint_name: str) -> int: - """Returns the number of DoFs of a joint.""" + """Return the number of DoFs of a joint.""" if joint_name not in self.joint_names(): raise ValueError(f"Joint '{joint_name}' not found") @@ -280,7 +280,7 @@ def joint_dofs(self, joint_name: str) -> int: return self.data.joint(joint_name).qpos.size def joint_position(self, joint_name: str) -> npt.NDArray: - """Returns the position of a joint.""" + """Return the position of a joint.""" if joint_name not in self.joint_names(): raise ValueError(f"Joint '{joint_name}' not found") @@ -288,7 +288,7 @@ def joint_position(self, joint_name: str) -> npt.NDArray: return self.data.joint(joint_name).qpos def joint_positions(self, joint_names: list[str] | None = None) -> npt.NDArray: - """Returns the positions of the joints.""" + """Return the positions of the joints.""" joint_names = joint_names if joint_names is not None else self.joint_names() @@ -299,7 +299,7 @@ def joint_positions(self, joint_names: list[str] | None = None) -> npt.NDArray: def set_joint_position( self, joint_name: str, position: npt.NDArray | float ) -> None: - """Sets the position of a joint.""" + """Set the position of a joint.""" position = np.atleast_1d(np.array(position).squeeze()) @@ -328,12 +328,12 @@ def set_joint_positions( # ================== def number_of_bodies(self) -> int: - """Returns the number of bodies in the model.""" + """Return the number of bodies in the model.""" return self.model.nbody def body_names(self) -> list[str]: - """Returns the names of the bodies in the model.""" + """Return the names of the bodies in the model.""" return [ mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_BODY, idx) @@ -341,7 +341,7 @@ def body_names(self) -> list[str]: ] def body_position(self, body_name: str) -> npt.NDArray: - """Returns the position of a body.""" + """Return the position of a body.""" if body_name not in self.body_names(): raise ValueError(f"Body '{body_name}' not found") @@ -349,7 +349,7 @@ def body_position(self, body_name: str) -> npt.NDArray: return self.data.body(body_name).xpos def body_orientation(self, body_name: str, dcm: bool = False) -> npt.NDArray: - """Returns the orientation of a body.""" + """Return the orientation of a body.""" if body_name not in self.body_names(): raise ValueError(f"Body '{body_name}' not found") @@ -363,12 +363,12 @@ def body_orientation(self, body_name: str, dcm: bool = False) -> npt.NDArray: # ====================== def number_of_geometries(self) -> int: - """Returns the number of geometries in the model.""" + """Return the number of geometries in the model.""" return self.model.ngeom def geometry_names(self) -> list[str]: - """Returns the names of the geometries in the model.""" + """Return the names of the geometries in the model.""" return [ mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_GEOM, idx) @@ -376,7 +376,7 @@ def geometry_names(self) -> list[str]: ] def geometry_position(self, geometry_name: str) -> npt.NDArray: - """Returns the position of a geometry.""" + """Return the position of a geometry.""" if geometry_name not in self.geometry_names(): raise ValueError(f"Geometry '{geometry_name}' not found") @@ -386,7 +386,7 @@ def geometry_position(self, geometry_name: str) -> npt.NDArray: def geometry_orientation( self, geometry_name: str, dcm: bool = False ) -> npt.NDArray: - """Returns the orientation of a geometry.""" + """Return the orientation of a geometry.""" if geometry_name not in self.geometry_names(): raise ValueError(f"Geometry '{geometry_name}' not found") diff --git a/src/jaxsim/mujoco/visualizer.py b/src/jaxsim/mujoco/visualizer.py index 9a07dfad1..05f4d3e4a 100644 --- a/src/jaxsim/mujoco/visualizer.py +++ b/src/jaxsim/mujoco/visualizer.py @@ -64,7 +64,7 @@ def reset( self.model = model if model is not None else self.model def render_frame(self, camera_name: str = "track") -> npt.NDArray: - """Renders a frame.""" + """Render a frame.""" mujoco.mj_forward(self.model, self.data) self.renderer.update_scene(data=self.data, camera=camera_name) @@ -72,13 +72,13 @@ def render_frame(self, camera_name: str = "track") -> npt.NDArray: return self.renderer.render() def record_frame(self, camera_name: str = "track") -> None: - """Stores a frame in the buffer.""" + """Store a frame in the buffer.""" frame = self.render_frame(camera_name=camera_name) self.frames.append(frame) def write_video(self, path: pathlib.Path, exist_ok: bool = False) -> None: - """Writes the video to a file.""" + """Write the video to a file.""" # Resolve the path to the video. path = path.expanduser().resolve() @@ -139,7 +139,7 @@ def sync( model: mj.MjModel | None = None, data: mj.MjData | None = None, ) -> None: - """Updates the viewer with the current model and data.""" + """Update the viewer with the current model and data.""" data = data if data is not None else self.data model = model if model is not None else self.model @@ -150,7 +150,7 @@ def sync( def open_viewer( self, model: mj.MjModel | None = None, data: mj.MjData | None = None ) -> mj.viewer.Handle: - """Opens a viewer.""" + """Open a viewer.""" data = data if data is not None else self.data model = model if model is not None else self.model diff --git a/src/jaxsim/parsers/rod/meshes.py b/src/jaxsim/parsers/rod/meshes.py index 9d1ada7b0..3679597e8 100644 --- a/src/jaxsim/parsers/rod/meshes.py +++ b/src/jaxsim/parsers/rod/meshes.py @@ -6,14 +6,14 @@ def extract_points_vertices(mesh: trimesh.Trimesh) -> np.ndarray: """ - Extracts the vertices of a mesh as points. + Extract the vertices of a mesh as points. """ return mesh.vertices def extract_points_random_surface_sampling(mesh: trimesh.Trimesh, n) -> np.ndarray: """ - Extracts N random points from the surface of a mesh. + Extract N random points from the surface of a mesh. Args: mesh: The mesh from which to extract points. @@ -30,7 +30,7 @@ def extract_points_uniform_surface_sampling( mesh: trimesh.Trimesh, n: int ) -> np.ndarray: """ - Extracts N uniformly sampled points from the surface of a mesh. + Extract N uniformly sampled points from the surface of a mesh. Args: mesh: The mesh from which to extract points. @@ -47,7 +47,7 @@ def extract_points_select_points_over_axis( mesh: trimesh.Trimesh, axis: str, direction: str, n: int ) -> np.ndarray: """ - Extracts N points from a mesh along a specified axis. The points are selected based on their position along the axis. + Extract N points from a mesh along a specified axis. The points are selected based on their position along the axis. Args: mesh: The mesh from which to extract points. @@ -75,7 +75,7 @@ def extract_points_aap( lower: float | None = None, ) -> np.ndarray: """ - Extracts points from a mesh along a specified axis within a specified range. The points are selected based on their position along the axis. + Extract points from a mesh along a specified axis within a specified range. The points are selected based on their position along the axis. Args: mesh: The mesh from which to extract points. diff --git a/src/jaxsim/parsers/rod/parser.py b/src/jaxsim/parsers/rod/parser.py index fc23420ae..83af6796f 100644 --- a/src/jaxsim/parsers/rod/parser.py +++ b/src/jaxsim/parsers/rod/parser.py @@ -364,7 +364,7 @@ def build_model_description( is_urdf: bool | None = None, ) -> descriptions.ModelDescription: """ - Builds a model description from an SDF/URDF resource. + Build a model description from an SDF/URDF resource. Args: model_description: A path to an SDF/URDF file, a string containing its content, diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index fd84c6941..454f33852 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -526,7 +526,7 @@ def imp_aref( vel: jtp.Vector, ) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector, jtp.Vector]: """ - Calculates impedance and offset acceleration in constraint frame. + Calculate impedance and offset acceleration in constraint frame. Args: pos: position in constraint frame. diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index 591451172..3d24637ec 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -173,7 +173,8 @@ def compute_impact_velocity( J_WC: jtp.MatrixLike, data: js.data.JaxSimModelData, ) -> jtp.Vector: - """Returns the new velocity of the system after a potential impact. + """ + Return the new velocity of the system after a potential impact. Args: inactive_collidable_points: The activation state of the collidable points. diff --git a/src/jaxsim/utils/jaxsim_dataclass.py b/src/jaxsim/utils/jaxsim_dataclass.py index 23509a1e0..40235a64a 100644 --- a/src/jaxsim/utils/jaxsim_dataclass.py +++ b/src/jaxsim/utils/jaxsim_dataclass.py @@ -124,7 +124,7 @@ def restore_self() -> None: @staticmethod def get_leaf_shapes(tree: jtp.PyTree) -> tuple[tuple[int, ...] | None]: """ - Helper method to get the leaf shapes of a PyTree. + Get the leaf shapes of a PyTree. Args: tree: The PyTree to consider. @@ -144,7 +144,7 @@ def get_leaf_shapes(tree: jtp.PyTree) -> tuple[tuple[int, ...] | None]: @staticmethod def get_leaf_dtypes(tree: jtp.PyTree) -> tuple: """ - Helper method to get the leaf dtypes of a PyTree. + Get the leaf dtypes of a PyTree. Args: tree: The PyTree to consider. @@ -164,7 +164,7 @@ def get_leaf_dtypes(tree: jtp.PyTree) -> tuple: @staticmethod def get_leaf_weak_types(tree: jtp.PyTree) -> tuple[bool, ...]: """ - Helper method to get the leaf weak types of a PyTree. + Get the leaf weak types of a PyTree. Args: tree: The PyTree to consider. diff --git a/src/jaxsim/utils/tracing.py b/src/jaxsim/utils/tracing.py index 1dacfa7e8..f956160a2 100644 --- a/src/jaxsim/utils/tracing.py +++ b/src/jaxsim/utils/tracing.py @@ -6,7 +6,7 @@ def tracing(var: Any) -> bool | jax.Array: - """Returns True if the variable is being traced by JAX, False otherwise.""" + """Return True if the variable is being traced by JAX, False otherwise.""" return isinstance( var, jax._src.core.Tracer | jax.interpreters.partial_eval.DynamicJaxprTracer @@ -14,6 +14,6 @@ def tracing(var: Any) -> bool | jax.Array: def not_tracing(var: Any) -> bool | jax.Array: - """Returns True if the variable is not being traced by JAX, False otherwise.""" + """Return True if the variable is not being traced by JAX, False otherwise.""" return True if tracing(var) is False else False diff --git a/tests/conftest.py b/tests/conftest.py index cf0c72a0d..94776545d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -149,7 +149,7 @@ def build_jaxsim_model( model_description: str | pathlib.Path | rod.Model, ) -> js.model.JaxSimModel: """ - Helper to build a JaxSim model from a model description. + Build a JaxSim model from a model description. Args: model_description: A model description provided by any fixture provider. @@ -444,7 +444,7 @@ def get_jaxsim_model_fixture( model_name: str, request: pytest.FixtureRequest ) -> str | pathlib.Path: """ - Factory to get the fixture providing the JaxSim model of a robot. + Get the fixture providing the JaxSim model of a robot. Args: model_name: The name of the model. diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 2db8721c8..8300f0747 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -15,7 +15,8 @@ def test_box_with_external_forces( velocity_representation: VelRepr, ): """ - This test simulates a box falling due to gravity. + Simulate a box falling due to gravity. + We apply to its CoM a 6D force that balances exactly the gravitational force. The box should not fall. """ From a4772cc1a14c6e9d16eea630391ce861f7f098ee Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Sat, 4 Jan 2025 11:27:14 +0100 Subject: [PATCH 4/6] Add docstring for `MeshCollision` --- src/jaxsim/parsers/descriptions/collision.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/jaxsim/parsers/descriptions/collision.py b/src/jaxsim/parsers/descriptions/collision.py index 7815af488..0ddb67baa 100644 --- a/src/jaxsim/parsers/descriptions/collision.py +++ b/src/jaxsim/parsers/descriptions/collision.py @@ -158,6 +158,13 @@ def __eq__(self, other: BoxCollision) -> bool: @dataclasses.dataclass class MeshCollision(CollisionShape): + """ + Represents a mesh-shaped collision shape. + + Attributes: + center: The center of the mesh in the local frame of the collision shape. + """ + center: jtp.VectorLike def __hash__(self) -> int: From a7e35a6295ce6e8061bf1064e35852e86e17d3a5 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Sat, 4 Jan 2025 12:01:46 +0100 Subject: [PATCH 5/6] Add missing docstrings --- src/jaxsim/api/contact.py | 3 ++ src/jaxsim/api/kin_dyn_parameters.py | 19 ++++++- src/jaxsim/api/model.py | 3 ++ src/jaxsim/api/ode.py | 22 +++++++- src/jaxsim/exceptions.py | 6 +++ src/jaxsim/integrators/common.py | 62 ++++++++++++++++++++++- src/jaxsim/integrators/fixed_step.py | 21 ++++++++ src/jaxsim/integrators/variable_step.py | 44 ++++++++++++++++ src/jaxsim/math/adjoint.py | 4 ++ src/jaxsim/math/cross.py | 4 ++ src/jaxsim/math/inertia.py | 4 ++ src/jaxsim/math/quaternion.py | 4 ++ src/jaxsim/math/rotation.py | 3 ++ src/jaxsim/math/transform.py | 3 ++ src/jaxsim/mujoco/loaders.py | 12 ++++- src/jaxsim/mujoco/utils.py | 7 ++- src/jaxsim/mujoco/visualizer.py | 8 ++- src/jaxsim/parsers/descriptions/joint.py | 3 ++ src/jaxsim/parsers/kinematic_graph.py | 38 ++++++++++++++ src/jaxsim/parsers/rod/utils.py | 11 ++++ src/jaxsim/rbda/contacts/common.py | 2 + src/jaxsim/rbda/contacts/relaxed_rigid.py | 7 ++- src/jaxsim/rbda/contacts/rigid.py | 4 +- src/jaxsim/rbda/contacts/soft.py | 36 +++++++++++++ src/jaxsim/terrain/terrain.py | 52 +++++++++++++++++++ src/jaxsim/utils/wrappers.py | 9 ++++ 26 files changed, 380 insertions(+), 11 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 294413f7e..ba46278db 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -293,6 +293,9 @@ def in_contact( def estimate_good_soft_contacts_parameters( *args, **kwargs ) -> jaxsim.rbda.contacts.ContactParamsTypes: + """ + Estimate good soft contacts parameters. Deprecated, use `estimate_good_contact_parameters` instead. + """ msg = "This method is deprecated, please use `{}`." logging.warning(msg.format(estimate_good_contact_parameters.__name__)) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 40cddde03..b0adf4e54 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -52,10 +52,16 @@ class KinDynParameters(JaxsimDataclass): @property def parent_array(self) -> jtp.Vector: + r""" + Return the parent array :math:`\lambda(i)` of the model. + """ return self._parent_array.get() @property def support_body_array_bool(self) -> jtp.Matrix: + r""" + Return the boolean support parent array :math:`\kappa_{b}(i)` of the model. + """ return self._support_body_array_bool.get() @staticmethod @@ -648,7 +654,16 @@ def build_from_inertial_parameters( def build_from_flat_parameters( index: jtp.IntLike, parameters: jtp.VectorLike ) -> LinkParameters: + """ + Build a LinkParameters object from a flat vector of parameters. + + Args: + index: The index of the link. + parameters: The flat vector of parameters. + Returns: + The LinkParameters object. + """ index = jnp.array(index).squeeze().astype(int) m = jnp.array(parameters[0]).squeeze().astype(float) @@ -772,7 +787,9 @@ class ContactParameters(JaxsimDataclass): @property def indices_of_enabled_collidable_points(self) -> npt.NDArray: - + """ + Return the indices of the enabled collidable points. + """ return np.where(np.array(self.enabled))[0] @staticmethod diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 85ab825d1..d172be0db 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -63,6 +63,9 @@ class JaxSimModel(JaxsimDataclass): @property def description(self) -> ModelDescription: + """ + Return the model description. + """ return self._description.get() def __eq__(self, other: JaxSimModel) -> bool: diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index a6b685cde..8a7f13f27 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -15,12 +15,32 @@ class SystemDynamicsFromModelAndData(Protocol): + """ + Protocol defining the signature of a function computing the system dynamics + given a model and data object. + """ + def __call__( self, model: js.model.JaxSimModel, data: js.data.JaxSimModelData, **kwargs: dict[str, Any], - ) -> tuple[ODEState, dict[str, Any]]: ... + ) -> tuple[ODEState, dict[str, Any]]: + """ + Compute the system dynamics given a model and data object. + + Args: + model: The model to consider. + data: The data of the considered model. + **kwargs: Additional keyword arguments. + + Returns: + A tuple with an `ODEState` object storing in each of its attributes the + corresponding derivative, and the dictionary of auxiliary data returned + by the system dynamics evaluation. + """ + + pass def wrap_system_dynamics_for_integration( diff --git a/src/jaxsim/exceptions.py b/src/jaxsim/exceptions.py index ae1cff655..7077a02b8 100644 --- a/src/jaxsim/exceptions.py +++ b/src/jaxsim/exceptions.py @@ -61,6 +61,9 @@ def _run_callback_only_if_condition_is_true(*args, **kwargs) -> None: def raise_runtime_error_if( condition: bool | jax.Array, msg: str, *args, **kwargs ) -> None: + """ + Raise a RuntimeError if a condition is met. Useful in jit-compiled functions. + """ return raise_if(condition, RuntimeError, msg, *args, **kwargs) @@ -68,5 +71,8 @@ def raise_runtime_error_if( def raise_value_error_if( condition: bool | jax.Array, msg: str, *args, **kwargs ) -> None: + """ + Raise a ValueError if a condition is met. Useful in jit-compiled functions. + """ return raise_if(condition, ValueError, msg, *args, **kwargs) diff --git a/src/jaxsim/integrators/common.py b/src/jaxsim/integrators/common.py index caa7c4ac2..855b4f763 100644 --- a/src/jaxsim/integrators/common.py +++ b/src/jaxsim/integrators/common.py @@ -36,9 +36,25 @@ class SystemDynamics(Protocol[State, StateDerivative]): + """ + Protocol defining the system dynamics. + """ + def __call__( self, x: State, t: Time, **kwargs - ) -> tuple[StateDerivative, dict[str, Any]]: ... + ) -> tuple[StateDerivative, dict[str, Any]]: + """ + Compute the state derivative of the system. + + Args: + x: The state of the system. + t: The time of the system. + **kwargs: Additional keyword arguments. + + Returns: + The state derivative of the system and the auxiliary dictionary. + """ + pass # ======================= @@ -48,6 +64,9 @@ def __call__( @jax_dataclasses.pytree_dataclass class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]): + """ + Factory class for integrators. + """ dynamics: Static[SystemDynamics[State, StateDerivative]] = dataclasses.field( repr=False, hash=False, compare=False, kw_only=True @@ -110,6 +129,9 @@ def step( def __call__( self, x0: State, t0: Time, dt: TimeStep, **kwargs ) -> tuple[NextState, dict[str, Any]]: + """ + Perform a single integration step. + """ pass def init( @@ -121,6 +143,9 @@ def init( include_dynamics_aux_dict: bool = False, **kwargs, ) -> dict[str, Any]: + """ + Initialize the integrator. This method is deprecated. + """ logging.warning( "The 'init' method has been deprecated. There is no need to call it." @@ -131,6 +156,18 @@ def init( @jax_dataclasses.pytree_dataclass class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]): + """ + Base class for explicit Runge-Kutta integrators. + + Attributes: + A: The Runge-Kutta matrix. + b: The weights coefficients. + c: The nodes coefficients. + order_of_bT_rows: The order of the solution. + row_index_of_solution: The row of the integration output corresponding to the final solution. + fsal_enabled_if_supported: Whether to enable the FSAL property, if supported. + index_of_fsal: The index of the intermediate derivative to be used as the first derivative of the next iteration. + """ # The Runge-Kutta matrix. A: ClassVar[jtp.Matrix] @@ -156,10 +193,16 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType] @property def has_fsal(self) -> bool: + """ + Check if the integrator supports the FSAL property. + """ return self.fsal_enabled_if_supported and self.index_of_fsal is not None @property def order(self) -> int: + """ + Return the order of the integrator. + """ return self.order_of_bT_rows[self.row_index_of_solution] @override @@ -221,6 +264,9 @@ def build( def __call__( self, x0: State, t0: Time, dt: TimeStep, **kwargs ) -> tuple[NextState, dict[str, Any]]: + """ + Perform a single integration step. + """ # Here z is a batched state with as many batch elements as b.T rows. # Note that z has multiple batches only if b.T has more than one row, @@ -331,7 +377,9 @@ def get_ẋ0_and_aux_dict() -> tuple[StateDerivative, dict[str, Any]]: def scan_body( carry: jax.Array, i: int | jax.Array ) -> tuple[jax.Array, dict[str, Any]]: - """""" + """ + Compute the kᵢ derivative of the Runge-Kutta stage. + """ # Unpack the carry, i.e. the stacked kᵢ vectors. K = carry @@ -498,6 +546,16 @@ class ExplicitRungeKuttaSO3Mixin: def post_process_state( cls, x0: js.ode_data.ODEState, t0: Time, xf: js.ode_data.ODEState, dt: TimeStep ) -> js.ode_data.ODEState: + r""" + Post-process the integrated state at :math:`t_f = t_0 + \Delta t` so that the + quaternion is normalized. + + Args: + x0: The initial state of the system. + t0: The initial time of the system. + xf: The final state of the system obtain through the integration. + dt: The time step used for the integration. + """ # Extract the initial base quaternion. W_Q_B_t0 = x0.physics_model.base_quaternion diff --git a/src/jaxsim/integrators/fixed_step.py b/src/jaxsim/integrators/fixed_step.py index 9ec0ef477..31d282089 100644 --- a/src/jaxsim/integrators/fixed_step.py +++ b/src/jaxsim/integrators/fixed_step.py @@ -17,6 +17,9 @@ @jax_dataclasses.pytree_dataclass class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]): + """ + Forward Euler integrator. + """ A: ClassVar[jtp.Matrix] = jnp.atleast_2d(0).astype(float) @@ -30,6 +33,9 @@ class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]): @jax_dataclasses.pytree_dataclass class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]): + """ + Heun's second-order integrator. + """ A: ClassVar[jtp.Matrix] = jnp.array( [ @@ -56,6 +62,9 @@ class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]): @jax_dataclasses.pytree_dataclass class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]): + """ + Fourth-order Runge-Kutta integrator. + """ A: ClassVar[jtp.Matrix] = jnp.array( [ @@ -89,14 +98,26 @@ class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]): @jax_dataclasses.pytree_dataclass class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, ForwardEuler[js.ode_data.ODEState]): + """ + Forward Euler integrator for SO(3) states. + """ + pass @jax_dataclasses.pytree_dataclass class Heun2SO3(ExplicitRungeKuttaSO3Mixin, Heun2[js.ode_data.ODEState]): + """ + Heun's second-order integrator for SO(3) states. + """ + pass @jax_dataclasses.pytree_dataclass class RungeKutta4SO3(ExplicitRungeKuttaSO3Mixin, RungeKutta4[js.ode_data.ODEState]): + """ + Fourth-order Runge-Kutta integrator for SO(3) states. + """ + pass diff --git a/src/jaxsim/integrators/variable_step.py b/src/jaxsim/integrators/variable_step.py index 5e063b8d0..70b0eefdc 100644 --- a/src/jaxsim/integrators/variable_step.py +++ b/src/jaxsim/integrators/variable_step.py @@ -216,6 +216,17 @@ def flatten(pytree) -> jax.Array: @jax_dataclasses.pytree_dataclass class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]): + """ + An Embedded Runge-Kutta integrator. + + This class implements a general-purpose Embedded Runge-Kutta integrator + that can be used to solve ordinary differential equations with adaptive + step sizes. + + The integrator is based on an Explicit Runge-Kutta method, and it uses + two different solutions to estimate the local integration error. The + error is then used to adapt the step size to reach a desired accuracy. + """ AfterInitKey: ClassVar[str] = "after_init" InitializingKey: ClassVar[str] = "initializing" @@ -257,6 +268,7 @@ def init( x0: The initial state of the system. t0: The initial time of the system. dt: The time step of the integration. + **kwargs: Additional parameters. Returns: The metadata of the integrator to be passed to the first step. @@ -296,6 +308,9 @@ def init( def __call__( self, x0: State, t0: Time, dt: TimeStep, **kwargs ) -> tuple[NextState, dict[str, Any]]: + """ + Integrate the system for a single step. + """ # This method is called differently in three stages: # @@ -512,10 +527,16 @@ def reject_step(): @property def order_of_solution(self) -> int: + """ + The order of the solution. + """ return self.order_of_bT_rows[self.row_index_of_solution] @property def order_of_solution_estimate(self) -> int: + """ + The order of the solution estimate. + """ return self.order_of_bT_rows[self.row_index_of_solution_estimate] @classmethod @@ -534,6 +555,23 @@ def build( max_step_rejections: jtp.IntLike = MAX_STEP_REJECTIONS_DEFAULT, **kwargs, ) -> Self: + """ + Build an Embedded Runge-Kutta integrator. + + Args: + dynamics: The system dynamics function. + fsal_enabled_if_supported: + Whether to enable the FSAL property if supported by the integrator. + dt_max: The maximum step size. + dt_min: The minimum step size. + rtol: The relative tolerance. + atol: The absolute tolerance. + safety: The safety factor to shrink the step size. + beta_max: The maximum factor to increase the step size. + beta_min: The minimum factor to increase the step size. + max_step_rejections: The maximum number of step rejections. + **kwargs: Additional parameters. + """ # Check that b.T has enough rows based on the configured index of the # solution estimate. This is necessary for embedded methods. @@ -569,6 +607,9 @@ def build( @jax_dataclasses.pytree_dataclass class HeunEulerSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin): + """ + The Heun-Euler integrator for SO(3) dynamics. + """ A: ClassVar[jtp.Matrix] = jnp.array( [ @@ -602,6 +643,9 @@ class HeunEulerSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin): @jax_dataclasses.pytree_dataclass class BogackiShampineSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin): + """ + The Bogacki-Shampine integrator for SO(3) dynamics. + """ A: ClassVar[jtp.Matrix] = jnp.array( [ diff --git a/src/jaxsim/math/adjoint.py b/src/jaxsim/math/adjoint.py index 5d763eb3b..8c84483af 100644 --- a/src/jaxsim/math/adjoint.py +++ b/src/jaxsim/math/adjoint.py @@ -7,6 +7,10 @@ class Adjoint: + """ + A utility class for adjoint matrix operations. + """ + @staticmethod def from_quaternion_and_translation( quaternion: jtp.Vector = jnp.array([1.0, 0, 0, 0]), diff --git a/src/jaxsim/math/cross.py b/src/jaxsim/math/cross.py index 0b1c5579b..821a38aa3 100644 --- a/src/jaxsim/math/cross.py +++ b/src/jaxsim/math/cross.py @@ -6,6 +6,10 @@ class Cross: + """ + A utility class for cross product matrix operations. + """ + @staticmethod def vx(velocity_sixd: jtp.Vector) -> jtp.Matrix: """ diff --git a/src/jaxsim/math/inertia.py b/src/jaxsim/math/inertia.py index 12349eade..3d7a15517 100644 --- a/src/jaxsim/math/inertia.py +++ b/src/jaxsim/math/inertia.py @@ -6,6 +6,10 @@ class Inertia: + """ + A utility class for inertia matrix operations. + """ + @staticmethod def to_sixd(mass: jtp.Float, com: jtp.Vector, I: jtp.Matrix) -> jtp.Matrix: """ diff --git a/src/jaxsim/math/quaternion.py b/src/jaxsim/math/quaternion.py index a87321bac..195e24990 100644 --- a/src/jaxsim/math/quaternion.py +++ b/src/jaxsim/math/quaternion.py @@ -8,6 +8,10 @@ class Quaternion: + """ + A utility class for quaternion operations. + """ + @staticmethod def to_xyzw(wxyz: jtp.Vector) -> jtp.Vector: """ diff --git a/src/jaxsim/math/rotation.py b/src/jaxsim/math/rotation.py index a2d942a55..f62bcc5b6 100644 --- a/src/jaxsim/math/rotation.py +++ b/src/jaxsim/math/rotation.py @@ -8,6 +8,9 @@ class Rotation: + """ + A utility class for rotation matrix operations. + """ @staticmethod def x(theta: jtp.Float) -> jtp.Matrix: diff --git a/src/jaxsim/math/transform.py b/src/jaxsim/math/transform.py index 226e9b801..674744477 100644 --- a/src/jaxsim/math/transform.py +++ b/src/jaxsim/math/transform.py @@ -5,6 +5,9 @@ class Transform: + """ + A utility class for transformation matrix operations. + """ @staticmethod def from_quaternion_and_translation( diff --git a/src/jaxsim/mujoco/loaders.py b/src/jaxsim/mujoco/loaders.py index dd751f9ca..ba476299b 100644 --- a/src/jaxsim/mujoco/loaders.py +++ b/src/jaxsim/mujoco/loaders.py @@ -62,7 +62,9 @@ def load_rod_model( class RodModelToMjcf: - """""" + """ + Class to convert a ROD model to a Mujoco MJCF string. + """ @staticmethod def assets_from_rod_model( @@ -522,6 +524,10 @@ def convert( class UrdfToMjcf: + """ + Class to convert a URDF file to a Mujoco MJCF string. + """ + @staticmethod def convert( urdf: str | pathlib.Path, @@ -564,6 +570,10 @@ def convert( class SdfToMjcf: + """ + Class to convert a SDF file to a Mujoco MJCF string. + """ + @staticmethod def convert( sdf: str | pathlib.Path, diff --git a/src/jaxsim/mujoco/utils.py b/src/jaxsim/mujoco/utils.py index 2afff1732..cb0645a76 100644 --- a/src/jaxsim/mujoco/utils.py +++ b/src/jaxsim/mujoco/utils.py @@ -133,6 +133,9 @@ class MujocoCamera: @classmethod def build(cls, **kwargs) -> MujocoCamera: + """ + Build a Mujoco camera from a dictionary. + """ if not all(isinstance(value, str) for value in kwargs.values()): raise ValueError(f"Values must be strings: {kwargs}") @@ -219,5 +222,7 @@ def build_from_target_view( ) def asdict(self) -> dict[str, str]: - + """ + Convert the camera to a dictionary. + """ return {k: v for k, v in dataclasses.asdict(self).items() if v is not None} diff --git a/src/jaxsim/mujoco/visualizer.py b/src/jaxsim/mujoco/visualizer.py index 05f4d3e4a..36e2c08f9 100644 --- a/src/jaxsim/mujoco/visualizer.py +++ b/src/jaxsim/mujoco/visualizer.py @@ -10,7 +10,9 @@ class MujocoVideoRecorder: - """""" + """ + Video recorder for the MuJoCo passive viewer. + """ def __init__( self, @@ -117,7 +119,9 @@ def compute_down_sampling(original_fps: int, target_min_fps: int) -> int: class MujocoVisualizer: - """""" + """ + Visualizer for the MuJoCo passive viewer. + """ def __init__( self, model: mj.MjModel | None = None, data: mj.MjData | None = None diff --git a/src/jaxsim/parsers/descriptions/joint.py b/src/jaxsim/parsers/descriptions/joint.py index c4432365b..04ccfaa4a 100644 --- a/src/jaxsim/parsers/descriptions/joint.py +++ b/src/jaxsim/parsers/descriptions/joint.py @@ -14,6 +14,9 @@ @dataclasses.dataclass(frozen=True) class JointType: + """ + Enumeration of joint types. + """ Fixed: ClassVar[int] = 0 Revolute: ClassVar[int] = 1 diff --git a/src/jaxsim/parsers/kinematic_graph.py b/src/jaxsim/parsers/kinematic_graph.py index 44d058706..491e694cd 100644 --- a/src/jaxsim/parsers/kinematic_graph.py +++ b/src/jaxsim/parsers/kinematic_graph.py @@ -97,20 +97,32 @@ class KinematicGraph(Sequence[LinkDescription]): @functools.cached_property def links_dict(self) -> dict[str, LinkDescription]: + """ + Get a dictionary of links indexed by their name. + """ return {l.name: l for l in iter(self)} @functools.cached_property def frames_dict(self) -> dict[str, LinkDescription]: + """ + Get a dictionary of frames indexed by their name. + """ return {f.name: f for f in self.frames} @functools.cached_property def joints_dict(self) -> dict[str, JointDescription]: + """ + Get a dictionary of joints indexed by their name. + """ return {j.name: j for j in self.joints} @functools.cached_property def joints_connection_dict( self, ) -> dict[tuple[str, str], JointDescription]: + """ + Get a dictionary of joints indexed by the tuple (parent, child) link names. + """ return {(j.parent.name, j.child.name): j for j in self.joints} def __post_init__(self) -> None: @@ -734,9 +746,15 @@ def __getitem__(self, key: int | str) -> LinkDescription: raise TypeError(type(key).__name__) def count(self, value: LinkDescription) -> int: + """ + Count the occurrences of a link in the kinematic graph. + """ return list(iter(self)).count(value) def index(self, value: LinkDescription, start: int = 0, stop: int = -1) -> int: + """ + Find the index of a link in the kinematic graph. + """ return list(iter(self)).index(value, start, stop) @@ -747,6 +765,12 @@ def index(self, value: LinkDescription, start: int = 0, stop: int = -1) -> int: @dataclasses.dataclass(frozen=True) class KinematicGraphTransforms: + """ + Class to compute forward kinematics on a kinematic graph. + + Attributes: + graph: The kinematic graph on which to compute forward kinematics. + """ graph: KinematicGraph @@ -767,6 +791,9 @@ def __post_init__(self) -> None: @property def initial_joint_positions(self) -> npt.NDArray: + """ + Get the initial joint positions of the kinematic graph. + """ return np.atleast_1d( np.array(list(self._initial_joint_positions.values())) @@ -910,6 +937,17 @@ def pre_H_suc( joint_axis: npt.NDArray, joint_position: float | None = None, ) -> npt.NDArray: + """ + Compute the SE(3) transform from the predecessor to the successor frame. + + Args: + joint_type: The type of the joint. + joint_axis: The axis of the joint. + joint_position: The position of the joint. + + Returns: + The 4x4 transform matrix from the predecessor to the successor frame. + """ import jaxsim.math diff --git a/src/jaxsim/parsers/rod/utils.py b/src/jaxsim/parsers/rod/utils.py index 5250d65ba..2c1d6b376 100644 --- a/src/jaxsim/parsers/rod/utils.py +++ b/src/jaxsim/parsers/rod/utils.py @@ -223,6 +223,17 @@ def create_mesh_collision( link_description: descriptions.LinkDescription, method: MeshMappingMethod = None, ) -> descriptions.MeshCollision: + """ + Create a mesh collision from an SDF collision element. + + Args: + collision: The SDF collision element. + link_description: The link description. + method: The method to use for mesh wrapping. + + Returns: + The mesh collision description. + """ file = pathlib.Path(resolve_local_uri(uri=collision.geometry.mesh.uri)) file_type = file.suffix.replace(".", "") diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py index 56b403fa7..30c715f0a 100644 --- a/src/jaxsim/rbda/contacts/common.py +++ b/src/jaxsim/rbda/contacts/common.py @@ -125,6 +125,7 @@ def compute_contact_forces( Args: model: The robot model considered by the contact model. data: The data of the considered model. + **kwargs: Optional additional arguments, specific to the contact model. Returns: A tuple containing as first element the computed 6D contact force applied to @@ -146,6 +147,7 @@ def compute_link_contact_forces( Args: model: The robot model considered by the contact model. data: The data of the considered model. + **kwargs: Optional additional arguments, specific to the contact model. Returns: A tuple containing as first element the 6D contact force applied to the diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index 454f33852..999c4cf90 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -160,6 +160,7 @@ def default(name: str): ) def valid(self) -> jtp.BoolLike: + """Check if the parameters are valid.""" return bool( jnp.all(self.time_constant >= 0.0) @@ -187,6 +188,7 @@ class RelaxedRigidContacts(common.ContactModel): @property def solver_options(self) -> dict[str, Any]: + """Get the solver options.""" return dict( zip( @@ -207,6 +209,7 @@ def build( Args: solver_options: The options to pass to the L-BFGS solver. + **kwargs: The parameters of the relaxed rigid contacts model. Returns: The `RelaxedRigidContacts` instance. @@ -483,8 +486,8 @@ def _regularizers( Args: model: The jaxsim model. - penetration: The point position in the constraint frame. - velocity: The point velocity in the constraint frame. + position_constraint: The position of the collidable points in the constraint frame. + velocity_constraint: The velocity of the collidable points in the constraint frame. parameters: The parameters of the relaxed rigid contacts model. Returns: diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index 3d24637ec..879c56f64 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -79,7 +79,7 @@ def build( ) def valid(self) -> jtp.BoolLike: - + """Check if the parameters are valid.""" return bool( jnp.all(self.mu >= 0.0) and jnp.all(self.K >= 0.0) @@ -104,6 +104,7 @@ class RigidContacts(ContactModel): @property def solver_options(self) -> dict[str, Any]: + """Get the solver options as a dictionary.""" return dict( zip( @@ -127,6 +128,7 @@ def build( regularization_delassus: The regularization term to add to the diagonal of the Delassus matrix. solver_options: The options to pass to the QP solver. + **kwargs: Extra arguments which are ignored. Returns: The `RigidContacts` instance. diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index 8d4c0d545..43cdf70a1 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -244,6 +244,28 @@ def hunt_crossley_contact_model( p: jtp.FloatLike = 0.5, q: jtp.FloatLike = 0.5, ) -> tuple[jtp.Vector, jtp.Vector]: + """ + Compute the contact force using the Hunt/Crossley model. + + Args: + position: The position of the collidable point. + velocity: The velocity of the collidable point. + tangential_deformation: The material deformation of the collidable point. + terrain: The terrain model. + K: The stiffness parameter. + D: The damping parameter of the soft contacts model. + mu: The static friction coefficient. + p: + The exponent p corresponding to the damping-related non-linearity + of the Hunt/Crossley model. + q: + The exponent q corresponding to the spring-related non-linearity + of the Hunt/Crossley model + + Returns: + A tuple containing the computed contact force and the derivative of the + material deformation. + """ # Convert the input vectors to arrays. W_p_C = jnp.array(position, dtype=float).squeeze() @@ -364,6 +386,20 @@ def compute_contact_force( parameters: SoftContactsParams, terrain: Terrain, ) -> tuple[jtp.Vector, jtp.Vector]: + """ + Compute the contact force. + + Args: + position: The position of the collidable point. + velocity: The velocity of the collidable point. + tangential_deformation: The material deformation of the collidable point. + parameters: The parameters of the soft contacts model. + terrain: The terrain model. + + Returns: + A tuple containing the computed contact force and the derivative of the + material deformation. + """ CW_fl, ṁ = SoftContacts.hunt_crossley_contact_model( position=position, diff --git a/src/jaxsim/terrain/terrain.py b/src/jaxsim/terrain/terrain.py index f6b4ddcc2..f5b364dec 100644 --- a/src/jaxsim/terrain/terrain.py +++ b/src/jaxsim/terrain/terrain.py @@ -13,11 +13,28 @@ class Terrain(abc.ABC): + """ + Base class for terrain models. + + Attributes: + delta: The delta value used for numerical differentiation. + """ delta = 0.010 @abc.abstractmethod def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float: + """ + Compute the height of the terrain at a specific (x, y) location. + + Args: + x: The x-coordinate of the location. + y: The y-coordinate of the location. + + Returns: + The height of the terrain at the specified location. + """ + pass def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector: @@ -47,19 +64,51 @@ def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector: @jax_dataclasses.pytree_dataclass class FlatTerrain(Terrain): + """ + Represents a terrain model with a flat surface and a constant height. + """ _height: float = dataclasses.field(default=0.0, kw_only=True) @staticmethod def build(height: jtp.FloatLike = 0.0) -> FlatTerrain: + """ + Create a FlatTerrain instance with a specified height. + + Args: + height: The height of the flat terrain. + + Returns: + FlatTerrain: A FlatTerrain instance. + """ return FlatTerrain(_height=float(height)) def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float: + """ + Compute the height of the terrain at a specific (x, y) location. + + Args: + x: The x-coordinate of the location. + y: The y-coordinate of the location. + + Returns: + The height of the terrain at the specified location. + """ return jnp.array(self._height, dtype=float) def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector: + """ + Compute the normal vector of the terrain at a specific (x, y) location. + + Args: + x: The x-coordinate of the location. + y: The y-coordinate of the location. + + Returns: + The normal vector of the terrain surface at the specified location. + """ return jnp.array([0.0, 0.0, 1.0], dtype=float) @@ -77,6 +126,9 @@ def __eq__(self, other: FlatTerrain) -> bool: @jax_dataclasses.pytree_dataclass class PlaneTerrain(FlatTerrain): + """ + Represents a terrain model with a flat surface defined by a normal vector. + """ _normal: tuple[float, float, float] = jax_dataclasses.field( default=(0.0, 0.0, 1.0), kw_only=True diff --git a/src/jaxsim/utils/wrappers.py b/src/jaxsim/utils/wrappers.py index ea3eb6a83..bfb29701f 100644 --- a/src/jaxsim/utils/wrappers.py +++ b/src/jaxsim/utils/wrappers.py @@ -25,6 +25,9 @@ class HashlessObject(Generic[T]): obj: T def get(self: HashlessObject[T]) -> T: + """ + Get the wrapped object. + """ return self.obj def __hash__(self) -> int: @@ -52,6 +55,9 @@ class CustomHashedObject(Generic[T]): hash_function: Callable[[T], int] = hash def get(self: CustomHashedObject[T]) -> T: + """ + Get the wrapped object. + """ return self.obj def __hash__(self) -> int: @@ -93,6 +99,9 @@ class HashedNumpyArray: ) def get(self) -> jax.Array | npt.NDArray: + """ + Get the wrapped array. + """ return self.array def __hash__(self) -> int: From 3dca043b95b5bc1e753f9eecd74336f4de312c2e Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Sat, 4 Jan 2025 12:02:10 +0100 Subject: [PATCH 6/6] Update docstrings for clarity and consistency --- src/jaxsim/api/data.py | 3 ++- src/jaxsim/exceptions.py | 2 ++ src/jaxsim/math/utils.py | 4 ++-- src/jaxsim/parsers/descriptions/collision.py | 4 ---- src/jaxsim/parsers/descriptions/model.py | 2 +- src/jaxsim/rbda/contacts/relaxed_rigid.py | 2 +- src/jaxsim/rbda/contacts/rigid.py | 5 +++-- src/jaxsim/rbda/contacts/soft.py | 1 + src/jaxsim/rbda/contacts/visco_elastic.py | 1 + tests/conftest.py | 8 ++++++++ tests/test_meshes.py | 7 +++++-- tests/utils_idyntree.py | 2 ++ 12 files changed, 28 insertions(+), 13 deletions(-) diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index af3ea7ef2..d03df8fff 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -456,7 +456,8 @@ def generalized_position(self) -> tuple[jtp.Matrix, jtp.Vector]: @jax.jit def generalized_velocity(self) -> jtp.Vector: r""" - Get the generalized velocity + Get the generalized velocity. + :math:`\boldsymbol{\nu} = (\boldsymbol{v}_{W,B};\, \boldsymbol{\omega}_{W,B};\, \mathbf{s}) \in \mathbb{R}^{6+n}` Returns: diff --git a/src/jaxsim/exceptions.py b/src/jaxsim/exceptions.py index 7077a02b8..16590d051 100644 --- a/src/jaxsim/exceptions.py +++ b/src/jaxsim/exceptions.py @@ -17,6 +17,8 @@ def raise_if( msg: The message to display when the exception is raised. The message can be a format string (fmt), whose fields are filled with the args and kwargs. + *args: The arguments to fill the format string. + **kwargs: The keyword arguments to fill the format string """ # Disable host callback if running on unsupported hardware or if the user diff --git a/src/jaxsim/math/utils.py b/src/jaxsim/math/utils.py index 64d7a24ca..82e919a4e 100644 --- a/src/jaxsim/math/utils.py +++ b/src/jaxsim/math/utils.py @@ -5,8 +5,8 @@ def safe_norm(array: jtp.ArrayLike, axis=None) -> jtp.Array: """ - Provides a calculation for an array norm so that it is safe - to compute the gradient and handle NaNs. + Compute an array norm handling NaNs and making sure that + it is safe to get the gradient. Args: array: The array for which to compute the norm. diff --git a/src/jaxsim/parsers/descriptions/collision.py b/src/jaxsim/parsers/descriptions/collision.py index 0ddb67baa..719c92d2b 100644 --- a/src/jaxsim/parsers/descriptions/collision.py +++ b/src/jaxsim/parsers/descriptions/collision.py @@ -22,7 +22,6 @@ class CollidablePoint: parent_link: The parent link to which the collidable point is attached. position: The position of the collidable point relative to the parent link. enabled: A flag indicating whether the collidable point is enabled for collision detection. - """ parent_link: LinkDescription @@ -86,7 +85,6 @@ class CollisionShape(abc.ABC): Attributes: collidable_points: A list of collidable points associated with the collision shape. - """ collidable_points: tuple[CollidablePoint] @@ -107,7 +105,6 @@ class BoxCollision(CollisionShape): Attributes: center: The center of the box in the local frame of the collision shape. - """ center: jtp.VectorLike @@ -135,7 +132,6 @@ class SphereCollision(CollisionShape): Attributes: center: The center of the sphere in the local frame of the collision shape. - """ center: jtp.VectorLike diff --git a/src/jaxsim/parsers/descriptions/model.py b/src/jaxsim/parsers/descriptions/model.py index d9f5f0bed..f8d1446d6 100644 --- a/src/jaxsim/parsers/descriptions/model.py +++ b/src/jaxsim/parsers/descriptions/model.py @@ -158,7 +158,7 @@ def reduce(self, considered_joints: Sequence[str]) -> ModelDescription: Reduce the model by removing specified joints. Args: - The joint names to consider. + considered_joints: Sequence of joint names to consider. Returns: A `ModelDescription` instance that only includes the considered joints. diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index 999c4cf90..9778db117 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -112,7 +112,7 @@ def build( damping: jtp.FloatLike | None = None, mu: jtp.FloatLike | None = None, ) -> Self: - """Create a `RelaxedRigidContactsParams` instance""" + """Create a `RelaxedRigidContactsParams` instance.""" def default(name: str): return cls.__dataclass_fields__[name].default_factory() diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index 879c56f64..e61a51f45 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -62,7 +62,7 @@ def build( K: jtp.FloatLike | None = None, D: jtp.FloatLike | None = None, ) -> Self: - """Create a `RigidContactParams` instance""" + """Create a `RigidContactParams` instance.""" return cls( mu=jnp.array( @@ -416,7 +416,8 @@ def _compute_ineq_constraint_matrix( inactive_collidable_points: jtp.Vector, mu: jtp.FloatLike ) -> jtp.Matrix: """ - Compute the inequality constraint matrix for a single collidable point + Compute the inequality constraint matrix for a single collidable point. + Rows 0-3: enforce the friction pyramid constraint, Row 4: last one is for the non negativity of the vertical force Row 5: contact complementarity condition diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index 43cdf70a1..dde16cfb2 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -207,6 +207,7 @@ def build( model: The robot model considered by the contact model. If passed, it is used to estimate good default parameters. + **kwargs: Additional parameters to pass to the contact model. Returns: The `SoftContacts` instance. diff --git a/src/jaxsim/rbda/contacts/visco_elastic.py b/src/jaxsim/rbda/contacts/visco_elastic.py index c433fe23d..40ad4ab61 100644 --- a/src/jaxsim/rbda/contacts/visco_elastic.py +++ b/src/jaxsim/rbda/contacts/visco_elastic.py @@ -206,6 +206,7 @@ def build( If passed, it is used to estimate good default parameters. max_squarings: The maximum number of squarings performed in the matrix exponential. + **kwargs: Extra arguments to ignore. Returns: The `ViscoElasticContacts` instance. diff --git a/tests/conftest.py b/tests/conftest.py index 94776545d..9fcb9be4b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -243,6 +243,7 @@ def ergocub_model_description_path() -> pathlib.Path: Returns: The path to the URDF model description of the ErgoCub robot. + """ try: @@ -271,6 +272,7 @@ def jaxsim_model_ergocub( Returns: The JaxSim model of the ErgoCub robot. + """ return build_jaxsim_model(model_description=ergocub_model_description_path) @@ -283,6 +285,7 @@ def jaxsim_model_ergocub_reduced(jaxsim_model_ergocub) -> js.model.JaxSimModel: Returns: The JaxSim model of the ErgoCub robot with only locomotion joints. + """ model_full = jaxsim_model_ergocub @@ -316,6 +319,7 @@ def jaxsim_model_ur10() -> js.model.JaxSimModel: Returns: The JaxSim model of the UR10 robot. + """ import robot_descriptions.ur10_description @@ -329,6 +333,7 @@ def jaxsim_model_ur10() -> js.model.JaxSimModel: def jaxsim_model_single_pendulum() -> js.model.JaxSimModel: """ Fixture providing the JaxSim model of a single pendulum. + Returns: The JaxSim model of a single pendulum. """ @@ -452,6 +457,7 @@ def get_jaxsim_model_fixture( Returns: The JaxSim model of the robot. + """ match model_name: @@ -507,6 +513,7 @@ def jaxsim_models_types(request) -> pathlib.Path | str: - A robot with no joints. - A fixed-base robot. - A floating-base robot. + """ model_name: str = request.param @@ -580,6 +587,7 @@ def jaxsim_model_box_32bit(set_jax_32bit, request) -> js.model.JaxSimModel: Returns: The JaxSim model of a box with 32-bit precision. + """ return get_jaxsim_model_fixture(model_name="box", request=request) diff --git a/tests/test_meshes.py b/tests/test_meshes.py index 58fcb9827..d9bd66dcc 100644 --- a/tests/test_meshes.py +++ b/tests/test_meshes.py @@ -6,8 +6,9 @@ def test_mesh_wrapping_vertex_extraction(): """ Test the vertex extraction method on different meshes. - 1. A simple box - 2. A sphere + + 1. A simple box. + 2. A sphere. """ # Test 1: A simple box. @@ -29,6 +30,7 @@ def test_mesh_wrapping_vertex_extraction(): def test_mesh_wrapping_aap(): """ Test the AAP wrapping method on different meshes. + 1. A simple box 1.1: Remove all points above x=0.0 1.2: Remove all points below y=0.0 @@ -64,6 +66,7 @@ def test_mesh_wrapping_aap(): def test_mesh_wrapping_points_over_axis(): """ Test the points over axis method on different meshes. + 1. A simple box 1.1: Select 10 points from the lower end of the x-axis 1.2: Select 10 points from the higher end of the y-axis diff --git a/tests/utils_idyntree.py b/tests/utils_idyntree.py index 5376a82ef..f19f412b1 100644 --- a/tests/utils_idyntree.py +++ b/tests/utils_idyntree.py @@ -33,6 +33,7 @@ def build_kindyncomputations_from_jaxsim_model( Note: Only `JaxSimModel` built from URDF files are supported. + """ if ( @@ -98,6 +99,7 @@ def store_jaxsim_data_in_kindyncomputations( Returns: The updated `KinDynComputations` with the state of `JaxSimModelData`. + """ if kin_dyn.dofs() != data.joint_positions().size: