Skip to content

Commit

Permalink
Update type hinting
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Jan 15, 2024
1 parent 2075833 commit b43c265
Show file tree
Hide file tree
Showing 18 changed files with 91 additions and 98 deletions.
10 changes: 5 additions & 5 deletions src/jaxsim/high_level/joint.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import dataclasses
import functools
from typing import Any, Optional
from typing import Any

import jax.numpy as jnp
import jax_dataclasses
Expand Down Expand Up @@ -54,7 +54,7 @@ def name(self) -> str:
return self.joint_description.name

@functools.partial(oop.jax_tf.method_ro)
def position(self, dof: Optional[int] = None) -> jtp.Float:
def position(self, dof: int | None = None) -> jtp.Float:
""""""

dof = dof if dof is not None else 0
Expand All @@ -65,7 +65,7 @@ def position(self, dof: Optional[int] = None) -> jtp.Float:
)

@functools.partial(oop.jax_tf.method_ro)
def velocity(self, dof: Optional[int] = None) -> jtp.Float:
def velocity(self, dof: int | None = None) -> jtp.Float:
""""""

dof = dof if dof is not None else 0
Expand All @@ -76,7 +76,7 @@ def velocity(self, dof: Optional[int] = None) -> jtp.Float:
)

@functools.partial(oop.jax_tf.method_ro)
def force_target(self, dof: Optional[int] = None) -> jtp.Float:
def force_target(self, dof: int | None = None) -> jtp.Float:
""""""

dof = dof if dof is not None else 0
Expand All @@ -89,7 +89,7 @@ def force_target(self, dof: Optional[int] = None) -> jtp.Float:
)

@functools.partial(oop.jax_tf.method_ro)
def position_limit(self, dof: Optional[int] = None) -> tuple[jtp.Float, jtp.Float]:
def position_limit(self, dof: int | None = None) -> tuple[jtp.Float, jtp.Float]:
""""""

dof = dof if dof is not None else 0
Expand Down
10 changes: 5 additions & 5 deletions src/jaxsim/high_level/link.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import dataclasses
import functools
from typing import Any, Optional
from typing import Any

import jax.lax
import jax.numpy as jnp
Expand Down Expand Up @@ -122,7 +122,7 @@ def transform(self) -> jtp.Matrix:
return self.parent_model.forward_kinematics()[self.index()]

@functools.partial(oop.jax_tf.method_ro, static_argnames=["vel_repr"])
def velocity(self, vel_repr: Optional[VelRepr] = None) -> jtp.Vector:
def velocity(self, vel_repr: VelRepr | None = None) -> jtp.Vector:
""""""

