Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add relaxed rigid contacts model #223

Merged
merged 29 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
d15149f
Remove unhashable member in set
flferretti Aug 13, 2024
7d41b29
Unify variable names in `terrain` module
flferretti Aug 13, 2024
10695c0
Fix return types
flferretti Aug 13, 2024
31ea9b2
Add `QuasiRigid` contact model
flferretti Aug 13, 2024
b04051a
Add the new soft contacts to `api.contact`
flferretti Aug 14, 2024
057f699
Normalize base quaternion when set by the user
flferretti Aug 14, 2024
b34909c
Add the new soft contacts to `api.ode`
flferretti Aug 14, 2024
663a2ea
Add the new soft contacts to `api.ode_data`
flferretti Aug 14, 2024
870bf68
Add `jaxopt` dependency
flferretti Aug 14, 2024
75218c1
Add `QuasiRigid` contact model
flferretti Aug 13, 2024
a399e06
Update typehints to PEP 484
flferretti Aug 22, 2024
f5b72f5
Speed up regularizers computation
flferretti Aug 23, 2024
5f401b7
Avoid to use `set` literal
flferretti Aug 26, 2024
31e5097
Apply suggestions from code review
flferretti Sep 6, 2024
26d0f47
Handle 1D penetrations instead of 3D collidable positions
flferretti Sep 9, 2024
5eee4ff
Set height and plane normal as private attributes
flferretti Sep 9, 2024
6058775
Expose solver parameters
flferretti Sep 9, 2024
1e40a20
Speed up compatibility computation between pytrees
flferretti Sep 11, 2024
56371ac
Use 3D velocity when warm-starting the optimizer
flferretti Sep 11, 2024
9afbd20
Apply suggestions from code review
flferretti Sep 12, 2024
25596db
Cast builder arguments to `jnp.array`
flferretti Sep 12, 2024
aea3d07
Update docstrings
flferretti Sep 12, 2024
d7d7ec9
Restore `lstsq` for the mass matrix inversion
flferretti Sep 12, 2024
fc88b0c
Fix resetting logic in case of overflow of simulation time
flferretti Sep 13, 2024
a4875f1
Make `ode.system_acceleration` return in the active representation
flferretti Sep 13, 2024
e9ac7ea
Drop `_compute_mixed_nu_dot_free` in favor of `ode.system_acceleration`
flferretti Sep 13, 2024
71ecac2
Fix system acceleration representation in `relaxed_rigid_contacts`
flferretti Sep 13, 2024
a09bc0d
Update variable names in `ode.system_acceleration`
flferretti Sep 13, 2024
42aed4b
Add reference sources and update comments
flferretti Sep 13, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ dependencies:
- python >= 3.12.0
- coloredlogs
- jax >= 0.4.13
- jaxopt >= 0.8.0
- jaxlib >= 0.4.13
- jaxlie >= 1.3.0
- jax-dataclasses >= 1.4.0
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ classifiers = [
dependencies = [
"coloredlogs",
"jax >= 0.4.13",
"jaxopt >= 0.8.0",
"jaxlib >= 0.4.13",
"jaxlie >= 1.3.0",
"jax_dataclasses >= 1.4.0",
Expand Down Expand Up @@ -181,6 +182,7 @@ platforms = ["linux-64", "osx-arm64", "osx-64"]
coloredlogs = "*"
jax = "*"
jax-dataclasses = "*"
jaxopt = "*"
jaxlib = "*"
jaxlie = "*"
lxml = "*"
Expand Down
28 changes: 27 additions & 1 deletion src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def collidable_point_dynamics(
Returns:
The 6D force applied to each collidable point and additional data based on the contact model configured:
- Soft: the material deformation rate.
- Rigid: nothing.
- Rigid: no additional data.
- QuasiRigid: no additional data.

Note:
The material deformation rate is always returned in the mixed frame
Expand All @@ -144,6 +145,10 @@ def collidable_point_dynamics(
W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data)

# Import privately the contacts classes.
from jaxsim.rbda.contacts.relaxed_rigid import (
RelaxedRigidContacts,
RelaxedRigidContactsState,
)
from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState

Expand Down Expand Up @@ -190,6 +195,27 @@ def collidable_point_dynamics(

aux_data = dict()

case RelaxedRigidContacts():
assert isinstance(model.contact_model, RelaxedRigidContacts)
assert isinstance(data.state.contact, RelaxedRigidContactsState)

# Build the contact model.
relaxed_rigid_contacts = RelaxedRigidContacts(
parameters=data.contacts_params, terrain=model.terrain
)

# Compute the 6D force expressed in the inertial frame and applied to each
# collidable point.
W_f_Ci, _ = relaxed_rigid_contacts.compute_contact_forces(
position=W_p_Ci,
velocity=W_ṗ_Ci,
model=model,
data=data,
link_forces=link_forces,
)

aux_data = dict()

case _:
raise ValueError(f"Invalid contact model {model.contact_model}")

Expand Down
14 changes: 8 additions & 6 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,16 +593,18 @@ def reset_base_quaternion(self, base_quaternion: jtp.VectorLike) -> Self:
The updated `JaxSimModelData` object.
"""

base_quaternion = jnp.array(base_quaternion)
W_Q_B = jnp.array(base_quaternion, dtype=float)

W_Q_B = jax.lax.select(
pred=jnp.allclose(jnp.linalg.norm(W_Q_B), 1.0, atol=1e-6, rtol=0.0),
on_true=W_Q_B,
on_false=W_Q_B / jnp.linalg.norm(W_Q_B),
)

return self.replace(
validate=True,
state=self.state.replace(
physics_model=self.state.physics_model.replace(
base_quaternion=jnp.atleast_1d(base_quaternion.squeeze()).astype(
float
)
)
physics_model=self.state.physics_model.replace(base_quaternion=W_Q_B)
),
)

Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1935,7 +1935,7 @@ def step(
tf_ns = jnp.where(tf_ns >= t0_ns, tf_ns, jnp.array(0, dtype=t0_ns.dtype))

jax.lax.cond(
pred=tf_ns >= t0_ns,
pred=tf_ns < t0_ns,
true_fun=lambda: jax.debug.print(
"The simulation time overflowed, resetting simulation time to 0."
),
Expand Down
43 changes: 19 additions & 24 deletions src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,17 +175,15 @@ def system_velocity_dynamics(
forces=W_f_Li_terrain,
additive=True,
)
# Get the link forces in the data representation
with references.switch_velocity_representation(data.velocity_representation):

# Get the link forces in inertial representation
f_L_total = references.link_forces(model=model, data=data)

# The following method always returns the inertial-fixed acceleration, and expects
# the link_forces expressed in the inertial frame.
W_v̇_WB, s̈ = system_acceleration(
model=model, data=data, joint_forces=joint_forces, link_forces=f_L_total
)
v̇_WB, s̈ = system_acceleration(
model=model, data=data, joint_forces=joint_forces, link_forces=f_L_total
)

return W_v̇_WB, s̈, aux_data
return v̇_WB, s̈, aux_data


def system_acceleration(
Expand All @@ -196,7 +194,7 @@ def system_acceleration(
link_forces: jtp.MatrixLike | None = None,
) -> tuple[jtp.Vector, jtp.Vector]:
"""
Compute the system acceleration in inertial-fixed representation.
Compute the system acceleration in the active representation.

Args:
model: The model to consider.
Expand All @@ -206,7 +204,7 @@ def system_acceleration(
The 6D forces to apply to the links expressed in the same representation of data.

Returns:
A tuple containing the base 6D acceleration in inertial-fixed representation
A tuple containing the base 6D acceleration in in the active representation
and the joint accelerations.
"""

Expand Down Expand Up @@ -272,18 +270,15 @@ def system_acceleration(
)

# - Joint accelerations: s̈ ∈ ℝⁿ
# - Base inertial-fixed acceleration: W_v̇_WB = (W_p̈_B, W_ω̇_B) ∈ ℝ⁶
with (
data.switch_velocity_representation(velocity_representation=VelRepr.Inertial),
references.switch_velocity_representation(VelRepr.Inertial),
):
W_v̇_WB, s̈ = js.model.forward_dynamics_aba(
model=model,
data=data,
joint_forces=references.joint_force_references(),
link_forces=references.link_forces(),
)
return W_v̇_WB, s̈
# - Base acceleration: v̇_WB ∈ ℝ⁶
v̇_WB, s̈ = js.model.forward_dynamics_aba(
model=model,
data=data,
joint_forces=references.joint_force_references(model=model),
link_forces=references.link_forces(model=model, data=data),
)

return v̇_WB, s̈


@jax.jit
Expand Down Expand Up @@ -353,7 +348,7 @@ def system_dynamics(
corresponding derivative, and the dictionary of auxiliary data returned
by the system dynamics evaluation.
"""

from jaxsim.rbda.contacts.relaxed_rigid import RelaxedRigidContacts
from jaxsim.rbda.contacts.rigid import RigidContacts
from jaxsim.rbda.contacts.soft import SoftContacts

Expand All @@ -371,7 +366,7 @@ def system_dynamics(
case SoftContacts():
ode_state_kwargs["tangential_deformation"] = aux_dict["m_dot"]

case RigidContacts():
case RigidContacts() | RelaxedRigidContacts():
pass

case _:
Expand Down
12 changes: 11 additions & 1 deletion src/jaxsim/api/ode_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.rbda import ContactsState
from jaxsim.rbda.contacts.relaxed_rigid import (
RelaxedRigidContacts,
RelaxedRigidContactsState,
)
from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
from jaxsim.utils import JaxsimDataclass
Expand Down Expand Up @@ -173,6 +177,10 @@ def build_from_jaxsim_model(
)
case RigidContacts():
contact = RigidContactsState.build()

case RelaxedRigidContacts():
contact = RelaxedRigidContactsState.build()

case _:
raise ValueError("Unable to determine contact state class prefix.")

Expand Down Expand Up @@ -216,7 +224,9 @@ def build(

# Get the contact model from the `JaxSimModel`.
match contact:
case SoftContactsState() | RigidContactsState():
case (
SoftContactsState() | RigidContactsState() | RelaxedRigidContactsState()
):
pass
case None:
contact = SoftContactsState.zero(model=model)
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/integrators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ def butcher_tableau_supports_fsal(
b: jtp.Matrix,
c: jtp.Vector,
index_of_solution: jtp.IntLike = 0,
) -> [bool, int | None]:
) -> tuple[bool, int | None]:
"""
Check if the Butcher tableau supports the FSAL (first-same-as-last) property.

Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/math/inertia.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def to_params(M: jtp.Matrix) -> tuple[jtp.Float, jtp.Vector, jtp.Matrix]:
M (jtp.Matrix): The 6x6 inertia matrix.

Returns:
Tuple[jtp.Float, jtp.Vector, jtp.Matrix]: A tuple containing mass, center of mass (3D), and inertia matrix (3x3).
tuple[jtp.Float, jtp.Vector, jtp.Matrix]: A tuple containing mass, center of mass (3D), and inertia matrix (3x3).

Raises:
ValueError: If the input matrix M has an unexpected shape.
Expand Down
6 changes: 3 additions & 3 deletions src/jaxsim/mujoco/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def convert(
joints_dict = {j.name: j for j in rod_model.joints()}

# Convert all the joints not considered to fixed joints.
for joint_name in set(j.name for j in rod_model.joints()) - considered_joints:
for joint_name in {j.name for j in rod_model.joints()} - considered_joints:
joints_dict[joint_name].type = "fixed"

# Convert the ROD model to URDF.
Expand Down Expand Up @@ -289,10 +289,10 @@ def convert(
mj_model = mj.MjModel.from_xml_string(xml=urdf_string, assets=assets)

# Get the joint names.
mj_joint_names = set(
mj_joint_names = {
mj.mj_id2name(mj_model, mj.mjtObj.mjOBJ_JOINT, idx)
for idx in range(mj_model.njnt)
)
}

# Check that the Mujoco model only has the considered joints.
if mj_joint_names != considered_joints:
Expand Down
6 changes: 3 additions & 3 deletions src/jaxsim/parsers/kinematic_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def reduce(self, considered_joints: Sequence[str]) -> KinematicGraph:
return copy.deepcopy(self)

# Check if all considered joints are part of the full kinematic graph
if len(set(considered_joints) - set(j.name for j in full_graph.joints)) != 0:
if len(set(considered_joints) - {j.name for j in full_graph.joints}) != 0:
extra_j = set(considered_joints) - {j.name for j in full_graph.joints}
msg = f"Not all joints to consider are part of the graph ({{{extra_j}}})"
raise ValueError(msg)
Expand Down Expand Up @@ -536,8 +536,8 @@ def reduce(self, considered_joints: Sequence[str]) -> KinematicGraph:
root_link_name=full_graph.root.name,
)

assert set(f.name for f in self.frames).isdisjoint(
set(f.name for f in unconnected_frames + reduced_frames)
assert {f.name for f in self.frames}.isdisjoint(
{f.name for f in unconnected_frames + reduced_frames}
)

for link in unconnected_links:
Expand Down
Loading