Skip to content

Commit

Permalink
Update usage of HashlessObject in JaxSimModel
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Jun 3, 2024
1 parent 7492038 commit d534de5
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 20 deletions.
8 changes: 4 additions & 4 deletions src/jaxsim/api/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def idx_of_parent_link(model: js.model.JaxSimModel, *, frame_idx: jtp.IntLike) -
"""

# Get the intermediate representation parsed from the model description.
ir = model.description.get()
ir = model.description

# Extract the indices of the frame and the link it is attached to.
F = ir.frames[frame_idx - model.number_of_links()]
Expand All @@ -51,7 +51,7 @@ def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> int:
The index of the frame.
"""

frame_names = np.array([frame.name for frame in model.description.get().frames])
frame_names = np.array([frame.name for frame in model.description.frames])

if frame_name in frame_names:
idx_in_list = np.argwhere(frame_names == frame_name)
Expand All @@ -72,7 +72,7 @@ def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str
The name of the frame.
"""

return model.description.get().frames[frame_index - model.number_of_links()].name
return model.description.frames[frame_index - model.number_of_links()].name


@functools.partial(jax.jit, static_argnames=["frame_names"])
Expand Down Expand Up @@ -144,7 +144,7 @@ def transform(
W_H_L = js.link.transform(model=model, data=data, link_index=L)

# Get the static frame pose wrt the parent link.
frame = model.description.get().frames[frame_index - model.number_of_links()]
frame = model.description.frames[frame_index - model.number_of_links()]
L_H_F = frame.pose

# Combine the transforms computing the frame pose.
Expand Down
17 changes: 10 additions & 7 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,21 @@ class JaxSimModel(JaxsimDataclass):
terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
default=jaxsim.terrain.FlatTerrain(), repr=False, compare=False, hash=False
)
kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = (
dataclasses.field(default=None, repr=False, compare=False, hash=False)
)

built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field(
default=None, repr=False, compare=False, hash=False
)

description: Static[
_description: Static[
HashlessObject[jaxsim.parsers.descriptions.ModelDescription | None]
] = dataclasses.field(default=None, repr=False, compare=False, hash=False)

kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = (
dataclasses.field(default=None, repr=False, compare=False, hash=False)
)
@property
def description(self) -> jaxsim.parsers.descriptions.ModelDescription:
return self._description.get()

def __eq__(self, other: JaxSimModel) -> bool:

Expand Down Expand Up @@ -153,7 +156,7 @@ def build(
# Build the model
model = JaxSimModel(
model_name=model_name,
description=HashlessObject(obj=model_description),
_description=HashlessObject(obj=model_description),
kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
model_description=model_description
),
Expand Down Expand Up @@ -276,7 +279,7 @@ def frame_names(self) -> tuple[str, ...]:
The names of the links in the model.
"""

return tuple([frame.name for frame in self.description.get().frames])
return tuple([frame.name for frame in self.description.frames])


# =====================
Expand Down Expand Up @@ -313,7 +316,7 @@ def reduce(

# Copy the model description with a deep copy of the joints.
intermediate_description = dataclasses.replace(
model.description.get(), joints=copy.deepcopy(model.description.get().joints)
model.description, joints=copy.deepcopy(model.description.joints)
)

# Update the initial position of the joints.
Expand Down
14 changes: 6 additions & 8 deletions tests/test_api_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@ def test_frame_index(jaxsim_models_types: js.model.JaxSimModel):
# =====

frame_indices = tuple(
frame.index
for frame in model.description.get().frames
if frame.index is not None
frame.index for frame in model.description.frames if frame.index is not None
)

frame_names = np.array([frame.name for frame in model.description.get().frames])
frame_names = np.array([frame.name for frame in model.description.frames])

for frame_idx, frame_name in zip(frame_indices, frame_names):
assert js.frame.name_to_idx(model=model, frame_name=frame_name) == frame_idx
Expand Down Expand Up @@ -60,7 +58,7 @@ def test_frame_transforms(
# Get all names of frames in the iDynTree model.
frame_names = [
frame.name
for frame in model.description.get().frames
for frame in model.description.frames
if frame.name in kin_dyn.frame_names()
]

Expand All @@ -74,7 +72,7 @@ def test_frame_transforms(
# Get indices of frames.
frame_indices = tuple(
frame.index
for frame in model.description.get().frames
for frame in model.description.frames
if frame.index is not None and frame.name in frame_names
)

Expand Down Expand Up @@ -115,7 +113,7 @@ def test_frame_jacobians(
# Get all names of frames in the iDynTree model.
frame_names = [
frame.name
for frame in model.description.get().frames
for frame in model.description.frames
if frame.name in kin_dyn.frame_names()
]

Expand All @@ -127,7 +125,7 @@ def test_frame_jacobians(
# Get indices of frames.
frame_indices = tuple(
frame.index
for frame in model.description.get().frames
for frame in model.description.frames
if frame.index is not None and frame.name in frame_names
)

Expand Down
2 changes: 1 addition & 1 deletion tests/utils_idyntree.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def build_kindyncomputations_from_jaxsim_model(
# Get the default positions already stored in the model description.
removed_joint_positions_default = {
str(j.name): float(j.initial_position)
for j in model.description.get()._joints_removed
for j in model.description._joints_removed
if j.name not in considered_joints
}

Expand Down

0 comments on commit d534de5

Please sign in to comment.