Skip to content

Commit

Permalink
Make sure targets are normalized too (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-soare authored Apr 26, 2024
1 parent b980c5d commit 45f351c
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 94 deletions.
4 changes: 2 additions & 2 deletions lerobot/common/policies/act/configuration_act.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,13 @@ class ActionChunkingTransformerConfig:
)

# Normalization / Unnormalization
normalize_input_modes: dict[str, str] = field(
input_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "mean_std",
}
)
unnormalize_output_modes: dict[str, str] = field(
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"action": "mean_std",
}
Expand Down
8 changes: 6 additions & 2 deletions lerobot/common/policies/act/modeling_act.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,11 @@ def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_s
if cfg is None:
cfg = ActionChunkingTransformerConfig()
self.cfg = cfg
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
)

# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
Expand Down Expand Up @@ -216,6 +219,7 @@ def update(self, batch, **_) -> dict:
self.train()

batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)

loss_dict = self.forward(batch)
# TODO(rcadene): self.unnormalize_outputs(out_dict)
Expand Down
8 changes: 2 additions & 6 deletions lerobot/common/policies/diffusion/configuration_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,13 @@ class DiffusionConfig:
)

# Normalization / Unnormalization
normalize_input_modes: dict[str, str] = field(
input_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "min_max",
}
)
unnormalize_output_modes: dict[str, str] = field(
default_factory=lambda: {
"action": "min_max",
}
)
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})

# Architecture / modeling.
# Vision backbone.
Expand Down
8 changes: 6 additions & 2 deletions lerobot/common/policies/diffusion/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,11 @@ def __init__(
if cfg is None:
cfg = DiffusionConfig()
self.cfg = cfg
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
)

# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
Expand Down Expand Up @@ -162,6 +165,7 @@ def update(self, batch: dict[str, Tensor], **_) -> dict:
self.diffusion.train()

batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)

loss = self.forward(batch)["loss"]
loss.backward()
Expand Down
175 changes: 98 additions & 77 deletions lerobot/common/policies/normalize.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,21 @@
import torch
from torch import nn
from torch import Tensor, nn


def create_stats_buffers(shapes, modes, stats=None):
def create_stats_buffers(
shapes: dict[str, list[int]],
modes: dict[str, str],
stats: dict[str, dict[str, Tensor]] | None = None,
) -> dict[str, dict[str, nn.ParameterDict]]:
"""
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max statistics.
Parameters:
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their shapes (e.g. `[3,96,96]`]).
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
and width, assuming a channel-first (c, h, w) format.
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values are their normalization modes among:
- "mean_std": substract the mean and divide by standard deviation.
- "min_max": map to [-1, 1] range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values
(e.g. `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for training the model for the first time,
these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
they are already in the policy state_dict.
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
statistics.
Args: (see Normalize and Unnormalize)
Returns:
dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing `nn.Parameters` set to
`requires_grad=False`, suitable to not be updated during backpropagation.
dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing
`nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation.
"""
stats_buffers = {}

Expand Down Expand Up @@ -75,24 +69,32 @@ def create_stats_buffers(shapes, modes, stats=None):


class Normalize(nn.Module):
"""
Normalizes the input data (e.g. "observation.image") for more stable and faster convergence during training.
Parameters:
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their shapes (e.g. `[3,96,96]`]).
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
and width, assuming a channel-first (c, h, w) format.
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values are their normalization modes among:
- "mean_std": substract the mean and divide by standard deviation.
- "min_max": map to [-1, 1] range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values
(e.g. `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for training the model for the first time,
these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
they are already in the policy state_dict.
"""

def __init__(self, shapes, modes, stats=None):
"""Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""

def __init__(
self,
shapes: dict[str, list[int]],
modes: dict[str, str],
stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
are their shapes (e.g. `[3,96,96]`). These shapes are used to create the tensor buffer containing
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
are their normalization modes among:
- "mean_std": subtract the mean and divide by standard deviation.
- "min_max": map to [-1, 1] range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
and values are dictionaries of statistic types and their values (e.g.
`{"mean": torch.randn(3,1,1), "std": torch.randn(3,1,1)}`). If provided, as expected for
training the model for the first time, these statistics will overwrite the default buffers. If
not provided, as expected for finetuning or evaluation, the default buffers should be
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
dataset is not needed to get the stats, since they are already in the policy state_dict.
"""
super().__init__()
self.shapes = shapes
self.modes = modes
Expand All @@ -104,29 +106,33 @@ def __init__(self, shapes, modes, stats=None):

# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad
def forward(self, batch):
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
for key, mode in self.modes.items():
buffer = getattr(self, "buffer_" + key.replace(".", "_"))

if mode == "mean_std":
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(
mean
).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(
std
).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(mean).any(), (
"`mean` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
assert not torch.isinf(std).any(), (
"`std` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
batch[key] = (batch[key] - mean) / (std + 1e-8)
elif mode == "min_max":
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(
min
).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(
max
).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(min).any(), (
"`min` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
assert not torch.isinf(max).any(), (
"`max` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
# normalize to [0,1]
batch[key] = (batch[key] - min) / (max - min)
# normalize to [-1, 1]
Expand All @@ -138,23 +144,34 @@ def forward(self, batch):

class Unnormalize(nn.Module):
"""
Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their original range used by the environment.
Parameters:
shapes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their shapes (e.g. [10]).
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
and width, assuming a channel-first (c, h, w) format.
modes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their unnormalization modes among:
- "mean_std": multiply by standard deviation and add mean
- "min_max": go from [-1, 1] range to original range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "action") and values are dictionaries of statistic types and their values
(e.g. `{"max": torch.tensor(1)}, "min": torch.tensor(0)}`). If provided, as expected for training the model for the first time,
these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
they are already in the policy state_dict.
Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) back to the
original range used by the environment.
"""

def __init__(self, shapes, modes, stats=None):
def __init__(
self,
shapes: dict[str, list[int]],
modes: dict[str, str],
stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
shapes (dict): A dictionary where keys are output modalities (e.g. "action") and values are
their shapes (e.g. `[10]`). These shapes are used to create the tensor buffer containing
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
modes (dict): A dictionary where keys are output modalities (e.g. "action") and values are
their unnormalization modes among:
- "mean_std": multiply by standard deviation and add the mean.
- "min_max": map from the [-1, 1] range back to the original range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "action") and
values are dictionaries of statistic types and their values (e.g.
`{"max": torch.tensor(1), "min": torch.tensor(0)}`). If provided, as expected for
training the model for the first time, these statistics will overwrite the default buffers. If
not provided, as expected for finetuning or evaluation, the default buffers should be
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
dataset is not needed to get the stats, since they are already in the policy state_dict.
"""
super().__init__()
self.shapes = shapes
self.modes = modes
Expand All @@ -166,29 +183,33 @@ def __init__(self, shapes, modes, stats=None):

# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad
def forward(self, batch):
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
for key, mode in self.modes.items():
buffer = getattr(self, "buffer_" + key.replace(".", "_"))

if mode == "mean_std":
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(
mean
).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(
std
).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(mean).any(), (
"`mean` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
assert not torch.isinf(std).any(), (
"`std` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
batch[key] = batch[key] * std + mean
elif mode == "min_max":
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(
min
).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(
max
).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
assert not torch.isinf(min).any(), (
"`min` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
assert not torch.isinf(max).any(), (
"`max` is infinity. You forgot to initialize with `stats` as argument, or called "
"`policy.load_state_dict`."
)
batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max - min) + min
else:
Expand Down
4 changes: 2 additions & 2 deletions lerobot/configs/policy/act.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ policy:
action: ["${env.action_dim}"]

# Normalization / Unnormalization
normalize_input_modes:
input_normalization_modes:
observation.images.top: mean_std
observation.state: mean_std
unnormalize_output_modes:
output_normalization_modes:
action: mean_std

# Architecture.
Expand Down
4 changes: 2 additions & 2 deletions lerobot/configs/policy/diffusion.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ policy:
action: ["${env.action_dim}"]

# Normalization / Unnormalization
normalize_input_modes:
input_normalization_modes:
observation.image: mean_std
observation.state: min_max
unnormalize_output_modes:
output_normalization_modes:
action: min_max

# Architecture / modeling.
Expand Down
1 change: 0 additions & 1 deletion lerobot/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,6 @@ def _maybe_eval_and_maybe_save(step):
policy,
video_dir=Path(out_dir) / "eval",
max_episodes_rendered=4,
transform=offline_dataset.transform,
seed=cfg.seed,
)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
Expand Down

0 comments on commit 45f351c

Please sign in to comment.