Add load_metadata
aliberts committed Oct 18, 2024
1 parent 1a51505 commit bce3dc3
Showing 1 changed file with 20 additions and 41 deletions.
lerobot/common/datasets/utils.py (61 changes: 20 additions & 41 deletions)
@@ -15,16 +15,16 @@
 # limitations under the License.
 import json
 import warnings
-from functools import cache
 from itertools import accumulate
 from pathlib import Path
 from pprint import pformat
 from typing import Dict
 
 import datasets
+import jsonlines
 import torch
 from datasets import load_dataset
-from huggingface_hub import DatasetCard, HfApi, hf_hub_download
+from huggingface_hub import DatasetCard, HfApi
 from PIL import Image as PILImage
 from torchvision import transforms
 
@@ -96,7 +96,6 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
     return items_dict
 
 
-@cache
 def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
     num_version = float(version.strip("v"))
     if num_version < 2 and enforce_v2:
@@ -144,50 +144,30 @@ def load_hf_dataset(
     return hf_dataset
 
 
-def load_stats(repo_id: str, version: str, local_dir: Path) -> dict[str, dict[str, torch.Tensor]]:
-    """stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
+def load_metadata(local_dir: Path) -> tuple[dict | list]:
+    """Loads metadata files from a dataset."""
+    info_path = local_dir / "meta/info.json"
+    episodes_path = local_dir / "meta/episodes.jsonl"
+    stats_path = local_dir / "meta/stats.json"
+    tasks_path = local_dir / "meta/tasks.json"
 
-    Example:
-    ```python
-    normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
-    ```
-    """
-    fpath = hf_hub_download(
-        repo_id, filename="meta/stats.json", local_dir=local_dir, repo_type="dataset", revision=version
-    )
-    with open(fpath) as f:
-        stats = json.load(f)
-
-    stats = flatten_dict(stats)
-    stats = {key: torch.tensor(value) for key, value in stats.items()}
-    return unflatten_dict(stats)
+    with open(info_path) as f:
+        info = json.load(f)
 
+    with jsonlines.open(episodes_path, "r") as reader:
+        episode_dicts = list(reader)
 
-def load_info(repo_id: str, version: str, local_dir: Path) -> dict:
-    """info contains structural information about the dataset. It should be the reference and
-    act as the 'source of thruth' for what's inside the dataset.
+    with open(stats_path) as f:
+        stats = json.load(f)
 
-    Example:
-    ```python
-    print("frame per second used to collect the video", info["fps"])
-    ```
-    """
-    fpath = hf_hub_download(
-        repo_id, filename="meta/info.json", local_dir=local_dir, repo_type="dataset", revision=version
-    )
-    with open(fpath) as f:
-        return json.load(f)
-
-
-def load_tasks(repo_id: str, version: str, local_dir: Path) -> dict:
-    """tasks contains all the tasks of the dataset, indexed by their task_index."""
-    fpath = hf_hub_download(
-        repo_id, filename="meta/tasks.json", local_dir=local_dir, repo_type="dataset", revision=version
-    )
-    with open(fpath) as f:
+    with open(tasks_path) as f:
         tasks = json.load(f)
 
-    return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
+    stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
+    stats = unflatten_dict(stats)
+    tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
+
+    return info, episode_dicts, stats, tasks
 
 
 def get_episode_data_index(episodes: list, episode_dicts: list[dict]) -> dict[str, torch.Tensor]:
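For anyone trying the new helper, here is a minimal usage sketch. It assumes a local copy of a v2-format dataset containing the `meta/` files read above (`info.json`, `episodes.jsonl`, `stats.json`, `tasks.json`); the path below is a hypothetical placeholder, and the `fps` and `action` keys are taken from the docstring examples removed in this diff.

```python
from pathlib import Path

import torch

from lerobot.common.datasets.utils import load_metadata

# Hypothetical local copy of a v2 dataset; replace with your own path.
local_dir = Path("~/my_datasets/pusht").expanduser()

info, episode_dicts, stats, tasks = load_metadata(local_dir)

print("frames per second:", info["fps"])    # structural info about the dataset
print("num episodes:", len(episode_dicts))  # one dict per line of episodes.jsonl
print("tasks:", tasks)                      # {task_index: task description}

# stats comes back as nested torch tensors, so the old normalization idiom still applies:
action = torch.zeros_like(stats["action"]["mean"])  # dummy action, illustration only
normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
```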

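The two stats lines near the end of `load_metadata` deserve a note: `stats.json` stores nested plain lists, and the loader flattens the nesting, converts each leaf to a tensor, then restores the nesting. Below is a self-contained sketch of that round trip, using stand-in `flatten_dict`/`unflatten_dict` helpers (the real ones live elsewhere in `utils.py`; the `/` separator here is an assumption).

```python
import torch


def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
    """Stand-in for the utils.py helper: nested dict -> {"a/b": leaf} (separator assumed)."""
    items = {}
    for k, v in d.items():
        key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.update(flatten_dict(v, key, sep))
        else:
            items[key] = v
    return items


def unflatten_dict(d: dict, sep: str = "/") -> dict:
    """Inverse of flatten_dict: {"a/b": leaf} -> nested dict."""
    out: dict = {}
    for key, v in d.items():
        node = out
        *parents, leaf = key.split(sep)
        for p in parents:
            node = node.setdefault(p, {})
        node[leaf] = v
    return out


# stats.json stores plain lists; the loader converts every leaf to a torch.Tensor
# while preserving the nested structure, exactly like the two lines in load_metadata.
stats = {"action": {"mean": [0.1, 0.2], "std": [1.0, 1.0]}}
stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
stats = unflatten_dict(stats)
print(stats["action"]["mean"])  # tensor([0.1000, 0.2000])
```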
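Similarly, `tasks.json` is re-indexed from a list of records into a `{task_index: task}` dict. A toy example with made-up records (the record shape is inferred from the dict comprehension, not from a schema in this diff):

```python
tasks_records = [
    {"task_index": 1, "task": "Push the T-shaped block onto the target."},
    {"task_index": 0, "task": "Pick up the cube."},
]
# Same comprehension as in load_metadata: sort by task_index, then index by it.
tasks = {item["task_index"]: item["task"] for item in sorted(tasks_records, key=lambda x: x["task_index"])}
print(tasks)  # {0: 'Pick up the cube.', 1: 'Push the T-shaped block onto the target.'}
```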