Skip to content

Commit

Permalink
make this a flag
Browse files Browse the repository at this point in the history
  • Loading branch information
beneisner committed May 22, 2024
1 parent 634f5b6 commit 52a2d8f
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/rpad/rlbench_utils/placement_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def get_anchor_points(
use_from_simulator=False,
handle_mapping=None,
names_to_handles=None,
gripper_in_first_phase=True,
):
if use_from_simulator:
handle_mapping = {
Expand All @@ -212,7 +213,7 @@ def get_anchor_points(
names = BACKGROUND_NAMES + ROBOT_NONGRIPPER_NAMES

# If it's the first phase, we also omit the gripper.
if phase == TASK_DICT[task_name]["phase_order"][0]:
if phase == TASK_DICT[task_name]["phase_order"][0] and gripper_in_first_phase:
names += GRIPPER_OBJ_NAMES

return filter_out_names(rgb, point_cloud, mask, handle_mapping, names)
Expand Down Expand Up @@ -279,6 +280,7 @@ def __init__(
anchor_mode: AnchorMode = AnchorMode.SINGLE_OBJECT,
action_mode: ActionMode = ActionMode.OBJECT,
include_wrist_cam: bool = False,
gripper_in_first_phase: bool = True,
) -> None:
"""Dataset for RL-Bench placement tasks.
Expand Down Expand Up @@ -336,6 +338,7 @@ def leaf_fn(path, x):
raise ValueError("Anchor mode must be one of the AnchorMode enum values.")
self.action_mode = action_mode
self.anchor_mode = anchor_mode
self.gripper_in_first_phase = gripper_in_first_phase

if cache:
self.memory = Memory(
Expand Down Expand Up @@ -449,6 +452,7 @@ def _select_anchor_vals(rgb, point_cloud, mask):
use_from_simulator=False,
handle_mapping=self.handle_mapping,
names_to_handles=self.names_to_handles,
gripper_in_first_phase=self.gripper_in_first_phase,
)

# Merge all the initial point clouds and masks into one.
Expand Down

0 comments on commit 52a2d8f

Please sign in to comment.