From b43c2657d81a5cf13d568191cf89d0f93faa42a5 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 15 Jan 2024 10:58:41 +0100 Subject: [PATCH] Update type hinting --- src/jaxsim/high_level/joint.py | 10 ++-- src/jaxsim/high_level/link.py | 10 ++-- src/jaxsim/high_level/model.py | 70 +++++++++++------------ src/jaxsim/parsers/descriptions/joint.py | 4 +- src/jaxsim/parsers/descriptions/link.py | 4 +- src/jaxsim/parsers/descriptions/model.py | 6 +- src/jaxsim/parsers/kinematic_graph.py | 2 +- src/jaxsim/parsers/rod/parser.py | 6 +- src/jaxsim/physics/algos/aba.py | 4 +- src/jaxsim/physics/algos/aba_motors.py | 4 +- src/jaxsim/physics/algos/rnea.py | 4 +- src/jaxsim/physics/algos/rnea_motors.py | 4 +- src/jaxsim/physics/algos/soft_contacts.py | 4 +- src/jaxsim/physics/algos/utils.py | 14 ++--- src/jaxsim/simulation/ode.py | 6 +- src/jaxsim/simulation/ode_integration.py | 8 +-- src/jaxsim/simulation/simulator.py | 21 +++---- tests/utils_idyntree.py | 8 +-- 18 files changed, 91 insertions(+), 98 deletions(-) diff --git a/src/jaxsim/high_level/joint.py b/src/jaxsim/high_level/joint.py index 5676b7fc4..db8d54876 100644 --- a/src/jaxsim/high_level/joint.py +++ b/src/jaxsim/high_level/joint.py @@ -1,6 +1,6 @@ import dataclasses import functools -from typing import Any, Optional +from typing import Any import jax.numpy as jnp import jax_dataclasses @@ -54,7 +54,7 @@ def name(self) -> str: return self.joint_description.name @functools.partial(oop.jax_tf.method_ro) - def position(self, dof: Optional[int] = None) -> jtp.Float: + def position(self, dof: int | None = None) -> jtp.Float: """""" dof = dof if dof is not None else 0 @@ -65,7 +65,7 @@ def position(self, dof: Optional[int] = None) -> jtp.Float: ) @functools.partial(oop.jax_tf.method_ro) - def velocity(self, dof: Optional[int] = None) -> jtp.Float: + def velocity(self, dof: int | None = None) -> jtp.Float: """""" dof = dof if dof is not None else 0 @@ -76,7 +76,7 @@ def velocity(self, dof: Optional[int] = None) -> jtp.Float: ) @functools.partial(oop.jax_tf.method_ro) - def force_target(self, dof: Optional[int] = None) -> jtp.Float: + def force_target(self, dof: int | None = None) -> jtp.Float: """""" dof = dof if dof is not None else 0 @@ -89,7 +89,7 @@ def force_target(self, dof: Optional[int] = None) -> jtp.Float: ) @functools.partial(oop.jax_tf.method_ro) - def position_limit(self, dof: Optional[int] = None) -> tuple[jtp.Float, jtp.Float]: + def position_limit(self, dof: int | None = None) -> tuple[jtp.Float, jtp.Float]: """""" dof = dof if dof is not None else 0 diff --git a/src/jaxsim/high_level/link.py b/src/jaxsim/high_level/link.py index 97b0224cf..d82e1d258 100644 --- a/src/jaxsim/high_level/link.py +++ b/src/jaxsim/high_level/link.py @@ -1,6 +1,6 @@ import dataclasses import functools -from typing import Any, Optional +from typing import Any import jax.lax import jax.numpy as jnp @@ -122,7 +122,7 @@ def transform(self) -> jtp.Matrix: return self.parent_model.forward_kinematics()[self.index()] @functools.partial(oop.jax_tf.method_ro, static_argnames=["vel_repr"]) - def velocity(self, vel_repr: Optional[VelRepr] = None) -> jtp.Vector: + def velocity(self, vel_repr: VelRepr | None = None) -> jtp.Vector: """""" v_WL = ( @@ -133,19 +133,19 @@ def velocity(self, vel_repr: Optional[VelRepr] = None) -> jtp.Vector: return v_WL @functools.partial(oop.jax_tf.method_ro, static_argnames=["vel_repr"]) - def linear_velocity(self, vel_repr: Optional[VelRepr] = None) -> jtp.Vector: + def linear_velocity(self, vel_repr: VelRepr | None = None) -> jtp.Vector: """""" return self.velocity(vel_repr=vel_repr)[0:3] @functools.partial(oop.jax_tf.method_ro, static_argnames=["vel_repr"]) - def angular_velocity(self, vel_repr: Optional[VelRepr] = None) -> jtp.Vector: + def angular_velocity(self, vel_repr: VelRepr | None = None) -> jtp.Vector: """""" return self.velocity(vel_repr=vel_repr)[3:6] @functools.partial(oop.jax_tf.method_ro, static_argnames=["output_vel_repr"]) - def jacobian(self, output_vel_repr: Optional[VelRepr] = None) -> jtp.Matrix: + def jacobian(self, output_vel_repr: VelRepr | None = None) -> jtp.Matrix: """""" if output_vel_repr is None: diff --git a/src/jaxsim/high_level/model.py b/src/jaxsim/high_level/model.py index c733cfbde..520db2446 100644 --- a/src/jaxsim/high_level/model.py +++ b/src/jaxsim/high_level/model.py @@ -116,11 +116,11 @@ class Model(Vmappable): @staticmethod def build_from_model_description( model_description: Union[str, pathlib.Path, rod.Model], - model_name: Optional[str] = None, + model_name: str | None = None, vel_repr: VelRepr = VelRepr.Mixed, gravity: jtp.Array = jaxsim.physics.default_gravity(), - is_urdf: Optional[bool] = None, - considered_joints: Optional[List[str]] = None, + is_urdf: bool | None = None, + considered_joints: List[str] | None = None, ) -> "Model": """ Build a Model object from a model description. @@ -169,11 +169,11 @@ def build_from_model_description( @staticmethod def build_from_sdf( sdf: Union[str, pathlib.Path], - model_name: Optional[str] = None, + model_name: str | None = None, vel_repr: VelRepr = VelRepr.Mixed, gravity: jtp.Array = jaxsim.physics.default_gravity(), - is_urdf: Optional[bool] = None, - considered_joints: Optional[List[str]] = None, + is_urdf: bool | None = None, + considered_joints: List[str] | None = None, ) -> "Model": """ Build a Model object from an SDF description. @@ -197,7 +197,7 @@ def build_from_sdf( @staticmethod def build( physics_model: jaxsim.physics.model.physics_model.PhysicsModel, - model_name: Optional[str] = None, + model_name: str | None = None, vel_repr: VelRepr = VelRepr.Mixed, ) -> "Model": """ @@ -397,7 +397,7 @@ def joint_names(self) -> tuple[str, ...]: @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) def links( - self, link_names: Optional[tuple[str, ...]] = None + self, link_names: tuple[str, ...] | None = None ) -> tuple[high_level.link.Link, ...]: """""" @@ -421,7 +421,7 @@ def links( @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) def joints( - self, joint_names: Optional[tuple[str, ...]] = None + self, joint_names: tuple[str, ...] | None = None ) -> tuple[high_level.joint.Joint, ...]: """""" @@ -446,7 +446,7 @@ def joints( @functools.partial(oop.jax_tf.method_ro, static_argnames=["link_names", "terrain"]) def in_contact( self, - link_names: Optional[tuple[str, ...]] = None, + link_names: tuple[str, ...] | None = None, terrain: Terrain = FlatTerrain(), ) -> jtp.Vector: """""" @@ -484,9 +484,7 @@ def in_contact( # ================= @functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"]) - def joint_positions( - self, joint_names: Optional[tuple[str, ...]] = None - ) -> jtp.Vector: + def joint_positions(self, joint_names: tuple[str, ...] | None = None) -> jtp.Vector: """""" return self.data.model_state.joint_positions[ @@ -496,8 +494,8 @@ def joint_positions( @functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"]) def joint_random_positions( self, - joint_names: Optional[tuple[str, ...]] = None, - key: Optional[jax.Array] = None, + joint_names: tuple[str, ...] | None = None, + key: jax.Array | None = None, ) -> jtp.Vector: """""" @@ -517,7 +515,7 @@ def joint_random_positions( @functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"]) def joint_velocities( - self, joint_names: Optional[tuple[str, ...]] = None + self, joint_names: tuple[str, ...] | None = None ) -> jtp.Vector: """""" @@ -527,7 +525,7 @@ def joint_velocities( @functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"]) def joint_generalized_forces_targets( - self, joint_names: Optional[tuple[str, ...]] = None + self, joint_names: tuple[str, ...] | None = None ) -> jtp.Vector: """""" @@ -535,7 +533,7 @@ def joint_generalized_forces_targets( @functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"]) def joint_limits( - self, joint_names: Optional[tuple[str, ...]] = None + self, joint_names: tuple[str, ...] | None = None ) -> Tuple[jtp.Vector, jtp.Vector]: """""" @@ -653,8 +651,8 @@ def external_forces(self) -> jtp.Matrix: def apply_external_force_to_link( self, link_name: str, - force: Optional[jtp.Array] = None, - torque: Optional[jtp.Array] = None, + force: jtp.Array | None = None, + torque: jtp.Array | None = None, additive: bool = True, ) -> None: """""" @@ -708,8 +706,8 @@ def apply_external_force_to_link( def apply_external_force_to_link_com( self, link_name: str, - force: Optional[jtp.Array] = None, - torque: Optional[jtp.Array] = None, + force: jtp.Array | None = None, + torque: jtp.Array | None = None, additive: bool = True, ) -> None: """""" @@ -779,7 +777,7 @@ def generalized_velocity(self) -> jtp.Vector: @functools.partial(oop.jax_tf.method_ro, static_argnames=["output_vel_repr"]) def generalized_free_floating_jacobian( - self, output_vel_repr: Optional[VelRepr] = None + self, output_vel_repr: VelRepr | None = None ) -> jtp.Matrix: """""" @@ -968,8 +966,8 @@ def forward_kinematics(self) -> jtp.Array: @functools.partial(oop.jax_tf.method_ro) def inverse_dynamics( self, - joint_accelerations: Optional[jtp.Vector] = None, - base_acceleration: Optional[jtp.Vector] = None, + joint_accelerations: jtp.Vector | None = None, + base_acceleration: jtp.Vector | None = None, ) -> Tuple[jtp.Vector, jtp.Vector]: """ Compute inverse dynamics with the RNEA algorithm. @@ -1058,7 +1056,7 @@ def to_inertial(C_vd_WB, W_H_C, C_v_WB, W_vl_WC): @functools.partial(oop.jax_tf.method_ro, static_argnames=["prefer_aba"]) def forward_dynamics( - self, tau: Optional[jtp.Vector] = None, prefer_aba: float = True + self, tau: jtp.Vector | None = None, prefer_aba: float = True ) -> Tuple[jtp.Vector, jtp.Vector]: """""" @@ -1070,7 +1068,7 @@ def forward_dynamics( @functools.partial(oop.jax_tf.method_ro) def forward_dynamics_aba( - self, tau: Optional[jtp.Vector] = None + self, tau: jtp.Vector | None = None ) -> Tuple[jtp.Vector, jtp.Vector]: """""" @@ -1136,7 +1134,7 @@ def to_active(W_vd_WB, W_H_C, W_v_WB, W_vl_WC): @functools.partial(oop.jax_tf.method_ro) def forward_dynamics_crb( - self, tau: Optional[jtp.Vector] = None + self, tau: jtp.Vector | None = None ) -> Tuple[jtp.Vector, jtp.Vector]: """""" @@ -1234,7 +1232,7 @@ def potential_energy(self) -> jtp.Float: @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"]) def set_joint_generalized_force_targets( - self, forces: jtp.Vector, joint_names: Optional[tuple[str, ...]] = None + self, forces: jtp.Vector, joint_names: tuple[str, ...] | None = None ) -> None: """""" @@ -1254,7 +1252,7 @@ def set_joint_generalized_force_targets( @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"]) def reset_joint_positions( - self, positions: jtp.Vector, joint_names: Optional[tuple[str, ...]] = None + self, positions: jtp.Vector, joint_names: tuple[str, ...] | None = None ) -> None: """""" @@ -1280,7 +1278,7 @@ def reset_joint_positions( @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"]) def reset_joint_velocities( - self, velocities: jtp.Vector, joint_names: Optional[tuple[str, ...]] = None + self, velocities: jtp.Vector, joint_names: tuple[str, ...] | None = None ) -> None: """""" @@ -1510,7 +1508,7 @@ def integrate( @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"]) def set_motor_inertias( - self, inertias: jtp.Vector, joint_names: Optional[tuple[str, ...]] = None + self, inertias: jtp.Vector, joint_names: tuple[str, ...] | None = None ) -> None: joint_names = joint_names or self.joint_names() @@ -1525,7 +1523,7 @@ def set_motor_inertias( @functools.partial(oop.jax_tf.method_rw, jit=False) def set_motor_gear_ratios( - self, gear_ratios: jtp.Vector, joint_names: Optional[tuple[str, ...]] = None + self, gear_ratios: jtp.Vector, joint_names: tuple[str, ...] | None = None ) -> None: joint_names = joint_names or self.joint_names() @@ -1549,7 +1547,7 @@ def set_motor_gear_ratios( def set_motor_viscous_frictions( self, viscous_frictions: jtp.Vector, - joint_names: Optional[tuple[str, ...]] = None, + joint_names: tuple[str, ...] | None = None, ) -> None: joint_names = joint_names or self.joint_names() @@ -1680,9 +1678,7 @@ def active_to_inertial_representation( else: raise ValueError(self.velocity_representation) - def _joint_indices( - self, joint_names: Optional[tuple[str, ...]] = None - ) -> jtp.Vector: + def _joint_indices(self, joint_names: tuple[str, ...] | None = None) -> jtp.Vector: """""" if joint_names is None: diff --git a/src/jaxsim/parsers/descriptions/joint.py b/src/jaxsim/parsers/descriptions/joint.py index 4a6d54fc5..0ac1ece3e 100644 --- a/src/jaxsim/parsers/descriptions/joint.py +++ b/src/jaxsim/parsers/descriptions/joint.py @@ -1,6 +1,6 @@ import dataclasses import enum -from typing import Optional, Tuple, Union +from typing import Tuple, Union import jax_dataclasses import numpy as np @@ -63,7 +63,7 @@ class JointDescription(JaxsimDataclass): child: LinkDescription = dataclasses.dataclass(repr=False) parent: LinkDescription = dataclasses.dataclass(repr=False) - index: Optional[int] = None + index: int | None = None friction_static: float = 0.0 friction_viscous: float = 0.0 diff --git a/src/jaxsim/parsers/descriptions/link.py b/src/jaxsim/parsers/descriptions/link.py index b2f1d2c17..b2a7fe55f 100644 --- a/src/jaxsim/parsers/descriptions/link.py +++ b/src/jaxsim/parsers/descriptions/link.py @@ -1,5 +1,5 @@ import dataclasses -from typing import List, Optional +from typing import List import jax.numpy as jnp import jax_dataclasses @@ -19,7 +19,7 @@ class LinkDescription(JaxsimDataclass): name: Static[str] mass: float inertia: jtp.Matrix - index: Optional[int] = None + index: int | None = None parent: Static["LinkDescription"] = dataclasses.field(default=None, repr=False) pose: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.eye(4), repr=False) children: Static[List["LinkDescription"]] = dataclasses.field( diff --git a/src/jaxsim/parsers/descriptions/model.py b/src/jaxsim/parsers/descriptions/model.py index f29807c38..8a1a7b757 100644 --- a/src/jaxsim/parsers/descriptions/model.py +++ b/src/jaxsim/parsers/descriptions/model.py @@ -1,6 +1,6 @@ import dataclasses import itertools -from typing import List, Optional +from typing import List from jaxsim import logging @@ -23,8 +23,8 @@ def build_model_from( joints: List[JointDescription], collisions: List[CollisionShape] = (), fixed_base: bool = False, - base_link_name: Optional[str] = None, - considered_joints: Optional[List[str]] = None, + base_link_name: str | None = None, + considered_joints: List[str] | None = None, model_pose: RootPose = RootPose(), ) -> "ModelDescription": # Create the full kinematic graph diff --git a/src/jaxsim/parsers/kinematic_graph.py b/src/jaxsim/parsers/kinematic_graph.py index e33193628..1c7432daa 100644 --- a/src/jaxsim/parsers/kinematic_graph.py +++ b/src/jaxsim/parsers/kinematic_graph.py @@ -98,7 +98,7 @@ def __post_init__(self): def build_from( links: List[descriptions.LinkDescription], joints: List[descriptions.JointDescription], - root_link_name: Optional[str] = None, + root_link_name: str | None = None, root_pose: RootPose = RootPose(), ) -> "KinematicGraph": if root_link_name is None: diff --git a/src/jaxsim/parsers/rod/parser.py b/src/jaxsim/parsers/rod/parser.py index f06a47bb1..99894d0b5 100644 --- a/src/jaxsim/parsers/rod/parser.py +++ b/src/jaxsim/parsers/rod/parser.py @@ -27,14 +27,14 @@ class SDFData(NamedTuple): joint_descriptions: List[descriptions.JointDescription] collision_shapes: List[descriptions.CollisionShape] - sdf_model: Optional[rod.Model] = None + sdf_model: rod.Model | None = None model_pose: kinematic_graph.RootPose = kinematic_graph.RootPose() def extract_model_data( model_description: Union[pathlib.Path, str, rod.Model], - model_name: Optional[str] = None, - is_urdf: Optional[bool] = None, + model_name: str | None = None, + is_urdf: bool | None = None, ) -> SDFData: """ Extract data from an SDF/URDF resource useful to build a JaxSim model. diff --git a/src/jaxsim/physics/algos/aba.py b/src/jaxsim/physics/algos/aba.py index 1b5d5abbf..98cb12500 100644 --- a/src/jaxsim/physics/algos/aba.py +++ b/src/jaxsim/physics/algos/aba.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Tuple import jax import jax.numpy as jnp @@ -18,7 +18,7 @@ def aba( q: jtp.Vector, qd: jtp.Vector, tau: jtp.Vector, - f_ext: Optional[jtp.Matrix] = None, + f_ext: jtp.Matrix | None = None, ) -> Tuple[jtp.Vector, jtp.Vector]: """ Articulated Body Algorithm (ABA) algorithm for forward dynamics. diff --git a/src/jaxsim/physics/algos/aba_motors.py b/src/jaxsim/physics/algos/aba_motors.py index 7c3060790..fd7873ffe 100644 --- a/src/jaxsim/physics/algos/aba_motors.py +++ b/src/jaxsim/physics/algos/aba_motors.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Tuple import jax import jax.numpy as jnp @@ -18,7 +18,7 @@ def aba( q: jtp.Vector, qd: jtp.Vector, tau: jtp.Vector, - f_ext: Optional[jtp.Matrix] = None, + f_ext: jtp.Matrix | None = None, ) -> Tuple[jtp.Vector, jtp.Vector]: """ Articulated Body Algorithm (ABA) algorithm with motor dynamics for forward dynamics. diff --git a/src/jaxsim/physics/algos/rnea.py b/src/jaxsim/physics/algos/rnea.py index c51de1182..8a1c62116 100644 --- a/src/jaxsim/physics/algos/rnea.py +++ b/src/jaxsim/physics/algos/rnea.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Tuple import jax import jax.numpy as jnp @@ -19,7 +19,7 @@ def rnea( qd: jtp.Vector, qdd: jtp.Vector, a0fb: jtp.Vector = jnp.zeros(6), - f_ext: Optional[jtp.Matrix] = None, + f_ext: jtp.Matrix | None = None, ) -> Tuple[jtp.Vector, jtp.Vector]: """ Recursive Newton-Euler Algorithm (RNEA) algorithm for inverse dynamics. diff --git a/src/jaxsim/physics/algos/rnea_motors.py b/src/jaxsim/physics/algos/rnea_motors.py index e324e8d86..84a8c02cd 100644 --- a/src/jaxsim/physics/algos/rnea_motors.py +++ b/src/jaxsim/physics/algos/rnea_motors.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Tuple import jax import jax.numpy as jnp @@ -19,7 +19,7 @@ def rnea( qd: jtp.Vector, qdd: jtp.Vector, a0fb: jtp.Vector = jnp.zeros(6), - f_ext: Optional[jtp.Matrix] = None, + f_ext: jtp.Matrix | None = None, ) -> Tuple[jtp.Vector, jtp.Vector]: """ Recursive Newton-Euler Algorithm (RNEA) algorithm for inverse dynamics. diff --git a/src/jaxsim/physics/algos/soft_contacts.py b/src/jaxsim/physics/algos/soft_contacts.py index b3a8b85f9..b7627bf9c 100644 --- a/src/jaxsim/physics/algos/soft_contacts.py +++ b/src/jaxsim/physics/algos/soft_contacts.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Optional, Tuple +from typing import Tuple import jax import jax.flatten_util @@ -52,7 +52,7 @@ def collidable_points_pos_vel( model: PhysicsModel, q: jtp.Vector, qd: jtp.Vector, - xfb: Optional[jtp.Vector] = None, + xfb: jtp.Vector | None = None, ) -> Tuple[jtp.Matrix, jtp.Matrix]: """ Compute the position and linear velocity of collidable points in the world frame. diff --git a/src/jaxsim/physics/algos/utils.py b/src/jaxsim/physics/algos/utils.py index 690fd3217..0204986e1 100644 --- a/src/jaxsim/physics/algos/utils.py +++ b/src/jaxsim/physics/algos/utils.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Tuple import jax.numpy as jnp @@ -8,12 +8,12 @@ def process_inputs( physics_model: PhysicsModel, - xfb: Optional[jtp.Vector] = None, - q: Optional[jtp.Vector] = None, - qd: Optional[jtp.Vector] = None, - qdd: Optional[jtp.Vector] = None, - tau: Optional[jtp.Vector] = None, - f_ext: Optional[jtp.Matrix] = None, + xfb: jtp.Vector | None = None, + q: jtp.Vector | None = None, + qd: jtp.Vector | None = None, + qdd: jtp.Vector | None = None, + tau: jtp.Vector | None = None, + f_ext: jtp.Matrix | None = None, ) -> Tuple[jtp.Vector, jtp.Vector, jtp.Vector, jtp.Vector, jtp.Vector, jtp.Matrix]: """ Adjust the inputs to the physics model. diff --git a/src/jaxsim/simulation/ode.py b/src/jaxsim/simulation/ode.py index c5e145f4d..f22941da4 100644 --- a/src/jaxsim/simulation/ode.py +++ b/src/jaxsim/simulation/ode.py @@ -74,10 +74,10 @@ def compute_contact_forces( def dx_dt( x: ode_data.ODEState, - t: Optional[jtp.Float], physics_model: PhysicsModel, soft_contacts_params: SoftContactsParams = SoftContactsParams(), - ode_input: Optional[ode_data.ODEInput] = None, + t: jtp.Float | None = None, + ode_input: ode_data.ODEInput | None = None, terrain: Terrain = FlatTerrain(), ) -> Tuple[ode_data.ODEState, Dict[str, Any]]: """ @@ -231,7 +231,7 @@ def dx_dt( # Build the full derivative of ODEState # ===================================== - def fix_one_dof(vector: jtp.Vector) -> Optional[jtp.Vector]: + def fix_one_dof(vector: jtp.Vector) -> jtp.Vector | None: """Fix the shape of computed quantities for models with just 1 DoF.""" if vector is None: diff --git a/src/jaxsim/simulation/ode_integration.py b/src/jaxsim/simulation/ode_integration.py index fe98e4a0f..c34d2bd36 100644 --- a/src/jaxsim/simulation/ode_integration.py +++ b/src/jaxsim/simulation/ode_integration.py @@ -1,6 +1,6 @@ import enum import functools -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Tuple, Union import jax.flatten_util from jax.experimental.ode import odeint @@ -39,7 +39,7 @@ def ode_integration_euler( physics_model: PhysicsModel, soft_contacts_params: SoftContactsParams = SoftContactsParams(), terrain: Terrain = FlatTerrain(), - ode_input: Optional[ode.ode_data.ODEInput] = None, + ode_input: ode.ode_data.ODEInput | None = None, *args, num_sub_steps: int = 1, return_aux: bool = False, @@ -70,7 +70,7 @@ def ode_integration_euler_semi_implicit( physics_model: PhysicsModel, soft_contacts_params: SoftContactsParams = SoftContactsParams(), terrain: Terrain = FlatTerrain(), - ode_input: Optional[ode.ode_data.ODEInput] = None, + ode_input: ode.ode_data.ODEInput | None = None, *args, num_sub_steps: int = 1, return_aux: bool = False, @@ -101,7 +101,7 @@ def ode_integration_rk4( physics_model: PhysicsModel, soft_contacts_params: SoftContactsParams = SoftContactsParams(), terrain: Terrain = FlatTerrain(), - ode_input: Optional[ode.ode_data.ODEInput] = None, + ode_input: ode.ode_data.ODEInput | None = None, *args, num_sub_steps=1, return_aux: bool = False, diff --git a/src/jaxsim/simulation/simulator.py b/src/jaxsim/simulation/simulator.py index 4f4f5b9a4..75dbd9d93 100644 --- a/src/jaxsim/simulation/simulator.py +++ b/src/jaxsim/simulation/simulator.py @@ -90,7 +90,7 @@ def build( steps_per_run: jtp.Int = 1, velocity_representation: VelRepr = VelRepr.Inertial, integrator_type: IntegratorType = IntegratorType.EulerSemiImplicit, - simulator_data: Optional[SimulatorData] = None, + simulator_data: SimulatorData | None = None, ) -> "JaxSim": """ Build a JaxSim simulator object. @@ -219,9 +219,7 @@ def get_model(self, model_name: str) -> Model: return self.data.models[model_name] @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) - def models( - self, model_names: Optional[tuple[str, ...]] = None - ) -> tuple[Model, ...]: + def models(self, model_names: tuple[str, ...] | None = None) -> tuple[Model, ...]: """ Return the simulated models. @@ -259,8 +257,8 @@ def set_gravity(self, gravity: jtp.Vector) -> None: def insert_model_from_description( self, model_description: Union[pathlib.Path, str, rod.Model], - model_name: Optional[str] = None, - considered_joints: Optional[List[str]] = None, + model_name: str | None = None, + considered_joints: List[str] | None = None, ) -> Model: """ Insert a model from a model description. @@ -302,8 +300,8 @@ def insert_model_from_description( def insert_model_from_sdf( self, sdf: Union[pathlib.Path, str], - model_name: Optional[str] = None, - considered_joints: Optional[List[str]] = None, + model_name: str | None = None, + considered_joints: List[str] | None = None, ) -> Model: """ Insert a model from an SDF resource. @@ -324,7 +322,7 @@ def insert_model_from_sdf( def insert_model( self, model_description: descriptions.ModelDescription, - model_name: Optional[str] = None, + model_name: str | None = None, ) -> Model: """ Insert a model from a model description object. @@ -436,9 +434,8 @@ def step(self, clear_inputs: bool = False) -> Dict[str, StepData]: def step_over_horizon( self, horizon_steps: jtp.Int, - callback_handler: Optional[ - Union["scb.SimulatorCallback", "scb.CallbackHandler"] - ] = None, + callback_handler: Union["scb.SimulatorCallback", "scb.CallbackHandler"] + | None = None, clear_inputs: jtp.Bool = False, ) -> Union[ "JaxSim", diff --git a/tests/utils_idyntree.py b/tests/utils_idyntree.py index 201ca13a4..a58391067 100644 --- a/tests/utils_idyntree.py +++ b/tests/utils_idyntree.py @@ -1,6 +1,6 @@ import dataclasses import pathlib -from typing import List, Optional, Union +from typing import List, Union import idyntree.bindings as idt import numpy as np @@ -66,11 +66,11 @@ def build( def set_robot_state( self, - joint_positions: Optional[npt.NDArray] = None, - joint_velocities: Optional[npt.NDArray] = None, + joint_positions: npt.NDArray | None = None, + joint_velocities: npt.NDArray | None = None, base_transform: npt.NDArray = np.eye(4), base_velocity: npt.NDArray = np.zeros(6), - world_gravity: Optional[npt.NDArray] = None, + world_gravity: npt.NDArray | None = None, ) -> None: joint_positions = ( joint_positions if joint_positions is not None else np.zeros(self.dofs())