diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py
index bd3692ace..cbad38d62 100644
--- a/lerobot/common/policies/diffusion/configuration_diffusion.py
+++ b/lerobot/common/policies/diffusion/configuration_diffusion.py
@@ -98,7 +98,7 @@ class DiffusionConfig:
 
     # Inputs / output structure.
     n_obs_steps: int = 2
-    horizon: int = 16
+    horizon: int = 10
     n_action_steps: int = 8
 
     input_shapes: dict[str, list[int]] = field(
@@ -134,7 +134,7 @@ class DiffusionConfig:
     down_dims: tuple[int, ...] = (512, 1024, 2048)
     kernel_size: int = 5
     n_groups: int = 8
-    diffusion_step_embed_dim: int = 128
+    diffusion_step_embed_dim: int = 256
     use_film_scale_modulation: bool = True
     # Noise scheduler.
     noise_scheduler_type: str = "DDPM"
@@ -145,6 +145,14 @@ class DiffusionConfig:
     prediction_type: str = "epsilon"
     clip_sample: bool = True
     clip_sample_range: float = 1.0
+    # Transformer
+    use_transformer: bool = True
+    n_layer: int = 8
+    n_head: int = 4
+    p_drop_emb: float = 0.0
+    p_drop_attn: float = 0.3
+    causal_attn: bool = True
+    n_cond_layers: int = 0
 
     # Inference
     num_inference_steps: int | None = None
@@ -200,7 +208,7 @@ def __post_init__(self):
         # Check that the horizon size and U-Net downsampling is compatible.
         # U-Net downsamples by 2 with each stage.
         downsampling_factor = 2 ** len(self.down_dims)
-        if self.horizon % downsampling_factor != 0:
+        if not self.use_transformer and self.horizon % downsampling_factor != 0:
             raise ValueError(
                 "The horizon should be an integer multiple of the downsampling factor (which is determined "
                 f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index 308a8be3c..b25a1e494 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -22,7 +22,7 @@
 
 import math
 from collections import deque
-from typing import Callable
+from typing import Callable, Tuple
 
 import einops
 import numpy as np
@@ -188,7 +188,12 @@ def __init__(self, config: DiffusionConfig):
             self._use_env_state = True
             global_cond_dim += config.input_shapes["observation.environment_state"][0]
 
-        self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
+        if config.use_transformer:
+            self.net = TransformerForDiffusion(config, cond_dim=global_cond_dim)
+        else:
+            self.net = DiffusionConditionalUnet1d(
+                config, global_cond_dim=global_cond_dim * config.n_obs_steps
+            )
 
         self.noise_scheduler = _make_noise_scheduler(
             config.noise_scheduler_type,
@@ -206,6 +211,20 @@ def __init__(self, config: DiffusionConfig):
         else:
             self.num_inference_steps = config.num_inference_steps
 
+    def get_optimizer(
+        self,
+        transformer_weight_decay: float = 1e-3,
+        rgb_encoder_weight_decay: float = 1e-6,
+        learning_rate: float = 1e-4,
+        betas: Tuple[float, float] = [0.9, 0.95],
+    ) -> torch.optim.Optimizer:
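+        # get_optim_groups() splits the denoising network into decay / no-decay AdamW groups;
+        # the RGB encoder is appended as its own group with a much lighter weight decay.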
+        optim_groups = self.net.get_optim_groups(weight_decay=transformer_weight_decay)
+        optim_groups.append(
+            {"params": self.rgb_encoder.parameters(), "weight_decay": rgb_encoder_weight_decay}
+        )
+        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
+        return optimizer
+
     # ========= inference ============
     def conditional_sample(
         self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
@@ -225,7 +244,7 @@ def conditional_sample(
 
         for t in self.noise_scheduler.timesteps:
             # Predict model output.
-            model_output = self.unet(
+            model_output = self.net(
                 sample,
                 torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
                 global_cond=global_cond,
@@ -324,7 +343,7 @@ def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
         noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)
 
         # Run the denoising network (that might denoise the trajectory, or attempt to predict the noise).
-        pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)
+        pred = self.net(noisy_trajectory, timesteps, global_cond=global_cond)
 
         # Compute the loss.
         # The target is either the original trajectory, or the noise.
@@ -749,3 +768,264 @@ def forward(self, x: Tensor, cond: Tensor) -> Tensor:
         out = self.conv2(out)
         out = out + self.residual_conv(x)
         return out
+
+
+class TransformerForDiffusion(nn.Module):
+    def __init__(self, config: DiffusionConfig, cond_dim: int):
+        super().__init__()
+        self.config = config
+
+        # compute number of tokens for main trunk and condition encoder
+        if config.n_obs_steps is None:
+            config.n_obs_steps = config.horizon
+
+        t = config.horizon
+        t_cond = 1
+        t_cond += config.n_obs_steps
+
+        input_dim = config.output_shapes["action"][0]
+        # input embedding stem
+        self.input_emb = nn.Linear(input_dim, config.diffusion_step_embed_dim)
+        self.pos_emb = nn.Parameter(torch.zeros(1, t, config.diffusion_step_embed_dim))
+        self.drop = nn.Dropout(config.p_drop_emb)
+
+        # cond encoder
+        self.time_emb = DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim)
+        self.cond_obs_emb = None
+
+        self.cond_obs_emb = nn.Linear(cond_dim, config.diffusion_step_embed_dim)
+
+        self.cond_pos_emb = None
+        self.encoder = None
+        self.decoder = None
+
+        self.cond_pos_emb = nn.Parameter(torch.zeros(1, t_cond, config.diffusion_step_embed_dim))
+        if config.n_cond_layers > 0:
+            encoder_layer = nn.TransformerEncoderLayer(
+                d_model=config.diffusion_step_embed_dim,
+                nhead=config.n_head,
+                dim_feedforward=4 * config.diffusion_step_embed_dim,
+                dropout=config.p_drop_attn,
+                activation="gelu",
+                batch_first=True,
+                norm_first=True,
+            )
+            self.encoder = nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=config.n_cond_layers)
+        else:
+            self.encoder = nn.Sequential(
+                nn.Linear(config.diffusion_step_embed_dim, 4 * config.diffusion_step_embed_dim),
+                nn.Mish(),
+                nn.Linear(4 * config.diffusion_step_embed_dim, config.diffusion_step_embed_dim),
+            )
+        # decoder
+        decoder_layer = nn.TransformerDecoderLayer(
+            d_model=config.diffusion_step_embed_dim,
+            nhead=config.n_head,
+            dim_feedforward=4 * config.diffusion_step_embed_dim,
+            dropout=config.p_drop_attn,
+            activation="gelu",
+            batch_first=True,
+            norm_first=True,  # important for stability
+        )
+        self.decoder = nn.TransformerDecoder(decoder_layer=decoder_layer, num_layers=config.n_layer)
+
+        # attention mask
+        if config.causal_attn:
+            # causal mask to ensure that attention is only applied to the left in the input sequence
+            # torch.nn.Transformer uses additive mask as opposed to multiplicative mask in minGPT
+            # therefore, the upper triangle should be -inf and others (including diag) should be 0.
+            sz = t
+            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+            mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
+            self.register_buffer("mask", mask)
+
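+            # memory_mask is (t, t_cond): action token i may attend to the diffusion-timestep token
+            # (index 0) and to observation tokens up to index i, but not to later observations.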
+            # assume conditioning over time and observation both
+            p, q = torch.meshgrid(torch.arange(t), torch.arange(t_cond), indexing="ij")
+            mask = p >= (q - 1)  # add one dimension since time is the first token in cond
+            mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
+            self.register_buffer("memory_mask", mask)
+        else:
+            self.mask = None
+            self.memory_mask = None
+
+        # decoder head
+        self.ln_f = nn.LayerNorm(config.diffusion_step_embed_dim)
+        self.head = nn.Linear(config.diffusion_step_embed_dim, input_dim)
+
+        # constants
+        self.t = t
+        self.t_cond = t_cond
+        self.horizon = config.horizon
+        self.n_obs_steps = config.n_obs_steps
+
+        # init
+        self.apply(self._init_weights)
+        # logger.info(
+        #     "number of parameters: %e", sum(p.numel() for p in self.parameters())
+        # )
+
+    def _init_weights(self, module):
+        ignore_types = (
+            nn.Dropout,
+            DiffusionSinusoidalPosEmb,
+            nn.TransformerEncoderLayer,
+            nn.TransformerDecoderLayer,
+            nn.TransformerEncoder,
+            nn.TransformerDecoder,
+            nn.ModuleList,
+            nn.Mish,
+            nn.Sequential,
+        )
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.MultiheadAttention):
+            weight_names = ["in_proj_weight", "q_proj_weight", "k_proj_weight", "v_proj_weight"]
+            for name in weight_names:
+                weight = getattr(module, name)
+                if weight is not None:
+                    torch.nn.init.normal_(weight, mean=0.0, std=0.02)
+
+            bias_names = ["in_proj_bias", "bias_k", "bias_v"]
+            for name in bias_names:
+                bias = getattr(module, name)
+                if bias is not None:
+                    torch.nn.init.zeros_(bias)
+        elif isinstance(module, nn.LayerNorm):
+            torch.nn.init.zeros_(module.bias)
+            torch.nn.init.ones_(module.weight)
+        elif isinstance(module, TransformerForDiffusion):
+            torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
+            if module.cond_obs_emb is not None:
+                torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
+        elif isinstance(module, ignore_types):
+            # no param
+            pass
+        else:
+            raise RuntimeError("Unaccounted module {}".format(module))
+
+    def get_optim_groups(self, weight_decay: float = 1e-3):
+        """
+        This long function is unfortunately doing something very simple and is being very defensive:
+        We are separating out all parameters of the model into two buckets: those that will experience
+        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
+        We are then returning the PyTorch optimizer object.
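+        Decay: Linear and MultiheadAttention weights. No decay: biases, LayerNorm/Embedding weights,
+        and the learned positional embeddings (pos_emb, cond_pos_emb).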
+ """ + + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention) + blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) + for mn, m in self.named_modules(): + for pn, _ in m.named_parameters(): + fpn = "{}.{}".format(mn, pn) if mn else pn # full param name + + if pn.endswith("bias"): + # all biases will not be decayed + no_decay.add(fpn) + elif pn.startswith("bias"): + # MultiheadAttention bias starts with "bias" + no_decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + # special case the position embedding parameter in the root GPT module as not decayed + no_decay.add("pos_emb") + # no_decay.add("_dummy_variable") + if self.cond_pos_emb is not None: + no_decay.add("cond_pos_emb") + + # validate that we considered every parameter + # param_dict = {pn: p for pn, p in self.named_parameters()} + param_dict = dict(self.named_parameters()) + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format( + str(inter_params) + ) + assert ( + len(param_dict.keys() - union_params) == 0 + ), "parameters {} were not separated into either decay/no_decay set!".format( + str(param_dict.keys() - union_params), + ) + + # create the pytorch optimizer object + optim_groups = [ + { + "params": [param_dict[pn] for pn in sorted(decay)], + "weight_decay": weight_decay, + }, + { + "params": [param_dict[pn] for pn in sorted(no_decay)], + "weight_decay": 0.0, + }, + ] + return optim_groups + + def configure_optimizers( + self, + learning_rate: float = 1e-4, + weight_decay: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.95), + ): + optim_groups = self.get_optim_groups(weight_decay=weight_decay) + optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas) + return optimizer + + def forward(self, sample: torch.Tensor, timestep: torch.Tensor, global_cond: torch.Tensor, **kwargs): + """ + x: (B,T,input_dim) + timestep: (B,) + global_cond: (B, global_cond_dim) + output: (B,T,input_dim) + """ + # 1. time + timesteps = timestep + batch_size = sample.shape[0] + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(batch_size) + time_emb = self.time_emb(timesteps).unsqueeze(1) + # (B,1,n_emb) + + cond = einops.rearrange(global_cond, "b (s n) ... 
+        # (B,To,n_cond)
+
+        # process input
+        input_emb = self.input_emb(sample)
+
+        # encoder
+        cond_embeddings = time_emb
+        # (B,1,n_emb)
+
+        cond_obs_emb = self.cond_obs_emb(cond)
+        # (B,To,n_emb)
+        cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)
+        # (B,To + 1,n_emb)
+
+        tc = cond_embeddings.shape[1]
+        position_embeddings = self.cond_pos_emb[:, :tc, :]  # each position maps to a (learnable) vector
+        x = self.drop(cond_embeddings + position_embeddings)
+        x = self.encoder(x)
+        memory = x
+        # (B,T_cond,n_emb)
+
+        # decoder
+        token_embeddings = input_emb
+        t = token_embeddings.shape[1]
+        position_embeddings = self.pos_emb[:, :t, :]  # each position maps to a (learnable) vector
+        x = self.drop(token_embeddings + position_embeddings)
+        # (B,T,n_emb)
+        x = self.decoder(tgt=x, memory=memory, tgt_mask=self.mask, memory_mask=self.memory_mask)
+        # (B,T,n_emb)
+
+        # head
+        x = self.ln_f(x)
+        x = self.head(x)
+        # (B,T,n_inp)
+        return x
diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml
index 880819bb9..2c41deffd 100644
--- a/lerobot/configs/policy/diffusion.yaml
+++ b/lerobot/configs/policy/diffusion.yaml
@@ -32,12 +32,14 @@ training:
   grad_clip_norm: 10
   lr: 1.0e-4
   lr_scheduler: cosine
-  lr_warmup_steps: 500
+  lr_warmup_steps: 1000
   adam_betas: [0.95, 0.999]
   adam_eps: 1.0e-8
   adam_weight_decay: 1.0e-6
   online_steps_between_rollouts: 1
 
+  transformer_weight_decay: 1e-3
+
   delta_timestamps:
     observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
     observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
@@ -56,7 +58,7 @@ policy:
 
   # Input / output structure.
   n_obs_steps: 2
-  horizon: 16
+  horizon: 10
   n_action_steps: 8
 
   input_shapes:
@@ -85,8 +87,16 @@ policy:
   down_dims: [512, 1024, 2048]
   kernel_size: 5
   n_groups: 8
-  diffusion_step_embed_dim: 128
+  diffusion_step_embed_dim: 256
   use_film_scale_modulation: True
+  # Transformer
+  use_transformer: True
+  n_layer: 8
+  n_head: 4
+  p_drop_emb: 0.0
+  p_drop_attn: 0.3
+  causal_attn: True
+  n_cond_layers: 0
   # Noise scheduler.
   noise_scheduler_type: DDPM
   num_train_timesteps: 100
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index f60f904eb..9cee6aa7d 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -75,13 +75,21 @@ def make_optimizer_and_scheduler(cfg, policy):
         )
         lr_scheduler = None
     elif cfg.policy.name == "diffusion":
-        optimizer = torch.optim.Adam(
-            policy.diffusion.parameters(),
-            cfg.training.lr,
-            cfg.training.adam_betas,
-            cfg.training.adam_eps,
-            cfg.training.adam_weight_decay,
-        )
+        if cfg.policy.use_transformer:
+            optimizer = policy.diffusion.get_optimizer(
+                transformer_weight_decay=cfg.training.transformer_weight_decay,
+                rgb_encoder_weight_decay=cfg.training.adam_weight_decay,
+                learning_rate=cfg.training.lr,
+                betas=cfg.training.adam_betas,
+            )
+        else:
+            optimizer = torch.optim.Adam(
+                policy.diffusion.parameters(),
+                cfg.training.lr,
+                cfg.training.adam_betas,
+                cfg.training.adam_eps,
+                cfg.training.adam_weight_decay,
+            )
         from diffusers.optimization import get_scheduler
 
         lr_scheduler = get_scheduler(