diff --git a/src/jaxsim/high_level/model.py b/src/jaxsim/high_level/model.py index b7131da2d..db64e4294 100644 --- a/src/jaxsim/high_level/model.py +++ b/src/jaxsim/high_level/model.py @@ -491,7 +491,7 @@ def joint_positions(self, joint_names: tuple[str, ...] = None) -> jtp.Vector: @functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"]) def joint_random_positions( - self, joint_names: tuple[str, ...] = None, key: jax.random.PRNGKeyArray = None + self, joint_names: tuple[str, ...] = None, key: jax.Array = None ) -> jtp.Vector: """"""