Skip to content

Commit

Permalink
Fix visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
aliberts committed Oct 23, 2024
1 parent a2a8538 commit 7ae8d05
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions lerobot/scripts/visualize_dataset_html.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,13 @@ def show_episode(dataset_namespace, dataset_name, episode_id):
"num_episodes": dataset.num_episodes,
"fps": dataset.fps,
}
video_paths = get_episode_video_paths(dataset, episode_id)
language_instruction = get_episode_language_instruction(dataset, episode_id)
video_paths = [dataset.get_video_file_path(episode_id, key) for key in dataset.video_keys]
tasks = dataset.episode_dicts[episode_id]["tasks"]
videos_info = [
{"url": url_for("static", filename=video_path), "filename": Path(video_path).name}
{"url": url_for("static", filename=video_path), "filename": video_path.name}
for video_path in video_paths
]
if language_instruction:
videos_info[0]["language_instruction"] = language_instruction
videos_info[0]["language_instruction"] = tasks

ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id))
return render_template(
Expand Down Expand Up @@ -137,10 +136,10 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
# init header of csv with state and action names
header = ["timestamp"]
if has_state:
dim_state = len(dataset.hf_dataset["observation.state"][0])
dim_state = dataset.shapes["observation.state"]
header += [f"state_{i}" for i in range(dim_state)]
if has_action:
dim_action = len(dataset.hf_dataset["action"][0])
dim_action = dataset.shapes["action"]
header += [f"action_{i}" for i in range(dim_action)]

columns = ["timestamp"]
Expand Down Expand Up @@ -171,7 +170,7 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
# get first frame of episode (hack to get video_path of the episode)
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
return [
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] for key in dataset.camera_keys
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] for key in dataset.video_keys
]


Expand Down Expand Up @@ -203,8 +202,8 @@ def visualize_dataset_html(

dataset = LeRobotDataset(repo_id, root=root)

if not dataset.video:
raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
if len(dataset.image_keys) > 0:
raise NotImplementedError(f"Image keys ({dataset.image_keys=}) are currently not supported.")

if output_dir is None:
output_dir = f"outputs/visualize_dataset_html/{repo_id}"
Expand All @@ -224,7 +223,7 @@ def visualize_dataset_html(
static_dir.mkdir(parents=True, exist_ok=True)
ln_videos_dir = static_dir / "videos"
if not ln_videos_dir.exists():
ln_videos_dir.symlink_to(dataset.videos_dir.resolve())
ln_videos_dir.symlink_to((dataset.root / "videos").resolve())

template_dir = Path(__file__).resolve().parent.parent / "templates"

Expand Down

0 comments on commit 7ae8d05

Please sign in to comment.