diff --git a/src/adam/casadi/casadi_like.py b/src/adam/casadi/casadi_like.py index 9ee0ce02..1a9ae888 100644 --- a/src/adam/casadi/casadi_like.py +++ b/src/adam/casadi/casadi_like.py @@ -103,16 +103,18 @@ def T(self) -> "CasadiLike": class CasadiLikeFactory(ArrayLikeFactory): - @staticmethod - def zeros(*x: int) -> "CasadiLike": + + def __init__(self, cs_type: Union[cs.SX, cs.DM]): + self.cs_type = cs_type + + def zeros(self, *x: int) -> "CasadiLike": """ Returns: CasadiLike: Matrix of zeros of dim *x """ - return CasadiLike(cs.SX.zeros(*x)) + return CasadiLike(self.cs_type.zeros(*x)) - @staticmethod - def eye(x: int) -> "CasadiLike": + def eye(self, x: int) -> "CasadiLike": """ Args: x (int): matrix dimension @@ -120,20 +122,20 @@ def eye(x: int) -> "CasadiLike": Returns: CasadiLike: Identity matrix """ - return CasadiLike(cs.SX.eye(x)) + return CasadiLike(self.cs_type.eye(x)) - @staticmethod - def array(*x) -> "CasadiLike": + def array(self, *x) -> "CasadiLike": """ Returns: CasadiLike: Vector wrapping *x """ - return CasadiLike(cs.SX(*x)) + return CasadiLike(self.cs_type(*x)) class SpatialMath(SpatialMath): - def __init__(self): - super().__init__(CasadiLikeFactory) + + def __init__(self, cs_type: Union[cs.SX, cs.DM]): + super().__init__(CasadiLikeFactory(cs_type)) @staticmethod def skew(x: Union["CasadiLike", npt.ArrayLike]) -> "CasadiLike": diff --git a/src/adam/casadi/computations.py b/src/adam/casadi/computations.py index eab62718..9e3e6ab7 100644 --- a/src/adam/casadi/computations.py +++ b/src/adam/casadi/computations.py @@ -22,6 +22,7 @@ def __init__( urdfstring: str, joints_name_list: list = None, root_link: str = "root_link", + cs_type: Union[cs.SX, cs.DM] = cs.SX, gravity: np.array = np.array([0.0, 0.0, -9.80665, 0.0, 0.0, 0.0]), f_opts: dict = dict(jit=False, jit_options=dict(flags="-Ofast")), ) -> None: @@ -31,7 +32,7 @@ def __init__( joints_name_list (list): list of the actuated joints root_link (str, optional): the first link. Defaults to 'root_link'. """ - math = SpatialMath() + math = SpatialMath(cs_type) factory = URDFModelFactory(path=urdfstring, math=math) model = Model.build(factory=factory, joints_name_list=joints_name_list) self.rbdalgos = RBDAlgorithms(model=model, math=math) diff --git a/src/adam/parametric/casadi/computations_parametric.py b/src/adam/parametric/casadi/computations_parametric.py index 37459821..85f92dda 100644 --- a/src/adam/parametric/casadi/computations_parametric.py +++ b/src/adam/parametric/casadi/computations_parametric.py @@ -1,7 +1,7 @@ # Copyright (C) 2021 Istituto Italiano di Tecnologia (IIT). All rights reserved. # This software may be modified and distributed under the terms of the # GNU Lesser General Public License v2.1 or any later version. -from typing import List +from typing import List, Union import casadi as cs import numpy as np @@ -24,6 +24,7 @@ def __init__( joints_name_list: list, links_name_list: list, root_link: str = "root_link", + cs_type: Union[cs.SX, cs.DM] = cs.SX, gravity: np.array = np.array([0.0, 0.0, -9.80665, 0.0, 0.0, 0.0]), f_opts: dict = dict(jit=False, jit_options=dict(flags="-Ofast")), ) -> None: @@ -34,7 +35,7 @@ def __init__( links_name_list (list): list of the parametrized links root_link (str, optional): the first link. Defaults to 'root_link'. """ - math = SpatialMath() + math = SpatialMath(cs_type) n_param_links = len(links_name_list) self.densities = cs.SX.sym("densities", n_param_links) self.length_multiplier = cs.SX.sym("length_multiplier", n_param_links)