Skip to content

Commit

Permalink
[Bug Fix] avoid changing input link index in rigid_entity (#482)
Browse files Browse the repository at this point in the history
* avoid changing input link index in rigid_entity

* use _get_ls_idx to convert local indx to global idx

* updated api call to use ls_idx_local in examaple

* reforamt to let pre_commit pass
  • Loading branch information
woshialex authored Jan 11, 2025
1 parent 0d5cd8f commit 138ece3
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 20 deletions.
6 changes: 3 additions & 3 deletions examples/rigid/domain_randomization.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ def main():
########################## domain randomization ##########################
robot.set_friction_ratio(
friction_ratio=0.5 + torch.rand(scene.n_envs, robot.n_links),
link_indices=np.arange(0, robot.n_links),
ls_idx_local=np.arange(0, robot.n_links),
)
robot.set_mass_shift(
mass_shift=-0.5 + torch.rand(scene.n_envs, robot.n_links),
link_indices=np.arange(0, robot.n_links),
ls_idx_local=np.arange(0, robot.n_links),
)
robot.set_COM_shift(
com_shift=-0.05 + 0.1 * torch.rand(scene.n_envs, robot.n_links, 3),
link_indices=np.arange(0, robot.n_links),
ls_idx_local=np.arange(0, robot.n_links),
)

joint_names = [
Expand Down
30 changes: 13 additions & 17 deletions genesis/engine/entities/rigid_entity/rigid_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2310,27 +2310,27 @@ def get_links_net_contact_force(self):

return entity_links_force

def set_friction_ratio(self, friction_ratio, link_indices, envs_idx=None):
def set_friction_ratio(self, friction_ratio, ls_idx_local, envs_idx=None):
"""
Set the friction ratio of the geoms of the specified links.
Parameters
----------
friction_ratio : torch.Tensor, shape (n_envs, n_links)
The friction ratio
link_indices : array_like
ls_idx_local : array_like
The indices of the links to set friction ratio.
envs_idx : None | array_like, optional
The indices of the environments. If None, all environments will be considered. Defaults to None.
"""
geom_indices = []
for i in link_indices:
for j in range(self._links[i].n_geoms):
geom_indices.append(self._links[i]._geom_start + j)
geom_indices = [
self._links[il]._geom_start + g_idx for il in ls_idx_local for g_idx in range(self._links[il].n_geoms)
]

self._solver.set_geoms_friction_ratio(
torch.cat(
[
ratio.unsqueeze(-1).repeat(1, self._links[j].n_geoms)
for j, ratio in zip(link_indices, friction_ratio.unbind(-1))
for j, ratio in zip(ls_idx_local, friction_ratio.unbind(-1))
],
dim=-1,
),
Expand Down Expand Up @@ -2359,37 +2359,33 @@ def set_friction(self, friction):
for link in self._links:
link.set_friction(friction)

def set_mass_shift(self, mass_shift, link_indices, envs_idx=None):
def set_mass_shift(self, mass_shift, ls_idx_local, envs_idx=None):
"""
Set the mass shift of specified links.
Parameters
----------
mass : torch.Tensor, shape (n_envs, n_links)
The mass shift
link_indices : array_like
ls_idx_local : array_like
The indices of the links to set mass shift.
envs_idx : None | array_like, optional
The indices of the environments. If None, all environments will be considered. Defaults to None.
"""
for i in range(len(link_indices)):
link_indices[i] += self._link_start
self._solver.set_links_mass_shift(mass_shift, link_indices, envs_idx)
self._solver.set_links_mass_shift(mass_shift, self._get_ls_idx(ls_idx_local), envs_idx)

def set_COM_shift(self, com_shift, link_indices, envs_idx=None):
def set_COM_shift(self, com_shift, ls_idx_local, envs_idx=None):
"""
Set the center of mass (COM) shift of specified links.
Parameters
----------
com : torch.Tensor, shape (n_envs, n_links, 3)
The COM shift
link_indices : array_like
ls_idx_local : array_like
The indices of the links to set COM shift.
envs_idx : None | array_like, optional
The indices of the environments. If None, all environments will be considered. Defaults to None.
"""
for i in range(len(link_indices)):
link_indices[i] += self._link_start
self._solver.set_links_COM_shift(com_shift, link_indices, envs_idx)
self._solver.set_links_COM_shift(com_shift, self._get_ls_idx(ls_idx_local), envs_idx)

@gs.assert_built
def set_links_inertial_mass(self, inertial_mass, ls_idx_local=None, envs_idx=None):
Expand Down

0 comments on commit 138ece3

Please sign in to comment.