diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index ea8fdc2d6..70654d880 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -160,9 +160,8 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
 
         return loss_dict
 
-    def make_optimizer_and_scheduler(self, **kwargs):
+    def make_optimizer_and_scheduler(self, cfg):
         """Create the optimizer and learning rate scheduler for ACT"""
-        lr, lr_backbone, weight_decay = kwargs["lr"], kwargs["lr_backbone"], kwargs["weight_decay"]
         optimizer_params_dicts = [
             {
                 "params": [
@@ -177,10 +176,12 @@ def make_optimizer_and_scheduler(self, **kwargs):
                     for n, p in self.named_parameters()
                     if n.startswith("model.backbone") and p.requires_grad
                 ],
-                "lr": lr_backbone,
+                "lr": cfg.training.lr_backbone,
             },
         ]
-        optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=lr, weight_decay=weight_decay)
+        optimizer = torch.optim.AdamW(
+            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
+        )
         lr_scheduler = None
         return optimizer, lr_scheduler
 
diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index 0093e451c..6d276fa45 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -156,33 +156,22 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         loss = self.diffusion.compute_loss(batch)
         return {"loss": loss}
 
-    def make_optimizer_and_scheduler(self, **kwargs):
+    def make_optimizer_and_scheduler(self, cfg):
         """Create the optimizer and learning rate scheduler for Diffusion policy"""
-        lr, adam_betas, adam_eps, adam_weight_decay = (
-            kwargs["lr"],
-            kwargs["adam_betas"],
-            kwargs["adam_eps"],
-            kwargs["adam_weight_decay"],
-        )
-        lr_scheduler_name, lr_warmup_steps, offline_steps = (
-            kwargs["lr_scheduler"],
-            kwargs["lr_warmup_steps"],
-            kwargs["offline_steps"],
-        )
         optimizer = torch.optim.Adam(
             self.diffusion.parameters(),
-            lr,
-            adam_betas,
-            adam_eps,
-            adam_weight_decay,
+            cfg.training.lr,
+            cfg.training.adam_betas,
+            cfg.training.adam_eps,
+            cfg.training.adam_weight_decay,
         )
 
         from diffusers.optimization import get_scheduler
 
         lr_scheduler = get_scheduler(
-            lr_scheduler_name,
+            cfg.training.lr_scheduler,
            optimizer=optimizer,
-            num_warmup_steps=lr_warmup_steps,
-            num_training_steps=offline_steps,
+            num_warmup_steps=cfg.training.lr_warmup_steps,
+            num_training_steps=cfg.training.offline_steps,
         )
         return optimizer, lr_scheduler
diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py
index 9e988c207..169f67a0a 100644
--- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py
+++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py
@@ -534,10 +534,9 @@ def update(self):
         # we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
         update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
 
-    def make_optimizer_and_scheduler(self, **kwargs):
+    def make_optimizer_and_scheduler(self, cfg):
         """Create the optimizer and learning rate scheduler for TD-MPC"""
-        lr = kwargs["lr"]
-        optimizer = torch.optim.Adam(self.parameters(), lr)
+        optimizer = torch.optim.Adam(self.parameters(), cfg.training.lr)
         lr_scheduler = None
         return optimizer, lr_scheduler
 
diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py
index 87cf59f19..18cf4491e 100644
--- a/lerobot/common/policies/vqbet/modeling_vqbet.py
+++ b/lerobot/common/policies/vqbet/modeling_vqbet.py
@@ -152,6 +152,12 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
 
         return loss_dict
 
+    def make_optimizer_and_scheduler(self, cfg):
+        """Create the optimizer and learning rate scheduler for VQ-BeT"""
+        optimizer = VQBeTOptimizer(self, cfg)
+        scheduler = VQBeTScheduler(optimizer, cfg)
+        return optimizer, scheduler
+
 
 class SpatialSoftmax(nn.Module):
     """
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index e2cf55d6f..0c048cfb6 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -281,7 +281,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     assert isinstance(policy, nn.Module)
     # Create optimizer and scheduler
     # Temporary hack to move optimizer out of policy
-    optimizer, lr_scheduler = policy.make_optimizer_and_scheduler(**cfg.training)
+    optimizer, lr_scheduler = policy.make_optimizer_and_scheduler(cfg)
     grad_scaler = GradScaler(enabled=cfg.use_amp)
 
     step = 0  # number of policy updates (forward + backward + optim)
diff --git a/tests/scripts/save_policy_to_safetensors.py b/tests/scripts/save_policy_to_safetensors.py
index 7287ed730..033638774 100644
--- a/tests/scripts/save_policy_to_safetensors.py
+++ b/tests/scripts/save_policy_to_safetensors.py
@@ -39,7 +39,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
     dataset = make_dataset(cfg)
     policy = make_policy(cfg, dataset_stats=dataset.stats)
     policy.train()
-    optimizer, _ = policy.make_optimizer_and_scheduler(**cfg.training)
+    optimizer, _ = policy.make_optimizer_and_scheduler(cfg)
 
     dataloader = torch.utils.data.DataLoader(
         dataset,
diff --git a/tests/test_policies.py b/tests/test_policies.py
index 692616610..76a056d24 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -213,7 +213,7 @@ def test_act_backbone_lr():
     dataset = make_dataset(cfg)
     policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
 
-    optimizer, _ = policy.make_optimizer_and_scheduler(**cfg.training)
+    optimizer, _ = policy.make_optimizer_and_scheduler(cfg)
     assert len(optimizer.param_groups) == 2
     assert optimizer.param_groups[0]["lr"] == cfg.training.lr
    assert optimizer.param_groups[1]["lr"] == cfg.training.lr_backbone
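
Usage sketch (not part of the patch): the refactored interface reads hyperparameters from `cfg.training` instead of from **kwargs, so any config object exposing that attribute path works. `DummyPolicy` and the config values below are illustrative assumptions, not part of the lerobot codebase.

import torch
from omegaconf import OmegaConf


class DummyPolicy(torch.nn.Module):
    """Stand-in policy mirroring the TD-MPC variant: one Adam optimizer, no scheduler."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def make_optimizer_and_scheduler(self, cfg):
        # All hyperparameters come from cfg.training, matching the new contract.
        optimizer = torch.optim.Adam(self.parameters(), cfg.training.lr)
        return optimizer, None


# Hydra would normally supply `cfg`; OmegaConf.create stands in for it here.
cfg = OmegaConf.create({"training": {"lr": 3e-4}})
optimizer, lr_scheduler = DummyPolicy().make_optimizer_and_scheduler(cfg)
assert optimizer.param_groups[0]["lr"] == 3e-4 and lr_scheduler is None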