diff --git a/src/jaxsim/high_level/model.py b/src/jaxsim/high_level/model.py index d1314bf9a..1ef0886c1 100644 --- a/src/jaxsim/high_level/model.py +++ b/src/jaxsim/high_level/model.py @@ -491,7 +491,7 @@ def joint_positions(self, joint_names: tuple[str, ...] = None) -> jtp.Vector: @functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"]) def joint_random_positions( - self, joint_names: tuple[str, ...] = None, key: jax.random.PRNGKeyArray = None + self, joint_names: tuple[str, ...] = None, key: jax.Array = None ) -> jtp.Vector: """""" @@ -1505,64 +1505,48 @@ def integrate( # Motor dynamics # ============== - # @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"]) + @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"]) def set_motor_inertias( - self, inertias: jtp.Vector, joint_names: List[str] = None + self, inertias: jtp.Vector, joint_names: tuple[str, ...] = None ) -> None: - if joint_names is None: - joint_names = self.joint_names() + joint_names = joint_names or self.joint_names() if inertias.size != len(joint_names): raise ValueError("Wrong arguments size", inertias.size, len(joint_names)) self.physics_model._joint_motor_inertia.update( - { - j: im - for j, im in zip( - self.physics_model._joint_motor_inertia.keys(), inertias - ) - } + dict(zip(self.physics_model._joint_motor_inertia, inertias)) ) logging.info("Setting attribute 'motor_inertias'") - # @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"]) + @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"]) def set_motor_gear_ratios( - self, gear_ratios: jtp.Vector, joint_names: List[str] = None + self, gear_ratios: jtp.Vector, joint_names: tuple[str, ...] = None ) -> None: - if joint_names is None: - joint_names = self.joint_names() + joint_names = joint_names or self.joint_names() if gear_ratios.size != len(joint_names): raise ValueError("Wrong arguments size", gear_ratios.size, len(joint_names)) # Check on gear ratios if motor_inertias are not zero - jax.lax.cond( - pred=(jnp.diag(self.physics_model._joint_motor_inertia) != 0).any(), - operand=gear_ratios, - true_fun=lambda gr: gr, - false_fun=lambda: (_ for _ in _).throw( - ValueError("Motor inertias are zero") - ), - ) + for idx, gr in gear_ratios: + if gr != 0 and self.motor_inertias()[idx] == 0: + raise ValueError( + f"Zero motor inertia with non-zero gear ratio found in position {idx}" + ) self.physics_model._joint_motor_gear_ratio.update( - { - j: gr - for j, gr in zip( - self.physics_model._joint_motor_gear_ratio.keys(), gear_ratios - ) - } + dict(zip(self.physics_model._joint_motor_gear_ratio, gear_ratios)) ) logging.info("Setting attribute 'motor_gear_ratios'") - # @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"]) + @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"]) def set_motor_viscous_frictions( - self, viscous_frictions: Tuple, joint_names: List[str] = None + self, viscous_frictions: jtp.Vector, joint_names: tuple[str, ...] = None ) -> None: - if joint_names is None: - joint_names = self.joint_names() + joint_names = joint_names or self.joint_names() if viscous_frictions.size != len(joint_names): raise ValueError( @@ -1570,17 +1554,35 @@ def set_motor_viscous_frictions( ) self.physics_model._joint_motor_viscous_friction.update( - { - j: kv - for j, kv in zip( - self.physics_model._joint_motor_viscous_friction.keys(), + dict( + zip( + self.physics_model._joint_motor_viscous_friction, viscous_frictions, ) - } + ) ) + # self.physics_model._joint_motor_viscous_friction |= dict( + # zip( + # self.physics_model._joint_motor_viscous_friction.keys(), + # viscous_frictions, + # ) + # ) + logging.info("Setting attribute 'motor_viscous_frictions'") + @functools.partial(oop.jax_tf.method_ro, jit=False) + def motor_inertias(self) -> jtp.Vector: + return jnp.array([*self.physics_model._joint_motor_inertia.values()]) + + @functools.partial(oop.jax_tf.method_ro, jit=False) + def motor_gear_ratios(self) -> jtp.Vector: + return jnp.array([*self.physics_model._joint_motor_gear_ratio.values()]) + + @functools.partial(oop.jax_tf.method_ro, jit=False) + def motor_viscous_frictions(self) -> jtp.Vector: + return jnp.array([*self.physics_model._joint_motor_viscous_friction.values()]) + # =============== # Private methods # ===============