Skip to content

Commit

Permalink
Refactor code and add oop decorators for JIT compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Sep 12, 2023
1 parent ce93982 commit 68354b5
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 35 deletions.
20 changes: 14 additions & 6 deletions src/jaxsim/high_level/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,23 @@ def position_limit(self, dof: int = None) -> Tuple[jtp.Float, jtp.Float]:
# =============
# Motor methods
# =============
@functools.partial(oop.jax_tf.method_ro)
def motor_inertia(self) -> jtp.Vector:
""""""

return jnp.array(self.joint_description.motor_inertia, dtype=float)

@functools.partial(oop.jax_tf.method_ro)
def motor_gear_ratio(self) -> jtp.Vector:
""""""

def motor_inertia(self) -> float:
return self.joint_description.motor_inertia
return jnp.array(self.joint_description.motor_gear_ratio, dtype=float)

def motor_gear_ratio(self) -> float:
return self.joint_description.motor_gear_ratio
@functools.partial(oop.jax_tf.method_ro)
def motor_viscous_friction(self) -> jtp.Vector:
""""""

def motor_viscous_friction(self) -> float:
return self.joint_description.motor_viscous_friction
return jnp.array(self.joint_description.motor_viscous_friction, dtype=float)

# =================
# Multi-DoF methods
Expand Down
44 changes: 21 additions & 23 deletions src/jaxsim/high_level/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1346,6 +1346,7 @@ def integrate(
# Motor dynamics
# ==============

@functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
def set_motor_inertias(
self, inertias: jtp.Vector, joint_names: List[str] = None
) -> None:
Expand All @@ -1356,18 +1357,17 @@ def set_motor_inertias(
raise ValueError("Wrong arguments size", inertias.size, len(joint_names))

self.physics_model._joint_motor_inertia.update(
dict(
map(
lambda j, g: (j, g),
self.physics_model._joint_motor_inertia.keys(),
inertias,
{
j: im
for j, im in zip(
self.physics_model._joint_motor_inertia.keys(), inertias
)
)
}
)

self.physics_model.has_motors = True
logging.info("Setting attribute 'has_motors' to True")
logging.info("Setting attribute 'motor_inertias'")

@functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
def set_motor_gear_ratios(
self, gear_ratios: jtp.Vector, joint_names: List[str] = None
) -> None:
Expand All @@ -1378,20 +1378,19 @@ def set_motor_gear_ratios(
raise ValueError("Wrong arguments size", gear_ratios.size, len(joint_names))

self.physics_model._joint_motor_gear_ratio.update(
dict(
map(
lambda j, g: (j, g),
self.physics_model._joint_motor_gear_ratio.keys(),
gear_ratios,
{
j: gr
for j, gr in zip(
self.physics_model._joint_motor_gear_ratio.keys(), gear_ratios
)
)
}
)

self.physics_model.has_motors = True
logging.info("Setting attribute 'has_motors' to True")
logging.info("Setting attribute 'motor_gear_ratios'")

@functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
def set_motor_viscous_frictions(
self, viscous_frictions: jtp.Vector, joint_names: List[str] = None
self, viscous_frictions: Tuple, joint_names: List[str] = None
) -> None:
if joint_names is None:
joint_names = self.joint_names()
Expand All @@ -1402,17 +1401,16 @@ def set_motor_viscous_frictions(
)

self.physics_model._joint_motor_viscous_friction.update(
dict(
map(
lambda j, g: (j, g),
{
j: kv
for j, kv in zip(
self.physics_model._joint_motor_viscous_friction.keys(),
viscous_frictions,
)
)
}
)

self.physics_model.has_motors = True
logging.info("Setting attribute 'has_motors' to True")
logging.info("Setting attribute 'motor_viscous_frictions'")

# ===============
# Private methods
Expand Down
11 changes: 5 additions & 6 deletions src/jaxsim/physics/model/physics_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class PhysicsModel(JaxsimDataclass):
)
)
is_floating_base: Static[bool] = dataclasses.field(default=False)
has_motors: Static[bool] = dataclasses.field(default=False)
gc: GroundContact = dataclasses.field(default_factory=lambda: GroundContact())
description: Static[
jaxsim.parsers.descriptions.model.ModelDescription
Expand Down Expand Up @@ -129,13 +128,15 @@ def build_from(
# Dicts from the joint index to the motor inertia, gear ratio and viscous friction.
# Note: the joint index is equal to its child link index.
joint_motor_inertia = {
joint.index: joint.motor_inertia for joint in model_description.joints
joint.index: jnp.array(joint.motor_inertia, dtype=float)
for joint in model_description.joints
}
joint_motor_gear_ratio = {
joint.index: joint.motor_gear_ratio for joint in model_description.joints
joint.index: jnp.array(joint.motor_gear_ratio, dtype=float)
for joint in model_description.joints
}
joint_motor_viscous_friction = {
joint.index: joint.motor_viscous_friction
joint.index: jnp.array(joint.motor_viscous_friction, dtype=float)
for joint in model_description.joints
}

Expand Down Expand Up @@ -204,7 +205,6 @@ def build_from(
_joint_motor_viscous_friction=joint_motor_viscous_friction,
gravity=jnp.hstack([gravity.squeeze(), np.zeros(3)]),
is_floating_base=True,
has_motors=False,
gc=GroundContact.build_from(model_description=model_description),
description=model_description,
)
Expand Down Expand Up @@ -341,7 +341,6 @@ def __repr__(self) -> str:
f"dofs: {self.dofs()},",
f"links: {self.NB},",
f"floating_base: {self.is_floating_base},",
f"has_motors: {self.has_motors},",
]
attributes_string = "\n ".join(attributes)

Expand Down

0 comments on commit 68354b5

Please sign in to comment.