Nest ACT model in ACT Policy (#122)
alexander-soare authored Apr 30, 2024
1 parent 9d60dce commit 986583d
Showing 1 changed file: lerobot/common/policies/act/modeling_act.py (104 additions, 95 deletions).
@@ -26,6 +26,108 @@ class ActionChunkingTransformerPolicy(nn.Module):
"""
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
"""

name = "act"

def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
configuration class is used.
"""
super().__init__()
if cfg is None:
cfg = ActionChunkingTransformerConfig()
self.cfg = cfg
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
)
self.model = _ActionChunkingTransformer(cfg)
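
A minimal instantiation sketch, assuming the import path implied by this file's location and that `dataset_stats` would hold the dataset's normalization statistics (both are assumptions, not shown in this diff):

```python
# Instantiation sketch. With cfg=None the default ActionChunkingTransformerConfig
# is used; dataset_stats (assumed: a dict of per-key normalization tensors) feeds
# the Normalize/Unnormalize modules.
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy

policy = ActionChunkingTransformerPolicy(cfg=None, dataset_stats=None)
```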

def reset(self):
"""This should be called whenever the environment is reset."""
if self.cfg.n_action_steps is not None:
self._action_queue = deque([], maxlen=self.cfg.n_action_steps)

@torch.no_grad
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
self.eval()

batch = self.normalize_inputs(batch)
self._stack_images(batch)

if len(self._action_queue) == 0:
            # `self.model.forward` returns a (batch_size, horizon, action_dim) tensor covering the
            # full action chunk. Keep only the first `n_action_steps` actions along the time
            # dimension; the queue effectively has shape (n_action_steps, batch_size, *), hence the
            # transpose below.
            actions = self.model(batch)[0][:, : self.cfg.n_action_steps]

# TODO(rcadene): make _forward return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]

self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
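
A usage sketch of the queue-based action selection above. `env`, `make_batch`, and `max_steps` are hypothetical stand-ins for a gym-style environment, an observation-to-tensor preprocessing step, and an episode cap:

```python
# Rollout sketch; the names below are illustrative, not part of this module.
max_steps = 300

policy.reset()  # clear the action queue at the start of each episode
obs = env.reset()
for _ in range(max_steps):
    batch = make_batch(obs)  # dict with "observation.state" and "observation.images.*" tensors
    action = policy.select_action(batch)  # pops one action; queries the model when the queue is empty
    obs, reward, done, info = env.step(action)
    if done:
        break
```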

def forward(self, batch, **_) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
self._stack_images(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)

l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
).mean()

loss_dict = {"l1_loss": l1_loss}
if self.cfg.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss_dict["kld_loss"] = mean_kld
loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight
else:
loss_dict["loss"] = l1_loss

return loss_dict
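
The closed-form KL term above can be checked in isolation; a small sketch with illustrative shapes:

```python
import torch

# For a diagonal Gaussian q = N(mu, sigma^2) against the standard normal p = N(0, I),
# D_KL(q || p) = -0.5 * sum_d (1 + log(sigma_d^2) - mu_d^2 - sigma_d^2), summed over the
# latent dimension and averaged over the batch (App. B of https://arxiv.org/abs/1312.6114).
mu = torch.randn(8, 32)            # (batch, latent_dim); shapes are illustrative
log_sigma_x2 = torch.randn(8, 32)  # log(sigma^2), as produced by the VAE encoder
mean_kld = (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - log_sigma_x2.exp())).sum(-1).mean()
print(mean_kld.shape)  # torch.Size([]) -> a scalar loss term
```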

    def _stack_images(self, batch: dict[str, Tensor]) -> None:
        """Stacks all the images in a batch and puts them in a new key: "observation.images".

        This function modifies `batch` in place, and expects it to have (at least):
{
"observation.state": (B, state_dim) batch of robot states.
"observation.images.{name}": (B, C, H, W) tensor of images.
}
"""
# Stack images in the order dictated by input_shapes.
batch["observation.images"] = torch.stack(
[batch[k] for k in self.cfg.input_shapes if k.startswith("observation.images.")],
dim=-4,
)
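
A shape sketch for this stacking step; the camera names ("top", "wrist") and dimensions are illustrative, since the real keys and their order come from `cfg.input_shapes`:

```python
import torch

# Each camera image is (B, C, H, W); stacking at dim=-4 adds a camera axis.
batch = {
    "observation.state": torch.zeros(4, 14),
    "observation.images.top": torch.zeros(4, 3, 96, 96),
    "observation.images.wrist": torch.zeros(4, 3, 96, 96),
}
images = torch.stack(
    [batch[k] for k in batch if k.startswith("observation.images.")], dim=-4
)
print(images.shape)  # (batch, num_cameras, channels, height, width) = (4, 2, 3, 96, 96)
```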

def save(self, fp):
torch.save(self.state_dict(), fp)

def load(self, fp):
d = torch.load(fp)
self.load_state_dict(d)
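
A round-trip sketch for these helpers (the file name is illustrative). Since they operate on the state dict only, the policy must be re-instantiated with a matching config before loading:

```python
# Save/load round-trip sketch; "act_policy.pt" is an illustrative path.
policy.save("act_policy.pt")

restored = ActionChunkingTransformerPolicy(cfg=policy.cfg)
restored.load("act_policy.pt")  # loads the saved state dict into the new instance
```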


class _ActionChunkingTransformer(nn.Module):
"""Action Chunking Transformer: The underlying neural network for ActionChunkingTransformerPolicy.
    Note: In this code we use the terms `vae_encoder`, `encoder`, `decoder`. The meanings are as follows.
- The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
@@ -59,24 +161,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
└───────────────────────┘
"""

name = "act"

def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
configuration class is used.
"""
def __init__(self, cfg: ActionChunkingTransformerConfig):
super().__init__()
if cfg is None:
cfg = ActionChunkingTransformerConfig()
self.cfg = cfg
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
)

# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
if self.cfg.use_vae:
@@ -141,76 +228,7 @@ def _reset_parameters(self):
if p.dim() > 1:
nn.init.xavier_uniform_(p)

def reset(self):
"""This should be called whenever the environment is reset."""
if self.cfg.n_action_steps is not None:
self._action_queue = deque([], maxlen=self.cfg.n_action_steps)

@torch.no_grad
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
"""Select a single action given environment observations.
        This method manages the actions in a queue and only calls `_forward` when the queue is
        empty, returning one action at a time for execution in the environment.
"""
self.eval()

batch = self.normalize_inputs(batch)

if len(self._action_queue) == 0:
            # `_forward` returns a (batch_size, horizon, action_dim) tensor covering the full action
            # chunk. Keep only the first `n_action_steps` actions along the time dimension; the queue
            # effectively has shape (n_action_steps, batch_size, *), hence the transpose below.
            actions = self._forward(batch)[0][:, : self.cfg.n_action_steps]

# TODO(rcadene): make _forward return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]

self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()

def forward(self, batch, **_) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch)

l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
).mean()

loss_dict = {"l1_loss": l1_loss}
if self.cfg.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss_dict["kld_loss"] = mean_kld
loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight
else:
loss_dict["loss"] = l1_loss

return loss_dict

    def _stack_images(self, batch: dict[str, Tensor]) -> None:
        """Stacks all the images in a batch and puts them in a new key: "observation.images".

        This function modifies `batch` in place, and expects it to have (at least):
{
"observation.state": (B, state_dim) batch of robot states.
"observation.images.{name}": (B, C, H, W) tensor of images.
}
"""
# Stack images in the order dictated by input_shapes.
batch["observation.images"] = torch.stack(
[batch[k] for k in self.cfg.input_shapes if k.startswith("observation.images.")],
dim=-4,
)

def _forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
`batch` should have the following structure:
@@ -231,8 +249,6 @@ def _forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
"action" in batch
), "actions must be provided when using the variational objective in training mode."

self._stack_images(batch)

batch_size = batch["observation.state"].shape[0]

# Prepare the latent for input to the transformer encoder.
@@ -324,13 +340,6 @@ def _forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:

return actions, (mu, log_sigma_x2)
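
A shape sketch for the value returned here. `batch` is assumed to be already normalized with images stacked under "observation.images", as in `select_action` above; dimensions are illustrative:

```python
# Return-structure sketch for the nested model's forward pass.
actions, (mu, log_sigma_x2) = policy.model(batch)
print(actions.shape)  # (batch_size, horizon, action_dim): the full predicted action chunk
# When the VAE encoder is not run (e.g. no "action" in the batch at inference),
# mu and log_sigma_x2 are both None, matching the tuple[None, None] annotation.
```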

def save(self, fp):
torch.save(self.state_dict(), fp)

def load(self, fp):
d = torch.load(fp)
self.load_state_dict(d)


class _TransformerEncoder(nn.Module):
"""Convenience module for running multiple encoder layers, maybe followed by normalization."""
