Skip to content

Commit

Permalink
Solve deprecation warning
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Oct 20, 2023
1 parent 5078b36 commit 94a4688
Showing 1 changed file with 40 additions and 38 deletions.
78 changes: 40 additions & 38 deletions src/jaxsim/high_level/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def joint_positions(self, joint_names: tuple[str, ...] = None) -> jtp.Vector:

@functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"])
def joint_random_positions(
self, joint_names: tuple[str, ...] = None, key: jax.random.PRNGKeyArray = None
self, joint_names: tuple[str, ...] = None, key: jax.Array = None
) -> jtp.Vector:
""""""

Expand Down Expand Up @@ -1505,82 +1505,84 @@ def integrate(
# Motor dynamics
# ==============

# @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
@functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
def set_motor_inertias(
self, inertias: jtp.Vector, joint_names: List[str] = None
self, inertias: jtp.Vector, joint_names: tuple[str, ...] = None
) -> None:
if joint_names is None:
joint_names = self.joint_names()
joint_names = joint_names or self.joint_names()

if inertias.size != len(joint_names):
raise ValueError("Wrong arguments size", inertias.size, len(joint_names))

self.physics_model._joint_motor_inertia.update(
{
j: im
for j, im in zip(
self.physics_model._joint_motor_inertia.keys(), inertias
)
}
dict(zip(self.physics_model._joint_motor_inertia, inertias))
)

logging.info("Setting attribute 'motor_inertias'")

# @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
@functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
def set_motor_gear_ratios(
self, gear_ratios: jtp.Vector, joint_names: List[str] = None
self, gear_ratios: jtp.Vector, joint_names: tuple[str, ...] = None
) -> None:
if joint_names is None:
joint_names = self.joint_names()
joint_names = joint_names or self.joint_names()

if gear_ratios.size != len(joint_names):
raise ValueError("Wrong arguments size", gear_ratios.size, len(joint_names))

# Check on gear ratios if motor_inertias are not zero
jax.lax.cond(
pred=(jnp.diag(self.physics_model._joint_motor_inertia) != 0).any(),
operand=gear_ratios,
true_fun=lambda gr: gr,
false_fun=lambda: (_ for _ in _).throw(
ValueError("Motor inertias are zero")
),
)
for idx, gr in gear_ratios:
if gr != 0 and self.motor_inertias()[idx] == 0:
raise ValueError(
f"Zero motor inertia with non-zero gear ratio found in position {idx}"
)

self.physics_model._joint_motor_gear_ratio.update(
{
j: gr
for j, gr in zip(
self.physics_model._joint_motor_gear_ratio.keys(), gear_ratios
)
}
dict(zip(self.physics_model._joint_motor_gear_ratio, gear_ratios))
)

logging.info("Setting attribute 'motor_gear_ratios'")

# @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
@functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
def set_motor_viscous_frictions(
self, viscous_frictions: Tuple, joint_names: List[str] = None
self, viscous_frictions: jtp.Vector, joint_names: tuple[str, ...] = None
) -> None:
if joint_names is None:
joint_names = self.joint_names()
joint_names = joint_names or self.joint_names()

if viscous_frictions.size != len(joint_names):
raise ValueError(
"Wrong arguments size", viscous_frictions.size, len(joint_names)
)

self.physics_model._joint_motor_viscous_friction.update(
{
j: kv
for j, kv in zip(
self.physics_model._joint_motor_viscous_friction.keys(),
dict(
zip(
self.physics_model._joint_motor_viscous_friction,
viscous_frictions,
)
}
)
)

# self.physics_model._joint_motor_viscous_friction |= dict(
# zip(
# self.physics_model._joint_motor_viscous_friction.keys(),
# viscous_frictions,
# )
# )

logging.info("Setting attribute 'motor_viscous_frictions'")

@functools.partial(oop.jax_tf.method_ro, jit=False)
def motor_inertias(self) -> jtp.Vector:
return jnp.array([*self.physics_model._joint_motor_inertia.values()])

@functools.partial(oop.jax_tf.method_ro, jit=False)
def motor_gear_ratios(self) -> jtp.Vector:
return jnp.array([*self.physics_model._joint_motor_gear_ratio.values()])

@functools.partial(oop.jax_tf.method_ro, jit=False)
def motor_viscous_frictions(self) -> jtp.Vector:
return jnp.array([*self.physics_model._joint_motor_viscous_friction.values()])

# ===============
# Private methods
# ===============
Expand Down

0 comments on commit 94a4688

Please sign in to comment.