Skip to content

Commit

Permalink
Merge pull request #89 from ami-iit/remove-mx
Browse files Browse the repository at this point in the history
Remove casadi.MX
  • Loading branch information
Giulero authored Jun 25, 2024
2 parents c460594 + 52b8377 commit 393abd8
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 16 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,20 @@ w_H_b = np.eye(4)
joints = np.ones(len(joints_name_list))
M = kinDyn.mass_matrix_fun()
print(M(w_H_b, joints))

# If you want to use the symbolic version
w_H_b = cs.SX.eye(4)
joints = cs.SX.sym('joints', len(joints_name_list))
M = kinDyn.mass_matrix_fun()
print(M(w_H_b, joints))

# This is usable also with casadi.MX
w_H_b = cs.MX.eye(4)
joints = cs.MX.sym('joints', len(joints_name_list))
M = kinDyn.mass_matrix_fun()
print(M(w_H_b, joints))


```

### PyTorch interface
Expand Down
22 changes: 11 additions & 11 deletions src/adam/casadi/casadi_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,38 +104,38 @@ def T(self) -> "CasadiLike":

class CasadiLikeFactory(ArrayLikeFactory):

def __init__(self, cs_type: Union[cs.SX, cs.DM]):
self.cs_type = cs_type

def zeros(self, *x: int) -> "CasadiLike":
@staticmethod
def zeros(*x: int) -> "CasadiLike":
"""
Returns:
CasadiLike: Matrix of zeros of dim *x
"""
return CasadiLike(self.cs_type.zeros(*x))
return CasadiLike(cs.SX.zeros(*x))

def eye(self, x: int) -> "CasadiLike":
@staticmethod
def eye(x: int) -> "CasadiLike":
"""
Args:
x (int): matrix dimension
Returns:
CasadiLike: Identity matrix
"""
return CasadiLike(self.cs_type.eye(x))
return CasadiLike(cs.SX.eye(x))

def array(self, *x) -> "CasadiLike":
@staticmethod
def array(*x) -> "CasadiLike":
"""
Returns:
CasadiLike: Vector wrapping *x
"""
return CasadiLike(self.cs_type(*x))
return CasadiLike(cs.SX(*x))


class SpatialMath(SpatialMath):

def __init__(self, cs_type: Union[cs.SX, cs.DM]):
super().__init__(CasadiLikeFactory(cs_type))
def __init__(self):
super().__init__(CasadiLikeFactory)

@staticmethod
def skew(x: Union["CasadiLike", npt.ArrayLike]) -> "CasadiLike":
Expand Down
69 changes: 66 additions & 3 deletions src/adam/casadi/computations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def __init__(
urdfstring: str,
joints_name_list: list = None,
root_link: str = "root_link",
cs_type: Union[cs.SX, cs.DM] = cs.SX,
gravity: np.array = np.array([0.0, 0.0, -9.80665, 0.0, 0.0, 0.0]),
f_opts: dict = dict(jit=False, jit_options=dict(flags="-Ofast"), cse=True),
) -> None:
Expand All @@ -32,7 +31,7 @@ def __init__(
joints_name_list (list): list of the actuated joints
root_link (str, optional): the first link. Defaults to 'root_link'.
"""
math = SpatialMath(cs_type)
math = SpatialMath()
factory = URDFModelFactory(path=urdfstring, math=math)
model = Model.build(factory=factory, joints_name_list=joints_name_list)
self.rbdalgos = RBDAlgorithms(model=model, math=math)
Expand Down Expand Up @@ -239,8 +238,13 @@ def mass_matrix(
joint_positions (Union[cs.SX, cs.DM]): The joints position
Returns:
M (jax): Mass Matrix
M (Union[cs.SX, cs.DM]): Mass Matrix
"""
if isinstance(base_transform, cs.MX) and isinstance(joint_positions, cs.MX):
raise ValueError(
"You are using casadi MX. Please use the function KinDynComputations.mass_matrix_fun()"
)

[M, _] = self.rbdalgos.crba(base_transform, joint_positions)
return M.array

Expand All @@ -256,6 +260,11 @@ def centroidal_momentum_matrix(
Returns:
Jcc (Union[cs.SX, cs.DM]): Centroidal Momentum matrix
"""
if isinstance(base_transform, cs.MX) and isinstance(joint_positions, cs.MX):
raise ValueError(
"You are using casadi MX. Please use the function KinDynComputations.centroidal_momentum_matrix_fun()"
)

[_, Jcm] = self.rbdalgos.crba(base_transform, joint_positions)
return Jcm.array

Expand All @@ -269,6 +278,11 @@ def relative_jacobian(self, frame: str, joint_positions: Union[cs.SX, cs.DM]):
Returns:
J (Union[cs.SX, cs.DM]): The Jacobian between the root and the frame
"""
if isinstance(joint_positions, cs.MX):
raise ValueError(
"You are using casadi MX. Please use the function KinDynComputations.relative_jacobian_fun()"
)

return self.rbdalgos.relative_jacobian(frame, joint_positions).array

def jacobian_dot(
Expand All @@ -291,6 +305,15 @@ def jacobian_dot(
Returns:
Jdot (Union[cs.SX, cs.DM]): The Jacobian derivative relative to the frame
"""
if (
isinstance(base_transform, cs.MX)
and isinstance(joint_positions, cs.MX)
and isinstance(base_velocity, cs.MX)
and isinstance(joint_velocities, cs.MX)
):
raise ValueError(
"You are using casadi MX. Please use the function KinDynComputations.jacobian_dot_fun()"
)
return self.rbdalgos.jacobian_dot(
frame, base_transform, joint_positions, base_velocity, joint_velocities
).array
Expand All @@ -311,6 +334,11 @@ def forward_kinematics(
Returns:
H (Union[cs.SX, cs.DM]): The fk represented as Homogenous transformation matrix
"""
if isinstance(base_transform, cs.MX) and isinstance(joint_positions, cs.MX):
raise ValueError(
"You are using casadi MX. Please use the function KinDynComputations.forward_kinematics_fun()"
)

return self.rbdalgos.forward_kinematics(
frame, base_transform, joint_positions
).array
Expand All @@ -326,6 +354,11 @@ def jacobian(self, frame: str, base_transform, joint_positions):
Returns:
J_tot (Union[cs.SX, cs.DM]): The Jacobian relative to the frame
"""
if isinstance(base_transform, cs.MX) and isinstance(joint_positions, cs.MX):
raise ValueError(
"You are using casadi MX. Please use the function KinDynComputations.jacobian_fun()"
)

return self.rbdalgos.jacobian(frame, base_transform, joint_positions).array

def bias_force(
Expand All @@ -347,6 +380,16 @@ def bias_force(
Returns:
h (Union[cs.SX, cs.DM]): the bias force
"""
if (
isinstance(base_transform, cs.MX)
and isinstance(joint_positions, cs.MX)
and isinstance(base_velocity, cs.MX)
and isinstance(joint_velocities, cs.MX)
):
raise ValueError(
"You are using casadi MX. Please use the function KinDynComputations.bias_force_fun()"
)

return self.rbdalgos.rnea(
base_transform, joint_positions, base_velocity, joint_velocities, self.g
).array
Expand All @@ -370,6 +413,16 @@ def coriolis_term(
Returns:
C (Union[cs.SX, cs.DM]): the Coriolis term
"""
if (
isinstance(base_transform, cs.MX)
and isinstance(joint_positions, cs.MX)
and isinstance(base_velocity, cs.MX)
and isinstance(joint_velocities, cs.MX)
):
raise ValueError(
"You are using casadi MX. Please use the function KinDynComputations.coriolis_term_fun()"
)

return self.rbdalgos.rnea(
base_transform,
joint_positions,
Expand All @@ -391,6 +444,11 @@ def gravity_term(
Returns:
G (Union[cs.SX, cs.DM]): the gravity term
"""
if isinstance(base_transform, cs.MX) and isinstance(joint_positions, cs.MX):
raise ValueError(
"You are using casadi MX. Please use the function KinDynComputations.gravity_term_fun()"
)

return self.rbdalgos.rnea(
base_transform,
joint_positions,
Expand All @@ -411,4 +469,9 @@ def CoM_position(
Returns:
CoM (Union[cs.SX, cs.DM]): The CoM position
"""
if isinstance(base_transform, cs.MX) and isinstance(joint_positions, cs.MX):
raise ValueError(
"You are using casadi MX. Please use the function KinDynComputations.CoM_position_fun()"
)

return self.rbdalgos.CoM_position(base_transform, joint_positions).array
3 changes: 1 addition & 2 deletions src/adam/parametric/casadi/computations_parametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def __init__(
joints_name_list: list,
links_name_list: list,
root_link: str = "root_link",
cs_type: Union[cs.SX, cs.DM] = cs.SX,
gravity: np.array = np.array([0.0, 0.0, -9.80665, 0.0, 0.0, 0.0]),
f_opts: dict = dict(jit=False, jit_options=dict(flags="-Ofast")),
) -> None:
Expand All @@ -35,7 +34,7 @@ def __init__(
links_name_list (list): list of the parametrized links
root_link (str, optional): the first link. Defaults to 'root_link'.
"""
math = SpatialMath(cs_type)
math = SpatialMath()
n_param_links = len(links_name_list)
self.densities = cs.SX.sym("densities", n_param_links)
self.length_multiplier = cs.SX.sym("length_multiplier", n_param_links)
Expand Down

0 comments on commit 393abd8

Please sign in to comment.