Skip to content

Commit

Permalink
Save some kindyn computation in JaxSimModelData
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Dec 26, 2024
1 parent b1a6c53 commit 21f9e53
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 25 deletions.
3 changes: 0 additions & 3 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,21 +666,18 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix:
W_J_WL_W = js.model.generalized_free_floating_jacobian(
model=model,
data=data,
output_vel_repr=VelRepr.Inertial,
)
# Compute the Jacobian derivative of the parent link in inertial representation.
W_J̇_WL_W = js.model.generalized_free_floating_jacobian_derivative(
model=model,
data=data,
output_vel_repr=VelRepr.Inertial,
)

# Get the Jacobian of the enabled collidable points in the mixed representation.
with data.switch_velocity_representation(VelRepr.Mixed):
CW_J_WC_BW = jacobian(
model=model,
data=data,
output_vel_repr=VelRepr.Mixed,
)

def compute_O_J̇_WC_I(
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/api/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def jacobian_derivative(
)

O_J̇_WL_I = js.model.generalized_free_floating_jacobian_derivative(
model=model, data=data, output_vel_repr=output_vel_repr
model=model, data=data
)[link_index]

return O_J̇_WL_I
Expand Down
8 changes: 5 additions & 3 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,9 @@ def build(

integrator_cls = integrator
integrator = integrator_cls.build(
dynamics=js.ode.wrap_system_dynamics_for_integration(
system_dynamics=js.ode.system_dynamics
)
# dynamics=js.ode.wrap_system_dynamics_for_integration(
# system_dynamics=js.ode.system_dynamics
# )
)

case _:
Expand Down Expand Up @@ -2218,6 +2218,8 @@ def forward(

data_tf = data.replace(state=state_tf)

data_tf = data_tf.update_kyn_dyn(model=model)

return data_tf


Expand Down
8 changes: 4 additions & 4 deletions src/jaxsim/integrators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,11 +342,11 @@ def compute_ki() -> tuple[jax.Array, dict[str, Any]]:
ti = t0 + c[i] * Δt

# Evaluate the dynamics.
ki, aux_dict = f(x=xi, t=ti)
return ki, aux_dict
ki = f(x=xi, t=ti)
return ki

# This selector enables FSAL property in the first iteration (i=0).
ki, aux_dict = jax.lax.cond(
ki = jax.lax.cond(
pred=jnp.logical_and(i == 0, self.has_fsal),
true_fun=lambda: x0,
false_fun=compute_ki,
Expand All @@ -357,7 +357,7 @@ def compute_ki() -> tuple[jax.Array, dict[str, Any]]:
K = jax.tree.map(op, K, ki)

carry = K
return carry, aux_dict
return carry, None

# Compute the state derivatives kᵢ.
K, _ = jax.lax.scan(
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/rbda/contacts/relaxed_rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def compute_contact_forces(
)
)

M = js.model.free_floating_mass_matrix(model=model, data=data)
M = data.kyn_dyn.mass_matrix

Jl_WC = jnp.vstack(
jax.vmap(lambda J, δ: J * (δ > 0))(
Expand Down
3 changes: 2 additions & 1 deletion tests/test_automatic_differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,13 +417,14 @@ def step(
base_linear_velocity=W_v_WB[0:3],
base_angular_velocity=W_v_WB[3:6],
joint_velocities=,
number_of_dofs=len(),
),
extended_state={"tangential_deformation": m},
),
)

# Update the kyn_dyn cache.
data_x0.update_kyn_dyn(model=model)
data_x0 = data_x0.update_kyn_dyn(model=model)

data_xf, _ = js.model.step(
model=model,
Expand Down
12 changes: 0 additions & 12 deletions tests/test_simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,19 +201,8 @@ def run_simulation(
return data


@pytest.mark.parametrize(
"integrator",
[
jaxsim.integrators.fixed_step.ForwardEuler,
jaxsim.integrators.fixed_step.ForwardEulerSO3,
jaxsim.integrators.fixed_step.RungeKutta4,
jaxsim.integrators.fixed_step.RungeKutta4SO3,
jaxsim.integrators.variable_step.BogackiShampineSO3,
],
)
def test_simulation_with_soft_contacts(
jaxsim_model_box: js.model.JaxSimModel,
integrator: jaxsim.integrators.Integrator,
):

model = jaxsim_model_box
Expand All @@ -229,7 +218,6 @@ def test_simulation_with_soft_contacts(
model.kin_dyn_parameters.contact_parameters.enabled = tuple(
enabled_collidable_points_mask.tolist()
)
model.integrator = integrator.build()

assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4

Expand Down

0 comments on commit 21f9e53

Please sign in to comment.