From effe29040480ff3bec7b26cad5f6b73760dacb17 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 11 Dec 2024 14:38:29 +0100 Subject: [PATCH] Fix `js.common.named_scope` decorator usage --- src/jaxsim/api/com.py | 16 ++++++++-------- src/jaxsim/api/contact.py | 18 +++++++++--------- src/jaxsim/api/frame.py | 14 +++++++------- src/jaxsim/api/joint.py | 8 ++++---- src/jaxsim/api/model.py | 38 +++++++++++++++++++------------------- src/jaxsim/api/ode.py | 6 +++--- 6 files changed, 50 insertions(+), 50 deletions(-) diff --git a/src/jaxsim/api/com.py b/src/jaxsim/api/com.py index 11b9b2d4a..f2122ced1 100644 --- a/src/jaxsim/api/com.py +++ b/src/jaxsim/api/com.py @@ -8,8 +8,8 @@ from .common import VelRepr -@js.common.named_scope @jax.jit +@js.common.named_scope def com_position( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Vector: @@ -45,8 +45,8 @@ def B_p̃_LCoM(i) -> jtp.Vector: return (W_H_B @ B_p̃_CoM)[0:3].astype(float) -@js.common.named_scope @jax.jit +@js.common.named_scope def com_linear_velocity( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Vector: @@ -76,8 +76,8 @@ def com_linear_velocity( return G_vl_WG -@js.common.named_scope @jax.jit +@js.common.named_scope def centroidal_momentum( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Vector: @@ -104,8 +104,8 @@ def centroidal_momentum( return G_J @ ν -@js.common.named_scope @jax.jit +@js.common.named_scope def centroidal_momentum_jacobian( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Matrix: @@ -153,8 +153,8 @@ def centroidal_momentum_jacobian( return G_Xf_B @ B_Jh -@js.common.named_scope @jax.jit +@js.common.named_scope def locked_centroidal_spatial_inertia( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ): @@ -191,8 +191,8 @@ def locked_centroidal_spatial_inertia( return G_Xf_B @ B_Mbb_B @ B_Xv_G -@js.common.named_scope @jax.jit +@js.common.named_scope def average_centroidal_velocity( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Vector: @@ -219,8 +219,8 @@ def average_centroidal_velocity( return G_J @ ν -@js.common.named_scope @jax.jit +@js.common.named_scope def average_centroidal_velocity_jacobian( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Matrix: @@ -246,8 +246,8 @@ def average_centroidal_velocity_jacobian( return jnp.linalg.inv(G_Mbb) @ G_J -@js.common.named_scope @jax.jit +@js.common.named_scope def bias_acceleration( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Vector: diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 31d2245d4..294413f7e 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -16,8 +16,8 @@ from .common import VelRepr -@js.common.named_scope @jax.jit +@js.common.named_scope def collidable_point_kinematics( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> tuple[jtp.Matrix, jtp.Matrix]: @@ -53,8 +53,8 @@ def collidable_point_kinematics( return W_p_Ci, W_ṗ_Ci -@js.common.named_scope @jax.jit +@js.common.named_scope def collidable_point_positions( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Matrix: @@ -74,8 +74,8 @@ def collidable_point_positions( return W_p_Ci -@js.common.named_scope @jax.jit +@js.common.named_scope def collidable_point_velocities( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Matrix: @@ -95,8 +95,8 @@ def collidable_point_velocities( return W_ṗ_Ci -@js.common.named_scope @jax.jit +@js.common.named_scope def collidable_point_forces( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, @@ -133,8 +133,8 @@ def collidable_point_forces( return f_Ci -@js.common.named_scope @jax.jit +@js.common.named_scope def collidable_point_dynamics( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, @@ -232,8 +232,8 @@ def collidable_point_dynamics( return f_Ci, aux_data -@js.common.named_scope @functools.partial(jax.jit, static_argnames=["link_names"]) +@js.common.named_scope def in_contact( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, @@ -430,8 +430,8 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float: return parameters -@js.common.named_scope @jax.jit +@js.common.named_scope def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array: r""" Return the pose of the enabled collidable points. @@ -476,8 +476,8 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C) -@js.common.named_scope @functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@js.common.named_scope def jacobian( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, @@ -569,8 +569,8 @@ def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: return O_J_WC -@js.common.named_scope @functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@js.common.named_scope def jacobian_derivative( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, diff --git a/src/jaxsim/api/frame.py b/src/jaxsim/api/frame.py index 3a1013137..39c7092fd 100644 --- a/src/jaxsim/api/frame.py +++ b/src/jaxsim/api/frame.py @@ -16,8 +16,8 @@ # ======================= -@js.common.named_scope @jax.jit +@js.common.named_scope def idx_of_parent_link( model: js.model.JaxSimModel, *, frame_index: jtp.IntLike ) -> jtp.Int: @@ -46,8 +46,8 @@ def idx_of_parent_link( ] -@js.common.named_scope @functools.partial(jax.jit, static_argnames="frame_name") +@js.common.named_scope def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> jtp.Int: """ Convert the name of a frame to its index. @@ -99,8 +99,8 @@ def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str ] -@js.common.named_scope @functools.partial(jax.jit, static_argnames=["frame_names"]) +@js.common.named_scope def names_to_idxs( model: js.model.JaxSimModel, *, frame_names: Sequence[str] ) -> jax.Array: @@ -142,8 +142,8 @@ def idxs_to_names( # ========== -@js.common.named_scope @jax.jit +@js.common.named_scope def transform( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, @@ -184,8 +184,8 @@ def transform( return W_H_L @ L_H_F -@js.common.named_scope @functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@js.common.named_scope def velocity( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, @@ -235,8 +235,8 @@ def velocity( return O_J_WF_I @ I_ν -@js.common.named_scope @functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@js.common.named_scope def jacobian( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, @@ -315,8 +315,8 @@ def jacobian( return O_J_WL_I -@js.common.named_scope @functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@js.common.named_scope def jacobian_derivative( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, diff --git a/src/jaxsim/api/joint.py b/src/jaxsim/api/joint.py index e70a393a1..8bc9127db 100644 --- a/src/jaxsim/api/joint.py +++ b/src/jaxsim/api/joint.py @@ -13,8 +13,8 @@ # ======================= -@js.common.named_scope @functools.partial(jax.jit, static_argnames="joint_name") +@js.common.named_scope def name_to_idx(model: js.model.JaxSimModel, *, joint_name: str) -> jtp.Int: """ Convert the name of a joint to its index. @@ -62,8 +62,8 @@ def idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike) -> str return model.kin_dyn_parameters.joint_model.joint_names[joint_index + 1] -@js.common.named_scope @functools.partial(jax.jit, static_argnames="joint_names") +@js.common.named_scope def names_to_idxs( model: js.model.JaxSimModel, *, joint_names: Sequence[str] ) -> jax.Array: @@ -143,8 +143,8 @@ def position_limit( return s_min.astype(float), s_max.astype(float) -@js.common.named_scope @functools.partial(jax.jit, static_argnames=["joint_names"]) +@js.common.named_scope def position_limits( model: js.model.JaxSimModel, *, joint_names: Sequence[str] | None = None ) -> tuple[jtp.Vector, jtp.Vector]: @@ -179,8 +179,8 @@ def position_limits( # ====================== -@js.common.named_scope @functools.partial(jax.jit, static_argnames=["joint_names"]) +@js.common.named_scope def random_joint_positions( model: js.model.JaxSimModel, *, diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index d33232af7..76ca9a024 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -491,8 +491,8 @@ def reduce( # =================== -@js.common.named_scope @jax.jit +@js.common.named_scope def total_mass(model: JaxSimModel) -> jtp.Float: """ Compute the total mass of the model. @@ -507,8 +507,8 @@ def total_mass(model: JaxSimModel) -> jtp.Float: return model.kin_dyn_parameters.link_parameters.mass.sum().astype(float) -@js.common.named_scope @jax.jit +@js.common.named_scope def link_spatial_inertia_matrices(model: JaxSimModel) -> jtp.Array: """ Compute the spatial 6D inertia matrices of all links of the model. @@ -530,8 +530,8 @@ def link_spatial_inertia_matrices(model: JaxSimModel) -> jtp.Array: # ============================== -@js.common.named_scope @jax.jit +@js.common.named_scope def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array: """ Compute the SE(3) transforms from the world frame to the frames of all links. @@ -918,8 +918,8 @@ def forward_dynamics( ) -@js.common.named_scope @jax.jit +@js.common.named_scope def forward_dynamics_aba( model: JaxSimModel, data: js.data.JaxSimModelData, @@ -1063,8 +1063,8 @@ def to_active( return C_v̇_WB.astype(float), s̈.astype(float) -@js.common.named_scope @jax.jit +@js.common.named_scope def forward_dynamics_crb( model: JaxSimModel, data: js.data.JaxSimModelData, @@ -1155,8 +1155,8 @@ def forward_dynamics_crb( return v̇_WB, s̈ -@js.common.named_scope @jax.jit +@js.common.named_scope def free_floating_mass_matrix( model: JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Matrix: @@ -1201,8 +1201,8 @@ def free_floating_mass_matrix( raise ValueError(data.velocity_representation) -@js.common.named_scope @jax.jit +@js.common.named_scope def free_floating_coriolis_matrix( model: JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Matrix: @@ -1318,8 +1318,8 @@ def compute_link_contribution(M, v, J, J̇) -> jtp.Array: raise ValueError(data.velocity_representation) -@js.common.named_scope @jax.jit +@js.common.named_scope def inverse_dynamics( model: JaxSimModel, data: js.data.JaxSimModelData, @@ -1474,8 +1474,8 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): return f_B.astype(float), τ.astype(float) -@js.common.named_scope @jax.jit +@js.common.named_scope def free_floating_gravity_forces( model: JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Vector: @@ -1524,8 +1524,8 @@ def free_floating_gravity_forces( ).astype(float) -@js.common.named_scope @jax.jit +@js.common.named_scope def free_floating_bias_forces( model: JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Vector: @@ -1594,8 +1594,8 @@ def free_floating_bias_forces( # ========================== -@js.common.named_scope @jax.jit +@js.common.named_scope def locked_spatial_inertia( model: JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Matrix: @@ -1613,8 +1613,8 @@ def locked_spatial_inertia( return total_momentum_jacobian(model=model, data=data)[:, 0:6] -@js.common.named_scope @jax.jit +@js.common.named_scope def total_momentum(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector: """ Compute the total momentum of the model. @@ -1702,8 +1702,8 @@ def total_momentum_jacobian( raise ValueError(output_vel_repr) -@js.common.named_scope @jax.jit +@js.common.named_scope def average_velocity(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector: """ Compute the average velocity of the model. @@ -1791,8 +1791,8 @@ def average_velocity_jacobian( # ======================== -@js.common.named_scope @jax.jit +@js.common.named_scope def link_bias_accelerations( model: JaxSimModel, data: js.data.JaxSimModelData, @@ -2001,8 +2001,8 @@ def body_to_other_representation( return O_v̇_WL -@js.common.named_scope @jax.jit +@js.common.named_scope def link_contact_forces( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, @@ -2077,8 +2077,8 @@ def link_contact_forces( # ====== -@js.common.named_scope @jax.jit +@js.common.named_scope def mechanical_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float: """ Compute the mechanical energy of the model. @@ -2097,8 +2097,8 @@ def mechanical_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp. return (K + U).astype(float) -@js.common.named_scope @jax.jit +@js.common.named_scope def kinetic_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float: """ Compute the kinetic energy of the model. @@ -2119,8 +2119,8 @@ def kinetic_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Flo return K.squeeze().astype(float) -@js.common.named_scope @jax.jit +@js.common.named_scope def potential_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float: """ Compute the potential energy of the model. @@ -2146,8 +2146,8 @@ def potential_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.F # ========== -@js.common.named_scope @jax.jit +@js.common.named_scope def step( model: JaxSimModel, data: js.data.JaxSimModelData, diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index 93e558a04..a6b685cde 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -85,8 +85,8 @@ def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]: # ================================== -@js.common.named_scope @jax.jit +@js.common.named_scope def system_velocity_dynamics( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, @@ -332,8 +332,8 @@ def system_acceleration( return v̇_WB, s̈ -@js.common.named_scope @jax.jit +@js.common.named_scope def system_position_dynamics( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, @@ -372,8 +372,8 @@ def system_position_dynamics( return W_ṗ_B, W_Q̇_B, ṡ -@js.common.named_scope @jax.jit +@js.common.named_scope def system_dynamics( model: js.model.JaxSimModel, data: js.data.JaxSimModelData,