From 7dda64e5a988c8927a0a2bd49668f91d2e9ae5fc Mon Sep 17 00:00:00 2001
From: linjintao
Date: Tue, 24 Mar 2020 20:49:44 +0800
Subject: [PATCH] Add optimizer registry based on mmdet

---
 mmaction/core/__init__.py           |   1 +
 mmaction/core/optimizer/__init__.py |   4 +
 mmaction/core/optimizer/builder.py  |  88 +++++++++++++++
 mmaction/core/optimizer/registry.py |  22 ++++
 mmaction/core/train.py              |  85 +--------------
 tests/test_optimizer.py             | 159 ++++++++++++++++++++++++++++
 6 files changed, 276 insertions(+), 83 deletions(-)
 create mode 100644 mmaction/core/optimizer/__init__.py
 create mode 100644 mmaction/core/optimizer/builder.py
 create mode 100644 mmaction/core/optimizer/registry.py
 create mode 100644 tests/test_optimizer.py

diff --git a/mmaction/core/__init__.py b/mmaction/core/__init__.py
index b21910f95a..ebd789c9ff 100644
--- a/mmaction/core/__init__.py
+++ b/mmaction/core/__init__.py
@@ -1,6 +1,7 @@
 from .dist_utils import *  # noqa: F401, F403
 from .evaluation import *  # noqa: F401, F403
 from .fp16 import *  # noqa: F401, F403
+from .optimizer import *  # noqa: F401, F403
 from .test import multi_gpu_test, single_gpu_test
 from .train import set_random_seed, train_model  # noqa: F401
 
diff --git a/mmaction/core/optimizer/__init__.py b/mmaction/core/optimizer/__init__.py
new file mode 100644
index 0000000000..82d4503a94
--- /dev/null
+++ b/mmaction/core/optimizer/__init__.py
@@ -0,0 +1,4 @@
+from .builder import build_optimizer
+from .registry import OPTIMIZERS
+
+__all__ = ['OPTIMIZERS', 'build_optimizer']
diff --git a/mmaction/core/optimizer/builder.py b/mmaction/core/optimizer/builder.py
new file mode 100644
index 0000000000..ccbc9794d3
--- /dev/null
+++ b/mmaction/core/optimizer/builder.py
@@ -0,0 +1,88 @@
+import re
+
+from mmcv.utils import build_from_cfg
+
+from .registry import OPTIMIZERS
+
+
+def build_optimizer(model, optimizer_cfg):
+    """Build optimizer from configs.
+
+    Args:
+        model (:obj:`nn.Module`): The model with parameters to be optimized.
+        optimizer_cfg (dict): The config dict of the optimizer.
+            Positional fields are:
+                - type: class name of the optimizer.
+                - lr: base learning rate.
+            Optional fields are:
+                - any arguments of the corresponding optimizer type, e.g.,
+                  weight_decay, momentum, etc.
+                - paramwise_options: a dict with 3 accepted fields
+                  (bias_lr_mult, bias_decay_mult, norm_decay_mult).
+                  `bias_lr_mult` and `bias_decay_mult` will be multiplied to
+                  the lr and weight decay respectively for all bias parameters
+                  (except for the normalization layers), and
+                  `norm_decay_mult` will be multiplied to the weight decay
+                  for all weight and bias parameters of normalization layers.
+
+    Returns:
+        torch.optim.Optimizer: The initialized optimizer.
+
+    Example:
+        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
+        >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
+        >>>                      weight_decay=0.0001)
+        >>> optimizer = build_optimizer(model, optimizer_cfg)
+    """
+    if hasattr(model, 'module'):
+        model = model.module
+
+    optimizer_cfg = optimizer_cfg.copy()
+    paramwise_options = optimizer_cfg.pop('paramwise_options', None)
+    # if no paramwise option is specified, just use the global setting
+    if paramwise_options is None:
+        params = model.parameters()
+    else:
+        if not isinstance(paramwise_options, dict):
+            raise TypeError(f'paramwise_options should be a dict, '
+                            f'but got {type(paramwise_options)}')
+        # get base lr and weight decay
+        base_lr = optimizer_cfg['lr']
+        base_wd = optimizer_cfg.get('weight_decay', None)
+        # weight_decay must be explicitly specified if mult is specified
+        if ('bias_decay_mult' in paramwise_options
+                or 'norm_decay_mult' in paramwise_options):
+            if base_wd is None:
+                raise ValueError('base_wd should not be None')
+        # get param-wise options
+        bias_lr_mult = paramwise_options.get('bias_lr_mult', 1.)
+        bias_decay_mult = paramwise_options.get('bias_decay_mult', 1.)
+        norm_decay_mult = paramwise_options.get('norm_decay_mult', 1.)
+        # set param-wise lr and weight decay
+        params = []
+        for name, param in model.named_parameters():
+            param_group = {'params': [param]}
+            if not param.requires_grad:
+                # FP16 training needs to copy gradient/weight between master
+                # weight copy and model weight, it is convenient to keep all
+                # parameters here to align with model.parameters()
+                params.append(param_group)
+                continue
+
+            # for norm layers, overwrite the weight decay of weight and bias
+            # TODO: obtain the norm layer prefixes dynamically
+            if re.search(r'(bn|gn)(\d+)?\.(weight|bias)', name):
+                if base_wd is not None:
+                    param_group['weight_decay'] = base_wd * norm_decay_mult
+            # for other layers, overwrite both lr and weight decay of bias
+            elif name.endswith('.bias'):
+                param_group['lr'] = base_lr * bias_lr_mult
+                if base_wd is not None:
+                    param_group['weight_decay'] = base_wd * bias_decay_mult
+            # otherwise use the global settings
+
+            params.append(param_group)
+
+    optimizer_cfg['params'] = params
+
+    return build_from_cfg(optimizer_cfg, OPTIMIZERS)
diff --git a/mmaction/core/optimizer/registry.py b/mmaction/core/optimizer/registry.py
new file mode 100644
index 0000000000..0dbc1a64e0
--- /dev/null
+++ b/mmaction/core/optimizer/registry.py
@@ -0,0 +1,22 @@
+import inspect
+
+import torch
+from mmcv.utils import Registry
+
+OPTIMIZERS = Registry('optimizer')
+
+
+def register_torch_optimizers():
+    torch_optimizers = []
+    for module_name in dir(torch.optim):
+        if module_name.startswith('__'):
+            continue
+        _optim = getattr(torch.optim, module_name)
+        if inspect.isclass(_optim) and issubclass(_optim,
+                                                  torch.optim.Optimizer):
+            OPTIMIZERS.register_module(_optim)
+            torch_optimizers.append(module_name)
+    return torch_optimizers
+
+
+TORCH_OPTIMIZERS = register_torch_optimizers()
diff --git a/mmaction/core/train.py b/mmaction/core/train.py
index 2c024c9b09..fb6453a8af 100644
--- a/mmaction/core/train.py
+++ b/mmaction/core/train.py
@@ -1,15 +1,14 @@
 import os
 import random
-import re
 from collections import OrderedDict
 
 import numpy as np
 import torch
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-from mmcv.runner import DistSamplerSeedHook, Runner, obj_from_dict
+from mmcv.runner import DistSamplerSeedHook, Runner
 
 from mmaction.core import (DistEvalHook, DistOptimizerHook, EvalHook,
-                           Fp16OptimizerHook)
+                           Fp16OptimizerHook, build_optimizer)
 from mmaction.datasets import build_dataloader, build_dataset
 from mmaction.utils import get_root_logger
 
@@ -126,86 +125,6 @@ def train_model(model,
             meta=meta)
 
 
-def build_optimizer(model, optimizer_cfg):
-    """Build optimizer from configs.
-
-    Args:
-        model (:obj:`nn.Module`): The model with parameters to be optimized.
-        optimizer_cfg (dict): The config dict of the optimizer.
-            Positional fields are:
-                - type: class name of the optimizer.
-                - lr: base learning rate.
-            Optional fields are:
-                - any arguments of the corresponding optimizer type, e.g.,
-                  weight_decay, momentum, etc.
-                - paramwise_options: a dict with 3 accepted fileds
-                  (bias_lr_mult, bias_decay_mult, norm_decay_mult).
-                  `bias_lr_mult` and `bias_decay_mult` will be multiplied to
-                  the lr and weight decay respectively for all bias parameters
-                  (except for the normalization layers), and
-                  `norm_decay_mult` will be multiplied to the weight decay
-                  for all weight and bias parameters of normalization layers.
-
-    Returns:
-        torch.optim.Optimizer: The initialized optimizer.
-
-    Example:
-        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
-        >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
-        >>>                      weight_decay=0.0001)
-        >>> optimizer = build_optimizer(model, optimizer_cfg)
-    """
-    if hasattr(model, 'module'):
-        model = model.module
-
-    optimizer_cfg = optimizer_cfg.copy()
-    paramwise_options = optimizer_cfg.pop('paramwise_options', None)
-    # if no paramwise option is specified, just use the global setting
-    if paramwise_options is None:
-        return obj_from_dict(optimizer_cfg, torch.optim,
-                             dict(params=model.parameters()))
-    else:
-        assert isinstance(paramwise_options, dict)
-        # get base lr and weight decay
-        base_lr = optimizer_cfg['lr']
-        base_wd = optimizer_cfg.get('weight_decay', None)
-        # weight_decay must be explicitly specified if mult is specified
-        if ('bias_decay_mult' in paramwise_options
-                or 'norm_decay_mult' in paramwise_options):
-            assert base_wd is not None
-        # get param-wise options
-        bias_lr_mult = paramwise_options.get('bias_lr_mult', 1.)
-        bias_decay_mult = paramwise_options.get('bias_decay_mult', 1.)
-        norm_decay_mult = paramwise_options.get('norm_decay_mult', 1.)
-        # set param-wise lr and weight decay
-        params = []
-        for name, param in model.named_parameters():
-            param_group = {'params': [param]}
-            if not param.requires_grad:
-                # FP16 training needs to copy gradient/weight between master
-                # weight copy and model weight, it is convenient to keep all
-                # parameters here to align with model.parameters()
-                params.append(param_group)
-                continue
-
-            # for norm layers, overwrite the weight decay of weight and bias
-            # TODO: obtain the norm layer prefixes dynamically
-            if re.search(r'(bn|gn)(\d+)?.(weight|bias)', name):
-                if base_wd is not None:
-                    param_group['weight_decay'] = base_wd * norm_decay_mult
-            # for other layers, overwrite both lr and weight decay of bias
-            elif name.endswith('.bias'):
-                param_group['lr'] = base_lr * bias_lr_mult
-                if base_wd is not None:
-                    param_group['weight_decay'] = base_wd * bias_decay_mult
-            # otherwise use the global settings
-
-            params.append(param_group)
-
-        optimizer_cls = getattr(torch.optim, optimizer_cfg.pop('type'))
-        return optimizer_cls(params, **optimizer_cfg)
-
-
 def _dist_train(model,
                 dataset,
                 cfg,
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
new file mode 100644
index 0000000000..cd932a72e1
--- /dev/null
+++ b/tests/test_optimizer.py
@@ -0,0 +1,159 @@
+import pytest
+import torch
+import torch.nn as nn
+
+from mmaction.core import build_optimizer
+from mmaction.core.optimizer.registry import TORCH_OPTIMIZERS
+
+
+class ExampleModel(nn.Module):
+
+    def __init__(self):
+        super(ExampleModel, self).__init__()
+        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)
+        self.bn = nn.BatchNorm2d(8)
+        self.gn = nn.GroupNorm(3, 8)
+
+    def forward(self, imgs):
+        return imgs
+
+
+def test_build_optimizer():
+    with pytest.raises(TypeError):
+        optimizer_cfg = dict(paramwise_options=['error'])
+        model = ExampleModel()
+        build_optimizer(model, optimizer_cfg)
+
+    with pytest.raises(ValueError):
+        optimizer_cfg = dict(
+            paramwise_options=dict(bias_decay_mult=1, norm_decay_mult=1),
+            lr=0.0001,
+            weight_decay=None)
+        model = ExampleModel()
+        build_optimizer(model, optimizer_cfg)
+
+    base_lr = 0.0001
+    base_wd = 0.0002
+    momentum = 0.9
+
+    # basic config with ExampleModel
+    optimizer_cfg = dict(
+        type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
+    model = ExampleModel()
+    optimizer = build_optimizer(model, optimizer_cfg)
+    param_dict = dict(model.named_parameters())
+    param_groups = optimizer.param_groups[0]
+    assert isinstance(optimizer, torch.optim.SGD)
+    assert optimizer.defaults['lr'] == 0.0001
+    assert optimizer.defaults['momentum'] == 0.9
+    assert optimizer.defaults['weight_decay'] == 0.0002
+    assert len(param_groups['params']) == 6
+    assert torch.equal(param_groups['params'][0], param_dict['conv1.weight'])
+    assert torch.equal(param_groups['params'][1], param_dict['conv1.bias'])
+    assert torch.equal(param_groups['params'][2], param_dict['bn.weight'])
+    assert torch.equal(param_groups['params'][3], param_dict['bn.bias'])
+    assert torch.equal(param_groups['params'][4], param_dict['gn.weight'])
+    assert torch.equal(param_groups['params'][5], param_dict['gn.bias'])
+
+    # basic config with Parallel model
+    model = torch.nn.DataParallel(ExampleModel())
+    optimizer = build_optimizer(model, optimizer_cfg)
+    param_dict = dict(model.named_parameters())
+    param_groups = optimizer.param_groups[0]
+    assert isinstance(optimizer, torch.optim.SGD)
+    assert optimizer.defaults['lr'] == 0.0001
+    assert optimizer.defaults['momentum'] == 0.9
+    assert optimizer.defaults['weight_decay'] == 0.0002
+    assert len(param_groups['params']) == 6
+    assert torch.equal(param_groups['params'][0],
+                       param_dict['module.conv1.weight'])
+    assert torch.equal(param_groups['params'][1],
+                       param_dict['module.conv1.bias'])
+    assert torch.equal(param_groups['params'][2],
+                       param_dict['module.bn.weight'])
+    assert torch.equal(param_groups['params'][3], param_dict['module.bn.bias'])
+    assert torch.equal(param_groups['params'][4],
+                       param_dict['module.gn.weight'])
+    assert torch.equal(param_groups['params'][5], param_dict['module.gn.bias'])
+
+    # Empty paramwise_options with ExampleModel
+    optimizer_cfg['paramwise_options'] = dict()
+    model = ExampleModel()
+    optimizer = build_optimizer(model, optimizer_cfg)
+    param_groups = optimizer.param_groups
+    assert isinstance(optimizer, torch.optim.SGD)
+    assert optimizer.defaults['lr'] == 0.0001
+    assert optimizer.defaults['momentum'] == 0.9
+    assert optimizer.defaults['weight_decay'] == 0.0002
+    for i, (name, param) in enumerate(model.named_parameters()):
+        param_group = param_groups[i]
+        assert param_group['params'] == [param]
+        assert param_group['momentum'] == 0.9
+        assert param_group['lr'] == 0.0001
+        assert param_group['weight_decay'] == 0.0002
+
+    # Empty paramwise_options with ExampleModel and no grad
+    for param in model.parameters():
+        param.requires_grad = False
+    optimizer = build_optimizer(model, optimizer_cfg)
+    param_groups = optimizer.param_groups
+    assert isinstance(optimizer, torch.optim.SGD)
+    assert optimizer.defaults['lr'] == 0.0001
+    assert optimizer.defaults['momentum'] == 0.9
+    assert optimizer.defaults['weight_decay'] == 0.0002
+    for i, (name, param) in enumerate(model.named_parameters()):
+        param_group = param_groups[i]
+        assert param_group['params'] == [param]
+        assert param_group['momentum'] == 0.9
+        assert param_group['lr'] == 0.0001
+        assert param_group['weight_decay'] == 0.0002
+
+    # paramwise_options with ExampleModel
+    paramwise_options = dict(
+        bias_lr_mult=0.9, bias_decay_mult=0.8, norm_decay_mult=0.7)
+    optimizer_cfg['paramwise_options'] = paramwise_options
+    model = ExampleModel()
+    optimizer = build_optimizer(model, optimizer_cfg)
+    param_groups = optimizer.param_groups
+    assert isinstance(optimizer, torch.optim.SGD)
+    assert optimizer.defaults['lr'] == 0.0001
+    assert optimizer.defaults['momentum'] == 0.9
+    assert optimizer.defaults['weight_decay'] == 0.0002
+    for i, (name, param) in enumerate(model.named_parameters()):
+        param_group = param_groups[i]
+        assert param_group['params'] == [param]
+        assert param_group['momentum'] == 0.9
+    assert param_groups[0]['lr'] == 0.0001
+    assert param_groups[0]['weight_decay'] == 0.0002
+    assert param_groups[1]['lr'] == 0.0001 * 0.9
+    assert param_groups[1]['weight_decay'] == 0.0002 * 0.8
+    assert param_groups[2]['lr'] == 0.0001
+    assert param_groups[2]['weight_decay'] == 0.0002 * 0.7
+    assert param_groups[3]['lr'] == 0.0001
+    assert param_groups[3]['weight_decay'] == 0.0002 * 0.7
+    assert param_groups[4]['lr'] == 0.0001
+    assert param_groups[4]['weight_decay'] == 0.0002 * 0.7
+    assert param_groups[5]['lr'] == 0.0001
+    assert param_groups[5]['weight_decay'] == 0.0002 * 0.7
+
+    # paramwise_options with ExampleModel and no grad
+    for param in model.parameters():
+        param.requires_grad = False
+    optimizer = build_optimizer(model, optimizer_cfg)
+    param_groups = optimizer.param_groups
+    assert isinstance(optimizer, torch.optim.SGD)
+    assert optimizer.defaults['lr'] == 0.0001
+    assert optimizer.defaults['momentum'] == 0.9
+    assert optimizer.defaults['weight_decay'] == 0.0002
+    for i, (name, param) in enumerate(model.named_parameters()):
+        param_group = param_groups[i]
+        assert param_group['params'] == [param]
+        assert param_group['momentum'] == 0.9
+        assert param_group['lr'] == 0.0001
+        assert param_group['weight_decay'] == 0.0002
+
+    torch_optimizers = [
+        'ASGD', 'Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'LBFGS',
+        'Optimizer', 'RMSprop', 'Rprop', 'SGD', 'SparseAdam'
+    ]
+    assert set(torch_optimizers).issubset(set(TORCH_OPTIMIZERS))
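
Usage note (not part of the patch): a minimal sketch of how the registry introduced above is meant to be used downstream, assuming this change is applied. `MyOptimizer` and the tiny model below are hypothetical and exist only for illustration; `OPTIMIZERS`, `build_optimizer`, and `paramwise_options` are the names added by this patch, and the registration call mirrors the one in registry.py.

    import torch
    import torch.nn as nn

    from mmaction.core import build_optimizer
    from mmaction.core.optimizer import OPTIMIZERS


    # Hypothetical custom optimizer; any torch.optim.Optimizer subclass can be
    # registered the same way registry.py registers the built-in torch classes.
    class MyOptimizer(torch.optim.SGD):
        pass


    OPTIMIZERS.register_module(MyOptimizer)

    # A wrapper module so parameter names end with '.weight' / '.bias',
    # which is what the paramwise logic in builder.py keys on.
    model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3))

    # Config-driven construction: `type` selects a registered class by name,
    # `paramwise_options` scales lr / weight decay for bias parameters.
    optimizer_cfg = dict(
        type='MyOptimizer',
        lr=0.01,
        momentum=0.9,
        weight_decay=0.0001,
        paramwise_options=dict(bias_lr_mult=2., bias_decay_mult=0.5))
    optimizer = build_optimizer(model, optimizer_cfg)

The torch optimizers registered at import time by registry.py (e.g. type='SGD', type='Adam') go through the same config path without any extra registration.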