From 52a2d8fb229c49c9878722f4dfde3bcbd2b9b6c6 Mon Sep 17 00:00:00 2001 From: Ben Eisner Date: Wed, 22 May 2024 17:47:38 -0400 Subject: [PATCH] make this a flag --- src/rpad/rlbench_utils/placement_dataset.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/rpad/rlbench_utils/placement_dataset.py b/src/rpad/rlbench_utils/placement_dataset.py index a63ff46..3762215 100644 --- a/src/rpad/rlbench_utils/placement_dataset.py +++ b/src/rpad/rlbench_utils/placement_dataset.py @@ -195,6 +195,7 @@ def get_anchor_points( use_from_simulator=False, handle_mapping=None, names_to_handles=None, + gripper_in_first_phase=True, ): if use_from_simulator: handle_mapping = { @@ -212,7 +213,7 @@ def get_anchor_points( names = BACKGROUND_NAMES + ROBOT_NONGRIPPER_NAMES # If it's the first phase, we also omit the gripper. - if phase == TASK_DICT[task_name]["phase_order"][0]: + if phase == TASK_DICT[task_name]["phase_order"][0] and gripper_in_first_phase: names += GRIPPER_OBJ_NAMES return filter_out_names(rgb, point_cloud, mask, handle_mapping, names) @@ -279,6 +280,7 @@ def __init__( anchor_mode: AnchorMode = AnchorMode.SINGLE_OBJECT, action_mode: ActionMode = ActionMode.OBJECT, include_wrist_cam: bool = False, + gripper_in_first_phase: bool = True, ) -> None: """Dataset for RL-Bench placement tasks. @@ -336,6 +338,7 @@ def leaf_fn(path, x): raise ValueError("Anchor mode must be one of the AnchorMode enum values.") self.action_mode = action_mode self.anchor_mode = anchor_mode + self.gripper_in_first_phase = gripper_in_first_phase if cache: self.memory = Memory( @@ -449,6 +452,7 @@ def _select_anchor_vals(rgb, point_cloud, mask): use_from_simulator=False, handle_mapping=self.handle_mapping, names_to_handles=self.names_to_handles, + gripper_in_first_phase=self.gripper_in_first_phase, ) # Merge all the initial point clouds and masks into one.