Simplify pick single reward to be more friendly to RL (#27)
* For PickSingle, replace the original complex reward with a simple reward, since the original reward can be unfriendly to RL, even though MPC can solve many objects with the original reward. The old reward is kept and renamed as legacy.
xuanlinli17 authored Sep 22, 2022
1 parent 787439b commit dd6554f
Showing 1 changed file with 40 additions and 1 deletion.
41 changes: 40 additions & 1 deletion mani_skill2/envs/pick_and_place/pick_single.py
@@ -218,6 +218,45 @@ def evaluate(self, **kwargs):
        )

    def compute_dense_reward(self, info, **kwargs):
        # Sep. 14, 2022:
        # We changed the original complex reward to a simple reward,
        # since the original reward can be unfriendly for RL,
        # even though MPC can solve many objects with the original reward.
        reward = 0.0

        if info["success"]:
            reward = 10.0
        else:
            obj_pose = self.obj_pose

            # reaching reward
            tcp_wrt_obj_pose = obj_pose.inv() * self.tcp.pose
            tcp_to_obj_dist = np.linalg.norm(tcp_wrt_obj_pose.p)
            reaching_reward = 1 - np.tanh(
                3.0
                * np.maximum(
                    tcp_to_obj_dist - np.linalg.norm(self.model_bbox_size), 0.0
                )
            )
            reward = reward + reaching_reward

            # grasp reward
            is_grasped = self.agent.check_grasp(self.obj, max_angle=30)
            reward += 3.0 if is_grasped else 0.0

            # reaching-goal reward
            if is_grasped:
                obj_to_goal_pos = self.goal_pos - obj_pose.p
                obj_to_goal_dist = np.linalg.norm(obj_to_goal_pos)
                reaching_goal_reward = 3 * (1 - np.tanh(3.0 * obj_to_goal_dist))
                reward += reaching_goal_reward

        return reward
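For intuition, the new reward is staged: the reaching term contributes at most 1, grasping adds a flat 3, and the goal-reaching term adds at most 3, all strictly below the flat 10 for success. A minimal standalone sketch of the reaching term, assuming only numpy, with `bbox_diag` standing in for `np.linalg.norm(self.model_bbox_size)` and the 0.1 m diagonal a made-up example value:

import numpy as np

def reaching_reward(tcp_to_obj_dist, bbox_diag):
    # 1.0 once the TCP is within the object's bounding-box extent,
    # decaying smoothly toward 0 as the clipped distance grows
    return 1.0 - np.tanh(3.0 * np.maximum(tcp_to_obj_dist - bbox_diag, 0.0))

# Example with a hypothetical 0.1 m bounding-box diagonal:
for d in (0.05, 0.10, 0.20, 0.50):
    print(f"dist={d:.2f} m -> reaching_reward={reaching_reward(d, 0.1):.3f}")
# dist=0.05 m -> reaching_reward=1.000
# dist=0.10 m -> reaching_reward=1.000
# dist=0.20 m -> reaching_reward=0.709
# dist=0.50 m -> reaching_reward=0.166

Because the distance is clipped at the bounding-box diagonal, the term saturates at 1 once the gripper is within the object's extent rather than pulling the TCP toward the object center.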

    def compute_dense_reward_legacy(self, info, **kwargs):
        # original complex reward that is geometry-independent,
        # which ensures that MPC can successfully pick up most objects,
        # but can be unfriendly for RL
        reward = 0.0
        # hard code gripper info
        finger_length = 0.025
@@ -324,7 +363,7 @@ def compute_dense_reward(self, info, **kwargs):
                gripper_width / 2 - self.agent.robot.get_qpos()[-2:]
            )  # ensures that gripper is open

        return reward
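The visible tail of the legacy reward keeps the fingers open via the residual `gripper_width / 2 - self.agent.robot.get_qpos()[-2:]`, where the last two joint positions are the two fingers. The enclosing expression is truncated above, so the following is only a hedged sketch of how such a term could be computed; `gripper_width=0.08` and the `finger_qpos` values are assumed, not taken from the diff:

import numpy as np

def gripper_open_penalty(finger_qpos, gripper_width=0.08):
    # distance of each finger from its fully open position; 0 when fully open
    return float(np.linalg.norm(gripper_width / 2 - finger_qpos))

print(gripper_open_penalty(np.array([0.04, 0.04])))  # fully open   -> 0.0
print(gripper_open_penalty(np.array([0.00, 0.00])))  # fully closed -> ~0.057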

    def render(self, mode="human"):
        if mode in ["human", "rgb_array"]:
