diff --git a/examples/rigid/set_phys_attr.py b/examples/rigid/set_phys_attr.py new file mode 100644 index 00000000..8dc0bf56 --- /dev/null +++ b/examples/rigid/set_phys_attr.py @@ -0,0 +1,229 @@ +import argparse + +import numpy as np + +import genesis as gs + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("-v", "--vis", action="store_true", default=False) + args = parser.parse_args() + + ########################## init ########################## + gs.init(backend=gs.gpu) + + ########################## create a scene ########################## + viewer_options = gs.options.ViewerOptions( + camera_pos=(0, -3.5, 2.5), + camera_lookat=(0.0, 0.0, 0.5), + camera_fov=40, + max_FPS=60, + ) + + scene = gs.Scene( + viewer_options=viewer_options, + sim_options=gs.options.SimOptions( + dt=0.01, + ), + show_viewer=args.vis, + rigid_options=gs.options.RigidOptions( + # NOTE: Batching dofs/links info to set different physical attributes across environments (in parallel) + # By default, both are False as it's faster and thus only turn this on if necessary + batch_dofs_info=True, + batch_links_info=True, + ), + ) + + ########################## entities ########################## + plane = scene.add_entity( + gs.morphs.Plane(), + ) + franka = scene.add_entity( + gs.morphs.MJCF(file="xml/franka_emika_panda/panda.xml"), + ) + ########################## build ########################## + scene.build(n_envs=2) # test with 2 different environments + + jnt_names = [ + "joint1", + "joint2", + "joint3", + "joint4", + "joint5", + "joint6", + "joint7", + "finger_joint1", + "finger_joint2", + ] + dofs_idx = [franka.get_joint(name).dof_idx_local for name in jnt_names] + + lnk_names = [ + "link0", + "link1", + "link2", + "link3", + "link4", + "link5", + "link6", + "link7", + "hand", + "left_finger", + "right_finger", + ] + links_idx = [franka.get_link(name).idx_local for name in lnk_names] + + # Optional: set control gains + franka.set_dofs_kp( + np.array( + [ + [4500, 4500, 3500, 3500, 2000, 2000, 2000, 100, 100], + [100, 100, 2000, 2000, 2000, 3500, 3500, 4500, 4500], + ] + ), + dofs_idx, + ) + print("=== kp ===\n", franka.get_dofs_kp()) + franka.set_dofs_kv( + np.array( + [ + [450, 450, 350, 350, 200, 200, 200, 10, 10], + [10, 10, 200, 200, 200, 350, 350, 450, 450], + ] + ), + dofs_idx, + ) + print("=== kv ===\n", franka.get_dofs_kv()) + franka.set_dofs_force_range( + np.array( + [ + [-87, -87, -87, -87, -12, -12, -12, -100, -100], + [-120, -100, -12, -12, -12, -87, -87, -87, -87], + ] + ), + np.array( + [ + [87, 87, 87, 87, 12, 12, 12, 100, 100], + [100, 100, 12, 12, 12, 87, 87, 87, 87], + ] + ), + dofs_idx, + ) + print("=== force range ===\n", franka.get_dofs_force_range()) + franka.set_dofs_armature( + np.array( + [ + [0.1] * len(dofs_idx), + [0.2] * len(dofs_idx), + ] + ), + dofs_idx, + ) + print("=== armature ===\n", franka.get_dofs_armature()) + franka.set_dofs_stiffness( + np.array( + [ + [0.0] * len(dofs_idx), + [0.1] * len(dofs_idx), + ] + ), + dofs_idx, + ) + print("=== stiffness ===\n", franka.get_dofs_stiffness()) + franka.set_dofs_invweight( + np.array( + [ + [5.5882, 0.9693, 6.8053, 3.9007, 7.8085, 6.6139, 9.4213, 8.6984, 8.6984], + [8.6984, 8.6984, 9.4213, 6.6139, 7.8085, 3.9007, 6.8053, 0.9693, 5.5882], + ] + ), + dofs_idx, + ) + print("=== invweight ===\n", franka.get_dofs_invweight()) + franka.set_dofs_damping( + np.array( + [ + [1.0] * len(dofs_idx), + [2.0] * len(dofs_idx), + ] + ), + dofs_idx, + ) + print("=== damping ===\n", franka.get_dofs_damping()) + franka.set_links_inertial_mass( + np.array( + [ + [0.6298, 4.9707, 0.6469, 3.2286, 3.5879, 1.2259, 1.6666, 0.7355, 0.7300, 0.0150, 0.0150], + [0.015, 0.015, 0.73, 0.7355, 1.6666, 1.2259, 3.5879, 3.2286, 0.6469, 4.9707, 0.6298], + ] + ), + links_idx, + ) + print("=== links inertial mass ===\n", franka.get_links_inertial_mass()) + franka.set_links_invweight( + np.array( + [ + [0.0, 3.6037e-05, 0.00030664, 0.025365, 0.036351, 0.072328, 0.089559, 0.11661, 0.11288, 3.0179, 3.0179], + [3.0179, 3.0179, 0.11288, 0.11661, 0.089559, 0.072328, 0.036351, 0.025365, 0.00030664, 3.6037e-05, 0.0], + ] + ), + links_idx, + ) + print("=== links invweight ===\n", franka.get_links_invweight()) + + # Hard reset + for i in range(150): + if i < 50: + franka.set_dofs_position( + np.array([1, 1, 0, 0, 0, 0, 0, 0.04, 0.04])[None, :].repeat(scene.n_envs, 0), dofs_idx + ) + elif i < 100: + franka.set_dofs_position( + np.array([-1, 0.8, 1, -2, 1, 0.5, -0.5, 0.04, 0.04])[None, :].repeat(scene.n_envs, 0), dofs_idx + ) + else: + franka.set_dofs_position(np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])[None, :].repeat(scene.n_envs, 0), dofs_idx) + + scene.step() + + # PD control + for i in range(1250): + if i == 0: + franka.control_dofs_position( + np.array([1, 1, 0, 0, 0, 0, 0, 0.04, 0.04])[None, :].repeat(scene.n_envs, 0), + dofs_idx, + ) + elif i == 250: + franka.control_dofs_position( + np.array([-1, 0.8, 1, -2, 1, 0.5, -0.5, 0.04, 0.04])[None, :].repeat(scene.n_envs, 0), + dofs_idx, + ) + elif i == 500: + franka.control_dofs_position( + np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])[None, :].repeat(scene.n_envs, 0), + dofs_idx, + ) + elif i == 750: + # control first dof with velocity, and the rest with position + franka.control_dofs_position( + np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])[1:][None, :].repeat(scene.n_envs, 0), + dofs_idx[1:], + ) + franka.control_dofs_velocity( + np.array([1.0, 0, 0, 0, 0, 0, 0, 0, 0])[:1][None, :].repeat(scene.n_envs, 0), + dofs_idx[:1], + ) + elif i == 1000: + franka.control_dofs_force( + np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])[None, :].repeat(scene.n_envs, 0), + dofs_idx, + ) + # This is the internal control force computed based on the given control command + # If using force control, it's the same as the given control command + print("control force:", franka.get_dofs_control_force(dofs_idx)) + + scene.step() + + +if __name__ == "__main__": + main() diff --git a/genesis/engine/entities/rigid_entity/rigid_entity.py b/genesis/engine/entities/rigid_entity/rigid_entity.py index 0517c80a..955af57c 100644 --- a/genesis/engine/entities/rigid_entity/rigid_entity.py +++ b/genesis/engine/entities/rigid_entity/rigid_entity.py @@ -666,7 +666,8 @@ def _func_get_jacobian(self, tgt_link_idx, i_b, pos_mask, rot_mask): tgt_link_pos = self._solver.links_state[tgt_link_idx, i_b].pos i_l = tgt_link_idx while i_l > -1: - l_info = self._solver.links_info[i_l] + I_l = [i_l, i_b] if ti.static(self.solver._options.batch_links_info) else i_l + l_info = self._solver.links_info[I_l] l_state = self._solver.links_state[i_l, i_b] if l_info.joint_type == gs.JOINT_TYPE.FIXED: @@ -674,8 +675,9 @@ def _func_get_jacobian(self, tgt_link_idx, i_b, pos_mask, rot_mask): elif l_info.joint_type == gs.JOINT_TYPE.REVOLUTE: i_d = l_info.dof_start + I_d = [i_d, i_b] if ti.static(self.solver._options.batch_dofs_info) else i_d i_d_jac = i_d - self._dof_start - rotation = gu.ti_transform_by_quat(self._solver.dofs_info[i_d].motion_ang, l_state.quat) + rotation = gu.ti_transform_by_quat(self._solver.dofs_info[I_d].motion_ang, l_state.quat) translation = rotation.cross(tgt_link_pos - l_state.pos) self._jacobian[0, i_d_jac, i_b] = translation[0] * pos_mask[0] @@ -687,8 +689,9 @@ def _func_get_jacobian(self, tgt_link_idx, i_b, pos_mask, rot_mask): elif l_info.joint_type == gs.JOINT_TYPE.PRISMATIC: i_d = l_info.dof_start + I_d = [i_d, i_b] if ti.static(self.solver._options.batch_dofs_info) else i_d i_d_jac = i_d - self._dof_start - translation = gu.ti_transform_by_quat(self._solver.dofs_info[i_d].motion_vel, l_state.quat) + translation = gu.ti_transform_by_quat(self._solver.dofs_info[I_d].motion_vel, l_state.quat) self._jacobian[0, i_d_jac, i_b] = translation[0] * pos_mask[0] self._jacobian[1, i_d_jac, i_b] = translation[1] * pos_mask[1] @@ -706,7 +709,8 @@ def _func_get_jacobian(self, tgt_link_idx, i_b, pos_mask, rot_mask): for i_d_ in range(3): i_d = l_info.dof_start + i_d_ + 3 i_d_jac = i_d - self._dof_start - rotation = self._solver.dofs_info[i_d].motion_ang + I_d = [i_d, i_b] if ti.static(self.solver._options.batch_dofs_info) else i_d + rotation = self._solver.dofs_info[I_d].motion_ang translation = rotation.cross(tgt_link_pos - l_state.pos) self._jacobian[0, i_d_jac, i_b] = translation[0] * pos_mask[0] @@ -1156,8 +1160,14 @@ def _kernel_inverse_kinematics( # Resample init q if respect_joint_limit and i_sample < max_samples - 1: for i_l in range(self.link_start, self.link_end): - l_info = self._solver.links_info[i_l] - dof_info = self._solver.dofs_info[l_info.dof_start] + I_l = [i_l, i_b] if ti.static(self.solver._options.batch_links_info) else i_l + l_info = self._solver.links_info[I_l] + I_dof_start = ( + [l_info.dof_start, i_b] + if ti.static(self.solver._options.batch_dofs_info) + else l_info.dof_start + ) + dof_info = self._solver.dofs_info[I_dof_start] q_start = l_info.q_start if l_info.joint_type == gs.JOINT_TYPE.FREE: @@ -1561,6 +1571,14 @@ def get_links_ang(self, envs_idx=None): """ return self._solver.get_links_ang(np.arange(self.link_start, self.link_end), envs_idx) + @gs.assert_built + def get_links_inertial_mass(self, ls_idx_local=None, envs_idx=None): + return self._solver.get_links_inertial_mass(self._get_ls_idx(ls_idx_local), envs_idx) + + @gs.assert_built + def get_links_invweight(self, ls_idx_local=None, envs_idx=None): + return self._solver.get_links_invweight(self._get_ls_idx(ls_idx_local), envs_idx) + @gs.assert_built def set_pos(self, pos, zero_velocity=True, envs_idx=None): """ @@ -1694,6 +1712,18 @@ def _get_dofs_idx_local(self, dofs_idx_local=None): def _get_dofs_idx(self, dofs_idx_local=None): return self._get_dofs_idx_local(dofs_idx_local) + self._dof_start + def _get_ls_idx_local(self, ls_idx_local=None): + if ls_idx_local is None: + ls_idx_local = torch.arange(self.n_links, dtype=torch.int32, device=gs.device) + else: + ls_idx_local = torch.as_tensor(ls_idx_local, dtype=gs.tc_int) + if (ls_idx_local < 0).any() or (ls_idx_local >= self.n_links).any(): + gs.raise_exception("`ls_idx_local` exceeds valid range.") + return ls_idx_local + + def _get_ls_idx(self, ls_idx_local=None): + return self._get_ls_idx_local(ls_idx_local) + self._link_start + @gs.assert_built def set_qpos(self, qpos, qs_idx_local=None, zero_velocity=True, envs_idx=None): """ @@ -1716,7 +1746,7 @@ def set_qpos(self, qpos, qs_idx_local=None, zero_velocity=True, envs_idx=None): self.zero_all_dofs_velocity(envs_idx) @gs.assert_built - def set_dofs_kp(self, kp, dofs_idx_local=None): + def set_dofs_kp(self, kp, dofs_idx_local=None, envs_idx=None): """ Set the entity's dofs' positional gains for the PD controller. @@ -1726,12 +1756,14 @@ def set_dofs_kp(self, kp, dofs_idx_local=None): The positional gains to set. dofs_idx_local : None | array_like, optional The indices of the dofs to set. If None, all dofs will be set. Note that here this uses the local `q_idx`, not the scene-level one. Defaults to None. + envs_idx : None | array_like, optional + The indices of the environments. If None, all environments will be considered. Defaults to None. """ - self._solver.set_dofs_kp(kp, self._get_dofs_idx(dofs_idx_local)) + self._solver.set_dofs_kp(kp, self._get_dofs_idx(dofs_idx_local), envs_idx) @gs.assert_built - def set_dofs_kv(self, kv, dofs_idx_local=None): + def set_dofs_kv(self, kv, dofs_idx_local=None, envs_idx=None): """ Set the entity's dofs' velocity gains for the PD controller. @@ -1741,11 +1773,13 @@ def set_dofs_kv(self, kv, dofs_idx_local=None): The velocity gains to set. dofs_idx_local : None | array_like, optional The indices of the dofs to set. If None, all dofs will be set. Note that here this uses the local `q_idx`, not the scene-level one. Defaults to None. + envs_idx : None | array_like, optional + The indices of the environments. If None, all environments will be considered. Defaults to None. """ - self._solver.set_dofs_kv(kv, self._get_dofs_idx(dofs_idx_local)) + self._solver.set_dofs_kv(kv, self._get_dofs_idx(dofs_idx_local), envs_idx) @gs.assert_built - def set_dofs_force_range(self, lower, upper, dofs_idx_local=None): + def set_dofs_force_range(self, lower, upper, dofs_idx_local=None, envs_idx=None): """ Set the entity's dofs' force range. @@ -1757,9 +1791,27 @@ def set_dofs_force_range(self, lower, upper, dofs_idx_local=None): The upper bounds of the force range. dofs_idx_local : None | array_like, optional The indices of the dofs to set. If None, all dofs will be set. Note that here this uses the local `q_idx`, not the scene-level one. Defaults to None. + envs_idx : None | array_like, optional + The indices of the environments. If None, all environments will be considered. Defaults to None. """ - self._solver.set_dofs_force_range(lower, upper, self._get_dofs_idx(dofs_idx_local)) + self._solver.set_dofs_force_range(lower, upper, self._get_dofs_idx(dofs_idx_local), envs_idx) + + @gs.assert_built + def set_dofs_stiffness(self, stiffness, dofs_idx_local=None, envs_idx=None): + self._solver.set_dofs_stiffness(stiffness, self._get_dofs_idx(dofs_idx_local), envs_idx) + + @gs.assert_built + def set_dofs_invweight(self, invweight, dofs_idx_local=None, envs_idx=None): + self._solver.set_dofs_invweight(invweight, self._get_dofs_idx(dofs_idx_local), envs_idx) + + @gs.assert_built + def set_dofs_armature(self, armature, dofs_idx_local=None, envs_idx=None): + self._solver.set_dofs_armature(armature, self._get_dofs_idx(dofs_idx_local), envs_idx) + + @gs.assert_built + def set_dofs_damping(self, damping, dofs_idx_local=None, envs_idx=None): + self._solver.set_dofs_damping(damping, self._get_dofs_idx(dofs_idx_local), envs_idx) @gs.assert_built def set_dofs_velocity(self, velocity, dofs_idx_local=None, envs_idx=None): @@ -1948,7 +2000,7 @@ def get_dofs_position(self, dofs_idx_local=None, envs_idx=None): return self._solver.get_dofs_position(self._get_dofs_idx(dofs_idx_local), envs_idx) @gs.assert_built - def get_dofs_kp(self, dofs_idx_local=None): + def get_dofs_kp(self, dofs_idx_local=None, envs_idx=None): """ Get the positional gain (kp) for the entity's dofs used by the PD controller. @@ -1956,16 +2008,18 @@ def get_dofs_kp(self, dofs_idx_local=None): ---------- dofs_idx_local : None | array_like, optional The indices of the dofs to get. If None, all dofs will be returned. Note that here this uses the local `q_idx`, not the scene-level one. Defaults to None. + envs_idx : None | array_like, optional + The indices of the environments. If None, all environments will be considered. Defaults to None. Returns ------- - kp : torch.Tensor, shape (n_dofs,) + kp : torch.Tensor, shape (n_dofs,) or (n_envs, n_dofs) The positional gain (kp) for the entity's dofs. """ - return self._solver.get_dofs_kp(self._get_dofs_idx(dofs_idx_local)) + return self._solver.get_dofs_kp(self._get_dofs_idx(dofs_idx_local), envs_idx) @gs.assert_built - def get_dofs_kv(self, dofs_idx_local=None): + def get_dofs_kv(self, dofs_idx_local=None, envs_idx=None): """ Get the velocity gain (kv) for the entity's dofs used by the PD controller. @@ -1973,16 +2027,18 @@ def get_dofs_kv(self, dofs_idx_local=None): ---------- dofs_idx_local : None | array_like, optional The indices of the dofs to get. If None, all dofs will be returned. Note that here this uses the local `q_idx`, not the scene-level one. Defaults to None. + envs_idx : None | array_like, optional + The indices of the environments. If None, all environments will be considered. Defaults to None. Returns ------- - kv : torch.Tensor, shape (n_dofs,) + kv : torch.Tensor, shape (n_dofs,) or (n_envs, n_dofs) The velocity gain (kv) for the entity's dofs. """ - return self._solver.get_dofs_kv(self._get_dofs_idx(dofs_idx_local)) + return self._solver.get_dofs_kv(self._get_dofs_idx(dofs_idx_local), envs_idx) @gs.assert_built - def get_dofs_force_range(self, dofs_idx_local=None): + def get_dofs_force_range(self, dofs_idx_local=None, envs_idx=None): """ Get the force range (min and max limits) for the entity's dofs. @@ -1990,18 +2046,20 @@ def get_dofs_force_range(self, dofs_idx_local=None): ---------- dofs_idx_local : None | array_like, optional The indices of the dofs to get. If None, all dofs will be returned. Note that here this uses the local `q_idx`, not the scene-level one. Defaults to None. + envs_idx : None | array_like, optional + The indices of the environments. If None, all environments will be considered. Defaults to None. Returns ------- - lower_limit : torch.Tensor, shape (n_dofs,) + lower_limit : torch.Tensor, shape (n_dofs,) or (n_envs, n_dofs) The lower limit of the force range for the entity's dofs. - upper_limit : torch.Tensor, shape (n_dofs,) + upper_limit : torch.Tensor, shape (n_dofs,) or (n_envs, n_dofs) The upper limit of the force range for the entity's dofs. """ - return self._solver.get_dofs_force_range(self._get_dofs_idx(dofs_idx_local)) + return self._solver.get_dofs_force_range(self._get_dofs_idx(dofs_idx_local), envs_idx) @gs.assert_built - def get_dofs_limit(self, dofs_idx=None): + def get_dofs_limit(self, dofs_idx=None, envs_idx=None): """ Get the positional limits (min and max) for the entity's dofs. @@ -2009,15 +2067,33 @@ def get_dofs_limit(self, dofs_idx=None): ---------- dofs_idx : None | array_like, optional The indices of the dofs to get. If None, all dofs will be returned. Note that here this uses the local `q_idx`, not the scene-level one. Defaults to None. + envs_idx : None | array_like, optional + The indices of the environments. If None, all environments will be considered. Defaults to None. Returns ------- - lower_limit : torch.Tensor, shape (n_dofs,) + lower_limit : torch.Tensor, shape (n_dofs,) or (n_envs, n_dofs) The lower limit of the positional limits for the entity's dofs. - upper_limit : torch.Tensor, shape (n_dofs,) + upper_limit : torch.Tensor, shape (n_dofs,) or (n_envs, n_dofs) The upper limit of the positional limits for the entity's dofs. """ - return self._solver.get_dofs_limit(self._get_dofs_idx(dofs_idx)) + return self._solver.get_dofs_limit(self._get_dofs_idx(dofs_idx), envs_idx) + + @gs.assert_built + def get_dofs_stiffness(self, dofs_idx_local=None, envs_idx=None): + return self._solver.get_dofs_stiffness(self._get_dofs_idx(dofs_idx_local), envs_idx) + + @gs.assert_built + def get_dofs_invweight(self, dofs_idx_local=None, envs_idx=None): + return self._solver.get_dofs_invweight(self._get_dofs_idx(dofs_idx_local), envs_idx) + + @gs.assert_built + def get_dofs_armature(self, dofs_idx_local=None, envs_idx=None): + return self._solver.get_dofs_armature(self._get_dofs_idx(dofs_idx_local), envs_idx) + + @gs.assert_built + def get_dofs_damping(self, dofs_idx_local=None, envs_idx=None): + return self._solver.get_dofs_damping(self._get_dofs_idx(dofs_idx_local), envs_idx) @gs.assert_built def zero_all_dofs_velocity(self, envs_idx=None): @@ -2251,6 +2327,14 @@ def set_COM_shift(self, com_shift, link_indices, envs_idx=None): link_indices[i] += self._link_start self._solver.set_links_COM_shift(com_shift, link_indices, envs_idx) + @gs.assert_built + def set_links_inertial_mass(self, inertial_mass, ls_idx_local=None, envs_idx=None): + self._solver.set_links_inertial_mass(inertial_mass, self._get_ls_idx(ls_idx_local), envs_idx) + + @gs.assert_built + def set_links_invweight(self, invweight, ls_idx_local=None, envs_idx=None): + self._solver.set_links_invweight(invweight, self._get_ls_idx(ls_idx_local), envs_idx) + @gs.assert_built def get_mass(self): """ diff --git a/genesis/engine/entities/rigid_entity/rigid_joint.py b/genesis/engine/entities/rigid_entity/rigid_joint.py index 660971d9..e81c25b2 100644 --- a/genesis/engine/entities/rigid_entity/rigid_joint.py +++ b/genesis/engine/entities/rigid_entity/rigid_joint.py @@ -90,7 +90,8 @@ def get_pos(self): def _kernel_get_pos(self, tensor: ti.types.ndarray()): for i_b in range(self._solver._B): - l_info = self._solver.links_info[self._idx, i_b] + I_l = [self._idx, i_b] if ti.static(self._solver._options.batch_links_info) else self._idx + l_info = self._solver.links_info[I_l] i_p = l_info.parent_idx p_pos = ti.Vector.zero(gs.ti_float, 3) @@ -124,7 +125,8 @@ def get_quat(self): def _kernel_get_quat(self, tensor: ti.types.ndarray()): for i_b in range(self._solver._B): - l_info = self._solver.links_info[self._idx, i_b] + I_l = [self._idx, i_b] if ti.static(self._solver._options.batch_links_info) else self._idx + l_info = self._solver.links_info[I_l] i_p = l_info.parent_idx p_pos = ti.Vector.zero(gs.ti_float, 3) diff --git a/genesis/engine/solvers/rigid/collider_decomp.py b/genesis/engine/solvers/rigid/collider_decomp.py index 968ad214..35a7756f 100644 --- a/genesis/engine/solvers/rigid/collider_decomp.py +++ b/genesis/engine/solvers/rigid/collider_decomp.py @@ -51,6 +51,10 @@ def _init_collision_fields(self): links_root_idx = self._solver.links_info.root_idx.to_numpy() links_parent_idx = self._solver.links_info.parent_idx.to_numpy() links_is_fixed = self._solver.links_info.is_fixed.to_numpy() + if self._solver._options.batch_links_info: + links_root_idx = links_root_idx[:, 0] + links_parent_idx = links_parent_idx[:, 0] + links_is_fixed = links_is_fixed[:, 0] n_possible_pairs = 0 for i in range(self._solver.n_geoms): for j in range(i + 1, self._solver.n_geoms): @@ -164,10 +168,13 @@ def clear(self): i_la = self.contact_data[i_c, i_b].link_a i_lb = self.contact_data[i_c, i_b].link_b + I_la = [i_la, i_b] if ti.static(self._solver._options.batch_links_info) else i_la + I_lb = [i_lb, i_b] if ti.static(self._solver._options.batch_links_info) else i_lb + # pair of hibernated-fixed links -> hibernated contact # TODO: we should also include hibernated-hibernated links and wake up the whole contact island once a new collision is detected - if (self._solver.links_state[i_la, i_b].hibernated and self._solver.links_info[i_lb].is_fixed) or ( - self._solver.links_state[i_lb, i_b].hibernated and self._solver.links_info[i_la].is_fixed + if (self._solver.links_state[i_la, i_b].hibernated and self._solver.links_info[I_lb].is_fixed) or ( + self._solver.links_state[i_lb, i_b].hibernated and self._solver.links_info[I_la].is_fixed ): i_c_hibernated = self.n_contacts_hibernated[i_b] if i_c != i_c_hibernated: @@ -549,6 +556,8 @@ def _func_update_aabbs(self): def _func_check_collision_valid(self, i_ga, i_gb, i_b): i_la = self._solver.geoms_info[i_ga].link_idx i_lb = self._solver.geoms_info[i_gb].link_idx + I_la = [i_la, i_b] if ti.static(self._solver._options.batch_links_info) else i_la + I_lb = [i_lb, i_b] if ti.static(self._solver._options.batch_links_info) else i_lb is_valid = True # geoms in the same link @@ -558,22 +567,22 @@ def _func_check_collision_valid(self, i_ga, i_gb, i_b): # self collision if ( ti.static(not self._solver._enable_self_collision) - and self._solver.links_info[i_la].root_idx == self._solver.links_info[i_lb].root_idx + and self._solver.links_info[I_la].root_idx == self._solver.links_info[I_lb].root_idx ): is_valid = False # adjacent links - if self._solver.links_info[i_la].parent_idx == i_lb or self._solver.links_info[i_lb].parent_idx == i_la: + if self._solver.links_info[I_la].parent_idx == i_lb or self._solver.links_info[I_lb].parent_idx == i_la: is_valid = False # pair of fixed links - if self._solver.links_info[i_la].is_fixed and self._solver.links_info[i_lb].is_fixed: + if self._solver.links_info[I_la].is_fixed and self._solver.links_info[I_lb].is_fixed: is_valid = False # hibernated <-> fixed links if ti.static(self._solver._use_hibernation): - if (self._solver.links_state[i_la, i_b].hibernated and self._solver.links_info[i_lb].is_fixed) or ( - self._solver.links_state[i_lb, i_b].hibernated and self._solver.links_info[i_la].is_fixed + if (self._solver.links_state[i_la, i_b].hibernated and self._solver.links_info[I_lb].is_fixed) or ( + self._solver.links_state[i_lb, i_b].hibernated and self._solver.links_info[I_la].is_fixed ): is_valid = False @@ -994,7 +1003,9 @@ def _func_mpr(self, i_ga, i_gb, i_b): i_la = self._solver.geoms_info[i_ga].link_idx i_lb = self._solver.geoms_info[i_gb].link_idx - is_self_pair = self._solver.links_info.root_idx[i_la] == self._solver.links_info.root_idx[i_lb] + I_la = [i_la, i_b] if ti.static(self._solver._options.batch_links_info) else i_la + I_lb = [i_lb, i_b] if ti.static(self._solver._options.batch_links_info) else i_lb + is_self_pair = self._solver.links_info.root_idx[I_la] == self._solver.links_info.root_idx[I_lb] multi_contact = ( self._solver.geoms_info[i_ga].type != gs.GEOM_TYPE.SPHERE and self._solver.geoms_info[i_gb].type != gs.GEOM_TYPE.SPHERE diff --git a/genesis/engine/solvers/rigid/constraint_solver_decomp.py b/genesis/engine/solvers/rigid/constraint_solver_decomp.py index c680c31f..d3592b92 100644 --- a/genesis/engine/solvers/rigid/constraint_solver_decomp.py +++ b/genesis/engine/solvers/rigid/constraint_solver_decomp.py @@ -105,14 +105,16 @@ def add_collision_constraints(self): impact = self._collider.contact_data[i_col, i_b] link_a = impact.link_a link_b = impact.link_b + link_a_maybe_batch = [link_a, i_b] if ti.static(self._solver._options.batch_links_info) else link_a + link_b_maybe_batch = [link_b, i_b] if ti.static(self._solver._options.batch_links_info) else link_b f = impact.friction pos = impact.pos d1, d2 = gu.orthogonals(impact.normal) - t = self._solver.links_info[link_a].invweight + self._solver.links_info[link_b].invweight * ( - link_b > -1 - ) + t = self._solver.links_info[link_a_maybe_batch].invweight + self._solver.links_info[ + link_b_maybe_batch + ].invweight * (link_b > -1) for i in range(4): n = -d1 * f - impact.normal if i == 1: @@ -143,10 +145,13 @@ def add_collision_constraints(self): link = link_b while link > -1: + link_maybe_batch = ( + [link, i_b] if ti.static(self._solver._options.batch_links_info) else link + ) # reverse order to make sure dofs in each row of self.jac_relevant_dofs is strictly descending - for i_d_ in range(self._solver.links_info[link].n_dofs): - i_d = self._solver.links_info[link].dof_end - 1 - i_d_ + for i_d_ in range(self._solver.links_info[link_maybe_batch].n_dofs): + i_d = self._solver.links_info[link_maybe_batch].dof_end - 1 - i_d_ cdof_ang = self._solver.dofs_state[i_d, i_b].cdof_ang cdot_vel = self._solver.dofs_state[i_d, i_b].cdof_vel @@ -164,7 +169,7 @@ def add_collision_constraints(self): self.jac_relevant_dofs[n_con, con_n_relevant_dofs, i_b] = i_d con_n_relevant_dofs += 1 - link = self._solver.links_info[link].parent_idx + link = self._solver.links_info[link_maybe_batch].parent_idx if ti.static(self.sparse_solve): self.jac_n_relevant_dofs[n_con, i_b] = con_n_relevant_dofs @@ -183,21 +188,23 @@ def add_joint_limit_constraints(self): ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) for i_b in range(self._B): for i_l in range(self._solver.n_links): - l_info = self._solver.links_info[i_l] + I_l = [i_l, i_b] if ti.static(self._solver._options.batch_links_info) else i_l + l_info = self._solver.links_info[I_l] if l_info.joint_type == gs.JOINT_TYPE.REVOLUTE or l_info.joint_type == gs.JOINT_TYPE.PRISMATIC: i_q = l_info.q_start i_d = l_info.dof_start - pos_min = self._solver.qpos[i_q, i_b] - self._solver.dofs_info[i_d].limit[0] - pos_max = self._solver.dofs_info[i_d].limit[1] - self._solver.qpos[i_q, i_b] + I_d = [i_d, i_b] if ti.static(self._solver._options.batch_dofs_info) else i_d + pos_min = self._solver.qpos[i_q, i_b] - self._solver.dofs_info[I_d].limit[0] + pos_max = self._solver.dofs_info[I_d].limit[1] - self._solver.qpos[i_q, i_b] pos = min(min(pos_min, pos_max), 0) side = ((pos_min < pos_max) * 2 - 1) * (pos < 0) jac = side jac_qvel = jac * self._solver.dofs_state[i_d, i_b].vel - imp, aref = gu.imp_aref(self._solver.dofs_info[i_d].sol_params, pos, jac_qvel) - diag = self._solver.dofs_info[i_d].invweight * (pos < 0) * (1 - imp) / (imp + gs.EPS) + imp, aref = gu.imp_aref(self._solver.dofs_info[I_d].sol_params, pos, jac_qvel) + diag = self._solver.dofs_info[I_d].invweight * (pos < 0) * (1 - imp) / (imp + gs.EPS) aref = aref * (pos < 0) if pos < 0: n_con = self.n_constraints[i_b] diff --git a/genesis/engine/solvers/rigid/constraint_solver_decomp_island.py b/genesis/engine/solvers/rigid/constraint_solver_decomp_island.py index e33867a0..4a7d2257 100644 --- a/genesis/engine/solvers/rigid/constraint_solver_decomp_island.py +++ b/genesis/engine/solvers/rigid/constraint_solver_decomp_island.py @@ -131,12 +131,16 @@ def add_collision_constraints(self, island, i_b): impact = self._collider.contact_data[i_col, i_b] link_a = impact.link_a link_b = impact.link_b + link_a_maybe_batch = [link_a, i_b] if ti.static(self._solver._options.batch_links_info) else link_a + link_b_maybe_batch = [link_b, i_b] if ti.static(self._solver._options.batch_links_info) else link_b f = impact.friction pos = impact.pos d1, d2 = gu.orthogonals(impact.normal) - t = self._solver.links_info[link_a].invweight + self._solver.links_info[link_b].invweight * (link_b > -1) + t = self._solver.links_info[link_a_maybe_batch].invweight + self._solver.links_info[ + link_b_maybe_batch + ].invweight * (link_b > -1) for i in range(4): n = -d1 * f - impact.normal @@ -170,10 +174,11 @@ def add_collision_constraints(self, island, i_b): link = link_b while link > -1: + link_maybe_batch = [link, i_b] if ti.static(self._solver._options.batch_links_info) else link # reverse order to make sure dofs in each row of self.jac_relevant_dofs is strictly descending for i_d_ in range(self._solver.links_info[link].n_dofs): - i_d = self._solver.links_info[link].dof_end - 1 - i_d_ + i_d = self._solver.links_info[link_maybe_batch].dof_end - 1 - i_d_ cdof_ang = self._solver.dofs_state[i_d, i_b].cdof_ang cdot_vel = self._solver.dofs_state[i_d, i_b].cdof_vel @@ -190,7 +195,7 @@ def add_collision_constraints(self, island, i_b): self.jac_relevant_dofs[n_con, con_n_relevant_dofs, i_b] = i_d con_n_relevant_dofs += 1 - link = self._solver.links_info[link].parent_idx + link = self._solver.links_info[link_maybe_batch].parent_idx if ti.static(self.sparse_solve): self.jac_n_relevant_dofs[n_con, i_b] = con_n_relevant_dofs @@ -207,8 +212,8 @@ def add_collision_constraints(self, island, i_b): if ti.static(self._solver._use_hibernation): # wake up entities - self._solver._func_wakeup_entity(self._solver.links_info[link_a].entity_idx, i_b) - self._solver._func_wakeup_entity(self._solver.links_info[link_b].entity_idx, i_b) + self._solver._func_wakeup_entity(self._solver.links_info[link_a_maybe_batch].entity_idx, i_b) + self._solver._func_wakeup_entity(self._solver.links_info[link_b_maybe_batch].entity_idx, i_b) @ti.func def add_joint_limit_constraints(self, island, i_b): @@ -219,22 +224,23 @@ def add_joint_limit_constraints(self, island, i_b): e_info = self.entities_info[i_e] for i_l in range(e_info.link_start, e_info.link_end): - - l_info = self._solver.links_info[i_l] + I_l = [i_l, i_b] if ti.static(self._solver._options.batch_links_info) else i_l + l_info = self._solver.links_info[I_l] if l_info.joint_type == gs.JOINT_TYPE.REVOLUTE or l_info.joint_type == gs.JOINT_TYPE.PRISMATIC: i_q = l_info.q_start i_d = l_info.dof_start - pos_min = self._solver.qpos[i_q, i_b] - self._solver.dofs_info[i_d].limit[0] - pos_max = self._solver.dofs_info[i_d].limit[1] - self._solver.qpos[i_q, i_b] + I_d = [i_d, i_b] if ti.static(self._solver._options.batch_dofs_info) else i_d + pos_min = self._solver.qpos[i_q, i_b] - self._solver.dofs_info[I_d].limit[0] + pos_max = self._solver.dofs_info[I_d].limit[1] - self._solver.qpos[i_q, i_b] pos = min(min(pos_min, pos_max), 0) side = ((pos_min < pos_max) * 2 - 1) * (pos < 0) jac = side jac_qvel = jac * self._solver.dofs_state[i_d, i_b].vel - imp, aref = gu.imp_aref(self._solver.dofs_info[i_d].sol_params, pos, jac_qvel) - diag = self._solver.dofs_info[i_d].invweight * (pos < 0) * (1 - imp) / (imp + gs.EPS) + imp, aref = gu.imp_aref(self._solver.dofs_info[I_d].sol_params, pos, jac_qvel) + diag = self._solver.dofs_info[I_d].invweight * (pos < 0) * (1 - imp) / (imp + gs.EPS) aref = aref * (pos < 0) if pos < 0: n_con = self.n_constraints[i_b] diff --git a/genesis/engine/solvers/rigid/contact_island.py b/genesis/engine/solvers/rigid/contact_island.py index 2914b503..97ae3b93 100644 --- a/genesis/engine/solvers/rigid/contact_island.py +++ b/genesis/engine/solvers/rigid/contact_island.py @@ -66,8 +66,11 @@ def clear(self): @ti.func def add_edge(self, link_a, link_b, i_b): - ea = self.solver.links_info[link_a].entity_idx - eb = self.solver.links_info[link_b].entity_idx + link_a_maybe_batch = [link_a, i_b] if ti.static(self.solver._options.batch_links_info) else link_a + link_b_maybe_batch = [link_b, i_b] if ti.static(self.solver._options.batch_links_info) else link_b + + ea = self.solver.links_info[link_a_maybe_batch].entity_idx + eb = self.solver.links_info[link_b_maybe_batch].entity_idx self.entity_edge[ea, i_b].n = self.entity_edge[ea, i_b].n + 1 self.entity_edge[eb, i_b].n = self.entity_edge[eb, i_b].n + 1 @@ -100,9 +103,11 @@ def postprocess_island(self): impact = self.collider.contact_data[i_col, i_b] link_a = impact.link_a link_b = impact.link_b + link_a_maybe_batch = [link_a, i_b] if ti.static(self.solver._options.batch_links_info) else link_a + link_b_maybe_batch = [link_b, i_b] if ti.static(self.solver._options.batch_links_info) else link_b - ea = self.solver.links_info[link_a].entity_idx - eb = self.solver.links_info[link_b].entity_idx + ea = self.solver.links_info[link_a_maybe_batch].entity_idx + eb = self.solver.links_info[link_b_maybe_batch].entity_idx island_a = self.entity_island[ea, i_b] island_b = self.entity_island[eb, i_b] diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index a6abeb53..b6dd3a17 100644 --- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -159,6 +159,11 @@ def _init_invweight(self): dof_end = self.links_info.dof_end.to_numpy() n_dofs = self.links_info.n_dofs.to_numpy() parent_idx = self.links_info.parent_idx.to_numpy() + if self._options.batch_links_info: + dof_start = dof_start[:, 0] + dof_end = dof_end[:, 0] + n_dofs = n_dofs[:, 0] + parent_idx = parent_idx[:, 0] offsets = self.links_state.i_pos.to_numpy()[:, 0, :] @@ -209,9 +214,9 @@ def _kernel_init_invweight( invweight: ti.types.ndarray(), ): ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) - for i, i_b in ti.ndrange(self.n_links, self._B): - if self.links_info[i].invweight < 0: - self.links_info[i].invweight = invweight[i] + for I in ti.grouped(self.links_info): + if self.links_info[I].invweight < 0: + self.links_info[I].invweight = invweight[I[0]] def _batch_shape(self, shape=None, first_dim=False, B=None): if B is None: @@ -300,25 +305,28 @@ def _init_dof_fields(self): ctrl_mode=gs.ti_int, hibernated=gs.ti_int, # Flag for dofs that converge into a static state (hibernation) ) - self.dofs_info = struct_dof_info.field(shape=self.n_dofs_, needs_grad=False, layout=ti.Layout.SOA) + dofs_info_shape = self._batch_shape(self.n_dofs_) if self._options.batch_dofs_info else self.n_dofs_ + self.dofs_info = struct_dof_info.field(shape=dofs_info_shape, needs_grad=False, layout=ti.Layout.SOA) self.dofs_state = struct_dof_state.field( shape=self._batch_shape(self.n_dofs_), needs_grad=False, layout=ti.Layout.SOA ) joints = self.joints - self._kernel_init_dof_fields( - dofs_motion_ang=np.concatenate([joint.dofs_motion_ang for joint in joints], dtype=gs.np_float), - dofs_motion_vel=np.concatenate([joint.dofs_motion_vel for joint in joints], dtype=gs.np_float), - dofs_limit=np.concatenate([joint.dofs_limit for joint in joints], dtype=gs.np_float), - dofs_invweight=np.concatenate([joint.dofs_invweight for joint in joints], dtype=gs.np_float), - dofs_stiffness=np.concatenate([joint.dofs_stiffness for joint in joints], dtype=gs.np_float), - dofs_sol_params=np.concatenate([joint.dofs_sol_params for joint in joints], dtype=gs.np_float), - dofs_damping=np.concatenate([joint.dofs_damping for joint in joints], dtype=gs.np_float), - dofs_armature=np.concatenate([joint.dofs_armature for joint in joints], dtype=gs.np_float), - dofs_kp=np.concatenate([joint.dofs_kp for joint in joints], dtype=gs.np_float), - dofs_kv=np.concatenate([joint.dofs_kv for joint in joints], dtype=gs.np_float), - dofs_force_range=np.concatenate([joint.dofs_force_range for joint in joints], dtype=gs.np_float), - ) + is_nonempty = np.concatenate([joint.dofs_motion_ang for joint in joints], dtype=gs.np_float).shape[0] > 0 + if is_nonempty: # handle the case where there is a link with no dofs -- otherwise may cause invalid memory + self._kernel_init_dof_fields( + dofs_motion_ang=np.concatenate([joint.dofs_motion_ang for joint in joints], dtype=gs.np_float), + dofs_motion_vel=np.concatenate([joint.dofs_motion_vel for joint in joints], dtype=gs.np_float), + dofs_limit=np.concatenate([joint.dofs_limit for joint in joints], dtype=gs.np_float), + dofs_invweight=np.concatenate([joint.dofs_invweight for joint in joints], dtype=gs.np_float), + dofs_stiffness=np.concatenate([joint.dofs_stiffness for joint in joints], dtype=gs.np_float), + dofs_sol_params=np.concatenate([joint.dofs_sol_params for joint in joints], dtype=gs.np_float), + dofs_damping=np.concatenate([joint.dofs_damping for joint in joints], dtype=gs.np_float), + dofs_armature=np.concatenate([joint.dofs_armature for joint in joints], dtype=gs.np_float), + dofs_kp=np.concatenate([joint.dofs_kp for joint in joints], dtype=gs.np_float), + dofs_kv=np.concatenate([joint.dofs_kv for joint in joints], dtype=gs.np_float), + dofs_force_range=np.concatenate([joint.dofs_force_range for joint in joints], dtype=gs.np_float), + ) # just in case self.dofs_state.force.fill(0) @@ -339,26 +347,28 @@ def _kernel_init_dof_fields( dofs_force_range: ti.types.ndarray(), ): ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) - for i, b in ti.ndrange(self.n_dofs, self._B): + for I in ti.grouped(self.dofs_info): + i = I[0] # batching (if any) will be the second dim + for j in ti.static(range(3)): - self.dofs_info[i].motion_ang[j] = dofs_motion_ang[i, j] - self.dofs_info[i].motion_vel[j] = dofs_motion_vel[i, j] + self.dofs_info[I].motion_ang[j] = dofs_motion_ang[i, j] + self.dofs_info[I].motion_vel[j] = dofs_motion_vel[i, j] for j in ti.static(range(2)): - self.dofs_info[i].limit[j] = dofs_limit[i, j] - self.dofs_info[i].force_range[j] = dofs_force_range[i, j] + self.dofs_info[I].limit[j] = dofs_limit[i, j] + self.dofs_info[I].force_range[j] = dofs_force_range[i, j] for j in ti.static(range(7)): - self.dofs_info[i].sol_params[j] = dofs_sol_params[i, j] + self.dofs_info[I].sol_params[j] = dofs_sol_params[i, j] - self.dofs_info[i].sol_params[0] = self._sol_contact_resolve_time + self.dofs_info[I].sol_params[0] = self._sol_contact_resolve_time - self.dofs_info[i].armature = dofs_armature[i] - self.dofs_info[i].invweight = dofs_invweight[i] - self.dofs_info[i].stiffness = dofs_stiffness[i] - self.dofs_info[i].damping = dofs_damping[i] - self.dofs_info[i].kp = dofs_kp[i] - self.dofs_info[i].kv = dofs_kv[i] + self.dofs_info[I].armature = dofs_armature[i] + self.dofs_info[I].invweight = dofs_invweight[i] + self.dofs_info[I].stiffness = dofs_stiffness[i] + self.dofs_info[I].damping = dofs_damping[i] + self.dofs_info[I].kp = dofs_kp[i] + self.dofs_info[I].kv = dofs_kv[i] ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) for i, b in ti.ndrange(self.n_dofs, self._B): @@ -442,7 +452,8 @@ def _init_link_fields(self): hibernated=gs.ti_int, ) - self.links_info = struct_link_info.field(shape=self.n_links, needs_grad=False, layout=ti.Layout.SOA) + links_info_shape = self._batch_shape(self.n_links) if self._options.batch_links_info else self.n_links + self.links_info = struct_link_info.field(shape=links_info_shape, needs_grad=False, layout=ti.Layout.SOA) self.links_state = struct_link_state.field( shape=self._batch_shape(self.n_links), needs_grad=False, layout=ti.Layout.SOA ) @@ -501,38 +512,42 @@ def _kernel_init_link_fields( links_entity_idx: ti.types.ndarray(), ): ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) - for i, b in ti.ndrange(self.n_links, self._B): - self.links_info[i].parent_idx = links_parent_idx[i] - self.links_info[i].root_idx = links_root_idx[i] - self.links_info[i].q_start = links_q_start[i] - self.links_info[i].dof_start = links_dof_start[i] - self.links_info[i].q_end = links_q_end[i] - self.links_info[i].dof_end = links_dof_end[i] - self.links_info[i].n_dofs = links_dof_end[i] - links_dof_start[i] - self.links_info[i].joint_type = links_joint_type[i] - self.links_info[i].invweight = links_invweight[i] - self.links_info[i].is_fixed = links_is_fixed[i] - self.links_info[i].entity_idx = links_entity_idx[i] + for I in ti.grouped(self.links_info): + i = I[0] + + self.links_info[I].parent_idx = links_parent_idx[i] + self.links_info[I].root_idx = links_root_idx[i] + self.links_info[I].q_start = links_q_start[i] + self.links_info[I].dof_start = links_dof_start[i] + self.links_info[I].q_end = links_q_end[i] + self.links_info[I].dof_end = links_dof_end[i] + self.links_info[I].n_dofs = links_dof_end[i] - links_dof_start[i] + self.links_info[I].joint_type = links_joint_type[i] + self.links_info[I].invweight = links_invweight[i] + self.links_info[I].is_fixed = links_is_fixed[i] + self.links_info[I].entity_idx = links_entity_idx[i] for j in ti.static(range(4)): - self.links_info[i].quat[j] = links_quat[i, j] - self.links_info[i].joint_quat[j] = links_joint_quat[i, j] - self.links_info[i].inertial_quat[j] = links_inertial_quat[i, j] + self.links_info[I].quat[j] = links_quat[i, j] + self.links_info[I].joint_quat[j] = links_joint_quat[i, j] + self.links_info[I].inertial_quat[j] = links_inertial_quat[i, j] for j in ti.static(range(3)): - self.links_info[i].pos[j] = links_pos[i, j] - self.links_info[i].joint_pos[j] = links_joint_pos[i, j] - self.links_info[i].inertial_pos[j] = links_inertial_pos[i, j] + self.links_info[I].pos[j] = links_pos[i, j] + self.links_info[I].joint_pos[j] = links_joint_pos[i, j] + self.links_info[I].inertial_pos[j] = links_inertial_pos[i, j] - self.links_info[i].inertial_mass = links_inertial_mass[i] + self.links_info[I].inertial_mass = links_inertial_mass[i] for j1 in ti.static(range(3)): for j2 in ti.static(range(3)): - self.links_info[i].inertial_i[j1, j2] = links_inertial_i[i, j1, j2] + self.links_info[I].inertial_i[j1, j2] = links_inertial_i[i, j1, j2] ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) for i, b in ti.ndrange(self.n_links, self._B): + I = [i, b] if ti.static(self._options.batch_links_info) else i + # Update state for root fixed link. Their state will not be updated in forward kinematics later but can be manually changed by user. - if self.links_info[i].parent_idx == -1 and self.links_info[i].is_fixed: + if self.links_info[I].parent_idx == -1 and self.links_info[I].is_fixed: for j in ti.static(range(4)): self.links_state[i, b].quat[j] = links_quat[i, j] @@ -836,10 +851,11 @@ def _kernel_adjust_link_inertia( link_idx: ti.i32, ratio: ti.f32, ): - self.links_info[link_idx].invweight /= ratio - self.links_info[link_idx].inertial_mass *= ratio - for j1, j2 in ti.ndrange(3, 3): - self.links_info[link_idx].inertial_i[j1, j2] *= ratio + for I_l in ti.grouped(self.links_info): + self.links_info[I_l].invweight /= ratio + self.links_info[I_l].inertial_mass *= ratio + for j1, j2 in ti.ndrange(3, 3): + self.links_info[I_l].inertial_i[j1, j2] *= ratio def _init_vgeom_fields(self): struct_vgeom_info = ti.types.struct( @@ -971,8 +987,13 @@ def _kernel_init_entity_fields( self.entities_info[i].gravity_compensation = entities_gravity_compensation[i] - for i_d in range(entities_dof_start[i], entities_dof_end[i]): - self.dofs_info[i_d].dof_start = entities_dof_start[i] + if ti.static(self._options.batch_dofs_info): + for i_b in range(self._B): + for i_d in range(entities_dof_start[i], entities_dof_end[i]): + self.dofs_info[i_d, i_b].dof_start = entities_dof_start[i] + else: + for i_d in range(entities_dof_start[i], entities_dof_end[i]): + self.dofs_info[i_d].dof_start = entities_dof_start[i] if ti.static(self._use_hibernation): ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) @@ -1053,7 +1074,8 @@ def _func_compute_mass_matrix(self): i_e = self.awake_entities[i_e_, i_b] for i in range(self.entities_info[i_e].n_links): i_l = self.entities_info[i_e].link_end - 1 - i - i_p = self.links_info[i_l].parent_idx + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l + i_p = self.links_info[I_l].parent_idx if i_p != -1: self.links_state[i_p, i_b].crb_inertial = ( @@ -1075,7 +1097,8 @@ def _func_compute_mass_matrix(self): for i_b in range(self._B): for i_l_ in range(self.n_awake_links[i_b]): i_l = self.awake_links[i_l_, i_b] - l_info = self.links_info[i_l] + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l + l_info = self.links_info[I_l] for i_d in range(l_info.dof_start, l_info.dof_end): self.dofs_state[i_d, i_b].f_ang, self.dofs_state[i_d, i_b].f_vel = gu.inertial_mul( self.links_state[i_l, i_b].crb_pos, @@ -1098,10 +1121,11 @@ def _func_compute_mass_matrix(self): ) * self.mass_parent_mask[i_d, j_d] for i_d in range(e_info.dof_start, e_info.dof_end): + I_d = [i_d, i_b] if ti.static(self._options.batch_dofs_info) else i_d self.mass_mat[i_d, i_d, i_b] = ( self.mass_mat[i_d, i_d, i_b] - + self.dofs_info[i_d].armature - + self.dofs_info[i_d].damping * self._substep_dt + + self.dofs_info[I_d].armature + + self.dofs_info[I_d].damping * self._substep_dt ) for j_d in range(i_d + 1, e_info.dof_end): self.mass_mat[i_d, j_d, i_b] = self.mass_mat[j_d, i_d, i_b] @@ -1111,18 +1135,19 @@ def _func_compute_mass_matrix(self): # qDeriv += d qfrc_actuator / d qvel ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.ALL) for i_d, i_b in ti.ndrange(self.n_dofs, self._B): + I_d = [i_d, i_b] if ti.static(self._options.batch_dofs_info) else i_d if self.dofs_state[i_d, i_b].ctrl_mode == gs.CTRL_MODE.FORCE: pass elif self.dofs_state[i_d, i_b].ctrl_mode == gs.CTRL_MODE.VELOCITY: self.mass_mat[i_d, i_d, i_b] = ( - self.mass_mat[i_d, i_d, i_b] + self.dofs_info[i_d].kv * self._substep_dt + self.mass_mat[i_d, i_d, i_b] + self.dofs_info[I_d].kv * self._substep_dt ) elif self.dofs_state[i_d, i_b].ctrl_mode == gs.CTRL_MODE.POSITION: self.mass_mat[i_d, i_d, i_b] = ( - self.mass_mat[i_d, i_d, i_b] + self.dofs_info[i_d].kv * self._substep_dt + self.mass_mat[i_d, i_d, i_b] + self.dofs_info[I_d].kv * self._substep_dt ) else: @@ -1139,7 +1164,8 @@ def _func_compute_mass_matrix(self): for i_e, i_b in ti.ndrange(self.n_entities, self._B): for i in range(self.entities_info[i_e].n_links): i_l = self.entities_info[i_e].link_end - 1 - i - i_p = self.links_info[i_l].parent_idx + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l + i_p = self.links_info[I_l].parent_idx if i_p != -1: self.links_state[i_p, i_b].crb_inertial = ( @@ -1160,7 +1186,8 @@ def _func_compute_mass_matrix(self): ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.ALL) for i_b in range(self._B): for i_l in range(self.n_links): - l_info = self.links_info[i_l] + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l + l_info = self.links_info[I_l] for i_d in range(l_info.dof_start, l_info.dof_end): self.dofs_state[i_d, i_b].f_ang, self.dofs_state[i_d, i_b].f_vel = gu.inertial_mul( self.links_state[i_l, i_b].crb_pos, @@ -1181,10 +1208,11 @@ def _func_compute_mass_matrix(self): ) * self.mass_parent_mask[i_d, j_d] for i_d in range(e_info.dof_start, e_info.dof_end): + I_d = [i_d, i_b] if ti.static(self._options.batch_dofs_info) else i_d self.mass_mat[i_d, i_d, i_b] = ( self.mass_mat[i_d, i_d, i_b] - + self.dofs_info[i_d].armature - + self.dofs_info[i_d].damping * self._substep_dt + + self.dofs_info[I_d].armature + + self.dofs_info[I_d].damping * self._substep_dt ) for j_d in range(i_d + 1, e_info.dof_end): self.mass_mat[i_d, j_d, i_b] = self.mass_mat[j_d, i_d, i_b] @@ -1194,18 +1222,19 @@ def _func_compute_mass_matrix(self): # qDeriv += d qfrc_actuator / d qvel ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.ALL) for i_d, i_b in ti.ndrange(self.n_dofs, self._B): + I_d = [i_d, i_b] if ti.static(self._options.batch_dofs_info) else i_d if self.dofs_state[i_d, i_b].ctrl_mode == gs.CTRL_MODE.FORCE: pass elif self.dofs_state[i_d, i_b].ctrl_mode == gs.CTRL_MODE.VELOCITY: self.mass_mat[i_d, i_d, i_b] = ( - self.mass_mat[i_d, i_d, i_b] + self.dofs_info[i_d].kv * self._substep_dt + self.mass_mat[i_d, i_d, i_b] + self.dofs_info[I_d].kv * self._substep_dt ) elif self.dofs_state[i_d, i_b].ctrl_mode == gs.CTRL_MODE.POSITION: self.mass_mat[i_d, i_d, i_b] = ( - self.mass_mat[i_d, i_d, i_b] + self.dofs_info[i_d].kv * self._substep_dt + self.mass_mat[i_d, i_d, i_b] + self.dofs_info[I_d].kv * self._substep_dt ) @ti.func @@ -1353,15 +1382,16 @@ def _func_implicit_damping(self): # TODO: hibernate ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.ALL) for i_d, i_b in ti.ndrange(self.n_dofs, self._B): + I_d = [i_d, i_b] if ti.static(self._options.batch_dofs_info) else i_d if self.dofs_state[i_d, i_b].ctrl_mode == gs.CTRL_MODE.FORCE: pass elif self.dofs_state[i_d, i_b].ctrl_mode == gs.CTRL_MODE.VELOCITY: - self.mass_mat[i_d, i_d, i_b] = self.mass_mat[i_d, i_d, i_b] + self.dofs_info[i_d].kv * self._substep_dt + self.mass_mat[i_d, i_d, i_b] = self.mass_mat[i_d, i_d, i_b] + self.dofs_info[I_d].kv * self._substep_dt elif self.dofs_state[i_d, i_b].ctrl_mode == gs.CTRL_MODE.POSITION: - self.mass_mat[i_d, i_d, i_b] = self.mass_mat[i_d, i_d, i_b] + self.dofs_info[i_d].kv * self._substep_dt + self.mass_mat[i_d, i_d, i_b] = self.mass_mat[i_d, i_d, i_b] + self.dofs_info[I_d].kv * self._substep_dt self.dofs_state[i_d, i_b].force += self.dofs_state[i_d, i_b].qf_constraint @@ -1491,9 +1521,10 @@ def _func_COM_links(self): for i_b in range(self._B): for i_l_ in range(self.n_awake_links[i_b]): i_l = self.awake_links[i_l_, i_b] + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l l = self.links_state[i_l, i_b] - l_info = self.links_info[i_l] + l_info = self.links_info[I_l] mass = l_info.inertial_mass + l.mass_shift self.links_state[i_l, i_b].i_pos, self.links_state[i_l, i_b].i_quat = ( gu.ti_transform_pos_quat_by_trans_quat( @@ -1501,7 +1532,7 @@ def _func_COM_links(self): ) ) - i_r = self.links_info[i_l].root_idx + i_r = self.links_info[I_l].root_idx ti.atomic_add(self.links_state[i_r, i_b].mass_sum, mass) COM = mass * self.links_state[i_l, i_b].i_pos @@ -1511,8 +1542,9 @@ def _func_COM_links(self): for i_b in range(self._B): for i_l_ in range(self.n_awake_links[i_b]): i_l = self.awake_links[i_l_, i_b] + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l - i_r = self.links_info[i_l].root_idx + i_r = self.links_info[I_l].root_idx if i_l == i_r: self.links_state[i_l, i_b].root_COM = ( self.links_state[i_l, i_b].root_COM / self.links_state[i_l, i_b].mass_sum @@ -1522,19 +1554,21 @@ def _func_COM_links(self): for i_b in range(self._B): for i_l_ in range(self.n_awake_links[i_b]): i_l = self.awake_links[i_l_, i_b] + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l - i_r = self.links_info[i_l].root_idx + i_r = self.links_info[I_l].root_idx self.links_state[i_l, i_b].root_COM = self.links_state[i_r, i_b].root_COM ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.ALL) for i_b in range(self._B): for i_l_ in range(self.n_awake_links[i_b]): i_l = self.awake_links[i_l_, i_b] + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l l = self.links_state[i_l, i_b] - l_info = self.links_info[i_l] + l_info = self.links_info[I_l] - i_r = self.links_info[i_l].root_idx + i_r = self.links_info[I_l].root_idx self.links_state[i_l, i_b].COM = self.links_state[i_r, i_b].root_COM self.links_state[i_l, i_b].i_pos = self.links_state[i_l, i_b].i_pos - self.links_state[i_l, i_b].COM @@ -1553,8 +1587,9 @@ def _func_COM_links(self): for i_b in range(self._B): for i_l_ in range(self.n_awake_links[i_b]): i_l = self.awake_links[i_l_, i_b] + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l - l_info = self.links_info[i_l] + l_info = self.links_info[I_l] i_p = l_info.parent_idx p_pos = ti.Vector.zero(gs.ti_float, 3) @@ -1586,14 +1621,16 @@ def _func_COM_links(self): for i_b in range(self._B): for i_l_ in range(self.n_awake_links[i_b]): i_l = self.awake_links[i_l_, i_b] + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l - l_info = self.links_info[i_l] + l_info = self.links_info[I_l] if l_info.joint_type == gs.JOINT_TYPE.FREE: for i_d in range(l_info.dof_start, l_info.dof_end): - self.dofs_state[i_d, i_b].cdof_vel = self.dofs_info[i_d].motion_vel + I_d = [i_d, i_b] if ti.static(self._options.batch_dofs_info) else i_d + self.dofs_state[i_d, i_b].cdof_vel = self.dofs_info[I_d].motion_vel self.dofs_state[i_d, i_b].cdof_ang = gu.ti_transform_by_quat( - self.dofs_info[i_d].motion_ang, self.links_state[i_l, i_b].j_quat + self.dofs_info[I_d].motion_ang, self.links_state[i_l, i_b].j_quat ) offset_pos = self.links_state[i_l, i_b].COM - self.links_state[i_l, i_b].j_pos @@ -1617,8 +1654,9 @@ def _func_COM_links(self): pass else: for i_d in range(l_info.dof_start, l_info.dof_end): - motion_vel = self.dofs_info[i_d].motion_vel - motion_ang = self.dofs_info[i_d].motion_ang + I_d = [i_d, i_b] if ti.static(self._options.batch_dofs_info) else i_d + motion_vel = self.dofs_info[I_d].motion_vel + motion_ang = self.dofs_info[I_d].motion_ang self.dofs_state[i_d, i_b].cdof_ang = gu.ti_transform_by_quat( motion_ang, self.links_state[i_l, i_b].j_quat @@ -1654,9 +1692,10 @@ def _func_COM_links(self): ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.ALL) for i_b in range(self._B): for i_l in range(self.n_links): + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l l = self.links_state[i_l, i_b] - l_info = self.links_info[i_l] + l_info = self.links_info[I_l] mass = l_info.inertial_mass + l.mass_shift self.links_state[i_l, i_b].i_pos, self.links_state[i_l, i_b].i_quat = ( gu.ti_transform_pos_quat_by_trans_quat( @@ -1664,7 +1703,7 @@ def _func_COM_links(self): ) ) - i_r = self.links_info[i_l].root_idx + i_r = self.links_info[I_l].root_idx ti.atomic_add(self.links_state[i_r, i_b].mass_sum, mass) COM = mass * self.links_state[i_l, i_b].i_pos @@ -1673,8 +1712,9 @@ def _func_COM_links(self): ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.ALL) for i_b in range(self._B): for i_l in range(self.n_links): + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l - i_r = self.links_info[i_l].root_idx + i_r = self.links_info[I_l].root_idx if i_l == i_r: self.links_state[i_l, i_b].root_COM = ( self.links_state[i_l, i_b].root_COM / self.links_state[i_l, i_b].mass_sum @@ -1683,18 +1723,20 @@ def _func_COM_links(self): ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.ALL) for i_b in range(self._B): for i_l in range(self.n_links): + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l - i_r = self.links_info[i_l].root_idx + i_r = self.links_info[I_l].root_idx self.links_state[i_l, i_b].root_COM = self.links_state[i_r, i_b].root_COM ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.ALL) for i_b in range(self._B): for i_l in range(self.n_links): + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l l = self.links_state[i_l, i_b] - l_info = self.links_info[i_l] + l_info = self.links_info[I_l] - i_r = self.links_info[i_l].root_idx + i_r = self.links_info[I_l].root_idx self.links_state[i_l, i_b].COM = self.links_state[i_r, i_b].root_COM self.links_state[i_l, i_b].i_pos = self.links_state[i_l, i_b].i_pos - self.links_state[i_l, i_b].COM @@ -1712,8 +1754,9 @@ def _func_COM_links(self): ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.ALL) for i_b in range(self._B): for i_l in range(self.n_links): + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l - l_info = self.links_info[i_l] + l_info = self.links_info[I_l] i_p = l_info.parent_idx p_pos = ti.Vector.zero(gs.ti_float, 3) @@ -1744,14 +1787,16 @@ def _func_COM_links(self): ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.ALL) for i_b in range(self._B): for i_l in range(self.n_links): + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l - l_info = self.links_info[i_l] + l_info = self.links_info[I_l] if l_info.joint_type == gs.JOINT_TYPE.FREE: for i_d in range(l_info.dof_start, l_info.dof_end): - self.dofs_state[i_d, i_b].cdof_vel = self.dofs_info[i_d].motion_vel + I_d = [i_d, i_b] if ti.static(self._options.batch_dofs_info) else i_d + self.dofs_state[i_d, i_b].cdof_vel = self.dofs_info[I_d].motion_vel self.dofs_state[i_d, i_b].cdof_ang = gu.ti_transform_by_quat( - self.dofs_info[i_d].motion_ang, self.links_state[i_l, i_b].j_quat + self.dofs_info[I_d].motion_ang, self.links_state[i_l, i_b].j_quat ) offset_pos = self.links_state[i_l, i_b].COM - self.links_state[i_l, i_b].j_pos @@ -1776,8 +1821,9 @@ def _func_COM_links(self): else: for i_d in range(l_info.dof_start, l_info.dof_end): - motion_vel = self.dofs_info[i_d].motion_vel - motion_ang = self.dofs_info[i_d].motion_ang + I_d = [i_d, i_b] if ti.static(self._options.batch_dofs_info) else i_d + motion_vel = self.dofs_info[I_d].motion_vel + motion_ang = self.dofs_info[I_d].motion_ang self.dofs_state[i_d, i_b].cdof_ang = gu.ti_transform_by_quat( motion_ang, self.links_state[i_l, i_b].j_quat @@ -1836,7 +1882,8 @@ def _func_COM_cd(self): e_info = self.entities_info[i_e] for i_l in range(e_info.link_start, e_info.link_end): - l_info = self.links_info[i_l] + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l + l_info = self.links_info[I_l] i_p = l_info.parent_idx cd_vel = ti.Vector.zero(gs.ti_float, 3) @@ -1861,8 +1908,9 @@ def _func_COM_cdofd(self): for i_b in range(self._B): for i_l_ in range(self.n_awake_links[i_b]): i_l = self.awake_links[i_l_, i_b] + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l - l_info = self.links_info[i_l] + l_info = self.links_info[I_l] if l_info.joint_type == gs.JOINT_TYPE.FREE: cd_ang = ti.Vector.zero(gs.ti_float, 3) @@ -1904,8 +1952,9 @@ def _func_COM_cdofd(self): ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.ALL) for i_b in range(self._B): for i_l in range(self.n_links): + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l - l_info = self.links_info[i_l] + l_info = self.links_info[I_l] if l_info.joint_type == gs.JOINT_TYPE.FREE: cd_ang = ti.Vector.zero(gs.ti_float, 3) @@ -1967,7 +2016,8 @@ def _func_forward_kinematics(self): def _func_forward_kinematics_entity(self, i_e, i_b): # calculate_j for i_l in range(self.entities_info[i_e].link_start, self.entities_info[i_e].link_end): - l_info = self.links_info[i_l] + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l + l_info = self.links_info[I_l] if l_info.joint_type == gs.JOINT_TYPE.FREE: @@ -1991,7 +2041,8 @@ def _func_forward_kinematics_entity(self, i_e, i_b): else: self.dofs_state[l_info.dof_start, i_b].pos = self.qpos[l_info.q_start, i_b] - dof_info = self.dofs_info[l_info.dof_start] + I_dof_start = [l_info.dof_start, i_b] if ti.static(self._options.batch_dofs_info) else l_info.dof_start + dof_info = self.dofs_info[I_dof_start] self.links_state[i_l, i_b].j_pos = dof_info.motion_vel * self.qpos[l_info.q_start, i_b] self.links_state[i_l, i_b].j_quat = gu.ti_rotvec_to_quat( dof_info.motion_ang * self.qpos[l_info.q_start, i_b] @@ -2000,7 +2051,8 @@ def _func_forward_kinematics_entity(self, i_e, i_b): self.links_state[i_l, i_b].j_vel = dof_info.motion_vel * self.dofs_state[l_info.dof_start, i_b].vel for i_d in range(l_info.dof_start + 1, l_info.dof_end): - dof_info = self.dofs_info[i_d] + I_d = [i_d, i_b] if ti.static(self._options.batch_dofs_info) else i_d + dof_info = self.dofs_info[I_d] qi = l_info.q_start + i_d - l_info.dof_start self.dofs_state[i_d, i_b].pos = self.qpos[qi, i_b] ji_pos = dof_info.motion_vel * self.qpos[qi, i_b] @@ -2041,7 +2093,8 @@ def _func_forward_kinematics_entity(self, i_e, i_b): # joint_to_world for i_l in range(self.entities_info[i_e].link_start, self.entities_info[i_e].link_end): l_state = self.links_state[i_l, i_b] - l_info = self.links_info[i_l] + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l + l_info = self.links_info[I_l] i_p = l_info.parent_idx if i_p == -1: # root link @@ -2318,15 +2371,17 @@ def _func_apply_external_torque_link_frame(self, torque, link_idx, batch_idx): @ti.func def _func_apply_external_force_link_inertial_frame(self, pos, force, link_idx, batch_idx): + link_I = [link_idx, batch_idx] if ti.static(self._options.batch_links_info) else link_idx pos = gu.ti_transform_by_trans_quat( - pos, self.links_info[link_idx].inertial_pos, self.links_info[link_idx].inertial_quat + pos, self.links_info[link_I].inertial_pos, self.links_info[link_I].inertial_quat ) - force = gu.ti_transform_by_quat(force, self.links_info[link_idx].inertial_quat) + force = gu.ti_transform_by_quat(force, self.links_info[link_I].inertial_quat) self._func_apply_external_force_link_frame(pos, force, link_idx, batch_idx) @ti.func def _func_apply_external_torque_link_inertial_frame(self, torque, link_idx, batch_idx): - torque = gu.ti_transform_by_quat(torque, self.links_info[link_idx].inertial_quat) + link_I = [link_idx, batch_idx] if ti.static(self._options.batch_links_info) else link_idx + torque = gu.ti_transform_by_quat(torque, self.links_info[link_I].inertial_quat) self._func_apply_external_torque_link_frame(torque, link_idx, batch_idx) @ti.func @@ -2352,27 +2407,29 @@ def _func_torque_and_passive_force(self): wakeup = False for i_l in range(self.entities_info[i_e].link_start, self.entities_info[i_e].link_end): force = gs.ti_float(0.0) - l_info = self.links_info[i_l] + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l + l_info = self.links_info[I_l] for i_d in range(l_info.dof_start, l_info.dof_end): + I_d = [i_d, i_b] if ti.static(self._options.batch_dofs_info) else i_d if self.dofs_state[i_d, i_b].ctrl_mode == gs.CTRL_MODE.FORCE: force = self.dofs_state[i_d, i_b].ctrl_force elif self.dofs_state[i_d, i_b].ctrl_mode == gs.CTRL_MODE.VELOCITY: - force = self.dofs_info[i_d].kv * ( + force = self.dofs_info[I_d].kv * ( self.dofs_state[i_d, i_b].ctrl_vel - self.dofs_state[i_d, i_b].vel ) elif self.dofs_state[i_d, i_b].ctrl_mode == gs.CTRL_MODE.POSITION and not ( l_info.joint_type == gs.JOINT_TYPE.FREE and i_d >= l_info.dof_start + 3 ): force = ( - self.dofs_info[i_d].kp + self.dofs_info[I_d].kp * (self.dofs_state[i_d, i_b].ctrl_pos - self.dofs_state[i_d, i_b].pos) - - self.dofs_info[i_d].kv * self.dofs_state[i_d, i_b].vel + - self.dofs_info[I_d].kv * self.dofs_state[i_d, i_b].vel ) self.dofs_state[i_d, i_b].qf_applied = ti.math.clamp( force, - self.dofs_info[i_d].force_range[0], - self.dofs_info[i_d].force_range[1], + self.dofs_info[I_d].force_range[0], + self.dofs_info[I_d].force_range[1], ) if ti.abs(force) > gs.EPS: @@ -2410,13 +2467,14 @@ def _func_torque_and_passive_force(self): rotvec = gu.ti_quat_to_rotvec(q_diff) for i_d in range(l_info.dof_start + 3, l_info.dof_end): + I_d = [i_d, i_b] if ti.static(self._options.batch_dofs_info) else i_d force = ( - self.dofs_info[i_d].kp * rotvec[i_d - l_info.dof_start - 3] - - self.dofs_info[i_d].kv * self.dofs_state[i_d, i_b].vel + self.dofs_info[I_d].kp * rotvec[i_d - l_info.dof_start - 3] + - self.dofs_info[I_d].kv * self.dofs_state[i_d, i_b].vel ) self.dofs_state[i_d, i_b].qf_applied = ti.math.clamp( - force, self.dofs_info[i_d].force_range[0], self.dofs_info[i_d].force_range[1] + force, self.dofs_info[I_d].force_range[0], self.dofs_info[I_d].force_range[1] ) if ti.abs(force) > gs.EPS: wakeup = True @@ -2430,7 +2488,8 @@ def _func_torque_and_passive_force(self): for i_b in range(self._B): for i_l_ in range(self.n_awake_links[i_b]): i_l = self.awake_links[i_l_, i_b] - l_info = self.links_info[i_l] + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l + l_info = self.links_info[I_l] if ( l_info.joint_type == gs.JOINT_TYPE.REVOLUTE or l_info.joint_type == gs.JOINT_TYPE.PRISMATIC @@ -2442,20 +2501,25 @@ def _func_torque_and_passive_force(self): q_end = l_info.q_end for j_d in range(q_end - q_start): + I_d = ( + [dof_start + j_d, i_b] if ti.static(self._options.batch_dofs_info) else dof_start + j_d + ) self.dofs_state[dof_start + j_d, i_b].qf_passive = ( - -self.qpos[q_start + j_d, i_b] * self.dofs_info[dof_start + j_d].stiffness + -self.qpos[q_start + j_d, i_b] * self.dofs_info[I_d].stiffness ) ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.ALL) for i_b in range(self._B): for i_d_ in range(self.n_awake_dofs[i_b]): i_d = self.awake_dofs[i_d_, i_b] + I_d = [i_d, i_b] if ti.static(self._options.batch_dofs_info) else i_d - self.dofs_state[i_d, i_b].qf_passive += -self.dofs_info[i_d].damping * self.dofs_state[i_d, i_b].vel + self.dofs_state[i_d, i_b].qf_passive += -self.dofs_info[I_d].damping * self.dofs_state[i_d, i_b].vel else: ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.ALL) for i_l, i_b in ti.ndrange(self.n_links, self._B): - l_info = self.links_info[i_l] + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l + l_info = self.links_info[I_l] if ( l_info.joint_type == gs.JOINT_TYPE.REVOLUTE or l_info.joint_type == gs.JOINT_TYPE.PRISMATIC @@ -2467,13 +2531,15 @@ def _func_torque_and_passive_force(self): q_end = l_info.q_end for j_d in range(q_end - q_start): + I_d = [dof_start + j_d, i_b] if ti.static(self._options.batch_dofs_info) else dof_start + j_d self.dofs_state[dof_start + j_d, i_b].qf_passive = ( - -self.qpos[q_start + j_d, i_b] * self.dofs_info[dof_start + j_d].stiffness + -self.qpos[q_start + j_d, i_b] * self.dofs_info[I_d].stiffness ) ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.ALL) for i_d, i_b in ti.ndrange(self.n_dofs, self._B): - self.dofs_state[i_d, i_b].qf_passive += -self.dofs_info[i_d].damping * self.dofs_state[i_d, i_b].vel + I_d = [i_d, i_b] if ti.static(self._options.batch_dofs_info) else i_d + self.dofs_state[i_d, i_b].qf_passive += -self.dofs_info[I_d].damping * self.dofs_state[i_d, i_b].vel @ti.func def _func_system_update_acc(self): @@ -2484,8 +2550,9 @@ def _func_system_update_acc(self): i_e = self.awake_entities[i_e_, i_b] e_info = self.entities_info[i_e] for i_l in range(e_info.link_start, e_info.link_end): + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l - i_p = self.links_info[i_l].parent_idx + i_p = self.links_info[I_l].parent_idx if i_p == -1: self.links_state[i_l, i_b].cdd_vel = -self._gravity[None] * ( 1 - e_info.gravity_compensation @@ -2499,7 +2566,7 @@ def _func_system_update_acc(self): map_sum_vel = ti.Vector.zero(gs.ti_float, 3) map_sum_ang = ti.Vector.zero(gs.ti_float, 3) - for i_d in range(self.links_info[i_l].dof_start, self.links_info[i_l].dof_end): + for i_d in range(self.links_info[I_l].dof_start, self.links_info[I_l].dof_end): map_sum_vel = ( map_sum_vel + self.dofs_state[i_d, i_b].cdofd_vel * self.dofs_state[i_d, i_b].vel ) @@ -2514,8 +2581,9 @@ def _func_system_update_acc(self): for i_e, i_b in ti.ndrange(self.n_entities, self._B): e_info = self.entities_info[i_e] for i_l in range(e_info.link_start, e_info.link_end): + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l - i_p = self.links_info[i_l].parent_idx + i_p = self.links_info[I_l].parent_idx if i_p == -1: self.links_state[i_l, i_b].cdd_vel = -self._gravity[None] * (1 - e_info.gravity_compensation) self.links_state[i_l, i_b].cdd_ang = ti.Vector.zero(gs.ti_float, 3) @@ -2527,7 +2595,7 @@ def _func_system_update_acc(self): map_sum_vel = ti.Vector.zero(gs.ti_float, 3) map_sum_ang = ti.Vector.zero(gs.ti_float, 3) - for i_d in range(self.links_info[i_l].dof_start, self.links_info[i_l].dof_end): + for i_d in range(self.links_info[I_l].dof_start, self.links_info[I_l].dof_end): map_sum_vel = map_sum_vel + self.dofs_state[i_d, i_b].cdofd_vel * self.dofs_state[i_d, i_b].vel map_sum_ang = map_sum_ang + self.dofs_state[i_d, i_b].cdofd_ang * self.dofs_state[i_d, i_b].vel @@ -2596,7 +2664,8 @@ def _func_inverse_link_force(self): e_info = self.entities_info[i_e] for i in range(e_info.n_links): i_l = e_info.link_end - 1 - i - i_p = self.links_info[i_l].parent_idx + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l + i_p = self.links_info[I_l].parent_idx if i_p != -1: self.links_state[i_p, i_b].cfrc_flat_vel = ( self.links_state[i_p, i_b].cfrc_flat_vel + self.links_state[i_l, i_b].cfrc_flat_vel @@ -2611,7 +2680,8 @@ def _func_inverse_link_force(self): e_info = self.entities_info[i_e] for i in range(e_info.n_links): i_l = e_info.link_end - 1 - i - i_p = self.links_info[i_l].parent_idx + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l + i_p = self.links_info[I_l].parent_idx if i_p != -1: self.links_state[i_p, i_b].cfrc_flat_vel = ( self.links_state[i_p, i_b].cfrc_flat_vel + self.links_state[i_l, i_b].cfrc_flat_vel @@ -2628,15 +2698,16 @@ def _func_actuation(self): ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.ALL) for i_l, i_b in ti.ndrange(self.n_links, self._B): joint_type = self.links_info[i_l].joint_type - q_start = self.links_info[i_l].q_start + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l + q_start = self.links_info[I_l].q_start if joint_type == gs.JOINT_TYPE.REVOLUTE or joint_type == gs.JOINT_TYPE.PRISMATIC: gear = -1 # TODO - i_d = self.links_info[i_l].dof_start + i_d = self.links_info[I_l].dof_start self.dofs_state[i_d, i_b].act_length = gear * self.qpos[q_start, i_b] self.dofs_state[i_d, i_b].qf_actuator = self.dofs_state[i_d, i_b].act_length else: - for i_d in range(self.links_info[i_l].dof_start, self.links_info[i_l].dof_end): + for i_d in range(self.links_info[I_l].dof_start, self.links_info[I_l].dof_end): self.dofs_state[i_d, i_b].act_length = 0.0 self.dofs_state[i_d, i_b].qf_actuator = self.dofs_state[i_d, i_b].act_length @@ -2647,8 +2718,9 @@ def _func_bias_force(self): for i_b in range(self._B): for i_l_ in range(self.n_awake_links[i_b]): i_l = self.awake_links[i_l_, i_b] + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l - l_info = self.links_info[i_l] + l_info = self.links_info[I_l] dof_start = l_info.dof_start dof_end = l_info.dof_end @@ -2667,8 +2739,9 @@ def _func_bias_force(self): else: ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.ALL) for i_l, i_b in ti.ndrange(self.n_links, self._B): + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l - l_info = self.links_info[i_l] + l_info = self.links_info[I_l] dof_start = l_info.dof_start dof_end = l_info.dof_end @@ -2727,10 +2800,11 @@ def _func_integrate(self): for i_b in range(self._B): for i_l_ in range(self.n_awake_links[i_b]): i_l = self.awake_links[i_l_, i_b] - joint_type = self.links_info[i_l].joint_type - dof_start = self.links_info[i_l].dof_start - q_start = self.links_info[i_l].q_start - q_end = self.links_info[i_l].q_end + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l + joint_type = self.links_info[I_l].joint_type + dof_start = self.links_info[I_l].dof_start + q_start = self.links_info[I_l].q_start + q_end = self.links_info[I_l].q_end if joint_type == gs.JOINT_TYPE.FREE: rot = ti.Vector( @@ -2795,10 +2869,11 @@ def _func_integrate(self): ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.ALL) for i_l, i_b in ti.ndrange(self.n_links, self._B): - joint_type = self.links_info[i_l].joint_type - dof_start = self.links_info[i_l].dof_start - q_start = self.links_info[i_l].q_start - q_end = self.links_info[i_l].q_end + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l + joint_type = self.links_info[I_l].joint_type + dof_start = self.links_info[I_l].dof_start + q_start = self.links_info[I_l].q_start + q_end = self.links_info[I_l].q_end if joint_type == gs.JOINT_TYPE.FREE: rot = ti.Vector( @@ -2844,7 +2919,8 @@ def _func_integrate(self): def _func_integrate_dq_entity(self, dq, i_e, i_b, respect_joint_limit): e_info = self.entities_info[i_e] for i_l in range(e_info.link_start, e_info.link_end): - l_info = self.links_info[i_l] + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l + l_info = self.links_info[I_l] q_start = l_info.q_start dof_start = l_info.dof_start dq_start = l_info.dof_start - e_info.dof_start @@ -2883,10 +2959,11 @@ def _func_integrate_dq_entity(self, dq, i_e, i_b, respect_joint_limit): self.qpos[q_start + i_d_, i_b] = self.qpos[q_start + i_d_, i_b] + dq[dq_start + i_d_, i_b] if respect_joint_limit: + I_d = [dof_start + i_d_, i_b] if ti.static(self._options.batch_dofs_info) else dof_start + i_d_ self.qpos[q_start + i_d_, i_b] = ti.math.clamp( self.qpos[q_start + i_d_, i_b], - self.dofs_info[dof_start + i_d_].limit[0], - self.dofs_info[dof_start + i_d_].limit[1], + self.dofs_info[I_d].limit[0], + self.dofs_info[I_d].limit[1], ) def substep_pre_coupling(self, f): @@ -3206,12 +3283,13 @@ def _kernel_set_links_pos( ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) for i_l_, i_b_ in ti.ndrange(links_idx.shape[0], envs_idx.shape[0]): i_l = links_idx[i_l_] - if self.links_info[i_l].is_fixed: # change links_state directly as the link's pose is not contained in qpos + I_l = [i_l, i_b_] if ti.static(self._options.batch_links_info) else i_l + if self.links_info[I_l].is_fixed: # change links_state directly as the link's pose is not contained in qpos for i in ti.static(range(3)): self.links_state[i_l, envs_idx[i_b_]].pos[i] = pos[i_b_, i_l_, i] else: # free base link's pose is reflected in qpos, and links_state will be computed automatically - q_start = self.links_info[i_l].q_start + q_start = self.links_info[I_l].q_start for i in ti.static(range(3)): self.qpos[q_start + i, envs_idx[i_b_]] = pos[i_b_, i_l_, i] @@ -3239,12 +3317,13 @@ def _kernel_set_links_quat( ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) for i_l_, i_b_ in ti.ndrange(links_idx.shape[0], envs_idx.shape[0]): i_l = links_idx[i_l_] - if self.links_info[i_l].is_fixed: # change links_state directly as the link's pose is not contained in qpos + I_l = [i_l, i_b_] if ti.static(self._options.batch_links_info) else i_l + if self.links_info[I_l].is_fixed: # change links_state directly as the link's pose is not contained in qpos for i in ti.static(range(4)): self.links_state[i_l, envs_idx[i_b_]].quat[i] = quat[i_b_, i_l_, i] else: # free base link's pose is reflected in qpos, and links_state will be computed automatically - q_start = self.links_info[i_l].q_start + q_start = self.links_info[I_l].q_start for i in ti.static(range(4)): self.qpos[q_start + i + 3, envs_idx[i_b_]] = quat[i_b_, i_l_, i] @@ -3281,6 +3360,58 @@ def _kernel_set_links_COM_shift( for i in ti.static(range(3)): self.links_state[links_idx[i_l_], envs_idx[i_b_]].i_pos_shift[i] = com[i_b_, i_l_, i] + def _set_links_info(self, tensor, links_idx, name, envs_idx=None): + if self._options.batch_links_info: + tensor, links_idx, envs_idx = self._validate_1D_io_variables( + tensor, links_idx, envs_idx, idx_name="links_idx" + ) + else: + tensor, links_idx = self._validate_1D_io_variables(tensor, links_idx, idx_name="links_idx", batched=False) + envs_idx = torch.empty(()) + + if name == "invweight": + self._kernel_set_links_invweight(tensor, links_idx, envs_idx) + elif name == "inertial_mass": + self._kernel_set_links_inertial_mass(tensor, links_idx, envs_idx) + else: + gs.raise_exception(f"Invalid `name` {name}.") + + def set_links_inertial_mass(self, invweight, links_idx, envs_idx=None): + self._set_links_info(invweight, links_idx, "inertial_mass", envs_idx) + + @ti.kernel + def _kernel_set_links_inertial_mass( + self, + inertial_mass: ti.types.ndarray(), + links_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), + ): + ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) + if ti.static(self._options.batch_links_info): + for i_l_, i_b_ in ti.ndrange(links_idx.shape[0], envs_idx.shape[0]): + self.links_info[links_idx[i_l_], envs_idx[i_b_]].inertial_mass = inertial_mass[i_b_, i_l_] + else: + for i_l_ in range(links_idx.shape[0]): + self.links_info[links_idx[i_l_]].inertial_mass = inertial_mass[i_l_] + + def set_links_invweight(self, invweight, links_idx, envs_idx=None): + self._set_links_info(invweight, links_idx, "invweight", envs_idx) + + @ti.kernel + def _kernel_set_links_invweight( + self, + invweight: ti.types.ndarray(), + links_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), + ): + ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) + if ti.static(self._options.batch_links_info): + for i_l_, i_b_ in ti.ndrange(links_idx.shape[0], envs_idx.shape[0]): + self.links_info[links_idx[i_l_], envs_idx[i_b_]].invweight = invweight[i_b_, i_l_] + else: + for i_l_ in range(links_idx.shape[0]): + self.links_info[links_idx[i_l_]].invweight = invweight[i_l_] + def set_geoms_friction_ratio(self, friction_ratio, geoms_idx, envs_idx=None): friction_ratio, geoms_idx, envs_idx = self._validate_1D_io_variables( friction_ratio, geoms_idx, envs_idx, idx_name="geoms_idx" @@ -3337,59 +3468,195 @@ def _kernel_set_global_sol_params(self, sol_params: ti.types.ndarray()): ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) for i, b in ti.ndrange(self.n_dofs, self._B): + I = [i, b] if ti.static(self._options.batch_dofs_info) else i for j in ti.static(range(7)): - self.dofs_info[i].sol_params[j] = sol_params[j] + self.dofs_info[I].sol_params[j] = sol_params[j] - self.dofs_info[i].sol_params[0] = self._substep_dt * 2 + self.dofs_info[I].sol_params[0] = self._substep_dt * 2 - def set_dofs_kp(self, kp, dofs_idx): - kp, dofs_idx = self._validate_1D_io_variables(kp, dofs_idx, batched=False) - self._kernel_set_dofs_kp(kp, dofs_idx) + def _set_dofs_info(self, tensor_list, dofs_idx, name, envs_idx=None): + if self._options.batch_dofs_info: + for i, tensor in enumerate(tensor_list): + if i == (len(tensor_list) - 1): + tensor_list[i], dofs_idx, envs_idx = self._validate_1D_io_variables(tensor, dofs_idx, envs_idx) + else: + tensor_list[i], _, _ = self._validate_1D_io_variables(tensor, dofs_idx, envs_idx) + else: + for i, tensor in enumerate(tensor_list): + if i == (len(tensor_list) - 1): + tensor_list[i], _ = self._validate_1D_io_variables(tensor, dofs_idx, batched=False) + else: + tensor_list[i], dofs_idx = self._validate_1D_io_variables(tensor, dofs_idx, batched=False) + envs_idx = torch.empty(()) + + if name == "kp": + self._kernel_set_dofs_kp(tensor_list[0], dofs_idx, envs_idx) + elif name == "kv": + self._kernel_set_dofs_kv(tensor_list[0], dofs_idx, envs_idx) + elif name == "force_range": + self._kernel_set_dofs_force_range(tensor_list[0], tensor_list[1], dofs_idx, envs_idx) + elif name == "stiffness": + self._kernel_set_dofs_stiffness(tensor_list[0], dofs_idx, envs_idx) + elif name == "invweight": + self._kernel_set_dofs_invweight(tensor_list[0], dofs_idx, envs_idx) + elif name == "armature": + self._kernel_set_dofs_armature(tensor_list[0], dofs_idx, envs_idx) + elif name == "damping": + self._kernel_set_dofs_damping(tensor_list[0], dofs_idx, envs_idx) + elif name == "limit": + self._kernel_set_dofs_limit(tensor_list[0], tensor_list[1], dofs_idx, envs_idx) + else: + gs.raise_exception(f"Invalid `name` {name}.") + + def set_dofs_kp(self, kp, dofs_idx, envs_idx=None): + self._set_dofs_info([kp], dofs_idx, "kp", envs_idx) @ti.kernel def _kernel_set_dofs_kp( self, kp: ti.types.ndarray(), dofs_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), ): ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) - for i_d_ in range(dofs_idx.shape[0]): - self.dofs_info[dofs_idx[i_d_]].kp = kp[i_d_] + if ti.static(self._options.batch_dofs_info): + for i_d_, i_b_ in ti.ndrange(dofs_idx.shape[0], envs_idx.shape[0]): + self.dofs_info[dofs_idx[i_d_], envs_idx[i_b_]].kp = kp[i_b_, i_d_] + else: + for i_d_ in range(dofs_idx.shape[0]): + self.dofs_info[dofs_idx[i_d_]].kp = kp[i_d_] - def set_dofs_kv(self, kv, dofs_idx): - kv, dofs_idx = self._validate_1D_io_variables(kv, dofs_idx, batched=False) - self._kernel_set_dofs_kv(kv, dofs_idx) + def set_dofs_kv(self, kv, dofs_idx, envs_idx=None): + self._set_dofs_info([kv], dofs_idx, "kv", envs_idx) @ti.kernel def _kernel_set_dofs_kv( self, kv: ti.types.ndarray(), dofs_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), ): ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) - for i_d_ in range(dofs_idx.shape[0]): - self.dofs_info[dofs_idx[i_d_]].kv = kv[i_d_] + if ti.static(self._options.batch_dofs_info): + for i_d_, i_b_ in ti.ndrange(dofs_idx.shape[0], envs_idx.shape[0]): + self.dofs_info[dofs_idx[i_d_], envs_idx[i_b_]].kv = kv[i_b_, i_d_] + else: + for i_d_ in range(dofs_idx.shape[0]): + self.dofs_info[dofs_idx[i_d_]].kv = kv[i_d_] - def set_dofs_force_range(self, lower, upper, dofs_idx): - lower, _ = self._validate_1D_io_variables(lower, dofs_idx, batched=False) - upper, dofs_idx = self._validate_1D_io_variables(upper, dofs_idx, batched=False) + def set_dofs_force_range(self, lower, upper, dofs_idx, envs_idx=None): + self._set_dofs_info([lower, upper], dofs_idx, "force_range", envs_idx) - if (lower > upper).any(): - gs.raise_exception("`lower` should be less than or equal to `upper`.") + @ti.kernel + def _kernel_set_dofs_force_range( + self, + lower: ti.types.ndarray(), + upper: ti.types.ndarray(), + dofs_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), + ): + ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) + if ti.static(self._options.batch_dofs_info): + for i_d_, i_b_ in ti.ndrange(dofs_idx.shape[0], envs_idx.shape[0]): + self.dofs_info[dofs_idx[i_d_], envs_idx[i_b_]].force_range[0] = lower[i_b_, i_d_] + self.dofs_info[dofs_idx[i_d_], envs_idx[i_b_]].force_range[1] = upper[i_b_, i_d_] + else: + for i_d_ in range(dofs_idx.shape[0]): + self.dofs_info[dofs_idx[i_d_]].force_range[0] = lower[i_d_] + self.dofs_info[dofs_idx[i_d_]].force_range[1] = upper[i_d_] - self._kernel_set_dofs_force_range(lower, upper, dofs_idx) + def set_dofs_stiffness(self, stiffness, dofs_idx, envs_idx=None): + self._set_dofs_info([stiffness], dofs_idx, "stiffness", envs_idx) @ti.kernel - def _kernel_set_dofs_force_range( + def _kernel_set_dofs_stiffness( + self, + stiffness: ti.types.ndarray(), + dofs_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), + ): + ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) + if ti.static(self._options.batch_dofs_info): + for i_d_, i_b_ in ti.ndrange(dofs_idx.shape[0], envs_idx.shape[0]): + self.dofs_info[dofs_idx[i_d_], envs_idx[i_b_]].stiffness = stiffness[i_b_, i_d_] + else: + for i_d_ in range(dofs_idx.shape[0]): + self.dofs_info[dofs_idx[i_d_]].stiffness = stiffness[i_d_] + + def set_dofs_invweight(self, invweight, dofs_idx, envs_idx=None): + self._set_dofs_info([invweight], dofs_idx, "invweight", envs_idx) + + @ti.kernel + def _kernel_set_dofs_invweight( + self, + invweight: ti.types.ndarray(), + dofs_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), + ): + ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) + if ti.static(self._options.batch_dofs_info): + for i_d_, i_b_ in ti.ndrange(dofs_idx.shape[0], envs_idx.shape[0]): + self.dofs_info[dofs_idx[i_d_], envs_idx[i_b_]].invweight = invweight[i_b_, i_d_] + else: + for i_d_ in range(dofs_idx.shape[0]): + self.dofs_info[dofs_idx[i_d_]].invweight = invweight[i_d_] + + def set_dofs_armature(self, armature, dofs_idx, envs_idx=None): + self._set_dofs_info([armature], dofs_idx, "armature", envs_idx) + + @ti.kernel + def _kernel_set_dofs_armature( + self, + armature: ti.types.ndarray(), + dofs_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), + ): + ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) + if ti.static(self._options.batch_dofs_info): + for i_d_, i_b_ in ti.ndrange(dofs_idx.shape[0], envs_idx.shape[0]): + self.dofs_info[dofs_idx[i_d_], envs_idx[i_b_]].armature = armature[i_b_, i_d_] + else: + for i_d_ in range(dofs_idx.shape[0]): + self.dofs_info[dofs_idx[i_d_]].armature = armature[i_d_] + + def set_dofs_damping(self, damping, dofs_idx, envs_idx=None): + self._set_dofs_info([damping], dofs_idx, "damping", envs_idx) + + @ti.kernel + def _kernel_set_dofs_damping( + self, + damping: ti.types.ndarray(), + dofs_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), + ): + ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) + if ti.static(self._options.batch_dofs_info): + for i_d_, i_b_ in ti.ndrange(dofs_idx.shape[0], envs_idx.shape[0]): + self.dofs_info[dofs_idx[i_d_], envs_idx[i_b_]].damping = damping[i_b_, i_d_] + else: + for i_d_ in range(dofs_idx.shape[0]): + self.dofs_info[dofs_idx[i_d_]].damping = damping[i_d_] + + def set_dofs_limit(self, lower, upper, dofs_idx, envs_idx=None): + self._set_dofs_info([lower, upper], dofs_idx, "limit", envs_idx) + + @ti.kernel + def _kernel_set_dofs_limit( self, lower: ti.types.ndarray(), upper: ti.types.ndarray(), dofs_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), ): ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) - for i_d_ in range(dofs_idx.shape[0]): - self.dofs_info[dofs_idx[i_d_]].force_range[0] = lower[i_d_] - self.dofs_info[dofs_idx[i_d_]].force_range[1] = upper[i_d_] + if ti.static(self._options.batch_dofs_info): + for i_d_, i_b_ in ti.ndrange(dofs_idx.shape[0], envs_idx.shape[0]): + self.dofs_info[dofs_idx[i_d_], envs_idx[i_b_]].limit[0] = lower[i_b_, i_d_] + self.dofs_info[dofs_idx[i_d_], envs_idx[i_b_]].limit[1] = upper[i_b_, i_d_] + else: + for i_d_ in range(dofs_idx.shape[0]): + self.dofs_info[dofs_idx[i_d_]].limit[0] = lower[i_d_] + self.dofs_info[dofs_idx[i_d_]].limit[1] = upper[i_d_] def set_dofs_velocity(self, velocity, dofs_idx, envs_idx=None): velocity, dofs_idx, envs_idx = self._validate_1D_io_variables(velocity, dofs_idx, envs_idx) @@ -3437,7 +3704,8 @@ def _kernel_set_dofs_position( for i_e, i_b_ in ti.ndrange(self.n_entities, envs_idx.shape[0]): i_b = envs_idx[i_b_] for i_l in range(self.entities_info[i_e].link_start, self.entities_info[i_e].link_end): - l_info = self.links_info[i_l] + I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l + l_info = self.links_info[I_l] if l_info.joint_type == gs.JOINT_TYPE.FREE: xyz = ti.Vector( @@ -3664,6 +3932,60 @@ def _kernel_get_links_COM_shift( for i in ti.static(range(3)): tensor[i_b_, i_l_, i] = self.links_state[links_idx[i_l_], envs_idx[i_b_]].i_pos_shift[i] + def _get_links_info(self, links_idx, name, envs_idx=None): + if self._options.batch_links_info: + tensor, links_idx, envs_idx = self._validate_1D_io_variables( + None, links_idx, envs_idx, idx_name="links_idx" + ) + else: + tensor, links_idx = self._validate_1D_io_variables(None, links_idx, idx_name="links_idx", batched=False) + envs_idx = torch.empty(()) + + if name == "invweight": + self._kernel_get_links_invweight(tensor, links_idx, envs_idx) + return tensor + elif name == "inertial_mass": + self._kernel_get_links_inertial_mass(tensor, links_idx, envs_idx) + return tensor + else: + gs.raise_exception(f"Invalid `name` {name}.") + + def get_links_inertial_mass(self, links_idx, envs_idx=None): + return self._get_links_info(links_idx, "inertial_mass", envs_idx) + + @ti.kernel + def _kernel_get_links_inertial_mass( + self, + tensor: ti.types.ndarray(), + links_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), + ): + ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) + if ti.static(self._options.batch_links_info): + for i_l_, i_b_ in ti.ndrange(links_idx.shape[0], envs_idx.shape[0]): + tensor[i_b_, i_l_] = self.links_info[links_idx[i_l_], envs_idx[i_b_]].inertial_mass + else: + for i_l_ in range(links_idx.shape[0]): + tensor[i_l_] = self.links_info[links_idx[i_l_]].inertial_mass + + def get_links_invweight(self, links_idx, envs_idx=None): + return self._get_links_info(links_idx, "invweight", envs_idx) + + @ti.kernel + def _kernel_get_links_invweight( + self, + tensor: ti.types.ndarray(), + links_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), + ): + ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) + if ti.static(self._options.batch_links_info): + for i_l_, i_b_ in ti.ndrange(links_idx.shape[0], envs_idx.shape[0]): + tensor[i_b_, i_l_] = self.links_info[links_idx[i_l_], envs_idx[i_b_]].invweight + else: + for i_l_ in range(links_idx.shape[0]): + tensor[i_l_] = self.links_info[links_idx[i_l_]].invweight + def get_geoms_friction_ratio(self, geoms_idx, envs_idx=None): tensor, geoms_idx, envs_idx = self._validate_1D_io_variables(None, geoms_idx, envs_idx, idx_name="geoms_idx") @@ -3767,20 +4089,21 @@ def _kernel_get_dofs_control_force( for i_d_, i_b_ in ti.ndrange(dofs_idx.shape[0], envs_idx.shape[0]): i_d = dofs_idx[i_d_] i_b = envs_idx[i_b_] + I_d = [i_d, i_b] if ti.static(self._options.batch_dofs_info) else i_d force = gs.ti_float(0.0) if self.dofs_state[i_d, i_b].ctrl_mode == gs.CTRL_MODE.FORCE: force = self.dofs_state[i_d, i_b].ctrl_force elif self.dofs_state[i_d, i_b].ctrl_mode == gs.CTRL_MODE.VELOCITY: - force = self.dofs_info[i_d].kv * (self.dofs_state[i_d, i_b].ctrl_vel - self.dofs_state[i_d, i_b].vel) + force = self.dofs_info[I_d].kv * (self.dofs_state[i_d, i_b].ctrl_vel - self.dofs_state[i_d, i_b].vel) elif self.dofs_state[i_d, i_b].ctrl_mode == gs.CTRL_MODE.POSITION: force = ( - self.dofs_info[i_d].kp * (self.dofs_state[i_d, i_b].ctrl_pos - self.dofs_state[i_d, i_b].pos) - - self.dofs_info[i_d].kv * self.dofs_state[i_d, i_b].vel + self.dofs_info[I_d].kp * (self.dofs_state[i_d, i_b].ctrl_pos - self.dofs_state[i_d, i_b].pos) + - self.dofs_info[I_d].kv * self.dofs_state[i_d, i_b].vel ) tensor[i_b_, i_d_] = ti.math.clamp( force, - self.dofs_info[i_d].force_range[0], - self.dofs_info[i_d].force_range[1], + self.dofs_info[I_d].force_range[0], + self.dofs_info[I_d].force_range[1], ) @ti.kernel @@ -3828,29 +4151,61 @@ def get_dofs_force_range(self, dofs_idx, envs_idx=None): def get_dofs_limit(self, dofs_idx, envs_idx=None): return self._get_dofs_info(dofs_idx, "limit", envs_idx) + def get_dofs_stiffness(self, dofs_idx, envs_idx=None): + return self._get_dofs_info(dofs_idx, "stiffness", envs_idx) + + def get_dofs_invweight(self, dofs_idx, envs_idx=None): + return self._get_dofs_info(dofs_idx, "invweight", envs_idx) + + def get_dofs_armature(self, dofs_idx, envs_idx=None): + return self._get_dofs_info(dofs_idx, "armature", envs_idx) + + def get_dofs_damping(self, dofs_idx, envs_idx=None): + return self._get_dofs_info(dofs_idx, "damping", envs_idx) + def _get_dofs_info(self, dofs_idx, name, envs_idx=None): - tensor, dofs_idx = self._validate_1D_io_variables(None, dofs_idx, batched=False) + if self._options.batch_dofs_info: + tensor, dofs_idx, envs_idx = self._validate_1D_io_variables(None, dofs_idx, envs_idx) + else: + tensor, dofs_idx = self._validate_1D_io_variables(None, dofs_idx, batched=False) + envs_idx = torch.empty(()) if name == "kp": - self._kernel_get_dofs_kp(tensor, dofs_idx) + self._kernel_get_dofs_kp(tensor, dofs_idx, envs_idx) return tensor elif name == "kv": - self._kernel_get_dofs_kv(tensor, dofs_idx) + self._kernel_get_dofs_kv(tensor, dofs_idx, envs_idx) return tensor elif name == "force_range": lower = torch.empty_like(tensor) upper = torch.empty_like(tensor) - self._kernel_get_dofs_force_range(lower, upper, dofs_idx) + self._kernel_get_dofs_force_range(lower, upper, dofs_idx, envs_idx) return lower, upper elif name == "limit": lower = torch.empty_like(tensor) upper = torch.empty_like(tensor) - self._kernel_get_dofs_limit(lower, upper, dofs_idx) + self._kernel_get_dofs_limit(lower, upper, dofs_idx, envs_idx) return lower, upper + elif name == "stiffness": + self._kernel_get_dofs_stiffness(tensor, dofs_idx, envs_idx) + return tensor + + elif name == "invweight": + self._kernel_get_dofs_invweight(tensor, dofs_idx, envs_idx) + return tensor + + elif name == "armature": + self._kernel_get_dofs_armature(tensor, dofs_idx, envs_idx) + return tensor + + elif name == "damping": + self._kernel_get_dofs_damping(tensor, dofs_idx, envs_idx) + return tensor + else: gs.raise_exception() @@ -3859,20 +4214,30 @@ def _kernel_get_dofs_kp( self, tensor: ti.types.ndarray(), dofs_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), ): ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) - for i_d_ in range(dofs_idx.shape[0]): - tensor[i_d_] = self.dofs_info[dofs_idx[i_d_]].kp + if ti.static(self._options.batch_dofs_info): + for i_d_, i_b_ in ti.ndrange(dofs_idx.shape[0], envs_idx.shape[0]): + tensor[i_b_, i_d_] = self.dofs_info[dofs_idx[i_d_], envs_idx[i_b_]].kp + else: + for i_d_ in range(dofs_idx.shape[0]): + tensor[i_d_] = self.dofs_info[dofs_idx[i_d_]].kp @ti.kernel def _kernel_get_dofs_kv( self, tensor: ti.types.ndarray(), dofs_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), ): ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) - for i_d_ in range(dofs_idx.shape[0]): - tensor[i_d_] = self.dofs_info[dofs_idx[i_d_]].kv + if ti.static(self._options.batch_dofs_info): + for i_d_, i_b_ in ti.ndrange(dofs_idx.shape[0], envs_idx.shape[0]): + tensor[i_b_, i_d_] = self.dofs_info[dofs_idx[i_d_], envs_idx[i_b_]].kv + else: + for i_d_ in range(dofs_idx.shape[0]): + tensor[i_d_] = self.dofs_info[dofs_idx[i_d_]].kv @ti.kernel def _kernel_get_dofs_force_range( @@ -3880,11 +4245,17 @@ def _kernel_get_dofs_force_range( lower: ti.types.ndarray(), upper: ti.types.ndarray(), dofs_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), ): ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) - for i_d_ in range(dofs_idx.shape[0]): - lower[i_d_] = self.dofs_info[dofs_idx[i_d_]].force_range[0] - upper[i_d_] = self.dofs_info[dofs_idx[i_d_]].force_range[1] + if ti.static(self._options.batch_dofs_info): + for i_d_, i_b_ in ti.ndrange(dofs_idx.shape[0], envs_idx.shape[0]): + lower[i_b_, i_d_] = self.dofs_info[dofs_idx[i_d_], envs_idx[i_b_]].force_range[0] + upper[i_b_, i_d_] = self.dofs_info[dofs_idx[i_d_], envs_idx[i_b_]].force_range[1] + else: + for i_d_ in range(dofs_idx.shape[0]): + lower[i_d_] = self.dofs_info[dofs_idx[i_d_]].force_range[0] + upper[i_d_] = self.dofs_info[dofs_idx[i_d_]].force_range[1] @ti.kernel def _kernel_get_dofs_limit( @@ -3892,11 +4263,77 @@ def _kernel_get_dofs_limit( lower: ti.types.ndarray(), upper: ti.types.ndarray(), dofs_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), ): ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) - for i_d_ in range(dofs_idx.shape[0]): - lower[i_d_] = self.dofs_info[dofs_idx[i_d_]].limit[0] - upper[i_d_] = self.dofs_info[dofs_idx[i_d_]].limit[1] + if ti.static(self._options.batch_dofs_info): + for i_d_, i_b_ in ti.ndrange(dofs_idx.shape[0], envs_idx.shape[0]): + lower[i_b_, i_d_] = self.dofs_info[dofs_idx[i_d_], envs_idx[i_b_]].limit[0] + upper[i_b_, i_d_] = self.dofs_info[dofs_idx[i_d_], envs_idx[i_b_]].limit[1] + else: + for i_d_ in range(dofs_idx.shape[0]): + lower[i_d_] = self.dofs_info[dofs_idx[i_d_]].limit[0] + upper[i_d_] = self.dofs_info[dofs_idx[i_d_]].limit[1] + + @ti.kernel + def _kernel_get_dofs_stiffness( + self, + tensor: ti.types.ndarray(), + dofs_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), + ): + ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) + if ti.static(self._options.batch_dofs_info): + for i_d_, i_b_ in ti.ndrange(dofs_idx.shape[0], envs_idx.shape[0]): + tensor[i_b_, i_d_] = self.dofs_info[dofs_idx[i_d_], envs_idx[i_b_]].stiffness + else: + for i_d_ in range(dofs_idx.shape[0]): + tensor[i_d_] = self.dofs_info[dofs_idx[i_d_]].stiffness + + @ti.kernel + def _kernel_get_dofs_invweight( + self, + tensor: ti.types.ndarray(), + dofs_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), + ): + ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) + if ti.static(self._options.batch_dofs_info): + for i_d_, i_b_ in ti.ndrange(dofs_idx.shape[0], envs_idx.shape[0]): + tensor[i_b_, i_d_] = self.dofs_info[dofs_idx[i_d_], envs_idx[i_b_]].invweight + else: + for i_d_ in range(dofs_idx.shape[0]): + tensor[i_d_] = self.dofs_info[dofs_idx[i_d_]].invweight + + @ti.kernel + def _kernel_get_dofs_armature( + self, + tensor: ti.types.ndarray(), + dofs_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), + ): + ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) + if ti.static(self._options.batch_dofs_info): + for i_d_, i_b_ in ti.ndrange(dofs_idx.shape[0], envs_idx.shape[0]): + tensor[i_b_, i_d_] = self.dofs_info[dofs_idx[i_d_], envs_idx[i_b_]].armature + else: + for i_d_ in range(dofs_idx.shape[0]): + tensor[i_d_] = self.dofs_info[dofs_idx[i_d_]].armature + + @ti.kernel + def _kernel_get_dofs_damping( + self, + tensor: ti.types.ndarray(), + dofs_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), + ): + ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) + if ti.static(self._options.batch_dofs_info): + for i_d_, i_b_ in ti.ndrange(dofs_idx.shape[0], envs_idx.shape[0]): + tensor[i_b_, i_d_] = self.dofs_info[dofs_idx[i_d_], envs_idx[i_b_]].damping + else: + for i_d_ in range(dofs_idx.shape[0]): + tensor[i_d_] = self.dofs_info[dofs_idx[i_d_]].damping @ti.kernel def _kernel_set_drone_rpm( diff --git a/genesis/options/solvers.py b/genesis/options/solvers.py index 54e80b18..1b3b6c51 100644 --- a/genesis/options/solvers.py +++ b/genesis/options/solvers.py @@ -185,6 +185,10 @@ class RigidOptions(Options): integrator: gs.integrator = gs.integrator.approximate_implicitfast IK_max_targets: int = 6 + # batching info + batch_links_info: Optional[bool] = False + batch_dofs_info: Optional[bool] = False + # constraint solver constraint_solver: gs.constraint_solver = gs.constraint_solver.CG iterations: int = 100