diff --git a/zsos/semexp_env/eval.py b/zsos/semexp_env/eval.py index ea8503d..985be0c 100644 --- a/zsos/semexp_env/eval.py +++ b/zsos/semexp_env/eval.py @@ -10,6 +10,7 @@ from zsos.semexp_env.semexp_policy import SemExpITMPolicyV3 from zsos.utils.img_utils import reorient_rescale_map, resize_images +from zsos.utils.log_saver import is_evaluated, log_episode os.environ["OMP_NUM_THREADS"] = "1" @@ -57,29 +58,46 @@ def main(): torch.set_num_threads(1) envs = make_vec_envs(args) obs, infos = envs.reset() - + ep_id, scene_id, target_object = None, None, None for ep_num in range(num_episodes): vis_imgs = [] for step in range(args.max_episode_length): - obs_dict = merge_obs_infos(obs, infos) if step == 0: masks = torch.zeros(1, 1, device=obs.device) + ep_id, scene_id = infos[0]["episode_id"], infos[0]["scene_id"] + target_object = infos[0]["goal_name"] + print("Episode:", ep_id, "Scene:", scene_id) else: masks = torch.ones(1, 1, device=obs.device) - action, policy_infos = policy.act(obs_dict, masks) - if "VIDEO_DIR" in os.environ: - vis_imgs.append(create_frame(policy_infos)) + if "ZSOS_LOG_DIR" in os.environ and is_evaluated(ep_id, scene_id): + print(f"Episode {ep_id} in scene {scene_id} already evaluated") + # Call stop action to move on to the next episode + obs, rew, done, infos = envs.step(torch.tensor([0], dtype=torch.long)) + else: + obs_dict = merge_obs_infos(obs, infos) + action, policy_infos = policy.act(obs_dict, masks) - action = action.squeeze(0) + if "VIDEO_DIR" in os.environ: + vis_imgs.append(create_frame(policy_infos)) - obs, rew, done, infos = envs.step(action) + action = action.squeeze(0) + + obs, rew, done, infos = envs.step(action) if done: print("Success:", infos[0]["success"]) print("SPL:", infos[0]["spl"]) + data = { + "success": infos[0]["success"], + "spl": infos[0]["spl"], + "distance_to_goal": infos[0]["distance_to_goal"], + "target_object": target_object, + } if "VIDEO_DIR" in os.environ: - generate_video(vis_imgs, infos[0]) + generate_video(vis_imgs, ep_id, scene_id, data) + if "ZSOS_LOG_DIR" in os.environ and not is_evaluated(ep_id, scene_id): + log_episode(ep_id, scene_id, data) break print("Test successfully completed") @@ -136,7 +154,9 @@ def create_frame(policy_infos: Dict[str, Any]) -> np.ndarray: return vis_img -def generate_video(frames: List[np.ndarray], infos: Dict[str, Any]) -> None: +def generate_video( + frames: List[np.ndarray], ep_id: str, scene_id: str, infos: Dict[str, Any] +) -> None: """ Saves the given list of rgb frames as a video at 10 FPS. Uses the infos to get the files name, which should contain the following: @@ -151,16 +171,16 @@ def generate_video(frames: List[np.ndarray], infos: Dict[str, Any]) -> None: video_dir = os.environ.get("VIDEO_DIR", "video_dir") if not os.path.exists(video_dir): os.makedirs(video_dir) - episode_id = int(infos["episode_id"]) - scene_id = infos["scene_id"] + episode_id = int(ep_id) success = int(infos["success"]) spl = infos["spl"] dtg = infos["distance_to_goal"] - goal_name = infos["goal_name"] + goal_name = infos["target_object"] filename = ( f"epid={episode_id:03d}-scid={scene_id}-succ={success}-spl={spl:.2f}" - f"-dtg={dtg:.2f}-goal={goal_name}.mp4" + f"-dtg={dtg:.2f}-target={goal_name}.mp4" ) + filename = os.path.join(video_dir, filename) # Create a video clip from the frames clip = ImageSequenceClip(frames, fps=10) diff --git a/zsos/utils/episode_stats_logger.py b/zsos/utils/episode_stats_logger.py index 5d5b32c..629db48 100644 --- a/zsos/utils/episode_stats_logger.py +++ b/zsos/utils/episode_stats_logger.py @@ -1,4 +1,3 @@ -import json import os from typing import Any, Dict @@ -8,6 +7,7 @@ from frontier_exploration.utils.general_utils import xyz_to_habitat from zsos.utils.geometry_utils import transform_points from zsos.utils.habitat_visualizer import sim_xy_to_grid_xy +from zsos.utils.log_saver import log_episode def log_episode_stats(episode_id: int, scene_id: str, infos: Dict) -> str: @@ -26,29 +26,15 @@ def log_episode_stats(episode_id: int, scene_id: str, infos: Dict) -> str: print(f"Episode {episode_id} in scene {scene} failed due to '{failure_cause}'.") if "ZSOS_LOG_DIR" in os.environ: - log_dir = os.environ["ZSOS_LOG_DIR"] - try: - os.makedirs(log_dir, exist_ok=True) - except Exception: - pass - base = f"{episode_id}_{scene}.json" - filename = os.path.join(log_dir, base) - infos_no_map = infos.copy() infos_no_map.pop("top_down_map") data = { - "episode_id": episode_id, - "scene_id": scene_id, "failure_cause": failure_cause, **remove_numpy_arrays(infos_no_map), } - # Skip if the filename already exists AND it isn't empty - if not (os.path.exists(filename) and os.path.getsize(filename) > 0): - print(f"Logging episode {int(episode_id):04d} to {filename}") - with open(filename, "w") as f: - json.dump(data, f, indent=4) + log_episode(episode_id, scene, data) return failure_cause diff --git a/zsos/utils/log_saver.py b/zsos/utils/log_saver.py new file mode 100644 index 0000000..9994e98 --- /dev/null +++ b/zsos/utils/log_saver.py @@ -0,0 +1,44 @@ +import json +import os +import time +from typing import Dict, Union + + +def log_episode(episode_id: Union[str, int], scene_id: str, data: Dict) -> None: + log_dir = os.environ["ZSOS_LOG_DIR"] + try: + os.makedirs(log_dir, exist_ok=True) + except Exception: + pass + base = f"{episode_id}_{scene_id}.json" + filename = os.path.join(log_dir, base) + + # Skip if the filename already exists AND it isn't empty + if not (os.path.exists(filename) and os.path.getsize(filename) > 0): + print(f"Logging episode {int(episode_id):04d} to {filename}") + with open(filename, "w") as f: + json.dump( + {"episode_id": episode_id, "scene_id": scene_id, **data}, f, indent=4 + ) + + +def is_evaluated(episode_id: Union[str, int], scene_id: str) -> bool: + log_dir = os.environ["ZSOS_LOG_DIR"] + base = f"{episode_id}_{scene_id}.json" + filename = os.path.join(log_dir, base) + + # Return false if the directory doesn't exist + if not os.path.exists(log_dir): + return False + + # Delete any empty files that are older than 5 minutes + for f in os.listdir(log_dir): + try: + if os.path.getsize(os.path.join(log_dir, f)) == 0 and ( + time.time() - os.path.getmtime(os.path.join(log_dir, f)) > 300 + ): + os.remove(os.path.join(log_dir, f)) + except Exception: + pass + + return os.path.exists(filename)