diff --git a/genesis/engine/entities/rigid_entity/rigid_entity.py b/genesis/engine/entities/rigid_entity/rigid_entity.py index b12ea7a..2dd8a89 100644 --- a/genesis/engine/entities/rigid_entity/rigid_entity.py +++ b/genesis/engine/entities/rigid_entity/rigid_entity.py @@ -894,31 +894,45 @@ def inverse_kinematics_multilink( if n_links == 0: gs.raise_exception("Target link not provided.") - if len(poss) == n_links: - if self._solver.n_envs > 0: - if poss[0].shape[0] != self._solver.n_envs: - gs.raise_exception("First dimension of elements in `poss` must be equal to scene.n_envs.") - elif len(poss) == 0: - if self._solver.n_envs == 0: - poss = [gu.zero_pos()] * n_links - else: - poss = [self._solver._batch_array(gu.zero_pos(), True)] * n_links + if len(poss) == 0: + poss = [None] * n_links pos_mask = [False, False, False] - else: + elif len(poss) != n_links: gs.raise_exception("Accepting only `poss` with length equal to `links` or empty list.") - if len(quats) == n_links: - if self._solver.n_envs > 0: - if quats[0].shape[0] != self._solver.n_envs: - gs.raise_exception("First dimension of elements in `quats` must be equal to scene.n_envs.") - elif len(quats) == 0: - if self._solver.n_envs == 0: - quats = [gu.identity_quat()] * n_links - else: - quats = [self._solver._batch_array(gu.identity_quat(), True)] * n_links + if len(quats) == 0: + quats = [None] * n_links rot_mask = [False, False, False] - else: - gs.raise_exception("Accepting only `quats` with length equal to `links` or empty list.") + elif len(quats) != n_links: + gs.raise_exception("Accepting only `quatss` with length equal to `links` or empty list.") + + link_pos_mask = [] + link_rot_mask = [] + for i in range(n_links): + if poss[i] is None and quats[i] is None: + gs.raise_exception("At least one of `poss` or `quats` must be provided.") + if poss[i] is not None: + link_pos_mask.append(True) + if self._solver.n_envs > 0: + if poss[i].shape[0] != self._solver.n_envs: + gs.raise_exception("First dimension of elements in `poss` must be equal to scene.n_envs.") + else: + link_pos_mask.append(False) + if self._solver.n_envs == 0: + poss[i] = gu.zero_pos() + else: + poss[i] = self._solver._batch_array(gu.zero_pos(), True) + if quats[i] is not None: + link_rot_mask.append(True) + if self._solver.n_envs > 0: + if quats[i].shape[0] != self._solver.n_envs: + gs.raise_exception("First dimension of elements in `quats` must be equal to scene.n_envs.") + else: + link_rot_mask.append(False) + if self._solver.n_envs == 0: + quats[i] = gu.identity_quat() + else: + quats[i] = self._solver._batch_array(gu.identity_quat(), True) if init_qpos is not None: init_qpos = torch.as_tensor(init_qpos, dtype=gs.tc_float) @@ -947,6 +961,8 @@ def inverse_kinematics_multilink( gs.raise_exception("You can only align 0, 1 axis or all 3 axes.") else: pass # nothing needs to change for 0 or 3 axes + link_pos_mask = torch.as_tensor(link_pos_mask, dtype=gs.tc_int, device=gs.device) + link_rot_mask = torch.as_tensor(link_rot_mask, dtype=gs.tc_int, device=gs.device) links_idx = torch.as_tensor([link.idx for link in links], dtype=gs.tc_int, device=gs.device) poss = torch.stack( @@ -992,6 +1008,8 @@ def inverse_kinematics_multilink( rot_tol, pos_mask, rot_mask, + link_pos_mask, + link_rot_mask, max_step_size, respect_joint_limit, ) @@ -1032,6 +1050,8 @@ def _kernel_inverse_kinematics( rot_tol: ti.f32, pos_mask_: ti.types.ndarray(), rot_mask_: ti.types.ndarray(), + link_pos_mask: ti.types.ndarray(), + link_rot_mask: ti.types.ndarray(), max_step_size: ti.f32, respect_joint_limit: ti.i32, ): @@ -1067,7 +1087,7 @@ def _kernel_inverse_kinematics( tgt_pos_i = ti.Vector([poss[i_ee, i_b, 0], poss[i_ee, i_b, 1], poss[i_ee, i_b, 2]]) err_pos_i = tgt_pos_i - self._solver.links_state[i_l_ee, i_b].pos for k in range(3): - err_pos_i[k] *= pos_mask[k] + err_pos_i[k] *= pos_mask[k] * link_pos_mask[i_ee] if err_pos_i.norm() > pos_tol: solved = False @@ -1080,7 +1100,7 @@ def _kernel_inverse_kinematics( ) ) for k in range(3): - err_rot_i[k] *= rot_mask[k] + err_rot_i[k] *= rot_mask[k] * link_rot_mask[i_ee] if err_rot_i.norm() > rot_tol: solved = False @@ -1150,7 +1170,7 @@ def _kernel_inverse_kinematics( tgt_pos_i = ti.Vector([poss[i_ee, i_b, 0], poss[i_ee, i_b, 1], poss[i_ee, i_b, 2]]) err_pos_i = tgt_pos_i - self._solver.links_state[i_l_ee, i_b].pos for k in range(3): - err_pos_i[k] *= pos_mask[k] + err_pos_i[k] *= pos_mask[k] * link_pos_mask[i_ee] if err_pos_i.norm() > pos_tol: solved = False @@ -1163,7 +1183,7 @@ def _kernel_inverse_kinematics( ) ) for k in range(3): - err_rot_i[k] *= rot_mask[k] + err_rot_i[k] *= rot_mask[k] * link_rot_mask[i_ee] if err_rot_i.norm() > rot_tol: solved = False