Skip to content

Commit

Permalink
Fix js.common.named_scope decorator usage
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Dec 11, 2024
1 parent c896b48 commit effe290
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 50 deletions.
16 changes: 8 additions & 8 deletions src/jaxsim/api/com.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from .common import VelRepr


@js.common.named_scope
@jax.jit
@js.common.named_scope
def com_position(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
Expand Down Expand Up @@ -45,8 +45,8 @@ def B_p̃_LCoM(i) -> jtp.Vector:
return (W_H_B @ B_p̃_CoM)[0:3].astype(float)


@js.common.named_scope
@jax.jit
@js.common.named_scope
def com_linear_velocity(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
Expand Down Expand Up @@ -76,8 +76,8 @@ def com_linear_velocity(
return G_vl_WG


@js.common.named_scope
@jax.jit
@js.common.named_scope
def centroidal_momentum(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
Expand All @@ -104,8 +104,8 @@ def centroidal_momentum(
return G_J @ ν


@js.common.named_scope
@jax.jit
@js.common.named_scope
def centroidal_momentum_jacobian(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
Expand Down Expand Up @@ -153,8 +153,8 @@ def centroidal_momentum_jacobian(
return G_Xf_B @ B_Jh


@js.common.named_scope
@jax.jit
@js.common.named_scope
def locked_centroidal_spatial_inertia(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
):
Expand Down Expand Up @@ -191,8 +191,8 @@ def locked_centroidal_spatial_inertia(
return G_Xf_B @ B_Mbb_B @ B_Xv_G


@js.common.named_scope
@jax.jit
@js.common.named_scope
def average_centroidal_velocity(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
Expand All @@ -219,8 +219,8 @@ def average_centroidal_velocity(
return G_J @ ν


@js.common.named_scope
@jax.jit
@js.common.named_scope
def average_centroidal_velocity_jacobian(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
Expand All @@ -246,8 +246,8 @@ def average_centroidal_velocity_jacobian(
return jnp.linalg.inv(G_Mbb) @ G_J


@js.common.named_scope
@jax.jit
@js.common.named_scope
def bias_acceleration(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
Expand Down
18 changes: 9 additions & 9 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from .common import VelRepr


@js.common.named_scope
@jax.jit
@js.common.named_scope
def collidable_point_kinematics(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> tuple[jtp.Matrix, jtp.Matrix]:
Expand Down Expand Up @@ -53,8 +53,8 @@ def collidable_point_kinematics(
return W_p_Ci, W_ṗ_Ci


@js.common.named_scope
@jax.jit
@js.common.named_scope
def collidable_point_positions(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
Expand All @@ -74,8 +74,8 @@ def collidable_point_positions(
return W_p_Ci


@js.common.named_scope
@jax.jit
@js.common.named_scope
def collidable_point_velocities(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
Expand All @@ -95,8 +95,8 @@ def collidable_point_velocities(
return W_ṗ_Ci


@js.common.named_scope
@jax.jit
@js.common.named_scope
def collidable_point_forces(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
Expand Down Expand Up @@ -133,8 +133,8 @@ def collidable_point_forces(
return f_Ci


@js.common.named_scope
@jax.jit
@js.common.named_scope
def collidable_point_dynamics(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
Expand Down Expand Up @@ -232,8 +232,8 @@ def collidable_point_dynamics(
return f_Ci, aux_data


@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["link_names"])
@js.common.named_scope
def in_contact(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
Expand Down Expand Up @@ -430,8 +430,8 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
return parameters


@js.common.named_scope
@jax.jit
@js.common.named_scope
def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array:
r"""
Return the pose of the enabled collidable points.
Expand Down Expand Up @@ -476,8 +476,8 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt
return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C)


@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
@js.common.named_scope
def jacobian(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
Expand Down Expand Up @@ -569,8 +569,8 @@ def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
return O_J_WC


@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
@js.common.named_scope
def jacobian_derivative(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
Expand Down
14 changes: 7 additions & 7 deletions src/jaxsim/api/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
# =======================


@js.common.named_scope
@jax.jit
@js.common.named_scope
def idx_of_parent_link(
model: js.model.JaxSimModel, *, frame_index: jtp.IntLike
) -> jtp.Int:
Expand Down Expand Up @@ -46,8 +46,8 @@ def idx_of_parent_link(
]


@js.common.named_scope
@functools.partial(jax.jit, static_argnames="frame_name")
@js.common.named_scope
def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> jtp.Int:
"""
Convert the name of a frame to its index.
Expand Down Expand Up @@ -99,8 +99,8 @@ def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str
]


@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["frame_names"])
@js.common.named_scope
def names_to_idxs(
model: js.model.JaxSimModel, *, frame_names: Sequence[str]
) -> jax.Array:
Expand Down Expand Up @@ -142,8 +142,8 @@ def idxs_to_names(
# ==========


@js.common.named_scope
@jax.jit
@js.common.named_scope
def transform(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
Expand Down Expand Up @@ -184,8 +184,8 @@ def transform(
return W_H_L @ L_H_F


@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
@js.common.named_scope
def velocity(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
Expand Down Expand Up @@ -235,8 +235,8 @@ def velocity(
return O_J_WF_I @ I_ν


@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
@js.common.named_scope
def jacobian(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
Expand Down Expand Up @@ -315,8 +315,8 @@ def jacobian(
return O_J_WL_I


@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
@js.common.named_scope
def jacobian_derivative(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
Expand Down
8 changes: 4 additions & 4 deletions src/jaxsim/api/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# =======================


@js.common.named_scope
@functools.partial(jax.jit, static_argnames="joint_name")
@js.common.named_scope
def name_to_idx(model: js.model.JaxSimModel, *, joint_name: str) -> jtp.Int:
"""
Convert the name of a joint to its index.
Expand Down Expand Up @@ -62,8 +62,8 @@ def idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike) -> str
return model.kin_dyn_parameters.joint_model.joint_names[joint_index + 1]


@js.common.named_scope
@functools.partial(jax.jit, static_argnames="joint_names")
@js.common.named_scope
def names_to_idxs(
model: js.model.JaxSimModel, *, joint_names: Sequence[str]
) -> jax.Array:
Expand Down Expand Up @@ -143,8 +143,8 @@ def position_limit(
return s_min.astype(float), s_max.astype(float)


@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["joint_names"])
@js.common.named_scope
def position_limits(
model: js.model.JaxSimModel, *, joint_names: Sequence[str] | None = None
) -> tuple[jtp.Vector, jtp.Vector]:
Expand Down Expand Up @@ -179,8 +179,8 @@ def position_limits(
# ======================


@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["joint_names"])
@js.common.named_scope
def random_joint_positions(
model: js.model.JaxSimModel,
*,
Expand Down
Loading

0 comments on commit effe290

Please sign in to comment.