Skip to content

Commit

Permalink
moving logging to different file without habitat dependencies, adding…
Browse files Browse the repository at this point in the history
… episode logging support for gibson
  • Loading branch information
naokiyokoyamabd committed Sep 8, 2023
1 parent 6bdb7b5 commit 82ac534
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 29 deletions.
46 changes: 33 additions & 13 deletions zsos/semexp_env/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from zsos.semexp_env.semexp_policy import SemExpITMPolicyV3
from zsos.utils.img_utils import reorient_rescale_map, resize_images
from zsos.utils.log_saver import is_evaluated, log_episode

os.environ["OMP_NUM_THREADS"] = "1"

Expand Down Expand Up @@ -57,29 +58,46 @@ def main():
torch.set_num_threads(1)
envs = make_vec_envs(args)
obs, infos = envs.reset()

ep_id, scene_id, target_object = None, None, None
for ep_num in range(num_episodes):
vis_imgs = []
for step in range(args.max_episode_length):
obs_dict = merge_obs_infos(obs, infos)
if step == 0:
masks = torch.zeros(1, 1, device=obs.device)
ep_id, scene_id = infos[0]["episode_id"], infos[0]["scene_id"]
target_object = infos[0]["goal_name"]
print("Episode:", ep_id, "Scene:", scene_id)
else:
masks = torch.ones(1, 1, device=obs.device)
action, policy_infos = policy.act(obs_dict, masks)

if "VIDEO_DIR" in os.environ:
vis_imgs.append(create_frame(policy_infos))
if "ZSOS_LOG_DIR" in os.environ and is_evaluated(ep_id, scene_id):
print(f"Episode {ep_id} in scene {scene_id} already evaluated")
# Call stop action to move on to the next episode
obs, rew, done, infos = envs.step(torch.tensor([0], dtype=torch.long))
else:
obs_dict = merge_obs_infos(obs, infos)
action, policy_infos = policy.act(obs_dict, masks)

action = action.squeeze(0)
if "VIDEO_DIR" in os.environ:
vis_imgs.append(create_frame(policy_infos))

obs, rew, done, infos = envs.step(action)
action = action.squeeze(0)

obs, rew, done, infos = envs.step(action)

if done:
print("Success:", infos[0]["success"])
print("SPL:", infos[0]["spl"])
data = {
"success": infos[0]["success"],
"spl": infos[0]["spl"],
"distance_to_goal": infos[0]["distance_to_goal"],
"target_object": target_object,
}
if "VIDEO_DIR" in os.environ:
generate_video(vis_imgs, infos[0])
generate_video(vis_imgs, ep_id, scene_id, data)
if "ZSOS_LOG_DIR" in os.environ and not is_evaluated(ep_id, scene_id):
log_episode(ep_id, scene_id, data)
break

print("Test successfully completed")
Expand Down Expand Up @@ -136,7 +154,9 @@ def create_frame(policy_infos: Dict[str, Any]) -> np.ndarray:
return vis_img


def generate_video(frames: List[np.ndarray], infos: Dict[str, Any]) -> None:
def generate_video(
frames: List[np.ndarray], ep_id: str, scene_id: str, infos: Dict[str, Any]
) -> None:
"""
Saves the given list of rgb frames as a video at 10 FPS. Uses the infos to get the
files name, which should contain the following:
Expand All @@ -151,16 +171,16 @@ def generate_video(frames: List[np.ndarray], infos: Dict[str, Any]) -> None:
video_dir = os.environ.get("VIDEO_DIR", "video_dir")
if not os.path.exists(video_dir):
os.makedirs(video_dir)
episode_id = int(infos["episode_id"])
scene_id = infos["scene_id"]
episode_id = int(ep_id)
success = int(infos["success"])
spl = infos["spl"]
dtg = infos["distance_to_goal"]
goal_name = infos["goal_name"]
goal_name = infos["target_object"]
filename = (
f"epid={episode_id:03d}-scid={scene_id}-succ={success}-spl={spl:.2f}"
f"-dtg={dtg:.2f}-goal={goal_name}.mp4"
f"-dtg={dtg:.2f}-target={goal_name}.mp4"
)

filename = os.path.join(video_dir, filename)
# Create a video clip from the frames
clip = ImageSequenceClip(frames, fps=10)
Expand Down
18 changes: 2 additions & 16 deletions zsos/utils/episode_stats_logger.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import os
from typing import Any, Dict

Expand All @@ -8,6 +7,7 @@
from frontier_exploration.utils.general_utils import xyz_to_habitat
from zsos.utils.geometry_utils import transform_points
from zsos.utils.habitat_visualizer import sim_xy_to_grid_xy
from zsos.utils.log_saver import log_episode


def log_episode_stats(episode_id: int, scene_id: str, infos: Dict) -> str:
Expand All @@ -26,29 +26,15 @@ def log_episode_stats(episode_id: int, scene_id: str, infos: Dict) -> str:
print(f"Episode {episode_id} in scene {scene} failed due to '{failure_cause}'.")

if "ZSOS_LOG_DIR" in os.environ:
log_dir = os.environ["ZSOS_LOG_DIR"]
try:
os.makedirs(log_dir, exist_ok=True)
except Exception:
pass
base = f"{episode_id}_{scene}.json"
filename = os.path.join(log_dir, base)

infos_no_map = infos.copy()
infos_no_map.pop("top_down_map")

data = {
"episode_id": episode_id,
"scene_id": scene_id,
"failure_cause": failure_cause,
**remove_numpy_arrays(infos_no_map),
}

# Skip if the filename already exists AND it isn't empty
if not (os.path.exists(filename) and os.path.getsize(filename) > 0):
print(f"Logging episode {int(episode_id):04d} to {filename}")
with open(filename, "w") as f:
json.dump(data, f, indent=4)
log_episode(episode_id, scene, data)

return failure_cause

Expand Down
44 changes: 44 additions & 0 deletions zsos/utils/log_saver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import json
import os
import time
from typing import Dict, Union


def log_episode(episode_id: Union[str, int], scene_id: str, data: Dict) -> None:
log_dir = os.environ["ZSOS_LOG_DIR"]
try:
os.makedirs(log_dir, exist_ok=True)
except Exception:
pass
base = f"{episode_id}_{scene_id}.json"
filename = os.path.join(log_dir, base)

# Skip if the filename already exists AND it isn't empty
if not (os.path.exists(filename) and os.path.getsize(filename) > 0):
print(f"Logging episode {int(episode_id):04d} to {filename}")
with open(filename, "w") as f:
json.dump(
{"episode_id": episode_id, "scene_id": scene_id, **data}, f, indent=4
)


def is_evaluated(episode_id: Union[str, int], scene_id: str) -> bool:
log_dir = os.environ["ZSOS_LOG_DIR"]
base = f"{episode_id}_{scene_id}.json"
filename = os.path.join(log_dir, base)

# Return false if the directory doesn't exist
if not os.path.exists(log_dir):
return False

# Delete any empty files that are older than 5 minutes
for f in os.listdir(log_dir):
try:
if os.path.getsize(os.path.join(log_dir, f)) == 0 and (
time.time() - os.path.getmtime(os.path.join(log_dir, f)) > 300
):
os.remove(os.path.join(log_dir, f))
except Exception:
pass

return os.path.exists(filename)

0 comments on commit 82ac534

Please sign in to comment.