diff --git a/lerobot/common/datasets/compute_stats.py b/lerobot/common/datasets/compute_stats.py
index 208284465..495a2575e 100644
--- a/lerobot/common/datasets/compute_stats.py
+++ b/lerobot/common/datasets/compute_stats.py
@@ -21,7 +21,7 @@
 import tqdm
 from datasets import Image
 
-from lerobot.common.datasets.video_utils import VideoFrame
+from lerobot.common.datasets.video_utils import DepthFrame, VideoFrame
 
 
 def get_stats_einops_patterns(dataset, num_workers=0):
@@ -57,6 +57,17 @@ def get_stats_einops_patterns(dataset, num_workers=0):
             assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
             assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
 
+            stats_patterns[key] = "b c h w -> c 1 1"
+        elif isinstance(feats_type, DepthFrame):
+            # sanity check that depth images are channel first
+            _, c, h, w = batch[key].shape
+            assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
+
+            # sanity check that depth images are raw uint16 (unlike RGB frames, they are not normalized to [0, 1])
+            assert batch[key].dtype == torch.uint16, f"expect torch.uint16, but instead {batch[key].dtype=}"
+
             stats_patterns[key] = "b c h w -> c 1 1"
         elif batch[key].ndim == 2:
             stats_patterns[key] = "b c -> c "
diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index eb76f78d6..84b5de8b2 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -33,7 +33,7 @@
     load_videos,
     reset_episode_index,
 )
-from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
+from lerobot.common.datasets.video_utils import DepthFrame, VideoFrame, load_depth_frames, load_from_videos
 
 # For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
 CODEBASE_VERSION = "v1.6"
@@ -109,6 +109,20 @@ def video_frame_keys(self) -> list[str]:
             if isinstance(feats, VideoFrame):
                 video_frame_keys.append(key)
         return video_frame_keys
+
+    @property
+    def depth_frame_keys(self) -> list[str]:
+        """Keys to access depth frames stored as individual PNG files that need to be loaded into tensors.
+
+        Note: It is empty if the dataset contains no depth streams, and it can be a subset of
+        `self.cameras` when only some cameras are configured to record depth.
+        """
+        depth_frame_keys = []
+        for key, feats in self.hf_dataset.features.items():
+            if isinstance(feats, DepthFrame):
+                depth_frame_keys.append(key)
+        return depth_frame_keys
 
     @property
     def num_samples(self) -> int:
@@ -153,6 +167,13 @@ def __getitem__(self, idx):
                 self.video_backend,
             )
 
+        item = load_depth_frames(
+            item,
+            self.depth_frame_keys,
+            self.videos_dir,
+        )
+
         if self.image_transforms is not None:
             for cam in self.camera_keys:
                 item[cam] = self.image_transforms(item[cam])
diff --git a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py
index 52c4bba3d..206ac1bbe 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py
@@ -38,7 +38,7 @@
     calculate_episode_data_index,
     hf_transform_to_torch,
 )
-from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
+from lerobot.common.datasets.video_utils import DepthFrame, VideoFrame, encode_video_frames
 
 
 def get_cameras(hdf5_data):
@@ -180,6 +180,12 @@ def to_hf_dataset(data_dict, video) -> Dataset:
         else:
             features[key] = Image()
 
+    depth_keys = [key for key in data_dict if "observation.depth." in key]
+    for key in depth_keys:
+        features[key] = DepthFrame()
+
     features["observation.state"] = Sequence(
         length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
     )
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index d6aef15f5..763f8f3a3 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -88,6 +88,9 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
         elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
             # video frame will be processed downstream
             pass
+        elif isinstance(first_item, dict) and "path" in first_item:
+            # depth frame (path only, no timestamp) will be processed downstream
+            pass
         elif first_item is None:
             pass
         else:
diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py
index 4d4ac6b0a..2e74d9ab2 100644
--- a/lerobot/common/datasets/video_utils.py
+++ b/lerobot/common/datasets/video_utils.py
@@ -25,8 +25,37 @@
 import torch
 import torchvision
 from datasets.features.features import register_feature
+from PIL import Image
+import numpy as np
 
 
+def load_depth_frames(
+    item: dict[str, torch.Tensor],
+    depth_frame_keys: list[str],
+    videos_dir: Path,
+):
+    """Load depth frames from individual 16-bit PNG files into channel-first uint16 tensors."""
+    for key in depth_frame_keys:
+        if isinstance(item[key], list):
+            # load multiple frames at once
+            frames = []
+            for frame in item[key]:
+                depth_path = frame["path"]
+                depth_image = Image.open(depth_path)
+                # add a channel dimension: (h, w) -> (1, h, w)
+                depth_tensor = torch.from_numpy(np.array(depth_image, dtype=np.uint16)).unsqueeze(0)
+                frames.append(depth_tensor)
+            item[key] = torch.stack(frames)
+        else:
+            # load a single frame
+            depth_path = item[key]["path"]
+            depth_image = Image.open(depth_path)
+            item[key] = torch.from_numpy(np.array(depth_image, dtype=np.uint16)).unsqueeze(0)
+
+    return item
+
+
 def load_from_videos(
     item: dict[str, torch.Tensor],
     video_frame_keys: list[str],
@@ -237,6 +266,26 @@ class VideoFrame:
     def __call__(self):
         return self.pa_type
 
+
+@dataclass
+class DepthFrame:
+    """
+    Provides a type for a dataset containing depth frames.
+
+    Example:
+
+    ```python
+    data_dict = [{"depth": {"path": "videos/observation.depth.cam_high_episode_000000/frame_000000.png"}}]
+    features = {"depth": DepthFrame()}
+    Dataset.from_dict(data_dict, features=Features(features))
+    ```
+    """
+
+    pa_type: ClassVar[Any] = pa.struct({"path": pa.string()})
+    _type: str = field(default="DepthFrame", init=False, repr=False)
+
+    def __call__(self):
+        return self.pa_type
 
 
 with warnings.catch_warnings():
@@ -247,3 +296,4 @@ def __call__(self):
     )
     # to make VideoFrame available in HuggingFace `datasets`
     register_feature(VideoFrame, "VideoFrame")
+    register_feature(DepthFrame, "DepthFrame")
\ No newline at end of file
diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py
index 337519765..647fc9bc0 100644
--- a/lerobot/common/robot_devices/robots/manipulator.py
+++ b/lerobot/common/robot_devices/robots/manipulator.py
@@ -590,7 +590,14 @@ def teleop_step(
         for name in self.cameras:
             before_camread_t = time.perf_counter()
             images[name] = self.cameras[name].async_read()
-            images[name] = torch.from_numpy(images[name])
+
+            if isinstance(images[name], tuple):
+                # TODO: use uint16 instead of int32, because torch does not support uint16 in eager mode
+                images_1 = images[name][1].astype('int32')
+                images[name] = (images[name][0], images_1)
+                images[name] = tuple(torch.from_numpy(arr) for arr in images[name])
+            else:
+                images[name] = torch.from_numpy(images[name])
             self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
             self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
 
@@ -599,7 +606,12 @@
         obs_dict["observation.state"] = state
         action_dict["action"] = action
         for name in self.cameras:
-            obs_dict[f"observation.images.{name}"] = images[name]
+            if isinstance(images[name], tuple):
+                obs_dict[f"observation.images.{name}"] = images[name][0]
+                obs_dict[f"observation.depth.{name}"] = images[name][1]
+            else:
+                obs_dict[f"observation.images.{name}"] = images[name]
 
         return obs_dict, action_dict
 
@@ -630,7 +642,12 @@ def capture_observation(self):
         for name in self.cameras:
             before_camread_t = time.perf_counter()
             images[name] = self.cameras[name].async_read()
-            images[name] = torch.from_numpy(images[name])
+            if isinstance(images[name], tuple):
+                # TODO: use uint16 instead of int32, because torch does not support uint16 in eager mode
+                images_1 = images[name][1].astype('int32')
+                images[name] = (images[name][0], images_1)
+                images[name] = tuple(torch.from_numpy(arr) for arr in images[name])
+            else:
+                images[name] = torch.from_numpy(images[name])
             self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
             self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
 
@@ -638,7 +655,11 @@
         obs_dict = {}
         obs_dict["observation.state"] = state
         for name in self.cameras:
-            obs_dict[f"observation.images.{name}"] = images[name]
+            if isinstance(images[name], tuple):
+                obs_dict[f"observation.images.{name}"] = images[name][0]
+                obs_dict[f"observation.depth.{name}"] = images[name][1]
+            else:
+                obs_dict[f"observation.images.{name}"] = images[name]
 
         return obs_dict
 
     def send_action(self, action: torch.Tensor) -> torch.Tensor:
diff --git a/lerobot/configs/robot/aloha.yaml b/lerobot/configs/robot/aloha.yaml
index 938fa2e3d..a1d0d0b00 100644
--- a/lerobot/configs/robot/aloha.yaml
+++ b/lerobot/configs/robot/aloha.yaml
@@ -95,21 +95,25 @@ cameras:
     fps: 30
     width: 640
     height: 480
+    use_depth: false  # set to "true" to also capture this camera's depth stream
   cam_low:
     _target_: lerobot.common.robot_devices.cameras.intelrealsense.IntelRealSenseCamera
     camera_index: 130322270656
     fps: 30
     width: 640
     height: 480
+    use_depth: false  # set to "true" to also capture this camera's depth stream
   cam_left_wrist:
     _target_: lerobot.common.robot_devices.cameras.intelrealsense.IntelRealSenseCamera
     camera_index: 218622272670
     fps: 30
     width: 640
     height: 480
+    use_depth: false  # set to "true" to also capture this camera's depth stream
   cam_right_wrist:
     _target_: lerobot.common.robot_devices.cameras.intelrealsense.IntelRealSenseCamera
     camera_index: 130322272300
     fps: 30
     width: 640
     height: 480
+    use_depth: false  # set to "true" to also capture this camera's depth stream
diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py
index a6506a3fe..cba3e656e 100644
--- a/lerobot/scripts/control_robot.py
+++ b/lerobot/scripts/control_robot.py
@@ -111,6 +111,7 @@
 from functools import cache
 from pathlib import Path
 
 import cv2
+import numpy as np
 import torch
 import tqdm
@@ -169,6 +170,20 @@ def save_image(img_tensor, key, frame_index, episode_index, videos_dir):
     path.parent.mkdir(parents=True, exist_ok=True)
     img.save(str(path), quality=100)
 
+
+def save_depth(depth_tensor, key, frame_index, episode_index, videos_dir):
+    # Convert the torch tensor to a uint16 numpy array
+    depth_array = depth_tensor.numpy().astype(np.uint16)
+
+    # Convert the numpy array to a PIL Image
+    depth_image_pil = Image.fromarray(depth_array)
+
+    # Define the path for saving the PNG file
+    path = videos_dir / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
+    path.parent.mkdir(parents=True, exist_ok=True)
+
+    # Save the depth image as a 16-bit PNG file
+    depth_image_pil.save(str(path))
+
 
 def none_or_int(value):
     if value == "None":
@@ -463,7 +478,6 @@ def on_press(key):
             observation = robot.capture_observation()
 
             image_keys = [key for key in observation if "image" in key]
-            not_image_keys = [key for key in observation if "image" not in key]
 
             for key in image_keys:
                 futures += [
                     executor.submit(
                         save_image, observation[key], key, frame_index, episode_index, videos_dir
                     )
                 ]
 
@@ -472,13 +486,24 @@
+            depth_keys = [key for key in observation if "depth" in key]
+
+            not_image_depth_keys = [key for key in observation if "image" not in key and "depth" not in key]
+
+            for key in depth_keys:
+                futures += [
+                    executor.submit(
+                        save_depth, observation[key], key, frame_index, episode_index, videos_dir
+                    )
+                ]
+
             if not is_headless():
                 image_keys = [key for key in observation if "image" in key]
                 for key in image_keys:
                     cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
                 cv2.waitKey(1)
 
-            for key in not_image_keys:
+            for key in not_image_depth_keys:
                 if key not in ep_dict:
                     ep_dict[key] = []
                 ep_dict[key].append(observation[key])
@@ -555,7 +580,12 @@ def on_press(key):
             for i in range(num_frames):
                 ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps})
 
-        for key in not_image_keys:
+        for key in depth_keys:
+            ep_dict[key] = []
+            for i in range(num_frames):
+                ep_dict[key].append({"path": str(videos_dir / f"{key}_episode_{episode_index:06d}" / f"frame_{i:06d}.png")})
+
+        for key in not_image_depth_keys:
             ep_dict[key] = torch.stack(ep_dict[key])
 
         for key in action:
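Not part of the patch above, but a minimal, self-contained sketch of the PNG round-trip the patch relies on: `save_depth()` writes raw uint16 depth as 16-bit PNGs, and `load_depth_frames()` reads them back as channel-first `(1, H, W)` tensors. The path and the 480x640 resolution below are made-up examples, and calling `torch.from_numpy` on a uint16 array assumes a PyTorch version with uint16 support, which is the same assumption the patch itself makes.

```python
from pathlib import Path

import numpy as np
import torch
from PIL import Image

# Write a fake depth frame the way save_depth() does: raw uint16 values in a 16-bit PNG.
depth = np.random.randint(0, 2**16, size=(480, 640), dtype=np.uint16)
path = Path("videos/observation.depth.cam_high_episode_000000/frame_000000.png")
path.parent.mkdir(parents=True, exist_ok=True)
Image.fromarray(depth).save(str(path))

# Read it back the way load_depth_frames() does: a channel-first (1, H, W) uint16 tensor.
restored = torch.from_numpy(np.array(Image.open(path), dtype=np.uint16)).unsqueeze(0)
assert restored.shape == (1, 480, 640)
assert np.array_equal(restored.squeeze(0).numpy(), depth)
```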