@@ -98,6 +98,7 @@ def collidable_point_forces(
98
98
data : js .data .JaxSimModelData ,
99
99
link_forces : jtp .MatrixLike | None = None ,
100
100
joint_force_references : jtp .VectorLike | None = None ,
101
+ ** kwargs ,
101
102
) -> jtp .Matrix :
102
103
"""
103
104
Compute the 6D forces applied to each collidable point.
@@ -110,6 +111,7 @@ def collidable_point_forces(
110
111
representation of data.
111
112
joint_force_references:
112
113
The joint force references to apply to the joints.
114
+ kwargs: Additional keyword arguments to pass to the active contact model.
113
115
114
116
Returns:
115
117
The 6D forces applied to each collidable point expressed in the frame
@@ -121,6 +123,7 @@ def collidable_point_forces(
121
123
data = data ,
122
124
link_forces = link_forces ,
123
125
joint_force_references = joint_force_references ,
126
+ ** kwargs ,
124
127
)
125
128
126
129
return f_Ci
@@ -132,7 +135,8 @@ def collidable_point_dynamics(
132
135
data : js .data .JaxSimModelData ,
133
136
link_forces : jtp .MatrixLike | None = None ,
134
137
joint_force_references : jtp .VectorLike | None = None ,
135
- ) -> tuple [jtp .Matrix , dict [str , jtp .Array ]]:
138
+ ** kwargs ,
139
+ ) -> tuple [jtp .Matrix , dict [str , jtp .PyTree ]]:
136
140
r"""
137
141
Compute the 6D force applied to each collidable point.
138
142
@@ -144,6 +148,7 @@ def collidable_point_dynamics(
144
148
representation of data.
145
149
joint_force_references:
146
150
The joint force references to apply to the joints.
151
+ kwargs: Additional keyword arguments to pass to the active contact model.
147
152
148
153
Returns:
149
154
The 6D force applied to each collidable point and additional data based
@@ -158,86 +163,46 @@ def collidable_point_dynamics(
158
163
Instead, the 6D forces are returned in the active representation.
159
164
"""
160
165
161
- # Build the soft contact model.
166
+ # Build the common kw arguments to pass to the computation of the contact forces.
167
+ common_kwargs = dict (
168
+ link_forces = link_forces ,
169
+ joint_force_references = joint_force_references ,
170
+ )
171
+
172
+ # Build the additional kwargs to pass to the computation of the contact forces.
162
173
match model .contact_model :
163
174
164
175
case contacts .SoftContacts ():
165
- assert isinstance (model .contact_model , contacts .SoftContacts )
166
176
167
- # Compute the 6D force expressed in the inertial frame and applied to each
168
- # collidable point, and the corresponding material deformation rate.
169
- # Note that the material deformation rate is always returned in the mixed frame
170
- # C[W] = (W_p_C, [W]). This is convenient for integration purpose.
171
- W_f_Ci , (CW_ṁ ,) = model .contact_model .compute_contact_forces (
172
- model = model , data = data
173
- )
174
-
175
- # Create the dictionary of auxiliary data.
176
- # This contact model considers the material deformation as additional state
177
- # of the ODE system. We need to pass its dynamics to the integrator.
178
- aux_data = dict (m_dot = CW_ṁ )
177
+ kwargs_contact_model = {}
179
178
180
179
case contacts .RigidContacts ():
181
- assert isinstance (model .contact_model , contacts .RigidContacts )
182
180
183
- # Compute the 6D force expressed in the inertial frame and applied to each
184
- # collidable point.
185
- W_f_Ci , _ = model .contact_model .compute_contact_forces (
186
- model = model ,
187
- data = data ,
188
- link_forces = link_forces ,
189
- joint_force_references = joint_force_references ,
190
- )
191
-
192
- aux_data = dict ()
181
+ kwargs_contact_model = common_kwargs | kwargs
193
182
194
183
case contacts .RelaxedRigidContacts ():
195
- assert isinstance (model .contact_model , contacts .RelaxedRigidContacts )
196
184
197
- # Compute the 6D force expressed in the inertial frame and applied to each
198
- # collidable point.
199
- W_f_Ci , _ = model .contact_model .compute_contact_forces (
200
- model = model ,
201
- data = data ,
202
- link_forces = link_forces ,
203
- joint_force_references = joint_force_references ,
204
- )
205
-
206
- aux_data = dict ()
185
+ kwargs_contact_model = common_kwargs | kwargs
207
186
208
187
case contacts .ViscoElasticContacts ():
209
- assert isinstance (model .contact_model , contacts .ViscoElasticContacts )
210
188
211
- # It is not yet clear how to pass the time step to this stage.
212
- # A possibility is to restrict the integrator to only forward Euler
213
- # and store the Δt inside the model.
214
- module = jaxsim .rbda .contacts .visco_elastic .step .__module__
215
- name = jaxsim .rbda .contacts .visco_elastic .step .__name__
216
- msg = "You need to use the custom '{}.{}' function with this contact model."
217
- jaxsim .exceptions .raise_runtime_error_if (
218
- condition = True , msg = msg .format (module , name )
219
- )
220
-
221
- # Compute the 6D force expressed in the inertial frame and applied to each
222
- # collidable point.
223
- W_f_Ci , (W_f̿_Ci , m_tf ) = model .contact_model .compute_contact_forces (
224
- model = model ,
225
- data = data ,
226
- dt = None , # TODO
227
- link_forces = link_forces ,
228
- joint_force_references = joint_force_references ,
229
- )
230
-
231
- aux_data = dict (W_f_avg2_C = W_f̿_Ci , m_tf = m_tf )
189
+ kwargs_contact_model = common_kwargs | dict (dt = model .time_step ) | kwargs
232
190
233
191
case _:
234
- raise ValueError (f"Invalid contact model { model .contact_model } " )
192
+ raise ValueError (f"Invalid contact model: { model .contact_model } " )
193
+
194
+ # Compute the contact forces with the active contact model.
195
+ W_f_C , aux_data = model .contact_model .compute_contact_forces (
196
+ model = model ,
197
+ data = data ,
198
+ ** kwargs_contact_model ,
199
+ )
235
200
236
201
# Compute the transforms of the implicit frames `C[L] = (W_p_C, [L])`
237
202
# associated to each collidable point.
238
203
# In inertial-fixed representation, the computation of these transforms
239
204
# is not necessary and the conversion below becomes a no-op.
240
- W_H_Ci = (
205
+ W_H_C = (
241
206
js .contact .transforms (model = model , data = data )
242
207
if data .velocity_representation is not VelRepr .Inertial
243
208
else jnp .zeros (
@@ -253,7 +218,7 @@ def collidable_point_dynamics(
253
218
transform = W_H_C ,
254
219
is_force = True ,
255
220
)
256
- )(W_f_Ci , W_H_Ci )
221
+ )(W_f_C , W_H_C )
257
222
258
223
return f_Ci , aux_data
259
224
@@ -392,11 +357,13 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
392
357
max_penetration = max_δ ,
393
358
number_of_active_collidable_points_steady_state = nc ,
394
359
damping_ratio = damping_ratio ,
395
- ** dict (
396
- p = model .contact_model .parameters .p ,
397
- q = model .contact_model .parameters .q ,
398
- )
399
- | kwargs ,
360
+ ** (
361
+ dict (
362
+ p = model .contact_model .parameters .p ,
363
+ q = model .contact_model .parameters .q ,
364
+ )
365
+ | kwargs
366
+ ),
400
367
)
401
368
402
369
case contacts .ViscoElasticContacts ():
@@ -410,11 +377,13 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
410
377
max_penetration = max_δ ,
411
378
number_of_active_collidable_points_steady_state = nc ,
412
379
damping_ratio = damping_ratio ,
413
- ** dict (
414
- p = model .contact_model .parameters .p ,
415
- q = model .contact_model .parameters .q ,
416
- )
417
- | kwargs ,
380
+ ** (
381
+ dict (
382
+ p = model .contact_model .parameters .p ,
383
+ q = model .contact_model .parameters .q ,
384
+ )
385
+ | kwargs
386
+ ),
418
387
)
419
388
)
420
389
@@ -427,11 +396,13 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
427
396
428
397
parameters = contacts .RigidContactsParams .build (
429
398
mu = static_friction_coefficient ,
430
- ** dict (
431
- K = K ,
432
- D = 2 * jnp .sqrt (K ),
433
- )
434
- | kwargs ,
399
+ ** (
400
+ dict (
401
+ K = K ,
402
+ D = 2 * jnp .sqrt (K ),
403
+ )
404
+ | kwargs
405
+ ),
435
406
)
436
407
437
408
case contacts .RelaxedRigidContacts ():
0 commit comments