Skip to content

Commit

Permalink
Add padding keys and download_data option
Browse files Browse the repository at this point in the history
  • Loading branch information
aliberts committed Oct 11, 2024
1 parent 7f68088 commit 3ea5312
Showing 1 changed file with 37 additions and 7 deletions.
44 changes: 37 additions & 7 deletions lerobot/common/datasets/lerobot_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
tolerance_s: float = 1e-4,
download_data: bool = True,
video_backend: str | None = None,
):
"""LeRobotDataset encapsulates 3 main things:
Expand Down Expand Up @@ -128,6 +129,7 @@ def __init__(
timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames
decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
multiples of 1/fps. Defaults to 1e-4.
download_data (bool, optional): Flag to download actual data. Defaults to True.
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
"""
Expand All @@ -139,6 +141,7 @@ def __init__(
self.delta_timestamps = delta_timestamps
self.episodes = episodes
self.tolerance_s = tolerance_s
self.download_data = download_data
self.video_backend = video_backend if video_backend is not None else "pyav"
self.delta_indices = None

Expand All @@ -149,6 +152,13 @@ def __init__(
self.stats = load_stats(repo_id, self._version, self.root)
self.tasks = load_tasks(repo_id, self._version, self.root)

if not self.download_data:
# TODO(aliberts): Add actual support for this
# maybe use local_files_only=True or HF_HUB_OFFLINE=True
# see thread https://huggingface.slack.com/archives/C06ME3E7JUD/p1728637455476019
self.hf_dataset, self.episode_data_index = None, None
return

# Load actual data
self.download_episodes()
self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
Expand Down Expand Up @@ -243,6 +253,11 @@ def camera_keys(self) -> list[str]:
"""Keys to access image and video streams from cameras (regardless of their storage method)."""
return self.image_keys + self.video_keys

@property
def names(self) -> dict[list[str]]:
"""Names of the various dimensions of vector modalities."""
return self.info["names"]

@property
def num_samples(self) -> int:
"""Number of samples/frames."""
Expand Down Expand Up @@ -275,21 +290,29 @@ def episode_length(self, episode_index) -> int:
"""Number of samples/frames for given episode."""
return self.info["episodes"][episode_index]["length"]

def _get_query_indices(self, idx: int, ep_idx: int) -> dict[str, list[int]]:
# Pad values outside of current episode range
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
ep_start = self.episode_data_index["from"][ep_idx]
ep_end = self.episode_data_index["to"][ep_idx]
return {
query_indices = {
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
for key, delta_idx in self.delta_indices.items()
}
padding = { # Pad values outside of current episode range
f"{key}_is_pad": torch.BoolTensor(
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
)
for key, delta_idx in self.delta_indices.items()
}
return query_indices, padding

def _get_query_timestamps(
self, query_indices: dict[str, list[int]], current_ts: float
self,
current_ts: float,
query_indices: dict[str, list[int]] | None = None,
) -> dict[str, list[float]]:
query_timestamps = {}
for key in self.video_keys:
if key in query_indices:
if query_indices is not None and key in query_indices:
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
query_timestamps[key] = torch.stack(timestamps).tolist()
else:
Expand Down Expand Up @@ -320,23 +343,30 @@ def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -

return item

def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
for key, val in padding.items():
item[key] = torch.BoolTensor(val)
return item

def __len__(self):
return self.num_samples

def __getitem__(self, idx) -> dict:
item = self.hf_dataset[idx]
ep_idx = item["episode_index"].item()

query_indices = None
if self.delta_indices is not None:
current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
query_indices = self._get_query_indices(idx, current_ep_idx)
query_indices, padding = self._get_query_indices(idx, current_ep_idx)
query_result = self._query_hf_dataset(query_indices)
item = {**item, **padding}
for key, val in query_result.items():
item[key] = val

if len(self.video_keys) > 0:
current_ts = item["timestamp"].item()
query_timestamps = self._get_query_timestamps(query_indices, current_ts)
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
video_frames = self._query_videos(query_timestamps, ep_idx)
item = {**video_frames, **item}

Expand Down

0 comments on commit 3ea5312

Please sign in to comment.