diff --git a/mani_skill2/envs/sapien_env.py b/mani_skill2/envs/sapien_env.py index d18d99677..75da7fad4 100644 --- a/mani_skill2/envs/sapien_env.py +++ b/mani_skill2/envs/sapien_env.py @@ -42,6 +42,11 @@ class BaseEnv(gym.Env): sim_freq (int): simulation frequency (Hz) control_freq (int): control frequency (Hz) device (str): GPU device for renderer, e.g., 'cuda:x' + enable_shadow (bool): whether to enable shadow for lights. Defaults to False. + enable_gt_seg (bool): whether to include GT segmentaiton masks in observations. Defaults to False. + enable_kuafu (bool): whether to use KuafuRenderer (ray-tracing). + KuafuRenderer is only experimentally and partially supported now. + kuafu_kwargs (dict, optional): kwargs to set KuafuRenderer. """ # fmt: off @@ -65,14 +70,31 @@ def __init__( device: str = "", enable_shadow=False, enable_gt_seg=False, + enable_kuafu=False, + kuafu_kwargs: Optional[dict] = None, ): - # SAPIEN + # Create SAPIEN engine self._engine = sapien.Engine() - self._renderer = sapien.VulkanRenderer(default_mipmap_levels=1, device=device) + + # Create SAPIEN renderer + self._enable_kuafu = enable_kuafu + if self._enable_kuafu: + kuafu_config = sapien.KuafuConfig() + if kuafu_kwargs is not None: + for k, v in kuafu_kwargs.items(): + setattr(kuafu_config, k, v) + self._renderer = sapien.KuafuRenderer(kuafu_config) + logger.warning("Only rgb is supported by KuafuRenderer.") + else: + self._renderer = sapien.VulkanRenderer( + default_mipmap_levels=1, device=device + ) self._renderer.set_log_level("off") + self._engine.set_renderer(self._renderer) self._viewer = None + # Set simulation and control frequency self._sim_freq = sim_freq self._control_freq = control_freq if sim_freq % control_freq != 0: @@ -102,7 +124,7 @@ def __init__( self.enable_shadow = enable_shadow # For training purpose - self.enable_gt_seg = enable_gt_seg + self._enable_gt_seg = enable_gt_seg # TODO(jigu): `seed` is deprecated in the latest gym. self.seed() @@ -187,9 +209,12 @@ def _get_obs_extra(self) -> OrderedDict: def _get_obs_rgbd(self, **kwargs) -> OrderedDict: # Overwrite options if using GT segmentation - if self.enable_gt_seg: - kwargs["visual_seg"] = True - kwargs["actor_seg"] = True + if self._enable_gt_seg: + kwargs.update(visual_seg=True, actor_seg=True) + + # Overwrite options if using KuaFu renderer + if self._enable_kuafu: + kwargs.update(depth=False, visual_seg=False, actor_seg=False) return OrderedDict( image=self._get_obs_images(**kwargs), @@ -240,9 +265,12 @@ def _get_camera_images(self, camera: sapien.CameraEntity, **kwargs) -> OrderedDi def _get_obs_pointcloud(self, **kwargs): """Fuse pointclouds from all cameras in the world frame.""" # Overwrite options if using GT segmentation - if self.enable_gt_seg: - kwargs["visual_seg"] = True - kwargs["actor_seg"] = True + if self._enable_gt_seg: + kwargs.update(visual_seg=True, actor_seg=True) + + # Overwrite options if using KuaFu renderer + if self._enable_kuafu: + raise NotImplementedError("Do not support pointcloud mode for KuafuRenderer yet.") self.update_render() @@ -593,9 +621,16 @@ def render(self, mode="human", **kwargs): # NOTE(jigu): Must update renderer again # since some visual-only sites like goals should be hidden. self.update_render() - cameras_images = self._get_obs_images( - actor_seg=self.enable_gt_seg, visual_seg=self.enable_gt_seg - ) + + # Overwrite options if using GT segmentation + if self._enable_gt_seg: + kwargs.update(visual_seg=True, actor_seg=True) + + # Overwrite options if using KuaFu renderer + if self._enable_kuafu: + kwargs.update(depth=False, visual_seg=False, actor_seg=False) + + cameras_images = self._get_obs_images(**kwargs) for camera_images in cameras_images.values(): images.extend(observations_to_images(camera_images)) return tile_images(images)