Skip to content

Commit

Permalink
Merge pull request ami-iit#59 from ami-iit/new_api
Browse files Browse the repository at this point in the history
New high-level APIs with OOP wrappers
  • Loading branch information
diegoferigo authored Dec 6, 2023
2 parents 1a364bf + b18b930 commit 293161c
Show file tree
Hide file tree
Showing 30 changed files with 2,242 additions and 503 deletions.
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,10 @@ line-length = 88
[tool.isort]
profile = "black"
multi_line_output = 3

[tool.pytest.ini_options]
minversion = "6.0"
addopts = "-rsxX -v --strict-markers --forked"
testpaths = [
"tests",
]
14 changes: 5 additions & 9 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,13 @@ package_dir =
python_requires = >=3.10
install_requires =
coloredlogs
jax >= 0.4.1, <0.4.11
jaxlib < 0.4.11
jax >= 0.4.1
jaxlib
jaxlie
jax_dataclasses >= 1.4.0
ml-dtypes < 0.3.0
pptree
rod
scipy
typing_extensions; python_version < "3.11"

[options.packages.find]
where = src
Expand All @@ -71,13 +70,10 @@ style =
isort
testing =
idyntree
pytest
pytest >= 6.0
pytest-forked
pytest-icdiff
robot-descriptions
all =
%(style)s
%(testing)s

[tool:pytest]
addopts = -rsxX -v --strict-markers
testpaths = tests
4 changes: 3 additions & 1 deletion src/jaxsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,7 @@ def _is_editable() -> bool:
del _np_options
del _is_editable

from . import high_level, logging, math, sixd
from . import high_level, logging, math, simulation, sixd
from .high_level.common import VelRepr
from .simulation.ode_integration import IntegratorType
from .simulation.simulator import JaxSim
1 change: 1 addition & 0 deletions src/jaxsim/high_level/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from . import common, joint, link, model
from .common import VelRepr
114 changes: 84 additions & 30 deletions src/jaxsim/high_level/joint.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,127 @@
import dataclasses
from typing import Any, Tuple
import functools
from typing import Any

import jax.numpy as jnp
import jax_dataclasses
from jax_dataclasses import Static

import jaxsim.parsers
import jaxsim.typing as jtp
from jaxsim.utils import JaxsimDataclass
from jaxsim.utils import Vmappable, not_tracing, oop


@jax_dataclasses.pytree_dataclass
class Joint(JaxsimDataclass):
class Joint(Vmappable):
"""
High-level class to operate on a single joint of a simulated model.
High-level class to operate in r/o on a single joint of a simulated model.
"""

joint_description: Static[jaxsim.parsers.descriptions.JointDescription]

_parent_model: Any = dataclasses.field(default=None, repr=False, compare=False)
_parent_model: Any = dataclasses.field(
default=None, repr=False, compare=False, hash=False
)

@property
def parent_model(self) -> "jaxsim.high_level.model.Model":
""""""

return self._parent_model

def valid(self) -> bool:
return self.parent_model is not None
@functools.partial(oop.jax_tf.method_ro, jit=False)
def valid(self) -> jtp.Bool:
""""""

return jnp.array(self.parent_model is not None, dtype=bool)

@functools.partial(oop.jax_tf.method_ro, jit=False)
def index(self) -> jtp.Int:
""""""

return jnp.array(self.joint_description.index, dtype=int)

def index(self) -> int:
return self.joint_description.index
@functools.partial(oop.jax_tf.method_ro)
def dofs(self) -> jtp.Int:
""""""

def dofs(self) -> int:
return 1
return jnp.array(1, dtype=int)

@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def name(self) -> str:
""""""

return self.joint_description.name

def position(self, dof: int = 0) -> float:
return self.parent_model.joint_positions(joint_names=[self.name()])[dof]
@functools.partial(oop.jax_tf.method_ro)
def position(self, dof: int = None) -> jtp.Float:
""""""

dof = dof if dof is not None else 0

return jnp.array(
self.parent_model.joint_positions(joint_names=(self.name(),))[dof],
dtype=float,
)

@functools.partial(oop.jax_tf.method_ro)
def velocity(self, dof: int = None) -> jtp.Float:
""""""

def velocity(self, dof: int = 0) -> float:
return self.parent_model.joint_velocities(joint_names=[self.name()])[dof]
dof = dof if dof is not None else 0

def acceleration(self, dof: int = 0) -> float:
return self.parent_model.joint_accelerations(joint_names=[self.name()])[dof]
return jnp.array(
self.parent_model.joint_velocities(joint_names=(self.name(),))[dof],
dtype=float,
)

def force(self, dof: int = 0) -> float:
return self.parent_model.joint_generalized_forces(joint_names=[self.name()])[
dof
]
@functools.partial(oop.jax_tf.method_ro)
def force_target(self, dof: int = None) -> jtp.Float:
""""""

def position_limit(self, dof: int = 0) -> Tuple[float, float]:
if dof != 0:
dof = dof if dof is not None else 0

return jnp.array(
self.parent_model.joint_generalized_forces_targets(
joint_names=(self.name(),)
)[dof],
dtype=float,
)

@functools.partial(oop.jax_tf.method_ro)
def position_limit(self, dof: int = None) -> tuple[jtp.Float, jtp.Float]:
""""""

dof = dof if dof is not None else 0

if not_tracing(dof) and dof != 0:
msg = "Only joints with 1 DoF are currently supported"
raise ValueError(msg)

return self.joint_description.position_limit
low, high = self.joint_description.position_limit

return jnp.array(low, dtype=float), jnp.array(high, dtype=float)

# =================
# Multi-DoF methods
# =================

@functools.partial(oop.jax_tf.method_ro)
def joint_position(self) -> jtp.Vector:
return self.parent_model.joint_positions(joint_names=[self.name()])
""""""

return self.parent_model.joint_positions(joint_names=(self.name(),))

@functools.partial(oop.jax_tf.method_ro)
def joint_velocity(self) -> jtp.Vector:
return self.parent_model.joint_velocities(joint_names=[self.name()])
""""""

return self.parent_model.joint_velocities(joint_names=(self.name(),))

def joint_acceleration(self) -> jtp.Vector:
return self.parent_model.joint_accelerations(joint_names=[self.name()])
@functools.partial(oop.jax_tf.method_ro)
def joint_force_target(self) -> jtp.Vector:
""""""

def joint_force(self) -> jtp.Vector:
return self.parent_model.joint_generalized_forces(joint_names=[self.name()])
return self.parent_model.joint_generalized_forces_targets(
joint_names=(self.name(),)
)
Loading

0 comments on commit 293161c

Please sign in to comment.