Skip to content

Commit

Permalink
[FEATURE] Add link-wise mask for poss and quats in multilink IK (#499)
Browse files Browse the repository at this point in the history
* add link-wise mask for poss and quats
* add mask in recompute final error
  • Loading branch information
ziyanx02 authored Jan 9, 2025
1 parent 0a5e8a8 commit 5d83899
Showing 1 changed file with 45 additions and 25 deletions.
70 changes: 45 additions & 25 deletions genesis/engine/entities/rigid_entity/rigid_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,31 +894,45 @@ def inverse_kinematics_multilink(
if n_links == 0:
gs.raise_exception("Target link not provided.")

if len(poss) == n_links:
if self._solver.n_envs > 0:
if poss[0].shape[0] != self._solver.n_envs:
gs.raise_exception("First dimension of elements in `poss` must be equal to scene.n_envs.")
elif len(poss) == 0:
if self._solver.n_envs == 0:
poss = [gu.zero_pos()] * n_links
else:
poss = [self._solver._batch_array(gu.zero_pos(), True)] * n_links
if len(poss) == 0:
poss = [None] * n_links
pos_mask = [False, False, False]
else:
elif len(poss) != n_links:
gs.raise_exception("Accepting only `poss` with length equal to `links` or empty list.")

if len(quats) == n_links:
if self._solver.n_envs > 0:
if quats[0].shape[0] != self._solver.n_envs:
gs.raise_exception("First dimension of elements in `quats` must be equal to scene.n_envs.")
elif len(quats) == 0:
if self._solver.n_envs == 0:
quats = [gu.identity_quat()] * n_links
else:
quats = [self._solver._batch_array(gu.identity_quat(), True)] * n_links
if len(quats) == 0:
quats = [None] * n_links
rot_mask = [False, False, False]
else:
gs.raise_exception("Accepting only `quats` with length equal to `links` or empty list.")
elif len(quats) != n_links:
gs.raise_exception("Accepting only `quatss` with length equal to `links` or empty list.")

link_pos_mask = []
link_rot_mask = []
for i in range(n_links):
if poss[i] is None and quats[i] is None:
gs.raise_exception("At least one of `poss` or `quats` must be provided.")
if poss[i] is not None:
link_pos_mask.append(True)
if self._solver.n_envs > 0:
if poss[i].shape[0] != self._solver.n_envs:
gs.raise_exception("First dimension of elements in `poss` must be equal to scene.n_envs.")
else:
link_pos_mask.append(False)
if self._solver.n_envs == 0:
poss[i] = gu.zero_pos()
else:
poss[i] = self._solver._batch_array(gu.zero_pos(), True)
if quats[i] is not None:
link_rot_mask.append(True)
if self._solver.n_envs > 0:
if quats[i].shape[0] != self._solver.n_envs:
gs.raise_exception("First dimension of elements in `quats` must be equal to scene.n_envs.")
else:
link_rot_mask.append(False)
if self._solver.n_envs == 0:
quats[i] = gu.identity_quat()
else:
quats[i] = self._solver._batch_array(gu.identity_quat(), True)

if init_qpos is not None:
init_qpos = torch.as_tensor(init_qpos, dtype=gs.tc_float)
Expand Down Expand Up @@ -947,6 +961,8 @@ def inverse_kinematics_multilink(
gs.raise_exception("You can only align 0, 1 axis or all 3 axes.")
else:
pass # nothing needs to change for 0 or 3 axes
link_pos_mask = torch.as_tensor(link_pos_mask, dtype=gs.tc_int, device=gs.device)
link_rot_mask = torch.as_tensor(link_rot_mask, dtype=gs.tc_int, device=gs.device)

links_idx = torch.as_tensor([link.idx for link in links], dtype=gs.tc_int, device=gs.device)
poss = torch.stack(
Expand Down Expand Up @@ -992,6 +1008,8 @@ def inverse_kinematics_multilink(
rot_tol,
pos_mask,
rot_mask,
link_pos_mask,
link_rot_mask,
max_step_size,
respect_joint_limit,
)
Expand Down Expand Up @@ -1032,6 +1050,8 @@ def _kernel_inverse_kinematics(
rot_tol: ti.f32,
pos_mask_: ti.types.ndarray(),
rot_mask_: ti.types.ndarray(),
link_pos_mask: ti.types.ndarray(),
link_rot_mask: ti.types.ndarray(),
max_step_size: ti.f32,
respect_joint_limit: ti.i32,
):
Expand Down Expand Up @@ -1067,7 +1087,7 @@ def _kernel_inverse_kinematics(
tgt_pos_i = ti.Vector([poss[i_ee, i_b, 0], poss[i_ee, i_b, 1], poss[i_ee, i_b, 2]])
err_pos_i = tgt_pos_i - self._solver.links_state[i_l_ee, i_b].pos
for k in range(3):
err_pos_i[k] *= pos_mask[k]
err_pos_i[k] *= pos_mask[k] * link_pos_mask[i_ee]
if err_pos_i.norm() > pos_tol:
solved = False

Expand All @@ -1080,7 +1100,7 @@ def _kernel_inverse_kinematics(
)
)
for k in range(3):
err_rot_i[k] *= rot_mask[k]
err_rot_i[k] *= rot_mask[k] * link_rot_mask[i_ee]
if err_rot_i.norm() > rot_tol:
solved = False

Expand Down Expand Up @@ -1150,7 +1170,7 @@ def _kernel_inverse_kinematics(
tgt_pos_i = ti.Vector([poss[i_ee, i_b, 0], poss[i_ee, i_b, 1], poss[i_ee, i_b, 2]])
err_pos_i = tgt_pos_i - self._solver.links_state[i_l_ee, i_b].pos
for k in range(3):
err_pos_i[k] *= pos_mask[k]
err_pos_i[k] *= pos_mask[k] * link_pos_mask[i_ee]
if err_pos_i.norm() > pos_tol:
solved = False

Expand All @@ -1163,7 +1183,7 @@ def _kernel_inverse_kinematics(
)
)
for k in range(3):
err_rot_i[k] *= rot_mask[k]
err_rot_i[k] *= rot_mask[k] * link_rot_mask[i_ee]
if err_rot_i.norm() > rot_tol:
solved = False

Expand Down

0 comments on commit 5d83899

Please sign in to comment.