From fe7998a6b97a21d53aa11004569e4cee1f3cb89a Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 11 Dec 2024 15:00:53 +0100 Subject: [PATCH] Add joint and base acceleration attributes to `PhysicsModelState` --- src/jaxsim/api/ode_data.py | 68 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/src/jaxsim/api/ode_data.py b/src/jaxsim/api/ode_data.py index 9d5db94f9..df4c02510 100644 --- a/src/jaxsim/api/ode_data.py +++ b/src/jaxsim/api/ode_data.py @@ -188,6 +188,7 @@ class PhysicsModelState(JaxsimDataclass): # Joint state joint_positions: jtp.Vector joint_velocities: jtp.Vector + joint_accelerations: jtp.Vector # Base state base_position: jtp.Vector = jax_dataclasses.field( @@ -202,6 +203,12 @@ class PhysicsModelState(JaxsimDataclass): base_angular_velocity: jtp.Vector = jax_dataclasses.field( default_factory=lambda: jnp.zeros(3) ) + base_linear_acceleration: jtp.Vector = jax_dataclasses.field( + default_factory=lambda: jnp.zeros(3) + ) + base_angular_acceleration: jtp.Vector = jax_dataclasses.field( + default_factory=lambda: jnp.zeros(3) + ) def __hash__(self) -> int: @@ -211,10 +218,13 @@ def __hash__(self) -> int: ( HashedNumpyArray.hash_of_array(self.joint_positions), HashedNumpyArray.hash_of_array(self.joint_velocities), + HashedNumpyArray.hash_of_array(self.joint_accelerations), HashedNumpyArray.hash_of_array(self.base_position), HashedNumpyArray.hash_of_array(self.base_quaternion), HashedNumpyArray.hash_of_array(self.base_linear_velocity), HashedNumpyArray.hash_of_array(self.base_angular_velocity), + HashedNumpyArray.hash_of_array(self.base_linear_acceleration), + HashedNumpyArray.hash_of_array(self.base_angular_acceleration), ) ) @@ -230,10 +240,13 @@ def build_from_jaxsim_model( model: js.model.JaxSimModel | None = None, joint_positions: jtp.Vector | None = None, joint_velocities: jtp.Vector | None = None, + joint_accelerations: jtp.Vector | None = None, base_position: jtp.Vector | None = None, base_quaternion: jtp.Vector | None = None, base_linear_velocity: jtp.Vector | None = None, base_angular_velocity: jtp.Vector | None = None, + base_linear_acceleration: jtp.Vector | None = None, + base_angular_acceleration: jtp.Vector | None = None, ) -> PhysicsModelState: """ Build a `PhysicsModelState` from a `JaxSimModel`. @@ -242,12 +255,17 @@ def build_from_jaxsim_model( model: The `JaxSimModel` associated with the state. joint_positions: The vector of joint positions. joint_velocities: The vector of joint velocities. + joint_accelerations: The vector of joint accelerations. base_position: The 3D position of the base link. base_quaternion: The quaternion defining the orientation of the base link. base_linear_velocity: The linear velocity of the base link in inertial-fixed representation. base_angular_velocity: The angular velocity of the base link in inertial-fixed representation. + base_linear_acceleration: + The linear acceleration of the base link in inertial-fixed representation. + base_angular_acceleration: + The angular acceleration of the base link in inertial-fixed representation Note: If any of the state components are not provided, they are built from the @@ -260,10 +278,13 @@ def build_from_jaxsim_model( return PhysicsModelState.build( joint_positions=joint_positions, joint_velocities=joint_velocities, + joint_accelerations=joint_accelerations, base_position=base_position, base_quaternion=base_quaternion, base_linear_velocity=base_linear_velocity, base_angular_velocity=base_angular_velocity, + base_linear_acceleration=base_linear_acceleration, + base_angular_acceleration=base_angular_acceleration, number_of_dofs=model.dofs(), ) @@ -271,10 +292,13 @@ def build_from_jaxsim_model( def build( joint_positions: jtp.Vector | None = None, joint_velocities: jtp.Vector | None = None, + joint_accelerations: jtp.Vector | None = None, base_position: jtp.Vector | None = None, base_quaternion: jtp.Vector | None = None, base_linear_velocity: jtp.Vector | None = None, base_angular_velocity: jtp.Vector | None = None, + base_linear_acceleration: jtp.Vector | None = None, + base_angular_acceleration: jtp.Vector | None = None, number_of_dofs: jtp.Int | None = None, ) -> PhysicsModelState: """ @@ -283,12 +307,17 @@ def build( Args: joint_positions: The vector of joint positions. joint_velocities: The vector of joint velocities. + joint_accelerations: The vector of joint accelerations. base_position: The 3D position of the base link. base_quaternion: The quaternion defining the orientation of the base link. base_linear_velocity: The linear velocity of the base link in inertial-fixed representation. base_angular_velocity: The angular velocity of the base link in inertial-fixed representation. + base_linear_acceleration: + The linear acceleration of the base link in inertial-fixed representation. + base_angular_acceleration: + The angular acceleration of the base link in inertial-fixed representation number_of_dofs: The number of degrees of freedom of the physics model. @@ -308,6 +337,12 @@ def build( else jnp.zeros(number_of_dofs) ) + joint_accelerations = ( + joint_accelerations + if joint_accelerations is not None + else jnp.zeros(number_of_dofs) + ) + base_position = base_position if base_position is not None else jnp.zeros(3) base_quaternion = ( @@ -324,13 +359,28 @@ def build( base_angular_velocity if base_angular_velocity is not None else jnp.zeros(3) ) + base_linear_acceleration = ( + base_linear_acceleration + if base_linear_acceleration is not None + else jnp.zeros(3) + ) + + base_angular_acceleration = ( + base_angular_acceleration + if base_angular_acceleration is not None + else jnp.zeros(3) + ) + physics_model_state = PhysicsModelState( joint_positions=jnp.array(joint_positions, dtype=float), joint_velocities=jnp.array(joint_velocities, dtype=float), + joint_accelerations=jnp.array(joint_accelerations, dtype=float), base_position=jnp.array(base_position, dtype=float), base_quaternion=jnp.array(base_quaternion, dtype=float), base_linear_velocity=jnp.array(base_linear_velocity, dtype=float), base_angular_velocity=jnp.array(base_angular_velocity, dtype=float), + base_linear_acceleration=jnp.array(base_linear_acceleration, dtype=float), + base_angular_acceleration=jnp.array(base_angular_acceleration, dtype=float), ) # TODO (diegoferigo): assert state.valid(physics_model) @@ -371,6 +421,12 @@ def valid(self, model: js.model.JaxSimModel) -> bool: shape = self.joint_velocities.shape expected_shape = (model.dofs(),) + if shape != expected_shape: + return False + + shape = self.joint_accelerations.shape + expected_shape = (model.dofs(),) + if shape != expected_shape: return False @@ -395,6 +451,18 @@ def valid(self, model: js.model.JaxSimModel) -> bool: shape = self.base_angular_velocity.shape expected_shape = (3,) + if shape != expected_shape: + return False + + shape = self.base_linear_acceleration.shape + expected_shape = (3,) + + if shape != expected_shape: + return False + + shape = self.base_angular_acceleration.shape + expected_shape = (3,) + if shape != expected_shape: return False