Adds a script for recording training data from a Gym simulation and replaying in sim #437

Open
wants to merge 28 commits into base: main
Changes from 5 commits
Commits (28)
e8c214d
added data recording in sim example
samzapo Sep 11, 2024
5ad1bcf
Cleaned up for PR
samzapo Sep 12, 2024
6a66f4a
checkpoint
samzapo Sep 12, 2024
59270e0
rolled back changes that recorded reset observations.
samzapo Sep 12, 2024
735d5e5
Merge branch 'main' of github.com:huggingface/lerobot into record_gym…
samzapo Sep 12, 2024
38f9abd
first commit
samzapo Sep 13, 2024
29326a6
checkpoint, code in place for running sim and saving data
samzapo Sep 13, 2024
a77c005
checkpoint, setting-up replaying
samzapo Sep 13, 2024
9cce7d7
checkpoint
samzapo Sep 16, 2024
30baf69
working script lerobot/scripts/record_and_replay_teleop_dataset.py. …
samzapo Sep 16, 2024
66d785c
Checkpoint working replay from disk
samzapo Sep 16, 2024
a3100bd
Cleaned up code (removed runner class). Replay fails during compute …
samzapo Sep 17, 2024
125a438
Cleaned up code (removed 'runner' class). Replay fails during comput…
samzapo Sep 17, 2024
703ee07
Fixed gripper open/close issue
samzapo Sep 17, 2024
525b2f9
dummy code for teleop using leader arm
samzapo Sep 17, 2024
1a1ff0d
Randomizer re-seeds before each sim run
samzapo Sep 18, 2024
067ab01
Merge remote-tracking branch 'upstream/main' into record_gym_script
samzapo Sep 19, 2024
321cb62
does not double reset on start, corrected teleop docs
samzapo Sep 20, 2024
2c29606
Merge remote-tracking branch 'upstream/main' into record_gym_script
samzapo Sep 30, 2024
a47a08b
checkpoint: reading joint angles from leader arm
samzapo Oct 1, 2024
afa633b
Teleoperating with leader arm in sim
samzapo Oct 2, 2024
139dc46
(checkpoint) adding kwargs config
samzapo Oct 2, 2024
1731e35
Using kwargs config instead of explicitly setting args in environment…
samzapo Oct 3, 2024
ce8fe1b
Updated version of gym-drake-lca that uses IDYN controller, and bette…
samzapo Oct 3, 2024
862f40d
Merge remote-tracking branch 'upstream/main' into record_gym_script
samzapo Oct 3, 2024
ee66983
removed explicitly requesting vcodec
samzapo Oct 4, 2024
21bd0e0
checkpoint: final update
samzapo Oct 10, 2024
666e2b7
Merge branch 'main' of github.com:huggingface/lerobot into record_gym…
samzapo Oct 10, 2024
266 changes: 266 additions & 0 deletions examples/8_record_dataset_from_gym.py
@@ -0,0 +1,266 @@
"""This script demonstrates how to record a LeRobot dataset of training data using a very simple gym environment.
"""

import argparse
import copy
import os
import sys
import pathlib
import time
import importlib

import gymnasium as gym
import numpy as np
import torch

from datasets import Dataset, Features, Sequence, Value
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, DATA_DIR, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
from lerobot.scripts.push_dataset_to_hub import push_meta_data_to_hub, push_videos_to_hub, save_meta_data
from tqdm import tqdm


def process_args():
# Parse command-line arguments for the recording script
parser = argparse.ArgumentParser()
parser.add_argument("--module-name", type=str, default="gym_lowcostrobot")
parser.add_argument("--env-name", type=str, default="ReachCube-v0")
parser.add_argument("--num-episodes", type=int, default=2)
parser.add_argument("--num-steps", type=int, default=20)
parser.add_argument("--num-workers", type=int, default=8)
parser.add_argument("--keep-last", action="store_true")
parser.add_argument("--repo-id", type=str, default="myrepo")
parser.add_argument("--push-to-hub", action="store_true")
parser.add_argument("--fps", type=np.float32, default=30,
help="Frames (steps) per second of the recording (will be ignored if `step()` info provides timestamp).")

parser.add_argument(
"--image-keys",
type=str,
default="image_top,image_front",
help="The keys of the image observations to record.",
)
parser.add_argument(
"--state-keys",
type=str,
default="arm_qpos,arm_qvel,cube_pos",
help="The keys of the state observations to record.",
)

parser.add_argument(
"--revision", type=str, default=CODEBASE_VERSION, help="Codebase version used to generate the dataset."
)

return parser.parse_args()


if __name__ == "__main__":
args = process_args()

repo_id = args.repo_id
num_episodes = args.num_episodes
num_steps = args.num_steps
revision = args.revision

if DATA_DIR is None:
DATA_DIR = pathlib.Path("data_traces")
print(f"Warning: env variable DATA_DIR was not set, defaulting to './{DATA_DIR}'.")

out_data = DATA_DIR / repo_id

# During data collection, frames are stored as png images in `images_dir`
images_dir = out_data / "images"

# After data collection, png images of each episode are encoded into a mp4 file stored in `videos_dir`
videos_dir = out_data / "videos"
meta_data_dir = out_data / "meta_data"

# Create image and video directories
os.makedirs(images_dir, exist_ok=True)
os.makedirs(videos_dir, exist_ok=True)

# import the gym module containing the environment
try:
# The module name is only known at runtime, so import it programmatically.
module_obj = importlib.import_module(args.module_name)
# Make the imported module available under its name in this script's globals.
globals()[args.module_name] = module_obj
except ImportError:
sys.stderr.write(f"ERROR: missing python module: {args.module_name}\n")
sys.exit(1)

# Create the gym environment (check the supported kwargs in the environment package, e.g. gym_lowcostrobot)
env = gym.make(args.env_name, disable_env_checker=True,
observation_mode="both", action_mode="joint", render_mode="human")

ep_dicts = []
episode_data_index = {"from": [], "to": []}
ep_fps = []
id_from = 0
id_to = 0
print(f"gym environment created")

ep_idx = 0
while ep_idx < num_episodes:
# Reset the environment before recording a new episode
env.reset()

print(f"go {ep_idx}")

# init buffers
obs_replay = {k: [] for k in env.observation_space}
obs_replay["action"] = []

timestamps = []
drop_episode = False
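# Note: nothing below sets drop_episode to True; it is kept as a hook for
# filtering out bad episodes.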
for i in tqdm(range(num_steps)):
# Sample a random action and apply it to the environment
action = env.action_space.sample()
observation, _, _, _, info = env.step(action=action)

# store data
for key in observation:
obs_replay[key].append(copy.deepcopy(observation[key]))
obs_replay["action"].append(copy.deepcopy(action))

# Use simulator time, if available.
try:
timestamps.append(info["timestamp"])
except KeyError:
timestamps.append(np.float32(i)/args.fps)

print("stop")

if not drop_episode:
print(f"saving episode {ep_idx}")
ep_dict = {}

# Save this episode's frames as PNG images (they are encoded into MP4 videos after all episodes are recorded)
for img_key in args.image_keys.split(","):
save_images_concurrently(
obs_replay[img_key],
images_dir / f"{img_key}_episode_{ep_idx:06d}",
args.num_workers,
)
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"

# store the reference to the video frame
ep_dict[f"observation.{img_key}"] = [
{"path": f"videos/{fname}", "timestamp": tstp} for tstp in timestamps]

states = []
for state_name in args.state_keys.split(","):
states.append(np.array(obs_replay[state_name]))
state = torch.tensor(np.concatenate(states, axis=1))

action = torch.tensor(np.array(obs_replay["action"]))
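# Mark only the final frame of the episode as done.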
next_done = torch.zeros(num_steps, dtype=torch.bool)
next_done[-1] = True

ep_dict["observation.state"] = state
ep_dict["action"] = action
ep_dict["episode_index"] = torch.tensor(
[ep_idx] * num_steps, dtype=torch.int64)
ep_dict["frame_index"] = torch.arange(0, num_steps, 1)
ep_dict["timestamp"] = torch.tensor(timestamps)
ep_dict["next.done"] = next_done

print(f"num_steps={num_steps}")
print(f"timestamps[-1]={timestamps[-1]}")
ep_fps.append(num_steps / timestamps[-1])
ep_dicts.append(ep_dict)
print(f"Episode {ep_idx} done, fps: {ep_fps[-1]:.2f}")

episode_data_index["from"].append(id_from)
episode_data_index["to"].append(
id_from + num_steps if args.keep_last else id_from + num_steps - 1)

id_to = id_from + num_steps if args.keep_last else id_from + num_steps - 1
id_from = id_to

ep_idx += 1

env.close()

print("encode video frames")
for ep_idx in range(num_episodes):
for img_key in args.image_keys.split(","):
encode_video_frames(
vcodec="libx265",
imgs_dir=images_dir / f"{img_key}_episode_{ep_idx:06d}",
video_path=videos_dir / f"{img_key}_episode_{ep_idx:06d}.mp4",
fps=ep_fps[ep_idx],
)

print("concatenate episodes")
# Since our fps varies, the last frame is sometimes outside the timestamp tolerance
data_dict = concatenate_episodes(ep_dicts)

features = {}
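# Build the HF dataset features: image observations become VideoFrame columns
# under the "observation.images.*" naming convention.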

keys = [key for key in data_dict if "observation.image_" in key]
for key in keys:
features[key.replace("observation.image_",
"observation.images.")] = VideoFrame()
data_dict[key.replace("observation.image_",
"observation.images.")] = data_dict[key]
del data_dict[key]

features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(
dtype="float32", id=None)
)
features["action"] = Sequence(
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None))
features["episode_index"] = Value(dtype="int64", id=None)
features["frame_index"] = Value(dtype="int64", id=None)
features["timestamp"] = Value(dtype="float32", id=None)
features["next.done"] = Value(dtype="bool", id=None)
features["index"] = Value(dtype="int64", id=None)

hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
hf_dataset.set_transform(hf_transform_to_torch)

info = {
# Average fps across episodes, to keep a workable tolerance in data processing for the slowest video
"fps": sum(ep_fps) / len(ep_fps),
"video": 1,
}

print("from preloaded")
lerobot_dataset = LeRobotDataset.from_preloaded(
repo_id=repo_id,
hf_dataset=hf_dataset,
episode_data_index=episode_data_index,
info=info,
videos_dir=videos_dir,
)

print("compute stats")
stats = compute_stats(lerobot_dataset, num_workers=args.num_workers)

print("save to disk")
# to remove transforms that can't be saved
hf_dataset = hf_dataset.with_format(None)
hf_dataset.save_to_disk(str(out_data / "train"))

save_meta_data(info, stats, episode_data_index, meta_data_dir)
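# A minimal sketch of reloading the saved dataset later (assumes the same
# `out_data` layout written above):
#
#   from datasets import load_from_disk
#   reloaded = load_from_disk(str(out_data / "train"))
#   reloaded.set_transform(hf_transform_to_torch)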

if args.push_to_hub:
print(f"Pushing dataset to '{repo_id}'")
hf_dataset.push_to_hub(repo_id, token=True, revision="main")
hf_dataset.push_to_hub(repo_id, token=True, revision=revision)

push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
push_meta_data_to_hub(repo_id, meta_data_dir, revision=revision)

push_videos_to_hub(repo_id, videos_dir, revision="main")
push_videos_to_hub(repo_id, videos_dir, revision=revision)
8 changes: 4 additions & 4 deletions lerobot/common/datasets/lerobot_dataset.py
@@ -126,8 +126,8 @@ def tolerance_s(self) -> float:
are not close enough from the requested frames. It is only used when `delta_timestamps`
is provided or when loading video frames from mp4 files.
"""
- # 1e-4 to account for possible numerical error
- return 1 / self.fps - 1e-4
+ # Add 1e-4 seconds to expected step size to account for possible numerical error
+ return 1 / self.fps + 1e-4

def __len__(self):
return self.num_samples
@@ -357,8 +357,8 @@ def tolerance_s(self) -> float:
are not close enough from the requested frames. It is only used when `delta_timestamps`
is provided or when loading video frames from mp4 files.
"""
- # 1e-4 to account for possible numerical error
- return 1 / self.fps - 1e-4
+ # Add 1e-4 seconds to expected step size to account for possible numerical error
+ return 1 / self.fps + 1e-4

def __len__(self):
return self.num_samples
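Why the sign of the 1e-4 margin matters: the recording script's fps varies per episode, so the last frame can land slightly more than one nominal frame period away from where it is expected (see the "outside the timestamp tolerance" comment in the recording script). The old tolerance was slightly smaller than one frame period and rejected such a frame; the new tolerance is one period plus a small numerical margin. A minimal arithmetic sketch, illustrative only and not the dataset's internal code:

    fps = 30.0
    actual_gap = 0.0334                # a frame that arrived slightly late in a variable-fps recording
    old_tolerance = 1 / fps - 1e-4     # ~0.03323 s
    new_tolerance = 1 / fps + 1e-4     # ~0.03343 s
    print(actual_gap < old_tolerance)  # False: frame rejected as not close enough
    print(actual_gap < new_tolerance)  # True: frame accepted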