From 68354b51a0d36a1192a08c3587b6f0e9940db713 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 12 Sep 2023 18:15:42 +0200 Subject: [PATCH] Refactor code and add oop decorators for JIT compilation --- src/jaxsim/high_level/joint.py | 20 +++++++---- src/jaxsim/high_level/model.py | 44 +++++++++++------------ src/jaxsim/physics/model/physics_model.py | 11 +++--- 3 files changed, 40 insertions(+), 35 deletions(-) diff --git a/src/jaxsim/high_level/joint.py b/src/jaxsim/high_level/joint.py index 80feeb040..2828e15bf 100644 --- a/src/jaxsim/high_level/joint.py +++ b/src/jaxsim/high_level/joint.py @@ -111,15 +111,23 @@ def position_limit(self, dof: int = None) -> Tuple[jtp.Float, jtp.Float]: # ============= # Motor methods # ============= + @functools.partial(oop.jax_tf.method_ro) + def motor_inertia(self) -> jtp.Vector: + """""" + + return jnp.array(self.joint_description.motor_inertia, dtype=float) + + @functools.partial(oop.jax_tf.method_ro) + def motor_gear_ratio(self) -> jtp.Vector: + """""" - def motor_inertia(self) -> float: - return self.joint_description.motor_inertia + return jnp.array(self.joint_description.motor_gear_ratio, dtype=float) - def motor_gear_ratio(self) -> float: - return self.joint_description.motor_gear_ratio + @functools.partial(oop.jax_tf.method_ro) + def motor_viscous_friction(self) -> jtp.Vector: + """""" - def motor_viscous_friction(self) -> float: - return self.joint_description.motor_viscous_friction + return jnp.array(self.joint_description.motor_viscous_friction, dtype=float) # ================= # Multi-DoF methods diff --git a/src/jaxsim/high_level/model.py b/src/jaxsim/high_level/model.py index 418464394..6a29c68ee 100644 --- a/src/jaxsim/high_level/model.py +++ b/src/jaxsim/high_level/model.py @@ -1346,6 +1346,7 @@ def integrate( # Motor dynamics # ============== + @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"]) def set_motor_inertias( self, inertias: jtp.Vector, joint_names: List[str] = None ) -> None: @@ -1356,18 +1357,17 @@ def set_motor_inertias( raise ValueError("Wrong arguments size", inertias.size, len(joint_names)) self.physics_model._joint_motor_inertia.update( - dict( - map( - lambda j, g: (j, g), - self.physics_model._joint_motor_inertia.keys(), - inertias, + { + j: im + for j, im in zip( + self.physics_model._joint_motor_inertia.keys(), inertias ) - ) + } ) - self.physics_model.has_motors = True - logging.info("Setting attribute 'has_motors' to True") + logging.info("Setting attribute 'motor_inertias'") + @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"]) def set_motor_gear_ratios( self, gear_ratios: jtp.Vector, joint_names: List[str] = None ) -> None: @@ -1378,20 +1378,19 @@ def set_motor_gear_ratios( raise ValueError("Wrong arguments size", gear_ratios.size, len(joint_names)) self.physics_model._joint_motor_gear_ratio.update( - dict( - map( - lambda j, g: (j, g), - self.physics_model._joint_motor_gear_ratio.keys(), - gear_ratios, + { + j: gr + for j, gr in zip( + self.physics_model._joint_motor_gear_ratio.keys(), gear_ratios ) - ) + } ) - self.physics_model.has_motors = True - logging.info("Setting attribute 'has_motors' to True") + logging.info("Setting attribute 'motor_gear_ratios'") + @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"]) def set_motor_viscous_frictions( - self, viscous_frictions: jtp.Vector, joint_names: List[str] = None + self, viscous_frictions: Tuple, joint_names: List[str] = None ) -> None: if joint_names is None: joint_names = self.joint_names() @@ -1402,17 +1401,16 @@ def set_motor_viscous_frictions( ) self.physics_model._joint_motor_viscous_friction.update( - dict( - map( - lambda j, g: (j, g), + { + j: kv + for j, kv in zip( self.physics_model._joint_motor_viscous_friction.keys(), viscous_frictions, ) - ) + } ) - self.physics_model.has_motors = True - logging.info("Setting attribute 'has_motors' to True") + logging.info("Setting attribute 'motor_viscous_frictions'") # =============== # Private methods diff --git a/src/jaxsim/physics/model/physics_model.py b/src/jaxsim/physics/model/physics_model.py index fa74d2549..689276cfe 100644 --- a/src/jaxsim/physics/model/physics_model.py +++ b/src/jaxsim/physics/model/physics_model.py @@ -33,7 +33,6 @@ class PhysicsModel(JaxsimDataclass): ) ) is_floating_base: Static[bool] = dataclasses.field(default=False) - has_motors: Static[bool] = dataclasses.field(default=False) gc: GroundContact = dataclasses.field(default_factory=lambda: GroundContact()) description: Static[ jaxsim.parsers.descriptions.model.ModelDescription @@ -129,13 +128,15 @@ def build_from( # Dicts from the joint index to the motor inertia, gear ratio and viscous friction. # Note: the joint index is equal to its child link index. joint_motor_inertia = { - joint.index: joint.motor_inertia for joint in model_description.joints + joint.index: jnp.array(joint.motor_inertia, dtype=float) + for joint in model_description.joints } joint_motor_gear_ratio = { - joint.index: joint.motor_gear_ratio for joint in model_description.joints + joint.index: jnp.array(joint.motor_gear_ratio, dtype=float) + for joint in model_description.joints } joint_motor_viscous_friction = { - joint.index: joint.motor_viscous_friction + joint.index: jnp.array(joint.motor_viscous_friction, dtype=float) for joint in model_description.joints } @@ -204,7 +205,6 @@ def build_from( _joint_motor_viscous_friction=joint_motor_viscous_friction, gravity=jnp.hstack([gravity.squeeze(), np.zeros(3)]), is_floating_base=True, - has_motors=False, gc=GroundContact.build_from(model_description=model_description), description=model_description, ) @@ -341,7 +341,6 @@ def __repr__(self) -> str: f"dofs: {self.dofs()},", f"links: {self.NB},", f"floating_base: {self.is_floating_base},", - f"has_motors: {self.has_motors},", ] attributes_string = "\n ".join(attributes)