Added depth frame support for dataset recording/loading #455

Status: Open · wants to merge 3 commits into base: main
13 changes: 12 additions & 1 deletion lerobot/common/datasets/compute_stats.py
@@ -21,7 +21,7 @@
import tqdm
from datasets import Image

from lerobot.common.datasets.video_utils import VideoFrame
from lerobot.common.datasets.video_utils import DepthFrame, VideoFrame


def get_stats_einops_patterns(dataset, num_workers=0):
@@ -57,6 +57,17 @@ def get_stats_einops_patterns(dataset, num_workers=0):
assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
assert batch[key].min() >= 0, f"expect pixels greater than 0, but instead {batch[key].min()=}"

stats_patterns[key] = "b c h w -> c 1 1"
elif isinstance(feats_type, DepthFrame):
# sanity check that images are channel first
_, c, h, w = batch[key].shape
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"

# sanity check that depth images are uint16 (raw depth values, not normalized to [0, 1],
# so the pixel-range checks used for RGB images do not apply here)
assert batch[key].dtype == torch.uint16, f"expect torch.uint16, but instead {batch[key].dtype=}"

stats_patterns[key] = "b c h w -> c 1 1"
elif batch[key].ndim == 2:
stats_patterns[key] = "b c -> c "
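
For context, a minimal sketch of what the `"b c h w -> c 1 1"` pattern computes for a depth batch (the shapes and value range below are assumptions, not taken from the PR):

```python
import torch
from einops import reduce

# fake batch: 8 single-channel depth frames, 480x640, raw values in millimeters
batch = torch.randint(0, 4000, (8, 1, 480, 640), dtype=torch.int32).float()

# the reduction registered in stats_patterns: one statistic per channel,
# aggregated over the batch and spatial dimensions
mean = reduce(batch, "b c h w -> c 1 1", "mean")
max_ = reduce(batch, "b c h w -> c 1 1", "max")
print(mean.shape, max_.shape)  # torch.Size([1, 1, 1]) torch.Size([1, 1, 1])
```
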
23 changes: 22 additions & 1 deletion lerobot/common/datasets/lerobot_dataset.py
@@ -33,7 +33,7 @@
load_videos,
reset_episode_index,
)
from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
from lerobot.common.datasets.video_utils import DepthFrame, VideoFrame, load_depth_frames, load_from_videos

# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
CODEBASE_VERSION = "v1.6"
@@ -109,6 +109,20 @@ def video_frame_keys(self) -> list[str]:
if isinstance(feats, VideoFrame):
video_frame_keys.append(key)
return video_frame_keys

@property
def depth_frame_keys(self) -> list[str]:
    """Keys to access depth frames that need to be loaded from individual PNG files.

    Note: it is empty if the dataset contains no depth frames,
    and can be a subset of `self.cameras` when only some cameras record depth.
    """
    depth_frame_keys = []
    for key, feats in self.hf_dataset.features.items():
        if isinstance(feats, DepthFrame):
            depth_frame_keys.append(key)
    return depth_frame_keys

@property
def num_samples(self) -> int:
@@ -153,6 +167,13 @@ def __getitem__(self, idx):
self.video_backend,
)

item = load_depth_frames(
item,
self.depth_frame_keys,
self.videos_dir,
)

if self.image_transforms is not None:
for cam in self.camera_keys:
item[cam] = self.image_transforms(item[cam])
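
To illustrate how `depth_frame_keys` and `load_depth_frames` fit together in `__getitem__`, a hedged usage sketch (the repo id and camera name are hypothetical):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# hypothetical repo id for a dataset recorded with use_depth: true
dataset = LeRobotDataset("user/aloha_depth_demo")
print(dataset.depth_frame_keys)  # e.g. ["observation.depth.cam_high"]

item = dataset[0]
# item["observation.images.cam_high"] -> float32 (c, h, w) RGB frame decoded from video
# item["observation.depth.cam_high"]  -> (1, h, w) raw depth loaded from PNG by load_depth_frames
```
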
@@ -38,7 +38,7 @@
calculate_episode_data_index,
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
from lerobot.common.datasets.video_utils import DepthFrame, VideoFrame, encode_video_frames


def get_cameras(hdf5_data):
@@ -180,6 +180,12 @@ def to_hf_dataset(data_dict, video) -> Dataset:
else:
features[key] = Image()

depth_keys = [key for key in data_dict if "observation.depth." in key]

for key in depth_keys:
features[key] = DepthFrame()

features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
)
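
A minimal sketch of the feature mapping this hunk produces, mirroring the `DepthFrame` docstring example (the data below is made up):

```python
from datasets import Dataset, Features

from lerobot.common.datasets.video_utils import DepthFrame

data_dict = {
    "observation.depth.cam_high": [
        {"path": "videos/observation.depth.cam_high_episode_000000/frame_000000.png"}
    ],
}
features = Features({"observation.depth.cam_high": DepthFrame()})
hf_dataset = Dataset.from_dict(data_dict, features=features)
```
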
3 changes: 3 additions & 0 deletions lerobot/common/datasets/utils.py
@@ -88,6 +88,9 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
# video frame will be processed downstream
pass
elif isinstance(first_item, dict) and "path" in first_item:
    # depth frame (path only, no timestamp) will be processed downstream;
    # this branch must stay after the video-frame branch above
    pass
elif first_item is None:
pass
else:
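
Because a depth entry carries a `path` but no `timestamp`, the new branch only fires after the video-frame check; a small standalone sketch of that dispatch (paths are made up):

```python
def classify(first_item):
    # mirrors the elif chain in hf_transform_to_torch
    if isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
        return "video frame"
    elif isinstance(first_item, dict) and "path" in first_item:
        return "depth frame"
    return "other"

video_item = {"path": "videos/observation.images.cam_high_episode_000000.mp4", "timestamp": 0.033}
depth_item = {"path": "videos/observation.depth.cam_high_episode_000000/frame_000000.png"}
assert classify(video_item) == "video frame"
assert classify(depth_item) == "depth frame"
```
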
50 changes: 50 additions & 0 deletions lerobot/common/datasets/video_utils.py
@@ -25,8 +25,37 @@
import torch
import torchvision
from datasets.features.features import register_feature
import numpy as np
from PIL import Image


def load_depth_frames(
    item: dict[str, torch.Tensor],
    depth_frame_keys: list[str],
    videos_dir: Path,
):
    """
    Load depth frames from individual PNG files.
    """
    for key in depth_frame_keys:
        if isinstance(item[key], list):
            # load multiple frames at once
            frames = []
            for frame in item[key]:
                depth_path = frame["path"]
                depth_image = Image.open(depth_path)
                # add channel dimension -> (1, h, w)
                depth_tensor = torch.from_numpy(np.array(depth_image, dtype=np.uint16)).unsqueeze(0)
                frames.append(depth_tensor)
            item[key] = torch.stack(frames)
        else:
            # load one frame
            depth_path = item[key]["path"]
            depth_image = Image.open(depth_path)
            # add channel dimension -> (1, h, w); use uint16 to match the multi-frame branch
            item[key] = torch.from_numpy(np.array(depth_image, dtype=np.uint16)).unsqueeze(0)

    return item

def load_from_videos(
item: dict[str, torch.Tensor],
video_frame_keys: list[str],
@@ -237,6 +266,26 @@ class VideoFrame:

def __call__(self):
return self.pa_type

@dataclass
class DepthFrame:
"""
Provides a type for a dataset containing depth frames.

Example:

```python
data_dict = [{"image": {"path": "videos/observation.depth.cam_high_episode_000000/frame_000000.png"}}]
features = {"image": DepthFrame()}
Dataset.from_dict(data_dict, features=Features(features))
```
"""

pa_type: ClassVar[Any] = pa.struct({"path": pa.string()})
_type: str = field(default="DepthFrame", init=False, repr=False)

def __call__(self):
return self.pa_type


with warnings.catch_warnings():
@@ -247,3 +296,4 @@ def __call__(self):
)
# to make VideoFrame available in HuggingFace `datasets`
register_feature(VideoFrame, "VideoFrame")
# to make DepthFrame available in HuggingFace `datasets`
register_feature(DepthFrame, "DepthFrame")
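
Since the PNG storage spans `save_depth` (writer) and `load_depth_frames` (reader), here is a self-contained round-trip sketch of the 16-bit PNG handling they rely on (file name and value range are made up):

```python
import numpy as np
import torch
from PIL import Image

# write a fake uint16 depth frame the way save_depth does
depth = np.random.randint(0, 4000, size=(480, 640), dtype=np.uint16)
Image.fromarray(depth).save("frame_000000.png")  # PIL stores uint16 as a 16-bit grayscale PNG

# read it back the way load_depth_frames does
loaded = np.array(Image.open("frame_000000.png"), dtype=np.uint16)
assert (loaded == depth).all()  # PNG is lossless, so raw depth values survive the round trip

# torch has no uint16 tensors in eager mode, hence the int32 cast used in manipulator.py
tensor = torch.from_numpy(loaded.astype(np.int32)).unsqueeze(0)  # (1, h, w)
```
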
29 changes: 25 additions & 4 deletions lerobot/common/robot_devices/robots/manipulator.py
@@ -590,7 +590,14 @@ def teleop_step(
for name in self.cameras:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])

if isinstance(images[name], tuple):
    # RGB-D camera: async_read() returned a (color, depth) tuple
    # TODO: use uint16 instead of int32 once torch supports uint16 in eager mode
    images_1 = images[name][1].astype("int32")
    images[name] = (images[name][0], images_1)
    images[name] = tuple(torch.from_numpy(arr) for arr in images[name])
else:
    images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

@@ -599,7 +606,12 @@
obs_dict["observation.state"] = state
action_dict["action"] = action
for name in self.cameras:
obs_dict[f"observation.images.{name}"] = images[name]
if isinstance(images[name], tuple):
    obs_dict[f"observation.images.{name}"] = images[name][0]
    obs_dict[f"observation.depth.{name}"] = images[name][1]
else:
    obs_dict[f"observation.images.{name}"] = images[name]

return obs_dict, action_dict

@@ -630,15 +642,24 @@ def capture_observation(self):
for name in self.cameras:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
if isinstance(images[name], tuple):
    # RGB-D camera: async_read() returned a (color, depth) tuple
    # TODO: use uint16 instead of int32 once torch supports uint16 in eager mode
    images_1 = images[name][1].astype("int32")
    images[name] = (images[name][0], images_1)
    images[name] = tuple(torch.from_numpy(arr) for arr in images[name])
else:
    images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

# Populate output dictionaries and format to pytorch
obs_dict = {}
obs_dict["observation.state"] = state
for name in self.cameras:
obs_dict[f"observation.images.{name}"] = images[name]
if isinstance(images[name], tuple):
    obs_dict[f"observation.images.{name}"] = images[name][0]
    obs_dict[f"observation.depth.{name}"] = images[name][1]
else:
    obs_dict[f"observation.images.{name}"] = images[name]
return obs_dict

def send_action(self, action: torch.Tensor) -> torch.Tensor:
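
For reference, a runnable sketch of the return convention these branches rely on; the stand-in below mimics what `IntelRealSenseCamera.async_read()` appears to return per this diff, with assumed shapes and dtypes:

```python
import numpy as np

def fake_async_read(use_depth: bool):
    """Stand-in for IntelRealSenseCamera.async_read(); shapes/dtypes are assumptions."""
    color = np.zeros((480, 640, 3), dtype=np.uint8)
    if use_depth:
        depth = np.zeros((480, 640), dtype=np.uint16)
        return color, depth  # (color, depth) tuple when depth is enabled
    return color  # plain RGB array otherwise

frame = fake_async_read(use_depth=True)
if isinstance(frame, tuple):
    color, depth = frame
else:
    color = frame
```
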
4 changes: 4 additions & 0 deletions lerobot/configs/robot/aloha.yaml
@@ -95,21 +95,25 @@ cameras:
fps: 30
width: 640
height: 480
use_depth: false  # set to true to enable depth capture
cam_low:
_target_: lerobot.common.robot_devices.cameras.intelrealsense.IntelRealSenseCamera
camera_index: 130322270656
fps: 30
width: 640
height: 480
use_depth: false  # set to true to enable depth capture
cam_left_wrist:
_target_: lerobot.common.robot_devices.cameras.intelrealsense.IntelRealSenseCamera
camera_index: 218622272670
fps: 30
width: 640
height: 480
use_depth: false  # set to true to enable depth capture
cam_right_wrist:
_target_: lerobot.common.robot_devices.cameras.intelrealsense.IntelRealSenseCamera
camera_index: 130322272300
fps: 30
width: 640
height: 480
use_depth: false  # set to true to enable depth capture
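
In Python terms, each yaml entry corresponds to something like the sketch below (the constructor signature is inferred from the config keys and is an assumption):

```python
from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera

# mirrors the cam_low entry above; with use_depth=True, async_read() returns (color, depth)
cam_low = IntelRealSenseCamera(camera_index=130322270656, fps=30, width=640, height=480, use_depth=True)
```
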
36 changes: 33 additions & 3 deletions lerobot/scripts/control_robot.py
@@ -111,6 +111,7 @@
from functools import cache
from pathlib import Path

import cv2
import numpy as np
import torch
import tqdm
@@ -169,6 +170,20 @@ def save_image(img_tensor, key, frame_index, episode_index, videos_dir):
path.parent.mkdir(parents=True, exist_ok=True)
img.save(str(path), quality=100)

def save_depth(depth_tensor, key, frame_index, episode_index, videos_dir):
    # convert the torch tensor to a uint16 numpy array
    depth_array = depth_tensor.numpy().astype(np.uint16)

    # convert the numpy array to a 16-bit grayscale PIL image
    depth_image_pil = Image.fromarray(depth_array)

    # define the path for saving the PNG file
    path = videos_dir / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
    path.parent.mkdir(parents=True, exist_ok=True)

    # save the depth image as a lossless PNG (the JPEG-style `quality` option does not apply to PNG)
    depth_image_pil.save(str(path))


def none_or_int(value):
if value == "None":
@@ -463,7 +478,6 @@ def on_press(key):
observation = robot.capture_observation()

image_keys = [key for key in observation if "image" in key]
not_image_keys = [key for key in observation if "image" not in key]

for key in image_keys:
futures += [
@@ -472,13 +486,24 @@
)
]

depth_keys = [key for key in observation if "depth" in key]

not_image_depth_keys = [key for key in observation if "image" not in key and "depth" not in key]

for key in depth_keys:
futures += [
executor.submit(
save_depth, observation[key], key, frame_index, episode_index, videos_dir
)
]

if not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)

for key in not_image_keys:
for key in not_image_depth_keys:
if key not in ep_dict:
ep_dict[key] = []
ep_dict[key].append(observation[key])
@@ -555,7 +580,12 @@ def on_press(key):
for i in range(num_frames):
ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps})

for key in not_image_keys:
for key in depth_keys:
ep_dict[key] = []
for i in range(num_frames):
ep_dict[key].append({"path": str(videos_dir / f"{key}_episode_{episode_index:06d}" / f"frame_{i:06d}.png")})

for key in not_image_depth_keys:
ep_dict[key] = torch.stack(ep_dict[key])

for key in action:
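
To summarize the two loops above, a sketch of the per-key entries they produce for a two-frame episode (paths follow the naming used in this file, with `videos_dir` assumed to be `Path("videos")`):

```python
from pathlib import Path

fps = 30
num_frames = 2
episode_index = 0
videos_dir = Path("videos")
ep_dict = {}

# video frames reference a timestamp into one encoded mp4 per episode
key = "observation.images.cam_high"
fname = f"{key}_episode_{episode_index:06d}.mp4"
ep_dict[key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]

# depth frames reference one PNG per frame; no timestamp is needed
key = "observation.depth.cam_high"
ep_dict[key] = [
    {"path": str(videos_dir / f"{key}_episode_{episode_index:06d}" / f"frame_{i:06d}.png")}
    for i in range(num_frames)
]
```
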