Skip to content

Commit

Permalink
Merge pull request #326 from ami-iit/update_docstrings
Browse files Browse the repository at this point in the history
Improve docstring quality and consistency
  • Loading branch information
flferretti authored Jan 8, 2025
2 parents eabfd1b + 3dca043 commit 96523f0
Show file tree
Hide file tree
Showing 44 changed files with 517 additions and 114 deletions.
14 changes: 13 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ preview = true
# https://docs.astral.sh/ruff/rules/
select = [
"B",
"D",
"E",
"F",
"I",
Expand All @@ -162,6 +163,15 @@ select = [
ignore = [
"B008", # Function call in default argument
"B024", # Abstract base class without abstract methods
"D100", # Missing docstring in public module
"D104", # Missing docstring in public package
"D105", # Missing docstring in magic method
"D200", # One-line docstring should fit on one line with quotes
"D202", # No blank lines allowed after function docstring
"D205", # 1 blank line required between summary line and description
"D212", # Multi-line docstring summary should start at the first line
"D411", # Missing blank line before section
"D413", # Missing blank line after last section
"E402", # Module level import not at top of file
"E501", # Line too long
"E731", # Do not assign a `lambda` expression, use a `def`
Expand All @@ -173,9 +183,11 @@ ignore = [
[tool.ruff.lint.per-file-ignores]
# Ignore `E402` (import violations) in all `__init__.py` files
"**/{tests,docs,tools}/*" = ["E402"]
"**/{tests}/*" = ["B007"]
"**/{tests,examples}/*" = ["B007", "D100", "D102", "D103"]
"__init__.py" = ["F401"]
"docs/conf.py" = ["F401"]
"src/jaxsim/exceptions.py" = ["D401"]
"src/jaxsim/logging.py" = ["D101", "D103"]

# ==================
# Pixi configuration
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/api/com.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def other_representation_to_body(
C_v̇_WL: jtp.Vector, C_v_WC: jtp.Vector, L_H_C: jtp.Matrix, L_v_LC: jtp.Vector
) -> jtp.Vector:
"""
Helper to convert the body-fixed representation of the link bias acceleration
Convert the body-fixed representation of the link bias acceleration
C_v̇_WL expressed in a generic frame C to the body-fixed representation L_v̇_WL.
"""

Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/api/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


def named_scope(fn, name: str | None = None) -> Callable[_P, _R]:
"""Applies a JAX named scope to a function for improved profiling and clarity."""
"""Apply a JAX named scope to a function for improved profiling and clarity."""

@functools.wraps(fn)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
Expand Down
3 changes: 3 additions & 0 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,9 @@ def in_contact(
def estimate_good_soft_contacts_parameters(
*args, **kwargs
) -> jaxsim.rbda.contacts.ContactParamsTypes:
"""
Estimate good soft contacts parameters. Deprecated, use `estimate_good_contact_parameters` instead.
"""

msg = "This method is deprecated, please use `{}`."
logging.warning(msg.format(estimate_good_contact_parameters.__name__))
Expand Down
3 changes: 2 additions & 1 deletion src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,8 @@ def generalized_position(self) -> tuple[jtp.Matrix, jtp.Vector]:
@jax.jit
def generalized_velocity(self) -> jtp.Vector:
r"""
Get the generalized velocity
Get the generalized velocity.
:math:`\boldsymbol{\nu} = (\boldsymbol{v}_{W,B};\, \boldsymbol{\omega}_{W,B};\, \mathbf{s}) \in \mathbb{R}^{6+n}`
Returns:
Expand Down
19 changes: 18 additions & 1 deletion src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,16 @@ class KinDynParameters(JaxsimDataclass):

@property
def parent_array(self) -> jtp.Vector:
r"""
Return the parent array :math:`\lambda(i)` of the model.
"""
return self._parent_array.get()

@property
def support_body_array_bool(self) -> jtp.Matrix:
r"""
Return the boolean support parent array :math:`\kappa_{b}(i)` of the model.
"""
return self._support_body_array_bool.get()

@staticmethod
Expand Down Expand Up @@ -648,7 +654,16 @@ def build_from_inertial_parameters(
def build_from_flat_parameters(
index: jtp.IntLike, parameters: jtp.VectorLike
) -> LinkParameters:
"""
Build a LinkParameters object from a flat vector of parameters.
Args:
index: The index of the link.
parameters: The flat vector of parameters.
Returns:
The LinkParameters object.
"""
index = jnp.array(index).squeeze().astype(int)

m = jnp.array(parameters[0]).squeeze().astype(float)
Expand Down Expand Up @@ -772,7 +787,9 @@ class ContactParameters(JaxsimDataclass):

@property
def indices_of_enabled_collidable_points(self) -> npt.NDArray:

"""
Return the indices of the enabled collidable points.
"""
return np.where(np.array(self.enabled))[0]

@staticmethod
Expand Down
11 changes: 7 additions & 4 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ class JaxSimModel(JaxsimDataclass):

@property
def description(self) -> ModelDescription:
"""
Return the model description.
"""
return self._description.get()

def __eq__(self, other: JaxSimModel) -> bool:
Expand Down Expand Up @@ -1015,7 +1018,7 @@ def to_active(
W_v̇_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WB: jtp.Vector, W_v_WC: jtp.Vector
) -> jtp.Vector:
"""
Helper to convert the inertial-fixed apparent base acceleration W_v̇_WB to
Convert the inertial-fixed apparent base acceleration W_v̇_WB to
another representation C_v̇_WB expressed in a generic frame C.
"""

Expand Down Expand Up @@ -1376,7 +1379,7 @@ def inverse_dynamics(

def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC):
"""
Helper to convert the active representation of the base acceleration C_v̇_WB
Convert the active representation of the base acceleration C_v̇_WB
expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
"""

Expand Down Expand Up @@ -1825,7 +1828,7 @@ def other_representation_to_inertial(
C_v̇_WB: jtp.Vector, C_v_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WC: jtp.Vector
) -> jtp.Vector:
"""
Helper to convert the active representation of the base acceleration C_v̇_WB
Convert the active representation of the base acceleration C_v̇_WB
expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
"""

Expand Down Expand Up @@ -1961,7 +1964,7 @@ def body_to_other_representation(
L_v̇_WL: jtp.Vector, L_v_WL: jtp.Vector, C_H_L: jtp.Matrix, L_v_CL: jtp.Vector
) -> jtp.Vector:
"""
Helper to convert the body-fixed apparent acceleration L_v̇_WL to
Convert the body-fixed apparent acceleration L_v̇_WL to
another representation C_v̇_WL expressed in a generic frame C.
"""

Expand Down
22 changes: 21 additions & 1 deletion src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,32 @@


class SystemDynamicsFromModelAndData(Protocol):
"""
Protocol defining the signature of a function computing the system dynamics
given a model and data object.
"""

def __call__(
self,
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
**kwargs: dict[str, Any],
) -> tuple[ODEState, dict[str, Any]]: ...
) -> tuple[ODEState, dict[str, Any]]:
"""
Compute the system dynamics given a model and data object.
Args:
model: The model to consider.
data: The data of the considered model.
**kwargs: Additional keyword arguments.
Returns:
A tuple with an `ODEState` object storing in each of its attributes the
corresponding derivative, and the dictionary of auxiliary data returned
by the system dynamics evaluation.
"""

pass


def wrap_system_dynamics_for_integration(
Expand Down
8 changes: 8 additions & 0 deletions src/jaxsim/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ def raise_if(
msg:
The message to display when the exception is raised. The message can be a
format string (fmt), whose fields are filled with the args and kwargs.
*args: The arguments to fill the format string.
**kwargs: The keyword arguments to fill the format string
"""

# Disable host callback if running on unsupported hardware or if the user
Expand Down Expand Up @@ -61,12 +63,18 @@ def _run_callback_only_if_condition_is_true(*args, **kwargs) -> None:
def raise_runtime_error_if(
condition: bool | jax.Array, msg: str, *args, **kwargs
) -> None:
"""
Raise a RuntimeError if a condition is met. Useful in jit-compiled functions.
"""

return raise_if(condition, RuntimeError, msg, *args, **kwargs)


def raise_value_error_if(
condition: bool | jax.Array, msg: str, *args, **kwargs
) -> None:
"""
Raise a ValueError if a condition is met. Useful in jit-compiled functions.
"""

return raise_if(condition, ValueError, msg, *args, **kwargs)
62 changes: 60 additions & 2 deletions src/jaxsim/integrators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,25 @@


class SystemDynamics(Protocol[State, StateDerivative]):
"""
Protocol defining the system dynamics.
"""

def __call__(
self, x: State, t: Time, **kwargs
) -> tuple[StateDerivative, dict[str, Any]]: ...
) -> tuple[StateDerivative, dict[str, Any]]:
"""
Compute the state derivative of the system.
Args:
x: The state of the system.
t: The time of the system.
**kwargs: Additional keyword arguments.
Returns:
The state derivative of the system and the auxiliary dictionary.
"""
pass


# =======================
Expand All @@ -48,6 +64,9 @@ def __call__(

@jax_dataclasses.pytree_dataclass
class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
"""
Factory class for integrators.
"""

dynamics: Static[SystemDynamics[State, StateDerivative]] = dataclasses.field(
repr=False, hash=False, compare=False, kw_only=True
Expand Down Expand Up @@ -110,6 +129,9 @@ def step(
def __call__(
self, x0: State, t0: Time, dt: TimeStep, **kwargs
) -> tuple[NextState, dict[str, Any]]:
"""
Perform a single integration step.
"""
pass

def init(
Expand All @@ -121,6 +143,9 @@ def init(
include_dynamics_aux_dict: bool = False,
**kwargs,
) -> dict[str, Any]:
"""
Initialize the integrator. This method is deprecated.
"""

logging.warning(
"The 'init' method has been deprecated. There is no need to call it."
Expand All @@ -131,6 +156,18 @@ def init(

@jax_dataclasses.pytree_dataclass
class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]):
"""
Base class for explicit Runge-Kutta integrators.
Attributes:
A: The Runge-Kutta matrix.
b: The weights coefficients.
c: The nodes coefficients.
order_of_bT_rows: The order of the solution.
row_index_of_solution: The row of the integration output corresponding to the final solution.
fsal_enabled_if_supported: Whether to enable the FSAL property, if supported.
index_of_fsal: The index of the intermediate derivative to be used as the first derivative of the next iteration.
"""

# The Runge-Kutta matrix.
A: ClassVar[jtp.Matrix]
Expand All @@ -156,10 +193,16 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]

@property
def has_fsal(self) -> bool:
"""
Check if the integrator supports the FSAL property.
"""
return self.fsal_enabled_if_supported and self.index_of_fsal is not None

@property
def order(self) -> int:
"""
Return the order of the integrator.
"""
return self.order_of_bT_rows[self.row_index_of_solution]

@override
Expand Down Expand Up @@ -221,6 +264,9 @@ def build(
def __call__(
self, x0: State, t0: Time, dt: TimeStep, **kwargs
) -> tuple[NextState, dict[str, Any]]:
"""
Perform a single integration step.
"""

# Here z is a batched state with as many batch elements as b.T rows.
# Note that z has multiple batches only if b.T has more than one row,
Expand Down Expand Up @@ -331,7 +377,9 @@ def get_ẋ0_and_aux_dict() -> tuple[StateDerivative, dict[str, Any]]:
def scan_body(
carry: jax.Array, i: int | jax.Array
) -> tuple[jax.Array, dict[str, Any]]:
""""""
"""
Compute the kᵢ derivative of the Runge-Kutta stage.
"""

# Unpack the carry, i.e. the stacked kᵢ vectors.
K = carry
Expand Down Expand Up @@ -498,6 +546,16 @@ class ExplicitRungeKuttaSO3Mixin:
def post_process_state(
cls, x0: js.ode_data.ODEState, t0: Time, xf: js.ode_data.ODEState, dt: TimeStep
) -> js.ode_data.ODEState:
r"""
Post-process the integrated state at :math:`t_f = t_0 + \Delta t` so that the
quaternion is normalized.
Args:
x0: The initial state of the system.
t0: The initial time of the system.
xf: The final state of the system obtain through the integration.
dt: The time step used for the integration.
"""

# Extract the initial base quaternion.
W_Q_B_t0 = x0.physics_model.base_quaternion
Expand Down
Loading

0 comments on commit 96523f0

Please sign in to comment.