Skip to content

Commit

Permalink
Merge pull request #307 from ami-iit/remove_odeinput_physicsmodelinput
Browse files Browse the repository at this point in the history
Refactor simulation input classes and fix deprecation warnings
  • Loading branch information
flferretti authored Dec 6, 2024
2 parents 43febe5 + 22b3eb7 commit fbb7726
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 261 deletions.
4 changes: 3 additions & 1 deletion src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2198,7 +2198,9 @@ def step(
isinstance(model.contact_model, jaxsim.rbda.contacts.ViscoElasticContacts)
& (
~jnp.allclose(dt, model.time_step)
| ~isinstance(integrator, jaxsim.integrators.fixed_step.ForwardEuler)
| ~int(
isinstance(integrator, jaxsim.integrators.fixed_step.ForwardEuler)
)
)
),
msg=msg.format(module, name),
Expand Down
220 changes: 3 additions & 217 deletions src/jaxsim/api/ode_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,108 +10,14 @@
import jaxsim.typing as jtp
from jaxsim.utils import JaxsimDataclass

# =============================================================================
# Define the input and state of the ODE system defining the integrated dynamics
# =============================================================================
# ===================================================================
# Define the state of the ODE system defining the integrated dynamics
# ===================================================================

# Note: the ODE system is the combination of the floating-base dynamics and the
# soft-contacts dynamics.


@jax_dataclasses.pytree_dataclass
class ODEInput(JaxsimDataclass):
"""
The input to the ODE system.
Attributes:
physics_model: The input to the physics model.
"""

physics_model: PhysicsModelInput

@staticmethod
def build_from_jaxsim_model(
model: js.model.JaxSimModel | None = None,
link_forces: jtp.MatrixLike | None = None,
joint_force_references: jtp.VectorLike | None = None,
) -> ODEInput:
"""
Build an `ODEInput` from a `JaxSimModel`.
Args:
model: The `JaxSimModel` associated with the ODE input.
link_forces: The matrix of external forces applied to the links.
joint_force_references: The vector of joint force references.
Returns:
The `ODEInput` built from the `JaxSimModel`.
Note:
If any of the input components are not provided, they are built from the
`JaxSimModel` and initialized to zero.
"""

return ODEInput.build(
physics_model_input=PhysicsModelInput.build_from_jaxsim_model(
model=model,
link_forces=link_forces,
joint_force_references=joint_force_references,
),
model=model,
)

@staticmethod
def build(
physics_model_input: PhysicsModelInput | None = None,
model: js.model.JaxSimModel | None = None,
) -> ODEInput:
"""
Build an `ODEInput` from a `PhysicsModelInput`.
Args:
physics_model_input: The `PhysicsModelInput` associated with the ODE input.
model: The `JaxSimModel` associated with the ODE input.
Returns:
A `ODEInput` instance.
"""

physics_model_input = (
physics_model_input
if physics_model_input is not None
else PhysicsModelInput.zero(model=model)
)

return ODEInput(physics_model=physics_model_input)

@staticmethod
def zero(model: js.model.JaxSimModel) -> ODEInput:
"""
Build a zero `ODEInput` from a `JaxSimModel`.
Args:
model: The `JaxSimModel` associated with the ODE input.
Returns:
A zero `ODEInput` instance.
"""

return ODEInput.build(model=model)

def valid(self, model: js.model.JaxSimModel) -> bool:
"""
Check if the `ODEInput` is valid for a given `JaxSimModel`.
Args:
model: The `JaxSimModel` to validate the `ODEInput` against.
Returns:
`True` if the ODE input is valid for the given model, `False` otherwise.
"""

return self.physics_model.valid(model=model)


@jax_dataclasses.pytree_dataclass
class ODEState(JaxsimDataclass):
"""
Expand Down Expand Up @@ -493,123 +399,3 @@ def valid(self, model: js.model.JaxSimModel) -> bool:
return False

return True


@jax_dataclasses.pytree_dataclass
class PhysicsModelInput(JaxsimDataclass):
"""
Class storing the inputs of the physics model dynamics.
Attributes:
tau: The vector of joint forces.
f_ext: The matrix of external forces applied to the links.
"""

tau: jtp.Vector
f_ext: jtp.Matrix

@staticmethod
def build_from_jaxsim_model(
model: js.model.JaxSimModel | None = None,
link_forces: jtp.MatrixLike | None = None,
joint_force_references: jtp.VectorLike | None = None,
) -> PhysicsModelInput:
"""
Build a `PhysicsModelInput` from a `JaxSimModel`.
Args:
model: The `JaxSimModel` associated with the input.
link_forces: The matrix of external forces applied to the links.
joint_force_references: The vector of joint force references.
Returns:
A `PhysicsModelInput` instance.
Note:
If any of the input components are not provided, they are built from the
`JaxSimModel` and initialized to zero.
"""

return PhysicsModelInput.build(
joint_force_references=joint_force_references,
link_forces=link_forces,
number_of_dofs=model.dofs(),
number_of_links=model.number_of_links(),
)

@staticmethod
def build(
link_forces: jtp.MatrixLike | None = None,
joint_force_references: jtp.VectorLike | None = None,
number_of_dofs: jtp.Int | None = None,
number_of_links: jtp.Int | None = None,
) -> PhysicsModelInput:
"""
Build a `PhysicsModelInput`.
Args:
link_forces: The matrix of external forces applied to the links.
joint_force_references: The vector of joint force references.
number_of_dofs: The number of degrees of freedom of the model.
number_of_links: The number of links of the model.
Returns:
A `PhysicsModelInput` instance.
"""

joint_force_references = jnp.atleast_1d(
jnp.array(joint_force_references, dtype=float).squeeze()
if joint_force_references is not None
else jnp.zeros(number_of_dofs)
).astype(float)

link_forces = jnp.atleast_2d(
jnp.array(link_forces, dtype=float).squeeze()
if link_forces is not None
else jnp.zeros(shape=(number_of_links, 6))
).astype(float)

return PhysicsModelInput(
tau=joint_force_references,
f_ext=link_forces,
)

@staticmethod
def zero(model: js.model.JaxSimModel) -> PhysicsModelInput:
"""
Build a `PhysicsModelInput` with all components initialized to zero.
Args:
model: The `JaxSimModel` associated with the input.
Returns:
A `PhysicsModelInput` instance.
"""

return PhysicsModelInput.build_from_jaxsim_model(model=model)

def valid(self, model: js.model.JaxSimModel) -> bool:
"""
Check if the `PhysicsModelInput` is valid for a given `JaxSimModel`.
Args:
model: The `JaxSimModel` to validate the `PhysicsModelInput` against.
Returns:
`True` if the `PhysicsModelInput` is valid for the given model,
`False` otherwise.
"""

shape = self.tau.shape
expected_shape = (model.dofs(),)

if shape != expected_shape:
return False

shape = self.f_ext.shape
expected_shape = (model.number_of_links(), 6)

if shape != expected_shape:
return False

return True
Loading

0 comments on commit fbb7726

Please sign in to comment.