diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index 22384ca00..911e7121d 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -26,6 +26,108 @@ class ActionChunkingTransformerPolicy(nn.Module):
     """
     Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
     Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
+    """
+
+    name = "act"
+
+    def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
+        """
+        Args:
+            cfg: Policy configuration class instance or None, in which case the default instantiation of the
+                configuration class is used.
+        """
+        super().__init__()
+        if cfg is None:
+            cfg = ActionChunkingTransformerConfig()
+        self.cfg = cfg
+        self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
+        self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
+        self.unnormalize_outputs = Unnormalize(
+            cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
+        )
+        self.model = _ActionChunkingTransformer(cfg)
+
+    def reset(self):
+        """This should be called whenever the environment is reset."""
+        if self.cfg.n_action_steps is not None:
+            self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
+
+    @torch.no_grad
+    def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
+        """Select a single action given environment observations.
+
+        This method wraps `select_actions` in order to return one action at a time for execution in the
+        environment. It works by managing the actions in a queue and only calling `select_actions` when the
+        queue is empty.
+        """
+        self.eval()
+
+        batch = self.normalize_inputs(batch)
+        self._stack_images(batch)
+
+        if len(self._action_queue) == 0:
+            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
+            # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
+            actions = self.model(batch)[0][: self.cfg.n_action_steps]
+
+            # TODO(rcadene): make _forward return output dictionary?
+            actions = self.unnormalize_outputs({"action": actions})["action"]
+
+            self._action_queue.extend(actions.transpose(0, 1))
+        return self._action_queue.popleft()
+
+    def forward(self, batch, **_) -> dict[str, Tensor]:
+        """Run the batch through the model and compute the loss for training or validation."""
+        batch = self.normalize_inputs(batch)
+        batch = self.normalize_targets(batch)
+        self._stack_images(batch)
+        actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
+
+        l1_loss = (
+            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
+        ).mean()
+
+        loss_dict = {"l1_loss": l1_loss}
+        if self.cfg.use_vae:
+            # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
+            # each dimension independently, we sum over the latent dimension to get the total
+            # KL-divergence per batch element, then take the mean over the batch.
+            # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
+            mean_kld = (
+                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
+            )
+            loss_dict["kld_loss"] = mean_kld
+            loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight
+        else:
+            loss_dict["loss"] = l1_loss
+
+        return loss_dict
+
+    def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+        """Stacks all the images in a batch and puts them in a new key: "observation.images".
+
+        This function expects `batch` to have (at least):
+        {
+            "observation.state": (B, state_dim) batch of robot states.
+            "observation.images.{name}": (B, C, H, W) tensor of images.
+        }
+        """
+        # Stack images in the order dictated by input_shapes.
+        batch["observation.images"] = torch.stack(
+            [batch[k] for k in self.cfg.input_shapes if k.startswith("observation.images.")],
+            dim=-4,
+        )
+
+    def save(self, fp):
+        torch.save(self.state_dict(), fp)
+
+    def load(self, fp):
+        d = torch.load(fp)
+        self.load_state_dict(d)
+
+
+class _ActionChunkingTransformer(nn.Module):
+    """Action Chunking Transformer: The underlying neural network for ActionChunkingTransformerPolicy.
 
     Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows.
         - The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
@@ -59,24 +161,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
                            └───────────────────────┘
     """
 
-    name = "act"
-
-    def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
-        """
-        Args:
-            cfg: Policy configuration class instance or None, in which case the default instantiation of the
-                configuration class is used.
-        """
+    def __init__(self, cfg: ActionChunkingTransformerConfig):
         super().__init__()
-        if cfg is None:
-            cfg = ActionChunkingTransformerConfig()
         self.cfg = cfg
-        self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
-        self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
-        self.unnormalize_outputs = Unnormalize(
-            cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
-        )
-
         # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
         # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
         if self.cfg.use_vae:
@@ -141,76 +228,7 @@ def _reset_parameters(self):
             if p.dim() > 1:
                 nn.init.xavier_uniform_(p)
 
-    def reset(self):
-        """This should be called whenever the environment is reset."""
-        if self.cfg.n_action_steps is not None:
-            self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
-
-    @torch.no_grad
-    def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
-        """Select a single action given environment observations.
-
-        This method wraps `select_actions` in order to return one action at a time for execution in the
-        environment. It works by managing the actions in a queue and only calling `select_actions` when the
-        queue is empty.
-        """
-        self.eval()
-
-        batch = self.normalize_inputs(batch)
-
-        if len(self._action_queue) == 0:
-            # `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
-            # has shape (n_action_steps, batch_size, *), hence the transpose.
-            actions = self._forward(batch)[0][: self.cfg.n_action_steps]
-
-            # TODO(rcadene): make _forward return output dictionary?
-            actions = self.unnormalize_outputs({"action": actions})["action"]
-
-            self._action_queue.extend(actions.transpose(0, 1))
-        return self._action_queue.popleft()
-
-    def forward(self, batch, **_) -> dict[str, Tensor]:
-        """Run the batch through the model and compute the loss for training or validation."""
-        batch = self.normalize_inputs(batch)
-        batch = self.normalize_targets(batch)
-        actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch)
-
-        l1_loss = (
-            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
-        ).mean()
-
-        loss_dict = {"l1_loss": l1_loss}
-        if self.cfg.use_vae:
-            # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
-            # each dimension independently, we sum over the latent dimension to get the total
-            # KL-divergence per batch element, then take the mean over the batch.
-            # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
-            mean_kld = (
-                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
-            )
-            loss_dict["kld_loss"] = mean_kld
-            loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight
-        else:
-            loss_dict["loss"] = l1_loss
-
-        return loss_dict
-
-    def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
-        """Stacks all the images in a batch and puts them in a new key: "observation.images".
-
-        This function expects `batch` to have (at least):
-        {
-            "observation.state": (B, state_dim) batch of robot states.
-            "observation.images.{name}": (B, C, H, W) tensor of images.
-        }
-        """
-        # Stack images in the order dictated by input_shapes.
-        batch["observation.images"] = torch.stack(
-            [batch[k] for k in self.cfg.input_shapes if k.startswith("observation.images.")],
-            dim=-4,
-        )
-
-    def _forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
+    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
         """A forward pass through the Action Chunking Transformer (with optional VAE encoder).
 
         `batch` should have the following structure:
@@ -231,8 +249,6 @@ def _forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tens
                 "action" in batch
             ), "actions must be provided when using the variational objective in training mode."
 
-        self._stack_images(batch)
-
         batch_size = batch["observation.state"].shape[0]
 
         # Prepare the latent for input to the transformer encoder.
@@ -324,13 +340,6 @@ def _forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tens
 
         return actions, (mu, log_sigma_x2)
 
-    def save(self, fp):
-        torch.save(self.state_dict(), fp)
-
-    def load(self, fp):
-        d = torch.load(fp)
-        self.load_state_dict(d)
-
 
 class _TransformerEncoder(nn.Module):
     """Convenience module for running multiple encoder layers, maybe followed by normalization."""
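
Usage note (not part of the diff): after this refactor, `ActionChunkingTransformerPolicy` is the public entry point (normalization, image stacking, action queue) and `_ActionChunkingTransformer` is the inner network it wraps. Below is a minimal sketch of the intended call pattern; the config import path, the camera name "top", the tensor dimensions, and the `dataset_stats` variable are illustrative assumptions to adapt to your setup, not part of this change.

    import torch

    from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig  # assumed path
    from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy

    cfg = ActionChunkingTransformerConfig()
    # `dataset_stats` is the per-key statistics dict consumed by Normalize/Unnormalize;
    # it is assumed to come from your dataset (hypothetical variable here).
    policy = ActionChunkingTransformerPolicy(cfg, dataset_stats=dataset_stats)

    # Training/validation: forward() normalizes inputs and targets, stacks the camera images,
    # runs the inner _ActionChunkingTransformer, and returns a loss dictionary.
    chunk = 100  # should match the configured action chunk length
    batch = {
        "observation.state": torch.randn(8, 14),               # (B, state_dim); dims are illustrative
        "observation.images.top": torch.rand(8, 3, 480, 640),  # one entry per configured camera key
        "action": torch.randn(8, chunk, 14),                   # (B, chunk, action_dim)
        "action_is_pad": torch.zeros(8, chunk, dtype=torch.bool),
    }
    loss_dict = policy(batch)
    loss_dict["loss"].backward()

    # Rollout: reset() clears the action queue; select_action() refills it from the model every
    # n_action_steps environment steps and pops one action per call.
    policy.reset()
    obs = {
        "observation.state": torch.randn(1, 14),
        "observation.images.top": torch.rand(1, 3, 480, 640),
    }
    action = policy.select_action(obs)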