Skip to content

Commit

Permalink
[WIP] Save some kindyn computation in JaxSimModelData
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Dec 26, 2024
1 parent 6c75ad7 commit 51443df
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 36 deletions.
16 changes: 3 additions & 13 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,9 +511,7 @@ def jacobian(
)[indices_of_enabled_collidable_points]

# Compute the Jacobians of all links.
W_J_WL = js.model.generalized_free_floating_jacobian(
model=model, data=data, output_vel_repr=VelRepr.Inertial
)
W_J_WL = data.kyn_dyn.jacobian

# Compute the contact Jacobian.
# In inertial-fixed output representation, the Jacobian of the parent link is also
Expand Down Expand Up @@ -663,17 +661,9 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix:

with data.switch_velocity_representation(VelRepr.Inertial):
# Compute the Jacobian of the parent link in inertial representation.
W_J_WL_W = js.model.generalized_free_floating_jacobian(
model=model,
data=data,
output_vel_repr=VelRepr.Inertial,
)
W_J_WL_W = data.kyn_dyn.jacobian
# Compute the Jacobian derivative of the parent link in inertial representation.
W_J̇_WL_W = js.model.generalized_free_floating_jacobian_derivative(
model=model,
data=data,
output_vel_repr=VelRepr.Inertial,
)
W_J̇_WL_W = data.kyn_dyn.jacobian_derivative

# Get the Jacobian of the enabled collidable points in the mixed representation.
with data.switch_velocity_representation(VelRepr.Mixed):
Expand Down
18 changes: 18 additions & 0 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,22 @@ class KynDynComputation(common.ModelDataWithVelocityRepresentation):
mass_matrix: jtp.Matrix


@jax_dataclasses.pytree_dataclass
class KynDynComputation:

jacobian: jtp.Matrix

jacobian_derivative: jtp.Matrix

motion_subspaces: jtp.Matrix

joint_transforms: jtp.Matrix

mass_matrix: jtp.Matrix

forward_kinematics: jtp.Matrix


@jax_dataclasses.pytree_dataclass
class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
"""
Expand All @@ -162,6 +178,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):

state: ODEState

kyn_dyn: KynDynComputation

gravity: jtp.Vector

contacts_params: jaxsim.rbda.contacts.ContactsParams = dataclasses.field(repr=False)
Expand Down
4 changes: 1 addition & 3 deletions src/jaxsim/api/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,9 +422,7 @@ def jacobian_derivative(
output_vel_repr if output_vel_repr is not None else data.velocity_representation
)

O_J̇_WL_I = js.model.generalized_free_floating_jacobian_derivative(
model=model, data=data, output_vel_repr=output_vel_repr
)[link_index]
O_J̇_WL_I = data.kyn_dyn.jacobian_derivative[link_index]

return O_J̇_WL_I

Expand Down
25 changes: 22 additions & 3 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,9 @@ def build(

integrator_cls = integrator
integrator = integrator_cls.build(
dynamics=js.ode.wrap_system_dynamics_for_integration(
system_dynamics=js.ode.system_dynamics
)
# dynamics=js.ode.wrap_system_dynamics_for_integration(
# system_dynamics=js.ode.system_dynamics
# )
)

case _:
Expand Down Expand Up @@ -2178,6 +2178,25 @@ def forward(
joint_force_references: jtp.VectorLike | None = None,
) -> js.data.JaxSimModelData:

# Kinematics computation.
M = js.model.free_floating_mass_matrix(model=model, data=data)
J = js.model.generalized_free_floating_jacobian(model=model, data=data)
= js.model.generalized_free_floating_jacobian_derivative(model=model, data=data)
i_X_λ, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
joint_positions=data.joint_positions(), base_transform=data.base_transform()
)
FK = js.model.forward_kinematics(model=model, data=data)
kyn_dyn = js.data.KynDynComputation(
jacobian=J,
jacobian_derivative=,
joint_transforms=i_X_λ,
motion_subspaces=S,
mass_matrix=M,
forward_kinematics=FK,
)

data = data.replace(kyn_dyn=kyn_dyn)

# TODO: some contact models here may want to perform a dynamic filtering of
# the enabled collidable points.

Expand Down
8 changes: 4 additions & 4 deletions src/jaxsim/integrators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,11 +342,11 @@ def compute_ki() -> tuple[jax.Array, dict[str, Any]]:
ti = t0 + c[i] * Δt

# Evaluate the dynamics.
ki, aux_dict = f(x=xi, t=ti)
return ki, aux_dict
ki = f(x=xi, t=ti)
return ki

# This selector enables FSAL property in the first iteration (i=0).
ki, aux_dict = jax.lax.cond(
ki = jax.lax.cond(
pred=jnp.logical_and(i == 0, self.has_fsal),
true_fun=lambda: x0,
false_fun=compute_ki,
Expand All @@ -357,7 +357,7 @@ def compute_ki() -> tuple[jax.Array, dict[str, Any]]:
K = jax.tree.map(op, K, ki)

carry = K
return carry, aux_dict
return carry, None

# Compute the state derivatives kᵢ.
K, _ = jax.lax.scan(
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/rbda/contacts/relaxed_rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def compute_contact_forces(
)
)

M = js.model.free_floating_mass_matrix(model=model, data=data)
M = data.kyn_dyn.mass_matrix

Jl_WC = jnp.vstack(
jax.vmap(lambda J, δ: J * (δ > 0))(
Expand Down
12 changes: 0 additions & 12 deletions tests/test_simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,19 +201,8 @@ def run_simulation(
return data


@pytest.mark.parametrize(
"integrator",
[
jaxsim.integrators.fixed_step.ForwardEuler,
jaxsim.integrators.fixed_step.ForwardEulerSO3,
jaxsim.integrators.fixed_step.RungeKutta4,
jaxsim.integrators.fixed_step.RungeKutta4SO3,
jaxsim.integrators.variable_step.BogackiShampineSO3,
],
)
def test_simulation_with_soft_contacts(
jaxsim_model_box: js.model.JaxSimModel,
integrator: jaxsim.integrators.Integrator,
):

model = jaxsim_model_box
Expand All @@ -229,7 +218,6 @@ def test_simulation_with_soft_contacts(
model.kin_dyn_parameters.contact_parameters.enabled = tuple(
enabled_collidable_points_mask.tolist()
)
model.integrator = integrator.build()

assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4

Expand Down

0 comments on commit 51443df

Please sign in to comment.