Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve docstring quality and consistency #326

Merged
merged 6 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ preview = true
# https://docs.astral.sh/ruff/rules/
select = [
"B",
"D",
"E",
"F",
"I",
Expand All @@ -162,6 +163,15 @@ select = [
ignore = [
"B008", # Function call in default argument
"B024", # Abstract base class without abstract methods
"D100", # Missing docstring in public module
"D104", # Missing docstring in public package
"D105", # Missing docstring in magic method
"D200", # One-line docstring should fit on one line with quotes
"D202", # No blank lines allowed after function docstring
"D205", # 1 blank line required between summary line and description
"D212", # Multi-line docstring summary should start at the first line
"D411", # Missing blank line before section
"D413", # Missing blank line after last section
"E402", # Module level import not at top of file
"E501", # Line too long
"E731", # Do not assign a `lambda` expression, use a `def`
Expand All @@ -173,9 +183,11 @@ ignore = [
[tool.ruff.lint.per-file-ignores]
# Ignore `E402` (import violations) in all `__init__.py` files
"**/{tests,docs,tools}/*" = ["E402"]
"**/{tests}/*" = ["B007"]
"**/{tests,examples}/*" = ["B007", "D100", "D102", "D103"]
"__init__.py" = ["F401"]
"docs/conf.py" = ["F401"]
"src/jaxsim/exceptions.py" = ["D401"]
"src/jaxsim/logging.py" = ["D101", "D103"]

# ==================
# Pixi configuration
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/api/com.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def other_representation_to_body(
C_v̇_WL: jtp.Vector, C_v_WC: jtp.Vector, L_H_C: jtp.Matrix, L_v_LC: jtp.Vector
) -> jtp.Vector:
"""
Helper to convert the body-fixed representation of the link bias acceleration
Convert the body-fixed representation of the link bias acceleration
C_v̇_WL expressed in a generic frame C to the body-fixed representation L_v̇_WL.
"""

Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/api/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


def named_scope(fn, name: str | None = None) -> Callable[_P, _R]:
"""Applies a JAX named scope to a function for improved profiling and clarity."""
"""Apply a JAX named scope to a function for improved profiling and clarity."""

@functools.wraps(fn)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
Expand Down
3 changes: 3 additions & 0 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,9 @@ def in_contact(
def estimate_good_soft_contacts_parameters(
*args, **kwargs
) -> jaxsim.rbda.contacts.ContactParamsTypes:
"""
Estimate good soft contacts parameters. Deprecated, use `estimate_good_contact_parameters` instead.
"""

msg = "This method is deprecated, please use `{}`."
logging.warning(msg.format(estimate_good_contact_parameters.__name__))
Expand Down
3 changes: 2 additions & 1 deletion src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,8 @@ def generalized_position(self) -> tuple[jtp.Matrix, jtp.Vector]:
@jax.jit
def generalized_velocity(self) -> jtp.Vector:
r"""
Get the generalized velocity
Get the generalized velocity.

:math:`\boldsymbol{\nu} = (\boldsymbol{v}_{W,B};\, \boldsymbol{\omega}_{W,B};\, \mathbf{s}) \in \mathbb{R}^{6+n}`

Returns:
Expand Down
19 changes: 18 additions & 1 deletion src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,16 @@ class KinDynParameters(JaxsimDataclass):

@property
def parent_array(self) -> jtp.Vector:
r"""
Return the parent array :math:`\lambda(i)` of the model.
"""
return self._parent_array.get()

@property
def support_body_array_bool(self) -> jtp.Matrix:
r"""
Return the boolean support parent array :math:`\kappa_{b}(i)` of the model.
"""
return self._support_body_array_bool.get()

@staticmethod
Expand Down Expand Up @@ -648,7 +654,16 @@ def build_from_inertial_parameters(
def build_from_flat_parameters(
index: jtp.IntLike, parameters: jtp.VectorLike
) -> LinkParameters:
"""
Build a LinkParameters object from a flat vector of parameters.

Args:
index: The index of the link.
parameters: The flat vector of parameters.

Returns:
The LinkParameters object.
"""
index = jnp.array(index).squeeze().astype(int)

m = jnp.array(parameters[0]).squeeze().astype(float)
Expand Down Expand Up @@ -772,7 +787,9 @@ class ContactParameters(JaxsimDataclass):

@property
def indices_of_enabled_collidable_points(self) -> npt.NDArray:

"""
Return the indices of the enabled collidable points.
"""
return np.where(np.array(self.enabled))[0]

@staticmethod
Expand Down
11 changes: 7 additions & 4 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ class JaxSimModel(JaxsimDataclass):

@property
def description(self) -> ModelDescription:
"""
Return the model description.
"""
return self._description.get()

def __eq__(self, other: JaxSimModel) -> bool:
Expand Down Expand Up @@ -1015,7 +1018,7 @@ def to_active(
W_v̇_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WB: jtp.Vector, W_v_WC: jtp.Vector
) -> jtp.Vector:
"""
Helper to convert the inertial-fixed apparent base acceleration W_v̇_WB to
Convert the inertial-fixed apparent base acceleration W_v̇_WB to
another representation C_v̇_WB expressed in a generic frame C.
"""

Expand Down Expand Up @@ -1376,7 +1379,7 @@ def inverse_dynamics(

def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC):
"""
Helper to convert the active representation of the base acceleration C_v̇_WB
Convert the active representation of the base acceleration C_v̇_WB
expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
"""

Expand Down Expand Up @@ -1825,7 +1828,7 @@ def other_representation_to_inertial(
C_v̇_WB: jtp.Vector, C_v_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WC: jtp.Vector
) -> jtp.Vector:
"""
Helper to convert the active representation of the base acceleration C_v̇_WB
Convert the active representation of the base acceleration C_v̇_WB
expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
"""

Expand Down Expand Up @@ -1961,7 +1964,7 @@ def body_to_other_representation(
L_v̇_WL: jtp.Vector, L_v_WL: jtp.Vector, C_H_L: jtp.Matrix, L_v_CL: jtp.Vector
) -> jtp.Vector:
"""
Helper to convert the body-fixed apparent acceleration L_v̇_WL to
Convert the body-fixed apparent acceleration L_v̇_WL to
another representation C_v̇_WL expressed in a generic frame C.
"""

Expand Down
22 changes: 21 additions & 1 deletion src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,32 @@


class SystemDynamicsFromModelAndData(Protocol):
"""
Protocol defining the signature of a function computing the system dynamics
given a model and data object.
"""

def __call__(
self,
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
**kwargs: dict[str, Any],
) -> tuple[ODEState, dict[str, Any]]: ...
) -> tuple[ODEState, dict[str, Any]]:
"""
Compute the system dynamics given a model and data object.

Args:
model: The model to consider.
data: The data of the considered model.
**kwargs: Additional keyword arguments.

Returns:
A tuple with an `ODEState` object storing in each of its attributes the
corresponding derivative, and the dictionary of auxiliary data returned
by the system dynamics evaluation.
"""

pass


def wrap_system_dynamics_for_integration(
Expand Down
8 changes: 8 additions & 0 deletions src/jaxsim/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ def raise_if(
msg:
The message to display when the exception is raised. The message can be a
format string (fmt), whose fields are filled with the args and kwargs.
*args: The arguments to fill the format string.
**kwargs: The keyword arguments to fill the format string
"""

# Disable host callback if running on unsupported hardware or if the user
Expand Down Expand Up @@ -61,12 +63,18 @@ def _run_callback_only_if_condition_is_true(*args, **kwargs) -> None:
def raise_runtime_error_if(
condition: bool | jax.Array, msg: str, *args, **kwargs
) -> None:
"""
Raise a RuntimeError if a condition is met. Useful in jit-compiled functions.
"""

return raise_if(condition, RuntimeError, msg, *args, **kwargs)


def raise_value_error_if(
condition: bool | jax.Array, msg: str, *args, **kwargs
) -> None:
"""
Raise a ValueError if a condition is met. Useful in jit-compiled functions.
"""

return raise_if(condition, ValueError, msg, *args, **kwargs)
62 changes: 60 additions & 2 deletions src/jaxsim/integrators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,25 @@


class SystemDynamics(Protocol[State, StateDerivative]):
"""
Protocol defining the system dynamics.
"""

def __call__(
self, x: State, t: Time, **kwargs
) -> tuple[StateDerivative, dict[str, Any]]: ...
) -> tuple[StateDerivative, dict[str, Any]]:
"""
Compute the state derivative of the system.

Args:
x: The state of the system.
t: The time of the system.
**kwargs: Additional keyword arguments.

Returns:
The state derivative of the system and the auxiliary dictionary.
"""
pass


# =======================
Expand All @@ -48,6 +64,9 @@ def __call__(

@jax_dataclasses.pytree_dataclass
class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
"""
Factory class for integrators.
"""

dynamics: Static[SystemDynamics[State, StateDerivative]] = dataclasses.field(
repr=False, hash=False, compare=False, kw_only=True
Expand Down Expand Up @@ -110,6 +129,9 @@ def step(
def __call__(
self, x0: State, t0: Time, dt: TimeStep, **kwargs
) -> tuple[NextState, dict[str, Any]]:
"""
Perform a single integration step.
"""
pass

def init(
Expand All @@ -121,6 +143,9 @@ def init(
include_dynamics_aux_dict: bool = False,
**kwargs,
) -> dict[str, Any]:
"""
Initialize the integrator. This method is deprecated.
"""

logging.warning(
"The 'init' method has been deprecated. There is no need to call it."
Expand All @@ -131,6 +156,18 @@ def init(

@jax_dataclasses.pytree_dataclass
class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]):
"""
Base class for explicit Runge-Kutta integrators.

Attributes:
A: The Runge-Kutta matrix.
b: The weights coefficients.
c: The nodes coefficients.
order_of_bT_rows: The order of the solution.
row_index_of_solution: The row of the integration output corresponding to the final solution.
fsal_enabled_if_supported: Whether to enable the FSAL property, if supported.
index_of_fsal: The index of the intermediate derivative to be used as the first derivative of the next iteration.
"""

# The Runge-Kutta matrix.
A: ClassVar[jtp.Matrix]
Expand All @@ -156,10 +193,16 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]

@property
def has_fsal(self) -> bool:
"""
Check if the integrator supports the FSAL property.
"""
return self.fsal_enabled_if_supported and self.index_of_fsal is not None

@property
def order(self) -> int:
"""
Return the order of the integrator.
"""
return self.order_of_bT_rows[self.row_index_of_solution]

@override
Expand Down Expand Up @@ -221,6 +264,9 @@ def build(
def __call__(
self, x0: State, t0: Time, dt: TimeStep, **kwargs
) -> tuple[NextState, dict[str, Any]]:
"""
Perform a single integration step.
"""

# Here z is a batched state with as many batch elements as b.T rows.
# Note that z has multiple batches only if b.T has more than one row,
Expand Down Expand Up @@ -331,7 +377,9 @@ def get_ẋ0_and_aux_dict() -> tuple[StateDerivative, dict[str, Any]]:
def scan_body(
carry: jax.Array, i: int | jax.Array
) -> tuple[jax.Array, dict[str, Any]]:
""""""
"""
Compute the kᵢ derivative of the Runge-Kutta stage.
"""

# Unpack the carry, i.e. the stacked kᵢ vectors.
K = carry
Expand Down Expand Up @@ -498,6 +546,16 @@ class ExplicitRungeKuttaSO3Mixin:
def post_process_state(
cls, x0: js.ode_data.ODEState, t0: Time, xf: js.ode_data.ODEState, dt: TimeStep
) -> js.ode_data.ODEState:
r"""
Post-process the integrated state at :math:`t_f = t_0 + \Delta t` so that the
quaternion is normalized.

Args:
x0: The initial state of the system.
t0: The initial time of the system.
xf: The final state of the system obtain through the integration.
dt: The time step used for the integration.
"""

# Extract the initial base quaternion.
W_Q_B_t0 = x0.physics_model.base_quaternion
Expand Down
Loading