Skip to content

Commit

Permalink
Add joint and base acceleration attributes to PhysicsModelState
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Dec 26, 2024
1 parent de15613 commit c9c0651
Showing 1 changed file with 68 additions and 0 deletions.
68 changes: 68 additions & 0 deletions src/jaxsim/api/ode_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ class PhysicsModelState(JaxsimDataclass):
# Joint state
joint_positions: jtp.Vector
joint_velocities: jtp.Vector
joint_accelerations: jtp.Vector

# Base state
base_position: jtp.Vector = jax_dataclasses.field(
Expand All @@ -202,6 +203,12 @@ class PhysicsModelState(JaxsimDataclass):
base_angular_velocity: jtp.Vector = jax_dataclasses.field(
default_factory=lambda: jnp.zeros(3)
)
base_linear_acceleration: jtp.Vector = jax_dataclasses.field(
default_factory=lambda: jnp.zeros(3)
)
base_angular_acceleration: jtp.Vector = jax_dataclasses.field(
default_factory=lambda: jnp.zeros(3)
)

def __hash__(self) -> int:

Expand All @@ -211,10 +218,13 @@ def __hash__(self) -> int:
(
HashedNumpyArray.hash_of_array(self.joint_positions),
HashedNumpyArray.hash_of_array(self.joint_velocities),
HashedNumpyArray.hash_of_array(self.joint_accelerations),
HashedNumpyArray.hash_of_array(self.base_position),
HashedNumpyArray.hash_of_array(self.base_quaternion),
HashedNumpyArray.hash_of_array(self.base_linear_velocity),
HashedNumpyArray.hash_of_array(self.base_angular_velocity),
HashedNumpyArray.hash_of_array(self.base_linear_acceleration),
HashedNumpyArray.hash_of_array(self.base_angular_acceleration),
)
)

Expand All @@ -230,10 +240,13 @@ def build_from_jaxsim_model(
model: js.model.JaxSimModel | None = None,
joint_positions: jtp.Vector | None = None,
joint_velocities: jtp.Vector | None = None,
joint_accelerations: jtp.Vector | None = None,
base_position: jtp.Vector | None = None,
base_quaternion: jtp.Vector | None = None,
base_linear_velocity: jtp.Vector | None = None,
base_angular_velocity: jtp.Vector | None = None,
base_linear_acceleration: jtp.Vector | None = None,
base_angular_acceleration: jtp.Vector | None = None,
) -> PhysicsModelState:
"""
Build a `PhysicsModelState` from a `JaxSimModel`.
Expand All @@ -242,12 +255,17 @@ def build_from_jaxsim_model(
model: The `JaxSimModel` associated with the state.
joint_positions: The vector of joint positions.
joint_velocities: The vector of joint velocities.
joint_accelerations: The vector of joint accelerations.
base_position: The 3D position of the base link.
base_quaternion: The quaternion defining the orientation of the base link.
base_linear_velocity:
The linear velocity of the base link in inertial-fixed representation.
base_angular_velocity:
The angular velocity of the base link in inertial-fixed representation.
base_linear_acceleration:
The linear acceleration of the base link in inertial-fixed representation.
base_angular_acceleration:
The angular acceleration of the base link in inertial-fixed representation
Note:
If any of the state components are not provided, they are built from the
Expand All @@ -260,21 +278,27 @@ def build_from_jaxsim_model(
return PhysicsModelState.build(
joint_positions=joint_positions,
joint_velocities=joint_velocities,
joint_accelerations=joint_accelerations,
base_position=base_position,
base_quaternion=base_quaternion,
base_linear_velocity=base_linear_velocity,
base_angular_velocity=base_angular_velocity,
base_linear_acceleration=base_linear_acceleration,
base_angular_acceleration=base_angular_acceleration,
number_of_dofs=model.dofs(),
)

@staticmethod
def build(
joint_positions: jtp.Vector | None = None,
joint_velocities: jtp.Vector | None = None,
joint_accelerations: jtp.Vector | None = None,
base_position: jtp.Vector | None = None,
base_quaternion: jtp.Vector | None = None,
base_linear_velocity: jtp.Vector | None = None,
base_angular_velocity: jtp.Vector | None = None,
base_linear_acceleration: jtp.Vector | None = None,
base_angular_acceleration: jtp.Vector | None = None,
number_of_dofs: jtp.Int | None = None,
) -> PhysicsModelState:
"""
Expand All @@ -283,12 +307,17 @@ def build(
Args:
joint_positions: The vector of joint positions.
joint_velocities: The vector of joint velocities.
joint_accelerations: The vector of joint accelerations.
base_position: The 3D position of the base link.
base_quaternion: The quaternion defining the orientation of the base link.
base_linear_velocity:
The linear velocity of the base link in inertial-fixed representation.
base_angular_velocity:
The angular velocity of the base link in inertial-fixed representation.
base_linear_acceleration:
The linear acceleration of the base link in inertial-fixed representation.
base_angular_acceleration:
The angular acceleration of the base link in inertial-fixed representation
number_of_dofs:
The number of degrees of freedom of the physics model.
Expand All @@ -308,6 +337,12 @@ def build(
else jnp.zeros(number_of_dofs)
)

joint_accelerations = (
joint_accelerations
if joint_accelerations is not None
else jnp.zeros(number_of_dofs)
)

base_position = base_position if base_position is not None else jnp.zeros(3)

base_quaternion = (
Expand All @@ -324,13 +359,28 @@ def build(
base_angular_velocity if base_angular_velocity is not None else jnp.zeros(3)
)

base_linear_acceleration = (
base_linear_acceleration
if base_linear_acceleration is not None
else jnp.zeros(3)
)

base_angular_acceleration = (
base_angular_acceleration
if base_angular_acceleration is not None
else jnp.zeros(3)
)

physics_model_state = PhysicsModelState(
joint_positions=jnp.array(joint_positions, dtype=float),
joint_velocities=jnp.array(joint_velocities, dtype=float),
joint_accelerations=jnp.array(joint_accelerations, dtype=float),
base_position=jnp.array(base_position, dtype=float),
base_quaternion=jnp.array(base_quaternion, dtype=float),
base_linear_velocity=jnp.array(base_linear_velocity, dtype=float),
base_angular_velocity=jnp.array(base_angular_velocity, dtype=float),
base_linear_acceleration=jnp.array(base_linear_acceleration, dtype=float),
base_angular_acceleration=jnp.array(base_angular_acceleration, dtype=float),
)

# TODO (diegoferigo): assert state.valid(physics_model)
Expand Down Expand Up @@ -371,6 +421,12 @@ def valid(self, model: js.model.JaxSimModel) -> bool:
shape = self.joint_velocities.shape
expected_shape = (model.dofs(),)

if shape != expected_shape:
return False

shape = self.joint_accelerations.shape
expected_shape = (model.dofs(),)

if shape != expected_shape:
return False

Expand All @@ -395,6 +451,18 @@ def valid(self, model: js.model.JaxSimModel) -> bool:
shape = self.base_angular_velocity.shape
expected_shape = (3,)

if shape != expected_shape:
return False

shape = self.base_linear_acceleration.shape
expected_shape = (3,)

if shape != expected_shape:
return False

shape = self.base_angular_acceleration.shape
expected_shape = (3,)

if shape != expected_shape:
return False

Expand Down

0 comments on commit c9c0651

Please sign in to comment.