diff --git a/src/adam/casadi/computations.py b/src/adam/casadi/computations.py index 9e34194e..ba2c6901 100644 --- a/src/adam/casadi/computations.py +++ b/src/adam/casadi/computations.py @@ -2,7 +2,6 @@ import casadi as cs import numpy as np -from typing import Union from adam.casadi.casadi_like import SpatialMath from adam.core import RBDAlgorithms @@ -224,9 +223,7 @@ def get_total_mass(self) -> float: """ return self.rbdalgos.get_total_mass() - def mass_matrix( - self, base_transform: cs.SX, joint_positions: cs.SX - ): + def mass_matrix(self, base_transform: cs.SX, joint_positions: cs.SX): """Returns the Mass Matrix functions computed the CRBA Args: @@ -244,9 +241,7 @@ def mass_matrix( M, _ = self.rbdalgos.crba(base_transform, joint_positions) return M.array - def centroidal_momentum_matrix( - self, base_transform: cs.SX, joint_positions: cs.SX - ): + def centroidal_momentum_matrix(self, base_transform: cs.SX, joint_positions: cs.SX): """Returns the Centroidal Momentum Matrix functions computed the CRBA Args: @@ -427,9 +422,7 @@ def coriolis_term( np.zeros(6), ).array - def gravity_term( - self, base_transform: cs.SX, joint_positions: cs.SX - ) -> cs.SX: + def gravity_term(self, base_transform: cs.SX, joint_positions: cs.SX) -> cs.SX: """Returns the gravity term of the floating-base dynamics equation, using a reduced RNEA (no acceleration and external forces) @@ -453,9 +446,7 @@ def gravity_term( self.g, ).array - def CoM_position( - self, base_transform: cs.SX, joint_positions: cs.SX - ) -> cs.SX: + def CoM_position(self, base_transform: cs.SX, joint_positions: cs.SX) -> cs.SX: """Returns the CoM positon Args: diff --git a/src/adam/jax/computations.py b/src/adam/jax/computations.py index 5d10e52c..fab5dff3 100644 --- a/src/adam/jax/computations.py +++ b/src/adam/jax/computations.py @@ -1,6 +1,5 @@ # Copyright (C) Istituto Italiano di Tecnologia (IIT). All rights reserved. - import jax.numpy as jnp import numpy as np diff --git a/src/adam/model/__init__.py b/src/adam/model/__init__.py index 0426e317..3c23a54a 100644 --- a/src/adam/model/__init__.py +++ b/src/adam/model/__init__.py @@ -1,4 +1,4 @@ -from .abc_factories import Joint, Link, ModelFactory, Inertial, Pose +from .abc_factories import Inertial, Joint, Link, ModelFactory, Pose from .model import Model from .std_factories.std_joint import StdJoint from .std_factories.std_link import StdLink diff --git a/src/adam/model/abc_factories.py b/src/adam/model/abc_factories.py index 4720588f..576ad665 100644 --- a/src/adam/model/abc_factories.py +++ b/src/adam/model/abc_factories.py @@ -1,6 +1,5 @@ import abc import dataclasses -from typing import List import numpy.typing as npt @@ -11,8 +10,8 @@ class Pose: """Pose class""" - xyz: List - rpy: List + xyz: list + rpy: list @dataclasses.dataclass @@ -46,7 +45,7 @@ class Joint(abc.ABC): parent: str child: str type: str - axis: List + axis: list origin: Pose limit: Limits idx: int @@ -126,9 +125,9 @@ class Link(abc.ABC): math: SpatialMath name: str - visuals: List + visuals: list inertial: Inertial - collisions: List + collisions: list @abc.abstractmethod def spatial_inertia(self) -> npt.ArrayLike: @@ -178,25 +177,25 @@ def build_joint(self) -> Joint: pass @abc.abstractmethod - def get_links(self) -> List[Link]: + def get_links(self) -> list[Link]: """ Returns: - List[Link]: the list of the link + list[Link]: the list of the link """ pass @abc.abstractmethod - def get_frames(self) -> List[Link]: + def get_frames(self) -> list[Link]: """ Returns: - List[Link]: the list of the frames + list[Link]: the list of the frames """ pass @abc.abstractmethod - def get_joints(self) -> List[Joint]: + def get_joints(self) -> list[Joint]: """ Returns: - List[Joint]: the list of the joints + list[Joint]: the list of the joints """ pass diff --git a/src/adam/model/conversions/idyntree.py b/src/adam/model/conversions/idyntree.py index bbd8fd8e..dcdfc5f4 100644 --- a/src/adam/model/conversions/idyntree.py +++ b/src/adam/model/conversions/idyntree.py @@ -1,11 +1,9 @@ import idyntree.bindings import numpy as np import urdf_parser_py.urdf -from typing import List - +from adam.model.abc_factories import Joint, Link from adam.model.model import Model -from adam.model.abc_factories import Link, Joint def to_idyntree_solid_shape( @@ -63,7 +61,7 @@ def to_idyntree_solid_shape( def to_idyntree_link( link: Link, -) -> [idyntree.bindings.Link, List[idyntree.bindings.SolidShape]]: +) -> [idyntree.bindings.Link, list[idyntree.bindings.SolidShape]]: """ Args: link (Link): the link to convert diff --git a/src/adam/model/model.py b/src/adam/model/model.py index 887b4e7a..544b7ba5 100644 --- a/src/adam/model/model.py +++ b/src/adam/model/model.py @@ -1,5 +1,4 @@ import dataclasses -from typing import Dict, List from adam.model.abc_factories import Joint, Link, ModelFactory from adam.model.tree import Tree @@ -11,24 +10,24 @@ class Model: Model class. It describes the robot using links and frames and their connectivity""" name: str - links: Dict[str, Link] - frames: Dict[str, Link] - joints: Dict[str, Joint] + links: dict[str, Link] + frames: dict[str, Link] + joints: dict[str, Joint] tree: Tree NDoF: int - actuated_joints: List[str] + actuated_joints: list[str] def __post_init__(self): """set the "length of the model as the number of links""" self.N = len(self.links) @staticmethod - def build(factory: ModelFactory, joints_name_list: List[str] = None) -> "Model": + def build(factory: ModelFactory, joints_name_list: list[str] = None) -> "Model": """generates the model starting from the list of joints and the links-joints factory Args: factory (ModelFactory): the factory that generates the links and the joints, starting from a description (eg. urdf) - joints_name_list (List[str]): the list of the actuated joints + joints_name_list (list[str]): the list of the actuated joints Returns: Model: the model describing the robot @@ -63,9 +62,9 @@ def build(factory: ModelFactory, joints_name_list: List[str] = None) -> "Model": tree = Tree.build_tree(links=links_list, joints=joints_list) # generate some useful dict - joints: Dict[str, Joint] = {joint.name: joint for joint in joints_list} - links: Dict[str, Link] = {link.name: link for link in links_list} - frames: Dict[str, Link] = {frame.name: frame for frame in frames_list} + joints: dict[str, Joint] = {joint.name: joint for joint in joints_list} + links: dict[str, Link] = {link.name: link for link in links_list} + frames: dict[str, Link] = {frame.name: frame for frame in frames_list} return Model( name=factory.name, @@ -77,7 +76,7 @@ def build(factory: ModelFactory, joints_name_list: List[str] = None) -> "Model": actuated_joints=joints_name_list, ) - def get_joints_chain(self, root: str, target: str) -> List[Joint]: + def get_joints_chain(self, root: str, target: str) -> list[Joint]: """generate the joints chains from a link to a link Args: @@ -85,7 +84,7 @@ def get_joints_chain(self, root: str, target: str) -> List[Joint]: target (str): the target link Returns: - List[Joint]: the list of the joints + list[Joint]: the list of the joints """ if target not in list(self.links) and target not in list(self.frames): diff --git a/src/adam/model/std_factories/std_link.py b/src/adam/model/std_factories/std_link.py index 7754747d..49475d73 100644 --- a/src/adam/model/std_factories/std_link.py +++ b/src/adam/model/std_factories/std_link.py @@ -2,7 +2,7 @@ import urdf_parser_py.urdf from adam.core.spatial_math import SpatialMath -from adam.model import Link, Inertial, Pose +from adam.model import Inertial, Link, Pose class StdLink(Link): diff --git a/src/adam/model/std_factories/std_model.py b/src/adam/model/std_factories/std_model.py index 9f3a1dcd..68337bdf 100644 --- a/src/adam/model/std_factories/std_model.py +++ b/src/adam/model/std_factories/std_model.py @@ -1,7 +1,7 @@ +import os import pathlib -from typing import List import xml.etree.ElementTree as ET -import os + import urdf_parser_py.urdf from adam.core.spatial_math import SpatialMath @@ -85,17 +85,17 @@ def __init__(self, path: str, math: SpatialMath): ) self.name = self.urdf_desc.name - def get_joints(self) -> List[StdJoint]: + def get_joints(self) -> list[StdJoint]: """ Returns: - List[StdJoint]: build the list of the joints + list[StdJoint]: build the list of the joints """ return [self.build_joint(j) for j in self.urdf_desc.joints] - def get_links(self) -> List[StdLink]: + def get_links(self) -> list[StdLink]: """ Returns: - List[StdLink]: build the list of the links + list[StdLink]: build the list of the links A link is considered a "real" link if - it has an inertial @@ -116,10 +116,10 @@ def get_links(self) -> List[StdLink]: ) ] - def get_frames(self) -> List[StdLink]: + def get_frames(self) -> list[StdLink]: """ Returns: - List[StdLink]: build the list of the links + list[StdLink]: build the list of the links A link is considered a "fake" link (frame) if - it has no inertial diff --git a/src/adam/model/tree.py b/src/adam/model/tree.py index 73d36210..f497de75 100644 --- a/src/adam/model/tree.py +++ b/src/adam/model/tree.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Dict, Iterable, List, Tuple, Union +from typing import Iterable, Union from adam.model.abc_factories import Joint, Link @@ -10,19 +10,19 @@ class Node: name: str link: Link - arcs: List[Joint] - children: List["Node"] + arcs: list[Joint] + children: list["Node"] parent: Union[Link, None] = None parent_arc: Union[Joint, None] = None def __hash__(self) -> int: return hash(self.name) - def get_elements(self) -> Tuple[Link, Joint, Link]: + def get_elements(self) -> tuple[Link, Joint, Link]: """returns the node with its parent arc and parent link Returns: - Tuple[Link, Joint, Link]: the node, the parent_arc, the parent_link + tuple[Link, Joint, Link]: the node, the parent_arc, the parent_link """ return self.link, self.parent_arc, self.parent @@ -31,24 +31,24 @@ def get_elements(self) -> Tuple[Link, Joint, Link]: class Tree(Iterable): """The directed tree class""" - graph: Dict + graph: dict root: str def __post_init__(self): self.ordered_nodes_list = self.get_ordered_nodes_list(self.root) @staticmethod - def build_tree(links: List[Link], joints: List[Joint]) -> "Tree": + def build_tree(links: list[Link], joints: list[Joint]) -> "Tree": """builds the tree from the connectivity of the elements Args: - links (List[Link]) - joints (List[Joint]) + links (list[Link]) + joints (list[Joint]) Returns: Tree: the directed tree """ - nodes: Dict[str, Node] = { + nodes: dict[str, Node] = { l.name: Node( name=l.name, link=l, arcs=[], children=[], parent=None, parent_arc=None ) @@ -81,25 +81,25 @@ def print(self, root): pptree.print_tree(self.graph[root]) - def get_ordered_nodes_list(self, start: str) -> List[str]: + def get_ordered_nodes_list(self, start: str) -> list[str]: """get the ordered list of the nodes, given the connectivity Args: start (str): the start node Returns: - List[str]: the ordered list + list[str]: the ordered list """ ordered_list = [start] self.get_children(self.graph[start], ordered_list) return ordered_list @classmethod - def get_children(cls, node: Node, list: List): + def get_children(cls, node: Node, list: list): """Recursive method that finds children of child of child Args: node (Node): the analized node - list (List): the list of the children that needs to be filled + list (list): the list of the children that needs to be filled """ if node.children is not []: for child in node.children: diff --git a/src/adam/parametric/casadi/computations_parametric.py b/src/adam/parametric/casadi/computations_parametric.py index bb32dabb..1aaea1a2 100644 --- a/src/adam/parametric/casadi/computations_parametric.py +++ b/src/adam/parametric/casadi/computations_parametric.py @@ -1,7 +1,5 @@ # Copyright (C) Istituto Italiano di Tecnologia (IIT). All rights reserved. -from typing import List, Union - import casadi as cs import numpy as np @@ -9,7 +7,7 @@ from adam.core import RBDAlgorithms from adam.core.constants import Representations from adam.model import Model -from adam.parametric.model import URDFParametricModelFactory, ParametricLink +from adam.parametric.model import ParametricLink, URDFParametricModelFactory class KinDynComputationsParametric: @@ -260,7 +258,7 @@ def get_total_mass(self) -> cs.Function: "m", [self.length_multiplier, self.densities], [m], self.f_opts ) - def get_original_densities(self) -> List[float]: + def get_original_densities(self) -> list[float]: """Returns the original densities of the parametric links Returns: diff --git a/src/adam/parametric/jax/computations_parametric.py b/src/adam/parametric/jax/computations_parametric.py index 13b278b9..ca33f6db 100644 --- a/src/adam/parametric/jax/computations_parametric.py +++ b/src/adam/parametric/jax/computations_parametric.py @@ -1,16 +1,14 @@ # Copyright (C) Istituto Italiano di Tecnologia (IIT). All rights reserved. -from typing import List - import jax.numpy as jnp import numpy as np from jax import grad, jit, vmap -from adam.core.rbd_algorithms import RBDAlgorithms from adam.core.constants import Representations +from adam.core.rbd_algorithms import RBDAlgorithms from adam.jax.jax_like import SpatialMath from adam.model import Model -from adam.parametric.model import URDFParametricModelFactory, ParametricLink +from adam.parametric.model import ParametricLink, URDFParametricModelFactory class KinDynComputationsParametric: @@ -447,7 +445,7 @@ def get_total_mass( self.NDoF = self.rbdalgos.NDoF return self.rbdalgos.get_total_mass() - def get_original_densities(self) -> List[float]: + def get_original_densities(self) -> list[float]: """Returns the original densities of the parametric links Returns: diff --git a/src/adam/parametric/model/parametric_factories/parametric_link.py b/src/adam/parametric/model/parametric_factories/parametric_link.py index 1346f3f8..82d2c571 100644 --- a/src/adam/parametric/model/parametric_factories/parametric_link.py +++ b/src/adam/parametric/model/parametric_factories/parametric_link.py @@ -1,12 +1,12 @@ +import copy +import math +from enum import Enum + import numpy.typing as npt import urdf_parser_py.urdf -from enum import Enum -import copy from adam.core.spatial_math import SpatialMath from adam.model import Link - -import math from adam.model.abc_factories import Inertia, Inertial diff --git a/src/adam/parametric/model/parametric_factories/parametric_model.py b/src/adam/parametric/model/parametric_factories/parametric_model.py index 3702735e..8d38b226 100644 --- a/src/adam/parametric/model/parametric_factories/parametric_model.py +++ b/src/adam/parametric/model/parametric_factories/parametric_model.py @@ -1,11 +1,8 @@ -import pathlib -from typing import List -import os - import urdf_parser_py.urdf + from adam.core.spatial_math import SpatialMath -from adam.model import ModelFactory, StdJoint, StdLink, Link, Joint -from adam.model.std_factories.std_model import urdf_remove_sensors_tags, get_xml_string +from adam.model import Joint, Link, ModelFactory, StdJoint, StdLink +from adam.model.std_factories.std_model import get_xml_string, urdf_remove_sensors_tags from adam.parametric.model import ParametricJoint, ParametricLink @@ -21,7 +18,7 @@ def __init__( self, path: str, math: SpatialMath, - links_name_list: List, + links_name_list: list, length_multiplier, densities, ): @@ -45,26 +42,26 @@ def __init__( self.length_multiplier = length_multiplier self.densities = densities - def get_joints(self) -> List[Joint]: + def get_joints(self) -> list[Joint]: """ Returns: - List[Joint]: build the list of the joints + list[Joint]: build the list of the joints """ return [self.build_joint(j) for j in self.urdf_desc.joints] - def get_links(self) -> List[Link]: + def get_links(self) -> list[Link]: """ Returns: - List[Link]: build the list of the links + list[Link]: build the list of the links """ return [ self.build_link(l) for l in self.urdf_desc.links if l.inertial is not None ] - def get_frames(self) -> List[StdLink]: + def get_frames(self) -> list[StdLink]: """ Returns: - List[Link]: build the list of the links + list[Link]: build the list of the links """ return [self.build_link(l) for l in self.urdf_desc.links if l.inertial is None] diff --git a/src/adam/parametric/numpy/computations_parametric.py b/src/adam/parametric/numpy/computations_parametric.py index d8fecc9a..063c9311 100644 --- a/src/adam/parametric/numpy/computations_parametric.py +++ b/src/adam/parametric/numpy/computations_parametric.py @@ -1,14 +1,12 @@ # Copyright (C) Istituto Italiano di Tecnologia (IIT). All rights reserved. - import numpy as np -from typing import List -from adam.core.rbd_algorithms import RBDAlgorithms from adam.core.constants import Representations +from adam.core.rbd_algorithms import RBDAlgorithms from adam.model import Model -from adam.parametric.model import URDFParametricModelFactory, ParametricLink from adam.numpy.numpy_like import SpatialMath +from adam.parametric.model import ParametricLink, URDFParametricModelFactory class KinDynComputationsParametric: @@ -442,7 +440,7 @@ def get_total_mass( self.NDoF = model.NDoF return self.rbdalgos.get_total_mass() - def get_original_densities(self) -> List[float]: + def get_original_densities(self) -> list[float]: """Returns the original densities of the parametric links Returns: diff --git a/src/adam/parametric/pytorch/computations_parametric.py b/src/adam/parametric/pytorch/computations_parametric.py index b9131e0a..0b100dea 100644 --- a/src/adam/parametric/pytorch/computations_parametric.py +++ b/src/adam/parametric/pytorch/computations_parametric.py @@ -1,14 +1,12 @@ # Copyright (C) Istituto Italiano di Tecnologia (IIT). All rights reserved. - import numpy as np import torch -from typing import List -from adam.core.rbd_algorithms import RBDAlgorithms from adam.core.constants import Representations +from adam.core.rbd_algorithms import RBDAlgorithms from adam.model import Model -from adam.parametric.model import URDFParametricModelFactory, ParametricLink +from adam.parametric.model import ParametricLink, URDFParametricModelFactory from adam.pytorch.torch_like import SpatialMath @@ -446,7 +444,7 @@ def get_total_mass( self.NDoF = self.rbdalgos.NDoF return self.rbdalgos.get_total_mass() - def get_original_densities(self) -> List[float]: + def get_original_densities(self) -> list[float]: """Returns the original densities of the parametric links Returns: diff --git a/src/adam/pytorch/__init__.py b/src/adam/pytorch/__init__.py index 589bceaf..b427f55a 100644 --- a/src/adam/pytorch/__init__.py +++ b/src/adam/pytorch/__init__.py @@ -1,6 +1,6 @@ # Copyright (C) Istituto Italiano di Tecnologia (IIT). All rights reserved. -from .computations import KinDynComputations from .computation_batch import KinDynComputationsBatch +from .computations import KinDynComputations from .torch_like import TorchLike diff --git a/src/adam/pytorch/torch_like.py b/src/adam/pytorch/torch_like.py index 674cd0ff..de2cb812 100644 --- a/src/adam/pytorch/torch_like.py +++ b/src/adam/pytorch/torch_like.py @@ -3,9 +3,9 @@ from dataclasses import dataclass from typing import Union +import numpy as np import numpy.typing as ntp import torch -import numpy as np from adam.core.spatial_math import ArrayLike, ArrayLikeFactory, SpatialMath diff --git a/tests/conftest.py b/tests/conftest.py index 0e9b491f..31c582bf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,15 +1,17 @@ -import pytest -import numpy as np -import icub_models -import idyntree.bindings as idyntree -from adam import Representations -from adam.numpy.numpy_like import SpatialMath import dataclasses -from itertools import product import logging import os +from itertools import product + +import icub_models +import idyntree.bindings as idyntree +import numpy as np +import pytest import requests +from adam import Representations +from adam.numpy.numpy_like import SpatialMath + @dataclasses.dataclass class State: diff --git a/tests/parametric/test_casadi_parametric.py b/tests/parametric/test_casadi_parametric.py index 43dbae84..b8aa4cad 100644 --- a/tests/parametric/test_casadi_parametric.py +++ b/tests/parametric/test_casadi_parametric.py @@ -1,8 +1,9 @@ import casadi as cs import numpy as np import pytest +from conftest import RobotCfg, State + from adam.parametric.casadi import KinDynComputationsParametric -from conftest import State, RobotCfg @pytest.fixture(scope="module") diff --git a/tests/parametric/test_idyntree_conversion_parametric.py b/tests/parametric/test_idyntree_conversion_parametric.py index 29c8f70d..3d7fc4ea 100644 --- a/tests/parametric/test_idyntree_conversion_parametric.py +++ b/tests/parametric/test_idyntree_conversion_parametric.py @@ -1,14 +1,15 @@ import casadi as cs import numpy as np import pytest -from adam.parametric.casadi import KinDynComputationsParametric -from conftest import State, RobotCfg, compute_idyntree_values +from conftest import RobotCfg, State, compute_idyntree_values + +from adam.model import Model from adam.model.conversions.idyntree import to_idyntree_model +from adam.numpy.numpy_like import SpatialMath +from adam.parametric.casadi import KinDynComputationsParametric from adam.parametric.model.parametric_factories.parametric_model import ( URDFParametricModelFactory, ) -from adam.model import Model -from adam.numpy.numpy_like import SpatialMath @pytest.fixture(scope="module") diff --git a/tests/parametric/test_jax_parametric.py b/tests/parametric/test_jax_parametric.py index f1aed86f..75780bc2 100644 --- a/tests/parametric/test_jax_parametric.py +++ b/tests/parametric/test_jax_parametric.py @@ -1,7 +1,8 @@ import numpy as np import pytest +from conftest import RobotCfg, State + from adam.parametric.jax import KinDynComputationsParametric -from conftest import State, RobotCfg @pytest.fixture(scope="module") diff --git a/tests/parametric/test_numpy_parametric.py b/tests/parametric/test_numpy_parametric.py index 54e2c47f..15e4ec2f 100644 --- a/tests/parametric/test_numpy_parametric.py +++ b/tests/parametric/test_numpy_parametric.py @@ -1,7 +1,8 @@ import numpy as np import pytest +from conftest import RobotCfg, State + from adam.parametric.numpy import KinDynComputationsParametric -from conftest import State, RobotCfg @pytest.fixture(scope="module") diff --git a/tests/parametric/test_pytorch_parametric.py b/tests/parametric/test_pytorch_parametric.py index 9ec765b7..840007c4 100644 --- a/tests/parametric/test_pytorch_parametric.py +++ b/tests/parametric/test_pytorch_parametric.py @@ -1,8 +1,9 @@ import numpy as np import pytest -from adam.parametric.pytorch import KinDynComputationsParametric -from conftest import State, RobotCfg import torch +from conftest import RobotCfg, State + +from adam.parametric.pytorch import KinDynComputationsParametric torch.set_default_dtype(torch.float64) diff --git a/tests/test_casadi.py b/tests/test_casadi.py index 466c16f9..4f213a4d 100644 --- a/tests/test_casadi.py +++ b/tests/test_casadi.py @@ -1,8 +1,9 @@ import casadi as cs import numpy as np import pytest +from conftest import RobotCfg, State + from adam.casadi import KinDynComputations -from conftest import State, RobotCfg @pytest.fixture(scope="module") diff --git a/tests/test_idyntree_conversion.py b/tests/test_idyntree_conversion.py index ea4bc1fe..ceb5f779 100644 --- a/tests/test_idyntree_conversion.py +++ b/tests/test_idyntree_conversion.py @@ -1,8 +1,9 @@ import casadi as cs import numpy as np import pytest +from conftest import RobotCfg, State, compute_idyntree_values + from adam.casadi import KinDynComputations -from conftest import State, RobotCfg, compute_idyntree_values from adam.model.conversions.idyntree import to_idyntree_model diff --git a/tests/test_jax.py b/tests/test_jax.py index 94c788e9..a74449e6 100644 --- a/tests/test_jax.py +++ b/tests/test_jax.py @@ -1,9 +1,10 @@ import numpy as np import pytest -from adam.jax import KinDynComputations -from conftest import State, RobotCfg +from conftest import RobotCfg, State from jax import config +from adam.jax import KinDynComputations + config.update("jax_enable_x64", True) diff --git a/tests/test_numpy.py b/tests/test_numpy.py index f5619ed2..ba3e608c 100644 --- a/tests/test_numpy.py +++ b/tests/test_numpy.py @@ -1,7 +1,8 @@ import numpy as np import pytest +from conftest import RobotCfg, State + from adam.numpy import KinDynComputations -from conftest import State, RobotCfg @pytest.fixture(scope="module") diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index a85c8290..2af3bbc4 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -1,8 +1,9 @@ import numpy as np import pytest -from adam.pytorch import KinDynComputations -from conftest import State, RobotCfg import torch +from conftest import RobotCfg, State + +from adam.pytorch import KinDynComputations torch.set_default_dtype(torch.float64) diff --git a/tests/test_pytorch_batch.py b/tests/test_pytorch_batch.py index 51435cb7..4839b207 100644 --- a/tests/test_pytorch_batch.py +++ b/tests/test_pytorch_batch.py @@ -1,10 +1,11 @@ import numpy as np import pytest -from adam.pytorch import KinDynComputationsBatch -from conftest import State, RobotCfg import torch +from conftest import RobotCfg, State from jax import config +from adam.pytorch import KinDynComputationsBatch + config.update("jax_enable_x64", True)