Commit 272a9d9: WIP

Cadene committed Nov 27, 2024 (1 parent: fc4df91)
Showing 3 changed files with 142 additions and 19 deletions.
22 changes: 11 additions & 11 deletions lerobot/common/policies/act/modeling_act.py

```diff
@@ -140,25 +140,25 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)

-        l1_loss = (
-            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
-        ).mean()
+        bsize = actions_hat.shape[0]
+        l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
+        l1_loss = l1_loss * ~batch["action_is_pad"].unsqueeze(-1)
+        l1_loss = l1_loss.view(bsize, -1).mean(dim=1)

-        loss_dict = {"l1_loss": l1_loss.item()}
+        out_dict = {}
+        out_dict["l1_loss"] = l1_loss
         if self.config.use_vae:
             # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
             # each dimension independently, we sum over the latent dimension to get the total
             # KL-divergence per batch element, then take the mean over the batch.
             # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
-            mean_kld = (
-                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
-            )
-            loss_dict["kld_loss"] = mean_kld.item()
-            loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
+            kld_loss = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
+            out_dict["loss"] = l1_loss + kld_loss * self.config.kl_weight
         else:
-            loss_dict["loss"] = l1_loss
+            out_dict["loss"] = l1_loss

-        return loss_dict
+        out_dict["action"] = self.unnormalize_outputs({"action": actions_hat})["action"]
+        return out_dict


 class ACTTemporalEnsembler:
```
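
With this change, `ACTPolicy.forward` returns per-sample outputs: `l1_loss` and `loss` now have shape `(batch_size,)` instead of being pre-reduced to scalars with `.item()`, the KL term keeps its closed form `-0.5 * sum_j(1 + log sigma_j^2 - mu_j^2 - sigma_j^2)` summed over the latent dimension per sample, and the unnormalized predicted actions are returned under `"action"`. Any training loop consuming this output must now reduce over the batch itself. A minimal sketch of that contract (the surrounding `policy`, `batch`, and `optimizer` setup are assumptions, not part of this commit):

```python
# Hypothetical training step against the new per-sample loss contract.
out_dict = policy.forward(batch)       # out_dict["loss"] has shape (batch_size,)
loss = out_dict["loss"].mean()         # reduce over the batch before backprop
loss.backward()
optimizer.step()
optimizer.zero_grad()

# Per-sample values stay available, e.g. to inspect the hardest samples:
hardest = out_dict["l1_loss"].topk(k=3).indices
```
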
5 changes: 3 additions & 2 deletions lerobot/scripts/visualize_dataset.py

```diff
@@ -268,10 +268,11 @@ def main():
     args = parser.parse_args()
     kwargs = vars(args)
     repo_id = kwargs.pop("repo_id")
-    root = kwargs.pop("root")
+    # root = kwargs.pop("root")

     logging.info("Loading dataset")
-    dataset = LeRobotDataset(repo_id, root=root, local_files_only=True)
+
+    dataset = LeRobotDataset(repo_id)

     visualize_dataset(dataset, **vars(args))
```
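
Both visualization scripts now build the dataset from `repo_id` alone, leaving `root` and `local_files_only` to the `LeRobotDataset` defaults. A minimal sketch of the new call path (the repo id below is a placeholder, not part of this commit):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Placeholder repo id; resolution now follows the class defaults
# instead of an explicit --root / local_files_only pair.
dataset = LeRobotDataset("lerobot/pusht")
print(dataset.num_episodes)
```
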
134 changes: 128 additions & 6 deletions lerobot/scripts/visualize_dataset_html.py

```diff
@@ -55,13 +55,30 @@
 import argparse
 import logging
 import shutil
+import warnings
 from pathlib import Path

+import torch
 import tqdm
 from flask import Flask, redirect, render_template, url_for

 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.utils.utils import init_logging
+from lerobot.common.policies.factory import make_policy
+from lerobot.common.utils.utils import init_hydra_config, init_logging
+from lerobot.scripts.eval import get_pretrained_policy_path
+
+
+class EpisodeSampler(torch.utils.data.Sampler):
+    def __init__(self, dataset, episode_index):
+        from_idx = dataset.episode_data_index["from"][episode_index].item()
+        to_idx = dataset.episode_data_index["to"][episode_index].item()
+        self.frame_ids = range(from_idx, to_idx)
+
+    def __iter__(self):
+        return iter(self.frame_ids)
+
+    def __len__(self):
+        return len(self.frame_ids)


 def run_server(
```
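
`EpisodeSampler` yields the frame indices of a single episode, so a `DataLoader` built on top of it walks that episode in temporal order. A usage sketch, assuming a `LeRobotDataset` instance `dataset` (the episode index and loader settings are illustrative):

```python
import torch

# Iterate only the frames of episode 0, in order.
sampler = EpisodeSampler(dataset, episode_index=0)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,      # one frame at a time, as `select_action` expects
    sampler=sampler,
    num_workers=0,
)
for batch in dataloader:
    ...  # each batch holds one consecutive frame of the episode
```
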

```diff
@@ -119,14 +136,95 @@ def show_episode(dataset_namespace, dataset_name, episode_id):
     app.run(host=host, port=port)


+def run_inference(
+    dataset, episode_index, policy, policy_method="select_action", num_workers=4, batch_size=32, device="mps"
+):
+    if policy_method not in ["select_action", "forward"]:
+        raise ValueError(
+            f"`policy_method` is expected to be 'select_action' or 'forward', but '{policy_method}' is provided instead."
+        )
+
+    policy.eval()
+    policy.to(device)
+
+    logging.info("Loading dataloader")
+    episode_sampler = EpisodeSampler(dataset, episode_index)
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=num_workers,
+        # When using `select_action`, we set batch size 1 so that we feed 1 frame at a time, in a continuous fashion.
+        batch_size=1 if policy_method == "select_action" else batch_size,
+        sampler=episode_sampler,
+        drop_last=False,
+    )
+
+    warned_ndim_eq_0 = False
+    warned_ndim_gt_2 = False
+
+    logging.info("Running inference")
+    inference_results = {}
+    for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
+        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
+        with torch.inference_mode():
+            if policy_method == "select_action":
+                gt_action = batch.pop("action")
+                output_dict = {"action": policy.select_action(batch)}
+                batch["action"] = gt_action
+            elif policy_method == "forward":
+                output_dict = policy.forward(batch)
+                # TODO(rcadene): Save and display all predicted actions at a given timestamp
+                # Save predicted action for the next timestamp only
+                output_dict["action"] = output_dict["action"][:, 0, :]
+
+        for key in output_dict:
+            if output_dict[key].ndim == 0:
+                if not warned_ndim_eq_0:
+                    warnings.warn(
+                        f"Ignore output key '{key}'. Its value is a scalar instead of a vector. It might have been aggregated over the batch dimension (e.g. `loss.mean()`).",
+                        stacklevel=1,
+                    )
+                    warned_ndim_eq_0 = True
+                continue
+
+            if output_dict[key].ndim > 2:
+                if not warned_ndim_gt_2:
+                    warnings.warn(
+                        f"Ignore output key '{key}'. Its value is a tensor of {output_dict[key].ndim} dimensions instead of a vector.",
+                        stacklevel=1,
+                    )
+                    warned_ndim_gt_2 = True
+                continue
+
+            if key not in inference_results:
+                inference_results[key] = []
+            inference_results[key].append(output_dict[key].to("cpu"))
+
+    for key in inference_results:
+        inference_results[key] = torch.cat(inference_results[key])
+
+    return inference_results
+
+
 def get_ep_csv_fname(episode_id: int):
     ep_csv_fname = f"episode_{episode_id}.csv"
     return ep_csv_fname


-def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
+def write_episode_data_csv(output_dir, file_name, episode_index, dataset, policy=None):
     """Write a csv file containg timeseries data of an episode (e.g. state and action).
     This file will be loaded by Dygraph javascript to plot data in real time."""

+    if policy is not None:
+        inference_results = run_inference(
+            dataset,
+            episode_index,
+            policy,
+            policy_method="select_action",
+            # num_workers=hydra_cfg.training.num_workers,
+            # batch_size=hydra_cfg.training.batch_size,
+            # device=hydra_cfg.device,
+        )
+
     from_idx = dataset.episode_data_index["from"][episode_index]
     to_idx = dataset.episode_data_index["to"][episode_index]

```
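`run_inference` supports two modes: `select_action` replays the episode one frame at a time (batch size 1) so the policy receives observations as it would online, while `forward` batches frames for offline prediction and keeps only the action predicted for the next timestamp. Scalar outputs (e.g. an aggregated loss) and tensors with more than two dimensions are skipped with a one-time warning; everything else is concatenated over the episode. A call sketch (the episode index and device are illustrative; `dataset` and `policy` are built elsewhere in the script):

```python
# Hypothetical call; note the WIP default device is "mps".
results = run_inference(
    dataset,
    episode_index=0,
    policy=policy,
    policy_method="select_action",  # or "forward" for batched offline prediction
    device="cpu",                   # overrides the "mps" default, for portability
)
pred_actions = results["action"]    # shape: (num_frames, action_dim)
```
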

```diff
@@ -141,21 +239,26 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
     if has_action:
         dim_action = dataset.meta.shapes["action"][0]
         header += [f"action_{i}" for i in range(dim_action)]
+    if policy is not None:
+        dim_action = dataset.meta.shapes["action"][0]
+        header += [f"pred_action_{i}" for i in range(dim_action)]

     columns = ["timestamp"]
     if has_state:
         columns += ["observation.state"]
     if has_action:
         columns += ["action"]
-    data = dataset.hf_dataset.select_columns(columns)

     rows = []
+    data = dataset.hf_dataset.select_columns(columns)
     for i in range(from_idx, to_idx):
         row = [data[i]["timestamp"].item()]
         if has_state:
             row += data[i]["observation.state"].tolist()
         if has_action:
             row += data[i]["action"].tolist()
+        if policy is not None:
+            row += inference_results["action"][i].tolist()
         rows.append(row)

     output_dir.mkdir(parents=True, exist_ok=True)
```
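
When a policy is passed, the CSV gains one `pred_action_*` column per action dimension next to the ground-truth columns. For a 2-dimensional state and action, the header and first row would look roughly like this (column names for the state are an assumption, values are illustrative):

```
timestamp,state_0,state_1,action_0,action_1,pred_action_0,pred_action_1
0.0,0.12,-0.53,0.10,-0.49,0.11,-0.50
```

One caveat worth flagging: `inference_results["action"]` is indexed from 0 for the selected episode, while `i` runs from `from_idx` to `to_idx` over the whole dataset, so the lookup only lines up when `from_idx` is 0 (i.e. the first episode); presumably a known limitation of this WIP.
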

```diff
@@ -183,6 +286,9 @@ def visualize_dataset_html(
     host: str = "127.0.0.1",
     port: int = 9090,
     force_override: bool = False,
+    policy_method: str = "select_action",
+    pretrained_policy_name_or_path: str | None = None,
+    policy_overrides: list[str] | None = None,
 ) -> Path | None:
     init_logging()

```

```diff
@@ -214,12 +320,28 @@
     if episodes is None:
         episodes = list(range(dataset.num_episodes))

+    pretrained_policy_name_or_path = "aliberts/act_reachy_test_model"
+
+    policy = None
+    if pretrained_policy_name_or_path is not None:
+        logging.info("Loading policy")
+        pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
+
+        hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", overrides=["device=mps"])
+        # dataset = make_dataset(hydra_cfg)
+        policy = make_policy(hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
+
+        if policy_method == "select_action":
+            # Do not load previous observations or future actions, to simulate that the observations come from
+            # an environment.
+            dataset.delta_timestamps = None
+
     logging.info("Writing CSV files")
     for episode_index in tqdm.tqdm(episodes):
         # write states and actions in a csv (it can be slow for big datasets)
         ep_csv_fname = get_ep_csv_fname(episode_index)
         # TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
-        write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)
+        write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset, policy=policy)

     if serve:
         run_server(dataset, episodes, host, port, static_dir, template_dir)
```
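
Note that the hardcoded `pretrained_policy_name_or_path = "aliberts/act_reachy_test_model"` shadows the new function parameter, and the Hydra override pins `device=mps`, so as committed every run loads that test policy; this reads as WIP scaffolding. Once the hardcoded line is removed, the intended call would presumably look like this (the policy repo id is a placeholder):

```python
# Hypothetical call once the hardcoded test policy is removed.
visualize_dataset_html(
    dataset,
    pretrained_policy_name_or_path="user/act_model",  # placeholder policy repo id
    policy_method="select_action",
)
```
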

```diff
@@ -281,8 +403,8 @@ def main():
     args = parser.parse_args()
     kwargs = vars(args)
     repo_id = kwargs.pop("repo_id")
-    root = kwargs.pop("root")
-    dataset = LeRobotDataset(repo_id, root=root, local_files_only=True)
+    # root = kwargs.pop("root")
+    dataset = LeRobotDataset(repo_id)
     visualize_dataset_html(dataset, **kwargs)
```


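End to end, the HTML visualizer is launched the same way as before; a hedged invocation sketch (the `--repo-id` flag follows from the `repo_id` kwarg popped in `main`; other flags are unchanged by this commit and not shown in the diff):

```
python lerobot/scripts/visualize_dataset_html.py --repo-id lerobot/pusht
```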
