pass entire config to make_optimizer
michel-aractingi committed Sep 2, 2024
1 parent 3034272 commit 06fc9b8
Showing 7 changed files with 24 additions and 29 deletions.
lerobot/common/policies/act/modeling_act.py (5 additions, 4 deletions)

@@ -160,9 +160,8 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:

         return loss_dict

-    def make_optimizer_and_scheduler(self, **kwargs):
+    def make_optimizer_and_scheduler(self, cfg):
         """Create the optimizer and learning rate scheduler for ACT"""
-        lr, lr_backbone, weight_decay = kwargs["lr"], kwargs["lr_backbone"], kwargs["weight_decay"]
         optimizer_params_dicts = [
             {
                 "params": [
@@ -177,10 +176,12 @@ def make_optimizer_and_scheduler(self, **kwargs):
                     for n, p in self.named_parameters()
                     if n.startswith("model.backbone") and p.requires_grad
                 ],
-                "lr": lr_backbone,
+                "lr": cfg.training.lr_backbone,
             },
         ]
-        optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=lr, weight_decay=weight_decay)
+        optimizer = torch.optim.AdamW(
+            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
+        )
         lr_scheduler = None
         return optimizer, lr_scheduler
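The new signature reads hyperparameters from the `training` section of the Hydra config instead of unpacked keyword arguments. As a minimal sketch (not part of this commit, values are placeholders), the fields ACT now expects can be assembled like this; `policy` stands for an ACTPolicy instance:

from omegaconf import OmegaConf

# Hypothetical config sketch; only the fields read by ACT's make_optimizer_and_scheduler are shown.
cfg = OmegaConf.create(
    {
        "training": {
            "lr": 1e-5,            # cfg.training.lr, learning rate for the non-backbone parameter group
            "lr_backbone": 1e-5,   # cfg.training.lr_backbone, learning rate for the backbone parameter group
            "weight_decay": 1e-4,  # cfg.training.weight_decay, passed to AdamW
        }
    }
)
# optimizer, lr_scheduler = policy.make_optimizer_and_scheduler(cfg)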
lerobot/common/policies/diffusion/modeling_diffusion.py (8 additions, 19 deletions)

@@ -156,33 +156,22 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         loss = self.diffusion.compute_loss(batch)
         return {"loss": loss}

-    def make_optimizer_and_scheduler(self, **kwargs):
+    def make_optimizer_and_scheduler(self, cfg):
         """Create the optimizer and learning rate scheduler for Diffusion policy"""
-        lr, adam_betas, adam_eps, adam_weight_decay = (
-            kwargs["lr"],
-            kwargs["adam_betas"],
-            kwargs["adam_eps"],
-            kwargs["adam_weight_decay"],
-        )
-        lr_scheduler_name, lr_warmup_steps, offline_steps = (
-            kwargs["lr_scheduler"],
-            kwargs["lr_warmup_steps"],
-            kwargs["offline_steps"],
-        )
         optimizer = torch.optim.Adam(
             self.diffusion.parameters(),
-            lr,
-            adam_betas,
-            adam_eps,
-            adam_weight_decay,
+            cfg.training.lr,
+            cfg.training.adam_betas,
+            cfg.training.adam_eps,
+            cfg.training.adam_weight_decay,
         )
         from diffusers.optimization import get_scheduler

         lr_scheduler = get_scheduler(
-            lr_scheduler_name,
+            cfg.training.lr_scheduler,
             optimizer=optimizer,
-            num_warmup_steps=lr_warmup_steps,
-            num_training_steps=offline_steps,
+            num_warmup_steps=cfg.training.lr_warmup_steps,
+            num_training_steps=cfg.training.offline_steps,
         )
         return optimizer, lr_scheduler
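For the diffusion policy the same pattern applies with more fields. A hedged sketch of the `training` section it now reads (field names come from the diff above, values are placeholders):

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "training": {
            "lr": 1e-4,                   # Adam learning rate
            "adam_betas": [0.95, 0.999],  # Adam betas
            "adam_eps": 1e-8,             # Adam epsilon
            "adam_weight_decay": 1e-6,    # Adam weight decay
            "lr_scheduler": "cosine",     # scheduler name passed to diffusers.optimization.get_scheduler
            "lr_warmup_steps": 500,       # warmup steps for the scheduler
            "offline_steps": 200000,      # total training steps given to the scheduler
        }
    }
)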
lerobot/common/policies/tdmpc/modeling_tdmpc.py (2 additions, 3 deletions)

@@ -534,10 +534,9 @@ def update(self):
         # we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
         update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)

-    def make_optimizer_and_scheduler(self, **kwargs):
+    def make_optimizer_and_scheduler(self, cfg):
         """Create the optimizer and learning rate scheduler for TD-MPC"""
-        lr = kwargs["lr"]
-        optimizer = torch.optim.Adam(self.parameters(), lr)
+        optimizer = torch.optim.Adam(self.parameters(), cfg.training.lr)
         lr_scheduler = None
         return optimizer, lr_scheduler
lerobot/common/policies/vqbet/modeling_vqbet.py (6 additions, 0 deletions)

@@ -152,6 +152,12 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:

         return loss_dict

+    def make_optimizer_and_scheduler(self, cfg):
+        """Create the optimizer and learning rate scheduler for VQ-BeT"""
+        optimizer = VQBeTOptimizer(self, cfg)
+        scheduler = VQBeTScheduler(optimizer, cfg)
+        return optimizer, scheduler
+

 class SpatialSoftmax(nn.Module):
     """
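VQBeTOptimizer and VQBeTScheduler are VQ-BeT-specific helpers that now receive the whole config directly. For any other policy, the interface established by this commit can be implemented in a few lines; the class below is a hypothetical sketch (MyPolicy is not in the repository) mirroring the TD-MPC change:

import torch
from torch import nn


class MyPolicy(nn.Module):
    # Hypothetical policy illustrating the make_optimizer_and_scheduler(cfg) interface.
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 8)  # placeholder parameters so the optimizer has something to update

    def make_optimizer_and_scheduler(self, cfg):
        # Read hyperparameters from cfg.training, as the other policies do after this commit.
        optimizer = torch.optim.Adam(self.parameters(), cfg.training.lr)
        lr_scheduler = None  # a policy may instead return a scheduler, as the diffusion policy does
        return optimizer, lr_scheduler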
lerobot/scripts/train.py (1 addition, 1 deletion)

@@ -281,7 +281,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     assert isinstance(policy, nn.Module)
     # Create optimizer and scheduler
     # Temporary hack to move optimizer out of policy
-    optimizer, lr_scheduler = policy.make_optimizer_and_scheduler(**cfg.training)
+    optimizer, lr_scheduler = policy.make_optimizer_and_scheduler(cfg)
     grad_scaler = GradScaler(enabled=cfg.use_amp)

     step = 0  # number of policy updates (forward + backward + optim)
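With the call site updated, the training script consumes the returned pair together with the GradScaler created just below the changed line. The helper below is an assumed sketch of one update step, not code from this commit; policy, batch, and device are stand-ins:

import torch


def training_step(policy, batch, optimizer, lr_scheduler, grad_scaler, device, use_amp):
    # Assumed shape of one policy update with automatic mixed precision.
    with torch.autocast(device_type=device.type, enabled=use_amp):
        loss = policy.forward(batch)["loss"]
    grad_scaler.scale(loss).backward()
    grad_scaler.step(optimizer)
    grad_scaler.update()
    optimizer.zero_grad()
    if lr_scheduler is not None:
        lr_scheduler.step()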
tests/scripts/save_policy_to_safetensors.py (1 addition, 1 deletion)

@@ -39,7 +39,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
     dataset = make_dataset(cfg)
     policy = make_policy(cfg, dataset_stats=dataset.stats)
     policy.train()
-    optimizer, _ = policy.make_optimizer_and_scheduler(**cfg.training)
+    optimizer, _ = policy.make_optimizer_and_scheduler(cfg)

     dataloader = torch.utils.data.DataLoader(
         dataset,
tests/test_policies.py (1 addition, 1 deletion)

@@ -213,7 +213,7 @@ def test_act_backbone_lr():

     dataset = make_dataset(cfg)
     policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
-    optimizer, _ = policy.make_optimizer_and_scheduler(**cfg.training)
+    optimizer, _ = policy.make_optimizer_and_scheduler(cfg)
     assert len(optimizer.param_groups) == 2
     assert optimizer.param_groups[0]["lr"] == cfg.training.lr
     assert optimizer.param_groups[1]["lr"] == cfg.training.lr_backbone
