Skip to content

Commit

Permalink
Add @named_scope decorator to main API function
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Dec 9, 2024
1 parent d730d69 commit 33919cb
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 2 deletions.
8 changes: 8 additions & 0 deletions src/jaxsim/api/com.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .common import VelRepr


@js.common.named_scope
@jax.jit
def com_position(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
Expand Down Expand Up @@ -44,6 +45,7 @@ def B_p̃_LCoM(i) -> jtp.Vector:
return (W_H_B @ B_p̃_CoM)[0:3].astype(float)


@js.common.named_scope
@jax.jit
def com_linear_velocity(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
Expand Down Expand Up @@ -74,6 +76,7 @@ def com_linear_velocity(
return G_vl_WG


@js.common.named_scope
@jax.jit
def centroidal_momentum(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
Expand Down Expand Up @@ -101,6 +104,7 @@ def centroidal_momentum(
return G_J @ ν


@js.common.named_scope
@jax.jit
def centroidal_momentum_jacobian(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
Expand Down Expand Up @@ -149,6 +153,7 @@ def centroidal_momentum_jacobian(
return G_Xf_B @ B_Jh


@js.common.named_scope
@jax.jit
def locked_centroidal_spatial_inertia(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
Expand Down Expand Up @@ -186,6 +191,7 @@ def locked_centroidal_spatial_inertia(
return G_Xf_B @ B_Mbb_B @ B_Xv_G


@js.common.named_scope
@jax.jit
def average_centroidal_velocity(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
Expand Down Expand Up @@ -213,6 +219,7 @@ def average_centroidal_velocity(
return G_J @ ν


@js.common.named_scope
@jax.jit
def average_centroidal_velocity_jacobian(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
Expand All @@ -239,6 +246,7 @@ def average_centroidal_velocity_jacobian(
return jnp.linalg.inv(G_Mbb) @ G_J


@js.common.named_scope
@jax.jit
def bias_acceleration(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
Expand Down
9 changes: 9 additions & 0 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .common import VelRepr


@js.common.named_scope
@jax.jit
def collidable_point_kinematics(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
Expand Down Expand Up @@ -52,6 +53,7 @@ def collidable_point_kinematics(
return W_p_Ci, W_ṗ_Ci


@js.common.named_scope
@jax.jit
def collidable_point_positions(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
Expand All @@ -72,6 +74,7 @@ def collidable_point_positions(
return W_p_Ci


@js.common.named_scope
@jax.jit
def collidable_point_velocities(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
Expand All @@ -92,6 +95,7 @@ def collidable_point_velocities(
return W_ṗ_Ci


@js.common.named_scope
@jax.jit
def collidable_point_forces(
model: js.model.JaxSimModel,
Expand Down Expand Up @@ -129,6 +133,7 @@ def collidable_point_forces(
return f_Ci


@js.common.named_scope
@jax.jit
def collidable_point_dynamics(
model: js.model.JaxSimModel,
Expand Down Expand Up @@ -227,6 +232,7 @@ def collidable_point_dynamics(
return f_Ci, aux_data


@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["link_names"])
def in_contact(
model: js.model.JaxSimModel,
Expand Down Expand Up @@ -424,6 +430,7 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
return parameters


@js.common.named_scope
@jax.jit
def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array:
r"""
Expand Down Expand Up @@ -469,6 +476,7 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt
return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C)


@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
def jacobian(
model: js.model.JaxSimModel,
Expand Down Expand Up @@ -561,6 +569,7 @@ def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
return O_J_WC


@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
def jacobian_derivative(
model: js.model.JaxSimModel,
Expand Down
16 changes: 16 additions & 0 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def standard_gravity(self) -> jtp.Float:

return -self.gravity[2]

@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["joint_names"])
def joint_positions(
self,
Expand Down Expand Up @@ -300,6 +301,7 @@ def joint_positions(

return self.state.physics_model.joint_positions[joint_idxs]

@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["joint_names"])
def joint_velocities(
self,
Expand Down Expand Up @@ -347,6 +349,7 @@ def joint_velocities(

return self.state.physics_model.joint_velocities[joint_idxs]

@js.common.named_scope
@jax.jit
def base_position(self) -> jtp.Vector:
"""
Expand All @@ -358,6 +361,7 @@ def base_position(self) -> jtp.Vector:

return self.state.physics_model.base_position.squeeze()

@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["dcm"])
def base_orientation(self, dcm: jtp.BoolLike = False) -> jtp.Vector | jtp.Matrix:
"""
Expand Down Expand Up @@ -386,6 +390,7 @@ def base_orientation(self, dcm: jtp.BoolLike = False) -> jtp.Vector | jtp.Matrix
float
)

@js.common.named_scope
@jax.jit
def base_transform(self) -> jtp.Matrix:
"""
Expand All @@ -405,6 +410,7 @@ def base_transform(self) -> jtp.Matrix:
]
)

@js.common.named_scope
@jax.jit
def base_velocity(self) -> jtp.Vector:
"""
Expand Down Expand Up @@ -434,6 +440,7 @@ def base_velocity(self) -> jtp.Vector:
.astype(float)
)

@js.common.named_scope
@jax.jit
def generalized_position(self) -> tuple[jtp.Matrix, jtp.Vector]:
r"""
Expand All @@ -446,6 +453,7 @@ def generalized_position(self) -> tuple[jtp.Matrix, jtp.Vector]:

return self.base_transform(), self.joint_positions()

@js.common.named_scope
@jax.jit
def generalized_velocity(self) -> jtp.Vector:
r"""
Expand All @@ -466,6 +474,7 @@ def generalized_velocity(self) -> jtp.Vector:
# Store quantities
# ================

@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["joint_names"])
def reset_joint_positions(
self,
Expand Down Expand Up @@ -514,6 +523,7 @@ def replace(s: jtp.VectorLike) -> JaxSimModelData:
s=self.state.physics_model.joint_positions.at[joint_idxs].set(positions)
)

@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["joint_names"])
def reset_joint_velocities(
self,
Expand Down Expand Up @@ -562,6 +572,7 @@ def replace(ṡ: jtp.VectorLike) -> JaxSimModelData:
=self.state.physics_model.joint_velocities.at[joint_idxs].set(velocities)
)

@js.common.named_scope
@jax.jit
def reset_base_position(self, base_position: jtp.VectorLike) -> Self:
"""
Expand All @@ -585,6 +596,7 @@ def reset_base_position(self, base_position: jtp.VectorLike) -> Self:
),
)

@js.common.named_scope
@jax.jit
def reset_base_quaternion(self, base_quaternion: jtp.VectorLike) -> Self:
"""
Expand Down Expand Up @@ -612,6 +624,7 @@ def reset_base_quaternion(self, base_quaternion: jtp.VectorLike) -> Self:
),
)

@js.common.named_scope
@jax.jit
def reset_base_pose(self, base_pose: jtp.MatrixLike) -> Self:
"""
Expand All @@ -634,6 +647,7 @@ def reset_base_pose(self, base_pose: jtp.MatrixLike) -> Self:
base_quaternion=W_Q_B
)

@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["velocity_representation"])
def reset_base_linear_velocity(
self,
Expand Down Expand Up @@ -665,6 +679,7 @@ def reset_base_linear_velocity(
velocity_representation=velocity_representation,
)

@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["velocity_representation"])
def reset_base_angular_velocity(
self,
Expand Down Expand Up @@ -696,6 +711,7 @@ def reset_base_angular_velocity(
velocity_representation=velocity_representation,
)

@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["velocity_representation"])
def reset_base_velocity(
self,
Expand Down
7 changes: 7 additions & 0 deletions src/jaxsim/api/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# =======================


@js.common.named_scope
@jax.jit
def idx_of_parent_link(
model: js.model.JaxSimModel, *, frame_index: jtp.IntLike
Expand Down Expand Up @@ -45,6 +46,7 @@ def idx_of_parent_link(
]


@js.common.named_scope
@functools.partial(jax.jit, static_argnames="frame_name")
def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> jtp.Int:
"""
Expand Down Expand Up @@ -97,6 +99,7 @@ def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str
]


@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["frame_names"])
def names_to_idxs(
model: js.model.JaxSimModel, *, frame_names: Sequence[str]
Expand Down Expand Up @@ -139,6 +142,7 @@ def idxs_to_names(
# ==========


@js.common.named_scope
@jax.jit
def transform(
model: js.model.JaxSimModel,
Expand Down Expand Up @@ -180,6 +184,7 @@ def transform(
return W_H_L @ L_H_F


@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
def velocity(
model: js.model.JaxSimModel,
Expand Down Expand Up @@ -230,6 +235,7 @@ def velocity(
return O_J_WF_I @ I_ν


@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
def jacobian(
model: js.model.JaxSimModel,
Expand Down Expand Up @@ -309,6 +315,7 @@ def jacobian(
return O_J_WL_I


@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
def jacobian_derivative(
model: js.model.JaxSimModel,
Expand Down
4 changes: 4 additions & 0 deletions src/jaxsim/api/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# =======================


@js.common.named_scope
@functools.partial(jax.jit, static_argnames="joint_name")
def name_to_idx(model: js.model.JaxSimModel, *, joint_name: str) -> jtp.Int:
"""
Expand Down Expand Up @@ -61,6 +62,7 @@ def idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike) -> str
return model.kin_dyn_parameters.joint_model.joint_names[joint_index + 1]


@js.common.named_scope
@functools.partial(jax.jit, static_argnames="joint_names")
def names_to_idxs(
model: js.model.JaxSimModel, *, joint_names: Sequence[str]
Expand Down Expand Up @@ -141,6 +143,7 @@ def position_limit(
return s_min.astype(float), s_max.astype(float)


@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["joint_names"])
def position_limits(
model: js.model.JaxSimModel, *, joint_names: Sequence[str] | None = None
Expand Down Expand Up @@ -176,6 +179,7 @@ def position_limits(
# ======================


@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["joint_names"])
def random_joint_positions(
model: js.model.JaxSimModel,
Expand Down
Loading

0 comments on commit 33919cb

Please sign in to comment.