Skip to content

Commit ff611d9

Browse files
authored
Merge pull request #269 from ami-iit/update_contact_models
Update contact models
2 parents edc4a35 + 103750e commit ff611d9

File tree

8 files changed

+474
-254
lines changed

8 files changed

+474
-254
lines changed

src/jaxsim/api/contact.py

+48-77
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def collidable_point_forces(
9898
data: js.data.JaxSimModelData,
9999
link_forces: jtp.MatrixLike | None = None,
100100
joint_force_references: jtp.VectorLike | None = None,
101+
**kwargs,
101102
) -> jtp.Matrix:
102103
"""
103104
Compute the 6D forces applied to each collidable point.
@@ -110,6 +111,7 @@ def collidable_point_forces(
110111
representation of data.
111112
joint_force_references:
112113
The joint force references to apply to the joints.
114+
kwargs: Additional keyword arguments to pass to the active contact model.
113115
114116
Returns:
115117
The 6D forces applied to each collidable point expressed in the frame
@@ -121,6 +123,7 @@ def collidable_point_forces(
121123
data=data,
122124
link_forces=link_forces,
123125
joint_force_references=joint_force_references,
126+
**kwargs,
124127
)
125128

126129
return f_Ci
@@ -132,7 +135,8 @@ def collidable_point_dynamics(
132135
data: js.data.JaxSimModelData,
133136
link_forces: jtp.MatrixLike | None = None,
134137
joint_force_references: jtp.VectorLike | None = None,
135-
) -> tuple[jtp.Matrix, dict[str, jtp.Array]]:
138+
**kwargs,
139+
) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
136140
r"""
137141
Compute the 6D force applied to each collidable point.
138142
@@ -144,6 +148,7 @@ def collidable_point_dynamics(
144148
representation of data.
145149
joint_force_references:
146150
The joint force references to apply to the joints.
151+
kwargs: Additional keyword arguments to pass to the active contact model.
147152
148153
Returns:
149154
The 6D force applied to each collidable point and additional data based
@@ -158,86 +163,46 @@ def collidable_point_dynamics(
158163
Instead, the 6D forces are returned in the active representation.
159164
"""
160165

161-
# Build the soft contact model.
166+
# Build the common kw arguments to pass to the computation of the contact forces.
167+
common_kwargs = dict(
168+
link_forces=link_forces,
169+
joint_force_references=joint_force_references,
170+
)
171+
172+
# Build the additional kwargs to pass to the computation of the contact forces.
162173
match model.contact_model:
163174

164175
case contacts.SoftContacts():
165-
assert isinstance(model.contact_model, contacts.SoftContacts)
166176

167-
# Compute the 6D force expressed in the inertial frame and applied to each
168-
# collidable point, and the corresponding material deformation rate.
169-
# Note that the material deformation rate is always returned in the mixed frame
170-
# C[W] = (W_p_C, [W]). This is convenient for integration purpose.
171-
W_f_Ci, (CW_ṁ,) = model.contact_model.compute_contact_forces(
172-
model=model, data=data
173-
)
174-
175-
# Create the dictionary of auxiliary data.
176-
# This contact model considers the material deformation as additional state
177-
# of the ODE system. We need to pass its dynamics to the integrator.
178-
aux_data = dict(m_dot=CW_ṁ)
177+
kwargs_contact_model = {}
179178

180179
case contacts.RigidContacts():
181-
assert isinstance(model.contact_model, contacts.RigidContacts)
182180

183-
# Compute the 6D force expressed in the inertial frame and applied to each
184-
# collidable point.
185-
W_f_Ci, _ = model.contact_model.compute_contact_forces(
186-
model=model,
187-
data=data,
188-
link_forces=link_forces,
189-
joint_force_references=joint_force_references,
190-
)
191-
192-
aux_data = dict()
181+
kwargs_contact_model = common_kwargs | kwargs
193182

194183
case contacts.RelaxedRigidContacts():
195-
assert isinstance(model.contact_model, contacts.RelaxedRigidContacts)
196184

197-
# Compute the 6D force expressed in the inertial frame and applied to each
198-
# collidable point.
199-
W_f_Ci, _ = model.contact_model.compute_contact_forces(
200-
model=model,
201-
data=data,
202-
link_forces=link_forces,
203-
joint_force_references=joint_force_references,
204-
)
205-
206-
aux_data = dict()
185+
kwargs_contact_model = common_kwargs | kwargs
207186

208187
case contacts.ViscoElasticContacts():
209-
assert isinstance(model.contact_model, contacts.ViscoElasticContacts)
210188

211-
# It is not yet clear how to pass the time step to this stage.
212-
# A possibility is to restrict the integrator to only forward Euler
213-
# and store the Δt inside the model.
214-
module = jaxsim.rbda.contacts.visco_elastic.step.__module__
215-
name = jaxsim.rbda.contacts.visco_elastic.step.__name__
216-
msg = "You need to use the custom '{}.{}' function with this contact model."
217-
jaxsim.exceptions.raise_runtime_error_if(
218-
condition=True, msg=msg.format(module, name)
219-
)
220-
221-
# Compute the 6D force expressed in the inertial frame and applied to each
222-
# collidable point.
223-
W_f_Ci, (W_f̿_Ci, m_tf) = model.contact_model.compute_contact_forces(
224-
model=model,
225-
data=data,
226-
dt=None, # TODO
227-
link_forces=link_forces,
228-
joint_force_references=joint_force_references,
229-
)
230-
231-
aux_data = dict(W_f_avg2_C=W_f̿_Ci, m_tf=m_tf)
189+
kwargs_contact_model = common_kwargs | dict(dt=model.time_step) | kwargs
232190

233191
case _:
234-
raise ValueError(f"Invalid contact model {model.contact_model}")
192+
raise ValueError(f"Invalid contact model: {model.contact_model}")
193+
194+
# Compute the contact forces with the active contact model.
195+
W_f_C, aux_data = model.contact_model.compute_contact_forces(
196+
model=model,
197+
data=data,
198+
**kwargs_contact_model,
199+
)
235200

236201
# Compute the transforms of the implicit frames `C[L] = (W_p_C, [L])`
237202
# associated to each collidable point.
238203
# In inertial-fixed representation, the computation of these transforms
239204
# is not necessary and the conversion below becomes a no-op.
240-
W_H_Ci = (
205+
W_H_C = (
241206
js.contact.transforms(model=model, data=data)
242207
if data.velocity_representation is not VelRepr.Inertial
243208
else jnp.zeros(
@@ -253,7 +218,7 @@ def collidable_point_dynamics(
253218
transform=W_H_C,
254219
is_force=True,
255220
)
256-
)(W_f_Ci, W_H_Ci)
221+
)(W_f_C, W_H_C)
257222

258223
return f_Ci, aux_data
259224

@@ -392,11 +357,13 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
392357
max_penetration=max_δ,
393358
number_of_active_collidable_points_steady_state=nc,
394359
damping_ratio=damping_ratio,
395-
**dict(
396-
p=model.contact_model.parameters.p,
397-
q=model.contact_model.parameters.q,
398-
)
399-
| kwargs,
360+
**(
361+
dict(
362+
p=model.contact_model.parameters.p,
363+
q=model.contact_model.parameters.q,
364+
)
365+
| kwargs
366+
),
400367
)
401368

402369
case contacts.ViscoElasticContacts():
@@ -410,11 +377,13 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
410377
max_penetration=max_δ,
411378
number_of_active_collidable_points_steady_state=nc,
412379
damping_ratio=damping_ratio,
413-
**dict(
414-
p=model.contact_model.parameters.p,
415-
q=model.contact_model.parameters.q,
416-
)
417-
| kwargs,
380+
**(
381+
dict(
382+
p=model.contact_model.parameters.p,
383+
q=model.contact_model.parameters.q,
384+
)
385+
| kwargs
386+
),
418387
)
419388
)
420389

@@ -427,11 +396,13 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
427396

428397
parameters = contacts.RigidContactsParams.build(
429398
mu=static_friction_coefficient,
430-
**dict(
431-
K=K,
432-
D=2 * jnp.sqrt(K),
433-
)
434-
| kwargs,
399+
**(
400+
dict(
401+
K=K,
402+
D=2 * jnp.sqrt(K),
403+
)
404+
| kwargs
405+
),
435406
)
436407

437408
case contacts.RelaxedRigidContacts():

0 commit comments

Comments
 (0)