v_WL = (
Expand All @@ -133,19 +133,19 @@ def velocity(self, vel_repr: Optional[VelRepr] = None) -> jtp.Vector:
return v_WL

@functools.partial(oop.jax_tf.method_ro, static_argnames=["vel_repr"])
def linear_velocity(self, vel_repr: Optional[VelRepr] = None) -> jtp.Vector:
def linear_velocity(self, vel_repr: VelRepr | None = None) -> jtp.Vector:
""""""

return self.velocity(vel_repr=vel_repr)[0:3]

@functools.partial(oop.jax_tf.method_ro, static_argnames=["vel_repr"])
def angular_velocity(self, vel_repr: Optional[VelRepr] = None) -> jtp.Vector:
def angular_velocity(self, vel_repr: VelRepr | None = None) -> jtp.Vector:
""""""

return self.velocity(vel_repr=vel_repr)[3:6]

@functools.partial(oop.jax_tf.method_ro, static_argnames=["output_vel_repr"])
def jacobian(self, output_vel_repr: Optional[VelRepr] = None) -> jtp.Matrix:
def jacobian(self, output_vel_repr: VelRepr | None = None) -> jtp.Matrix:
""""""

if output_vel_repr is None:
Expand Down
70 changes: 33 additions & 37 deletions src/jaxsim/high_level/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,11 @@ class Model(Vmappable):
@staticmethod
def build_from_model_description(
model_description: Union[str, pathlib.Path, rod.Model],
model_name: Optional[str] = None,
model_name: str | None = None,
vel_repr: VelRepr = VelRepr.Mixed,
gravity: jtp.Array = jaxsim.physics.default_gravity(),
is_urdf: Optional[bool] = None,
considered_joints: Optional[List[str]] = None,
is_urdf: bool | None = None,
considered_joints: List[str] | None = None,
) -> "Model":
"""
Build a Model object from a model description.
Expand Down Expand Up @@ -169,11 +169,11 @@ def build_from_model_description(
@staticmethod
def build_from_sdf(
sdf: Union[str, pathlib.Path],
model_name: Optional[str] = None,
model_name: str | None = None,
vel_repr: VelRepr = VelRepr.Mixed,
gravity: jtp.Array = jaxsim.physics.default_gravity(),
is_urdf: Optional[bool] = None,
considered_joints: Optional[List[str]] = None,
is_urdf: bool | None = None,
considered_joints: List[str] | None = None,
) -> "Model":
"""
Build a Model object from an SDF description.
Expand All @@ -197,7 +197,7 @@ def build_from_sdf(
@staticmethod
def build(
physics_model: jaxsim.physics.model.physics_model.PhysicsModel,
model_name: Optional[str] = None,
model_name: str | None = None,
vel_repr: VelRepr = VelRepr.Mixed,
) -> "Model":
"""
Expand Down Expand Up @@ -397,7 +397,7 @@ def joint_names(self) -> tuple[str, ...]:

@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def links(
self, link_names: Optional[tuple[str, ...]] = None
self, link_names: tuple[str, ...] | None = None
) -> tuple[high_level.link.Link, ...]:
""""""

Expand All @@ -421,7 +421,7 @@ def links(

@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def joints(
self, joint_names: Optional[tuple[str, ...]] = None
self, joint_names: tuple[str, ...] | None = None
) -> tuple[high_level.joint.Joint, ...]:
""""""

Expand All @@ -446,7 +446,7 @@ def joints(
@functools.partial(oop.jax_tf.method_ro, static_argnames=["link_names", "terrain"])
def in_contact(
self,
link_names: Optional[tuple[str, ...]] = None,
link_names: tuple[str, ...] | None = None,
terrain: Terrain = FlatTerrain(),
) -> jtp.Vector:
""""""
Expand Down Expand Up @@ -484,9 +484,7 @@ def in_contact(
# =================

@functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"])
def joint_positions(
self, joint_names: Optional[tuple[str, ...]] = None
) -> jtp.Vector:
def joint_positions(self, joint_names: tuple[str, ...] | None = None) -> jtp.Vector:
""""""

return self.data.model_state.joint_positions[
Expand All @@ -496,8 +494,8 @@ def joint_positions(
@functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"])
def joint_random_positions(
self,
joint_names: Optional[tuple[str, ...]] = None,
key: Optional[jax.Array] = None,
joint_names: tuple[str, ...] | None = None,
key: jax.Array | None = None,
) -> jtp.Vector:
""""""

Expand All @@ -517,7 +515,7 @@ def joint_random_positions(

@functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"])
def joint_velocities(
self, joint_names: Optional[tuple[str, ...]] = None
self, joint_names: tuple[str, ...] | None = None
) -> jtp.Vector:
""""""

Expand All @@ -527,15 +525,15 @@ def joint_velocities(

@functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"])
def joint_generalized_forces_targets(
self, joint_names: Optional[tuple[str, ...]] = None
self, joint_names: tuple[str, ...] | None = None
) -> jtp.Vector:
""""""

return self.data.model_input.tau[self._joint_indices(joint_names=joint_names)]

@functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"])
def joint_limits(
self, joint_names: Optional[tuple[str, ...]] = None
self, joint_names: tuple[str, ...] | None = None
) -> Tuple[jtp.Vector, jtp.Vector]:
""""""

Expand Down Expand Up @@ -653,8 +651,8 @@ def external_forces(self) -> jtp.Matrix:
def apply_external_force_to_link(
self,
link_name: str,
force: Optional[jtp.Array] = None,
torque: Optional[jtp.Array] = None,
force: jtp.Array | None = None,
torque: jtp.Array | None = None,
additive: bool = True,
) -> None:
""""""
Expand Down Expand Up @@ -708,8 +706,8 @@ def apply_external_force_to_link(
def apply_external_force_to_link_com(
self,
link_name: str,
force: Optional[jtp.Array] = None,
torque: Optional[jtp.Array] = None,
force: jtp.Array | None = None,
torque: jtp.Array | None = None,
additive: bool = True,
) -> None:
""""""
Expand Down Expand Up @@ -779,7 +777,7 @@ def generalized_velocity(self) -> jtp.Vector:

@functools.partial(oop.jax_tf.method_ro, static_argnames=["output_vel_repr"])
def generalized_free_floating_jacobian(
self, output_vel_repr: Optional[VelRepr] = None
self, output_vel_repr: VelRepr | None = None
) -> jtp.Matrix:
""""""

Expand Down Expand Up @@ -968,8 +966,8 @@ def forward_kinematics(self) -> jtp.Array:
@functools.partial(oop.jax_tf.method_ro)
def inverse_dynamics(
self,
joint_accelerations: Optional[jtp.Vector] = None,
base_acceleration: Optional[jtp.Vector] = None,
joint_accelerations: jtp.Vector | None = None,
base_acceleration: jtp.Vector | None = None,
) -> Tuple[jtp.Vector, jtp.Vector]:
"""
Compute inverse dynamics with the RNEA algorithm.
Expand Down Expand Up @@ -1058,7 +1056,7 @@ def to_inertial(C_vd_WB, W_H_C, C_v_WB, W_vl_WC):

@functools.partial(oop.jax_tf.method_ro, static_argnames=["prefer_aba"])
def forward_dynamics(
self, tau: Optional[jtp.Vector] = None, prefer_aba: float = True
self, tau: jtp.Vector | None = None, prefer_aba: float = True
) -> Tuple[jtp.Vector, jtp.Vector]:
""""""

Expand All @@ -1070,7 +1068,7 @@ def forward_dynamics(

@functools.partial(oop.jax_tf.method_ro)
def forward_dynamics_aba(
self, tau: Optional[jtp.Vector] = None
self, tau: jtp.Vector | None = None
) -> Tuple[jtp.Vector, jtp.Vector]:
""""""

Expand Down Expand Up @@ -1136,7 +1134,7 @@ def to_active(W_vd_WB, W_H_C, W_v_WB, W_vl_WC):

@functools.partial(oop.jax_tf.method_ro)
def forward_dynamics_crb(
self, tau: Optional[jtp.Vector] = None
self, tau: jtp.Vector | None = None
) -> Tuple[jtp.Vector, jtp.Vector]:
""""""

Expand Down Expand Up @@ -1234,7 +1232,7 @@ def potential_energy(self) -> jtp.Float:

@functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
def set_joint_generalized_force_targets(
self, forces: jtp.Vector, joint_names: Optional[tuple[str, ...]] = None
self, forces: jtp.Vector, joint_names: tuple[str, ...] | None = None
) -> None:
""""""

Expand All @@ -1254,7 +1252,7 @@ def set_joint_generalized_force_targets(

@functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
def reset_joint_positions(
self, positions: jtp.Vector, joint_names: Optional[tuple[str, ...]] = None
self, positions: jtp.Vector, joint_names: tuple[str, ...] | None = None
) -> None:
""""""

Expand All @@ -1280,7 +1278,7 @@ def reset_joint_positions(

@functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
def reset_joint_velocities(
self, velocities: jtp.Vector, joint_names: Optional[tuple[str, ...]] = None
self, velocities: jtp.Vector, joint_names: tuple[str, ...] | None = None
) -> None:
""""""

Expand Down Expand Up @@ -1510,7 +1508,7 @@ def integrate(

@functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
def set_motor_inertias(
self, inertias: jtp.Vector, joint_names: Optional[tuple[str, ...]] = None
self, inertias: jtp.Vector, joint_names: tuple[str, ...] | None = None
) -> None:
joint_names = joint_names or self.joint_names()

Expand All @@ -1525,7 +1523,7 @@ def set_motor_inertias(

@functools.partial(oop.jax_tf.method_rw, jit=False)
def set_motor_gear_ratios(
self, gear_ratios: jtp.Vector, joint_names: Optional[tuple[str, ...]] = None
self, gear_ratios: jtp.Vector, joint_names: tuple[str, ...] | None = None
) -> None:
joint_names = joint_names or self.joint_names()

Expand All @@ -1549,7 +1547,7 @@ def set_motor_gear_ratios(
def set_motor_viscous_frictions(
self,
viscous_frictions: jtp.Vector,
joint_names: Optional[tuple[str, ...]] = None,
joint_names: tuple[str, ...] | None = None,
) -> None:
joint_names = joint_names or self.joint_names()

Expand Down Expand Up @@ -1680,9 +1678,7 @@ def active_to_inertial_representation(
else:
raise ValueError(self.velocity_representation)

def _joint_indices(
self, joint_names: Optional[tuple[str, ...]] = None
) -> jtp.Vector:
def _joint_indices(self, joint_names: tuple[str, ...] | None = None) -> jtp.Vector:
""""""

if joint_names is None:
Expand Down
4 changes: 2 additions & 2 deletions src/jaxsim/parsers/descriptions/joint.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import dataclasses
import enum
from typing import Optional, Tuple, Union
from typing import Tuple, Union

import jax_dataclasses
import numpy as np
Expand Down Expand Up @@ -63,7 +63,7 @@ class JointDescription(JaxsimDataclass):
child: LinkDescription = dataclasses.dataclass(repr=False)
parent: LinkDescription = dataclasses.dataclass(repr=False)

index: Optional[int] = None
index: int | None = None

friction_static: float = 0.0
friction_viscous: float = 0.0
Expand Down
4 changes: 2 additions & 2 deletions src/jaxsim/parsers/descriptions/link.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dataclasses
from typing import List, Optional
from typing import List

import jax.numpy as jnp
import jax_dataclasses
Expand All @@ -19,7 +19,7 @@ class LinkDescription(JaxsimDataclass):
name: Static[str]
mass: float
inertia: jtp.Matrix
index: Optional[int] = None
index: int | None = None
parent: Static["LinkDescription"] = dataclasses.field(default=None, repr=False)
pose: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.eye(4), repr=False)
children: Static[List["LinkDescription"]] = dataclasses.field(
Expand Down
6 changes: 3 additions & 3 deletions src/jaxsim/parsers/descriptions/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import dataclasses
import itertools
from typing import List, Optional
from typing import List

from jaxsim import logging

Expand All @@ -23,8 +23,8 @@ def build_model_from(
joints: List[JointDescription],
collisions: List[CollisionShape] = (),
fixed_base: bool = False,
base_link_name: Optional[str] = None,
considered_joints: Optional[List[str]] = None,
base_link_name: str | None = None,
considered_joints: List[str] | None = None,
model_pose: RootPose = RootPose(),
) -> "ModelDescription":
# Create the full kinematic graph
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/parsers/kinematic_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def __post_init__(self):
def build_from(
links: List[descriptions.LinkDescription],
joints: List[descriptions.JointDescription],
root_link_name: Optional[str] = None,
root_link_name: str | None = None,
root_pose: RootPose = RootPose(),
) -> "KinematicGraph":
if root_link_name is None:
Expand Down
Loading

0 comments on commit b43c265

Please sign in to comment.