Support using env_states when replaying trajectories
Jiayuan-Gu committed Nov 28, 2022
1 parent 286d753 commit 5178d73
Showing 6 changed files with 47 additions and 10 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -187,8 +187,9 @@ python -m mani_skill2.trajectory.replay_trajectory --traj-path demos/rigid_body_
- `--num-procs=10`: split trajectories to multiple processes (e.g., 10 processes) for acceleration.
- `--obs-mode=none`: specify the observation mode as `none`, i.e. not saving any observations.
- `--obs-mode=rgbd`: (not included in the script above) specify the observation mode as `rgbd` to replay the trajectory. If `--save-traj`, the saved trajectory will contain the RGBD observations. RGB images are saved as uint8 and depth images (multiplied by 1024) are saved as uint16.
- `--obs-mode=pointcloud`: (not included in the script above) specify the observation mode as `pointcloud`. We encourage you to further process the point cloud instead of using this point clould directly.
- `--obs-mode=pointcloud`: (not included in the script above) specify the observation mode as `pointcloud`. We encourage you to further process the point cloud instead of using this point cloud directly.
- `--obs-mode=state`: (not included in the script above) specify the observation mode as `state`. Note that the `state` observation mode is not allowed for challenge submission.
- `--use-env-states`: after each action is replayed, set the environment to the recorded environment state. This is necessary for correctly replaying trajectories recorded by MPC or RL (e.g., tasks migrated from ManiSkill1).

</details>

1 change: 1 addition & 0 deletions mani_skill2/envs/ms1/base_env.py
@@ -84,6 +84,7 @@ def _get_default_scene_config(self):
return scene_config

def reset(self, seed=None, reconfigure=False, model_id=None):
self._prev_actor_pose = None
self.set_episode_rng(seed)
_reconfigure = self._set_model(model_id)
reconfigure = _reconfigure or reconfigure
6 changes: 4 additions & 2 deletions mani_skill2/envs/ms1/move_bucket.py
@@ -220,7 +220,6 @@ def _set_bucket_links_pcd(self):
# -------------------------------------------------------------------------- #
# Success metric and shaped reward
# -------------------------------------------------------------------------- #

def evaluate(self, **kwargs):
w2b = (
self.bucket_body_link.pose.inv().to_transformation_matrix()
@@ -394,10 +393,13 @@ def compute_dense_reward(self, action, info: dict, **kwargs):
# ---------------------------------------------------------------------------- #
# Observation
# ---------------------------------------------------------------------------- #

def _get_task_actors(self):
return self.balls

def _get_task_articulations(self):
# bucket max dof is 1 in our data
return [(self.bucket, 2)]

def set_state(self, state: np.ndarray):
super().set_state(state)
self._prev_actor_pose = self.bucket.pose
4 changes: 4 additions & 0 deletions mani_skill2/envs/ms1/open_cabinet_door_drawer.py
@@ -400,6 +400,10 @@ def _get_task_articulations(self):
# The maximum DoF is 6 in our data.
return [(self.cabinet, 8)]

def set_state(self, state: np.ndarray):
super().set_state(state)
self._prev_actor_pose = self.target_link.pose


@register_gym_env(name="OpenCabinetDoor-v1", max_episode_steps=200)
class OpenCabinetDoorEnv(OpenCabinetEnv):
4 changes: 4 additions & 0 deletions mani_skill2/envs/ms1/push_chair.py
@@ -348,3 +348,7 @@ def compute_dense_reward(self, action: np.ndarray, info: dict, **kwargs):
def _get_task_articulations(self):
# The maximum DoF is 20 in our data.
return [(self.chair, 25)]

def set_state(self, state: np.ndarray):
super().set_state(state)
self._prev_actor_pose = self.root_link.pose
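
All three ManiSkill1-ported environments apply the same pattern: `reset()` clears a cached actor pose (see the `base_env.py` hunk above) and `set_state()` refreshes it after the simulation state is restored, which matters when `--use-env-states` replays by calling `set_state()` directly. Below is a minimal self-contained sketch of that pattern; the stub classes and the assumption that `_prev_actor_pose` feeds a pose-difference check (e.g. a static/success check or a dense-reward term) are illustrative, not part of this commit.

```python
import numpy as np


class _StubActor:
    """Placeholder for a SAPIEN actor/link exposing a .pose attribute."""

    def __init__(self):
        self.pose = np.zeros(7, dtype=np.float32)  # position + quaternion


class _StubBaseEnv:
    """Placeholder for the shared ms1 base environment."""

    def __init__(self):
        self.bucket = _StubActor()
        self._prev_actor_pose = None

    def reset(self):
        # Mirrors base_env.py in this commit: start each episode with no cached pose.
        self._prev_actor_pose = None

    def set_state(self, state: np.ndarray):
        # The real base class restores all actor/articulation states here.
        pass


class ExampleTaskEnv(_StubBaseEnv):
    def set_state(self, state: np.ndarray):
        # Same override as MoveBucket / OpenCabinet / PushChair: after the
        # environment state is restored (e.g. during --use-env-states replay),
        # refresh the cached pose of the task-relevant actor so the next
        # pose-difference computation does not see None or a stale pose.
        super().set_state(state)
        self._prev_actor_pose = self.bucket.pose


env = ExampleTaskEnv()
env.reset()
env.set_state(np.zeros(1, dtype=np.float32))  # dummy state vector
assert env._prev_actor_pose is not None
```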
39 changes: 32 additions & 7 deletions mani_skill2/trajectory/replay_trajectory.py
@@ -288,16 +288,33 @@ def from_pd_joint_delta_pos(
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--traj-path", type=str, required=True)
parser.add_argument("-o", "--obs-mode", type=str)
parser.add_argument("-c", "--target-control-mode", type=str)
parser.add_argument("-o", "--obs-mode", type=str, help="target observation mode")
parser.add_argument(
"-c", "--target-control-mode", type=str, help="target control mode"
)
parser.add_argument("--verbose", action="store_true")
parser.add_argument("--save-traj", action="store_true")
parser.add_argument("--save-video", action="store_true")
parser.add_argument(
"--save-traj", action="store_true", help="whether to save trajectories"
)
parser.add_argument(
"--save-video", action="store_true", help="whether to save videos"
)
parser.add_argument("--num-procs", type=int, default=1)
parser.add_argument("--max-retry", type=int, default=0)
parser.add_argument("--discard-timeout", action="store_true")
parser.add_argument("--allow-failure", action="store_true")
parser.add_argument(
"--discard-timeout",
action="store_true",
help="whether to discard timeout episodes",
)
parser.add_argument(
"--allow-failure", action="store_true", help="whether to allow failure episodes"
)
parser.add_argument("--vis", action="store_true")
parser.add_argument(
"--use-env-states",
action="store_true",
help="whether to replay by env states instead of actions",
)
return parser.parse_args()


@@ -397,19 +414,25 @@ def _main(args, proc_id: int = 0, num_procs=1, pbar=None):
# Original actions to replay
ori_actions = ori_h5_file[traj_id]["actions"][:]

# Original env states to replay
if args.use_env_states:
ori_env_states = ori_h5_file[traj_id]["env_states"][1:]

info = {}

# Without conversion between control modes
if target_control_mode is None:
n = len(ori_actions)
if pbar is not None:
pbar.reset(total=n)
for a in ori_actions:
for t, a in enumerate(ori_actions):
if pbar is not None:
pbar.update()
_, _, _, info = env.step(a)
if args.vis:
env.render()
if args.use_env_states:
env.set_state(ori_env_states[t])

# From joint position to others
elif ori_control_mode == "pd_joint_pos":
@@ -447,6 +470,8 @@
else:
# Rollback episode id for failed attempts
env._episode_id -= 1
if args.verbose:
print("info", info)
else:
tqdm.write(f"Episode {episode_id} is not replayed successfully. Skipping")

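
For reference, the new replay path can also be exercised outside the script. Below is a minimal sketch condensed from the loop added in `_main`: the file path, the trajectory key (`traj_0`), and the environment/reset kwargs are illustrative assumptions; `OpenCabinetDoor-v1` is the id registered in the `open_cabinet_door_drawer.py` hunk above.

```python
import gym
import h5py

import mani_skill2.envs  # noqa: F401  (assumed to register the ManiSkill2 envs with gym)

# Hypothetical trajectory file; the key layout matches what _main reads above.
with h5py.File("path/to/trajectory.h5", "r") as f:
    episode = f["traj_0"]
    actions = episode["actions"][:]
    # env_states[0] is the state right after reset, so env_states[1:][t] is the
    # recorded state *after* actions[t] was executed, matching the [1:] slice in _main.
    env_states = episode["env_states"][1:]

env = gym.make("OpenCabinetDoor-v1", obs_mode="none")
env.reset(seed=0)

for t, action in enumerate(actions):
    env.step(action)
    # What --use-env-states does: overwrite the simulation with the recorded
    # state so the replay cannot drift from the original trajectory.
    env.set_state(env_states[t])

env.close()
```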
