Skip to content

Commit

Permalink
Merge pull request #84 from ami-iit/use-cs-mx
Browse files Browse the repository at this point in the history
Refactor Casadi interface in order to accept also MX.
  • Loading branch information
Giulero authored Jun 11, 2024
2 parents d16713d + 213ff95 commit 9560312
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 14 deletions.
24 changes: 13 additions & 11 deletions src/adam/casadi/casadi_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,37 +103,39 @@ def T(self) -> "CasadiLike":


class CasadiLikeFactory(ArrayLikeFactory):
@staticmethod
def zeros(*x: int) -> "CasadiLike":

def __init__(self, cs_type: Union[cs.SX, cs.DM]):
self.cs_type = cs_type

def zeros(self, *x: int) -> "CasadiLike":
"""
Returns:
CasadiLike: Matrix of zeros of dim *x
"""
return CasadiLike(cs.SX.zeros(*x))
return CasadiLike(self.cs_type.zeros(*x))

@staticmethod
def eye(x: int) -> "CasadiLike":
def eye(self, x: int) -> "CasadiLike":
"""
Args:
x (int): matrix dimension
Returns:
CasadiLike: Identity matrix
"""
return CasadiLike(cs.SX.eye(x))
return CasadiLike(self.cs_type.eye(x))

@staticmethod
def array(*x) -> "CasadiLike":
def array(self, *x) -> "CasadiLike":
"""
Returns:
CasadiLike: Vector wrapping *x
"""
return CasadiLike(cs.SX(*x))
return CasadiLike(self.cs_type(*x))


class SpatialMath(SpatialMath):
def __init__(self):
super().__init__(CasadiLikeFactory)

def __init__(self, cs_type: Union[cs.SX, cs.DM]):
super().__init__(CasadiLikeFactory(cs_type))

@staticmethod
def skew(x: Union["CasadiLike", npt.ArrayLike]) -> "CasadiLike":
Expand Down
3 changes: 2 additions & 1 deletion src/adam/casadi/computations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(
urdfstring: str,
joints_name_list: list = None,
root_link: str = "root_link",
cs_type: Union[cs.SX, cs.DM] = cs.SX,
gravity: np.array = np.array([0.0, 0.0, -9.80665, 0.0, 0.0, 0.0]),
f_opts: dict = dict(jit=False, jit_options=dict(flags="-Ofast")),
) -> None:
Expand All @@ -31,7 +32,7 @@ def __init__(
joints_name_list (list): list of the actuated joints
root_link (str, optional): the first link. Defaults to 'root_link'.
"""
math = SpatialMath()
math = SpatialMath(cs_type)
factory = URDFModelFactory(path=urdfstring, math=math)
model = Model.build(factory=factory, joints_name_list=joints_name_list)
self.rbdalgos = RBDAlgorithms(model=model, math=math)
Expand Down
5 changes: 3 additions & 2 deletions src/adam/parametric/casadi/computations_parametric.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (C) 2021 Istituto Italiano di Tecnologia (IIT). All rights reserved.
# This software may be modified and distributed under the terms of the
# GNU Lesser General Public License v2.1 or any later version.
from typing import List
from typing import List, Union

import casadi as cs
import numpy as np
Expand All @@ -24,6 +24,7 @@ def __init__(
joints_name_list: list,
links_name_list: list,
root_link: str = "root_link",
cs_type: Union[cs.SX, cs.DM] = cs.SX,
gravity: np.array = np.array([0.0, 0.0, -9.80665, 0.0, 0.0, 0.0]),
f_opts: dict = dict(jit=False, jit_options=dict(flags="-Ofast")),
) -> None:
Expand All @@ -34,7 +35,7 @@ def __init__(
links_name_list (list): list of the parametrized links
root_link (str, optional): the first link. Defaults to 'root_link'.
"""
math = SpatialMath()
math = SpatialMath(cs_type)
n_param_links = len(links_name_list)
self.densities = cs.SX.sym("densities", n_param_links)
self.length_multiplier = cs.SX.sym("length_multiplier", n_param_links)
Expand Down

0 comments on commit 9560312

Please sign in to comment.