From 77422df7780bb106c4ea8afcdf4019794ce8f482 Mon Sep 17 00:00:00 2001
From: bolero2
Date: Thu, 22 Dec 2022 10:29:22 +0900
Subject: [PATCH] feat: added Adam/AdamW Optimizer

---
 tools/train.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/tools/train.py b/tools/train.py
index 5bc2211f..6d9afb02 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -136,18 +136,28 @@ def main():
         model = nn.DataParallel(model, device_ids=gpus).cuda()
 
     # optimizer
-    if config.TRAIN.OPTIMIZER == 'sgd':
-        params_dict = dict(model.named_parameters())
-        params = [{'params': list(params_dict.values()), 'lr': config.TRAIN.LR}]
+    params_dict = dict(model.named_parameters())
+    params = [{'params': list(params_dict.values()), 'lr': config.TRAIN.LR}]
 
+    if config.TRAIN.OPTIMIZER == 'sgd':
         optimizer = torch.optim.SGD(params,
                                 lr=config.TRAIN.LR,
                                 momentum=config.TRAIN.MOMENTUM,
                                 weight_decay=config.TRAIN.WD,
                                 nesterov=config.TRAIN.NESTEROV,
                                 )
+    elif config.TRAIN.OPTIMIZER == 'adam':
+        optimizer = torch.optim.Adam(params,
+                                lr=config.TRAIN.LR,
+                                weight_decay=config.TRAIN.WD,
+                                betas=(config.TRAIN.MOMENTUM, 0.999))
+    elif config.TRAIN.OPTIMIZER == 'adamw':
+        optimizer = torch.optim.AdamW(params,
+                                lr=config.TRAIN.LR,
+                                weight_decay=config.TRAIN.WD,
+                                betas=(config.TRAIN.MOMENTUM, 0.999))
     else:
-        raise ValueError('Only Support SGD optimizer')
+        raise ValueError('Only Support SGD, Adam, AdamW optimizer')
 
     epoch_iters = int(train_dataset.__len__() /
                     config.TRAIN.BATCH_SIZE_PER_GPU / len(gpus))
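
Note (not part of the commit): a minimal standalone sketch of the new selection
logic, for trying it outside the training script. `SimpleNamespace` stands in
for the yacs-style `config.TRAIN` node and `nn.Linear` for the real model; both
are hypothetical placeholders used only for illustration.

    # Sketch only: SimpleNamespace mimics config.TRAIN; nn.Linear mimics the model.
    from types import SimpleNamespace

    import torch
    import torch.nn as nn

    train = SimpleNamespace(OPTIMIZER='adamw', LR=1e-3, MOMENTUM=0.9, WD=1e-4)
    model = nn.Linear(8, 2)  # placeholder for the segmentation model

    # Parameter-group construction, now hoisted out of the 'sgd' branch.
    params = [{'params': list(dict(model.named_parameters()).values()),
               'lr': train.LR}]

    if train.OPTIMIZER == 'sgd':
        optimizer = torch.optim.SGD(params, lr=train.LR,
                                    momentum=train.MOMENTUM,
                                    weight_decay=train.WD)
    elif train.OPTIMIZER == 'adam':
        optimizer = torch.optim.Adam(params, lr=train.LR,
                                     weight_decay=train.WD,
                                     betas=(train.MOMENTUM, 0.999))
    elif train.OPTIMIZER == 'adamw':
        optimizer = torch.optim.AdamW(params, lr=train.LR,
                                      weight_decay=train.WD,
                                      betas=(train.MOMENTUM, 0.999))
    else:
        raise ValueError('Only Support SGD, Adam, AdamW optimizer')

    print(type(optimizer).__name__)  # -> AdamW

Reusing config.TRAIN.MOMENTUM as Adam's beta1 keeps the config schema
unchanged; with the usual MOMENTUM setting of 0.9 this reproduces the standard
(0.9, 0.999) betas. Note that AdamW applies weight decay in decoupled form, so
TRAIN.WD will behave differently than it does with SGD or Adam.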