forked from open-mmlab/mmaction2
Commit
Merge branch 'ljt/optim' into 'master'

Add optimizer registry based on mmdet

See merge request EIG-Research/mmaction-lite!64
Showing 6 changed files with 276 additions and 83 deletions.
@@ -0,0 +1,4 @@
from .builder import build_optimizer
from .registry import OPTIMIZERS

__all__ = ['OPTIMIZERS', 'build_optimizer']
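Judging from the imports in the unit test further down (`from mmaction.core import build_optimizer` and `from mmaction.core.optimizer.registry import TORCH_OPTIMIZERS`), this `__init__.py` appears to sit in `mmaction/core/optimizer/`, so user code would pull in the public names roughly as follows. This is a minimal sketch; the package path is inferred from the test imports, not shown in this diff.

# Inferred import surface; the package path is an assumption based on the
# imports used in the unit test below.
from mmaction.core.optimizer import OPTIMIZERS, build_optimizer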
@@ -0,0 +1,88 @@
import re

from mmcv.utils import build_from_cfg

from .registry import OPTIMIZERS


def build_optimizer(model, optimizer_cfg):
    """Build optimizer from configs.

    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
        optimizer_cfg (dict): The config dict of the optimizer.
            Positional fields are:
                - type: class name of the optimizer.
                - lr: base learning rate.
            Optional fields are:
                - any arguments of the corresponding optimizer type, e.g.,
                  weight_decay, momentum, etc.
                - paramwise_options: a dict with 3 accepted fields
                  (bias_lr_mult, bias_decay_mult, norm_decay_mult).
                  `bias_lr_mult` and `bias_decay_mult` will be multiplied to
                  the lr and weight decay respectively for all bias parameters
                  (except for the normalization layers), and
                  `norm_decay_mult` will be multiplied to the weight decay
                  for all weight and bias parameters of normalization layers.

    Returns:
        torch.optim.Optimizer: The initialized optimizer.

    Example:
        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
        >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
        ...                      weight_decay=0.0001)
        >>> optimizer = build_optimizer(model, optimizer_cfg)
    """
    if hasattr(model, 'module'):
        model = model.module

    optimizer_cfg = optimizer_cfg.copy()
    paramwise_options = optimizer_cfg.pop('paramwise_options', None)
    # if no paramwise option is specified, just use the global setting
    if paramwise_options is None:
        params = model.parameters()
    else:
        if not isinstance(paramwise_options, dict):
            raise TypeError(f'paramwise_options should be a dict, '
                            f'but got {type(paramwise_options)}')
        # get base lr and weight decay
        base_lr = optimizer_cfg['lr']
        base_wd = optimizer_cfg.get('weight_decay', None)
        # weight_decay must be explicitly specified if mult is specified
        if ('bias_decay_mult' in paramwise_options
                or 'norm_decay_mult' in paramwise_options):
            if base_wd is None:
                raise ValueError('base_wd should not be None')
        # get param-wise options
        bias_lr_mult = paramwise_options.get('bias_lr_mult', 1.)
        bias_decay_mult = paramwise_options.get('bias_decay_mult', 1.)
        norm_decay_mult = paramwise_options.get('norm_decay_mult', 1.)
        # set param-wise lr and weight decay
        params = []
        for name, param in model.named_parameters():
            param_group = {'params': [param]}
            if not param.requires_grad:
                # FP16 training needs to copy gradient/weight between master
                # weight copy and model weight, it is convenient to keep all
                # parameters here to align with model.parameters()
                params.append(param_group)
                continue

            # for norm layers, overwrite the weight decay of weight and bias
            # TODO: obtain the norm layer prefixes dynamically
            if re.search(r'(bn|gn)(\d+)?\.(weight|bias)', name):
                if base_wd is not None:
                    param_group['weight_decay'] = base_wd * norm_decay_mult
            # for other layers, overwrite both lr and weight decay of bias
            elif name.endswith('.bias'):
                param_group['lr'] = base_lr * bias_lr_mult
                if base_wd is not None:
                    param_group['weight_decay'] = base_wd * bias_decay_mult
            # otherwise use the global settings

            params.append(param_group)

    optimizer_cfg['params'] = params

    return build_from_cfg(optimizer_cfg, OPTIMIZERS)
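To make the param-wise logic above concrete, here is a short usage sketch that is not part of the diff. `ToyModel` and the multiplier values are hypothetical; the import path follows the unit test below, and the attribute name `bn` matters because the builder recognizes norm layers by the `bn`/`gn` prefix in the parameter name.

import torch.nn as nn

from mmaction.core import build_optimizer


# Hypothetical model used only for illustration.
class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.bn = nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))


optimizer_cfg = dict(
    type='SGD',
    lr=0.01,
    momentum=0.9,
    weight_decay=1e-4,
    # conv.bias gets 2x lr and no weight decay; bn.weight/bn.bias keep the
    # base lr but drop weight decay; conv.weight keeps the global settings.
    paramwise_options=dict(
        bias_lr_mult=2., bias_decay_mult=0., norm_decay_mult=0.))

model = ToyModel()
optimizer = build_optimizer(model, optimizer_cfg)

# One param group per parameter, in named_parameters() order.
for (name, _), group in zip(model.named_parameters(), optimizer.param_groups):
    print(name, group['lr'], group['weight_decay'])
# conv.weight -> lr 0.01, weight_decay 0.0001
# conv.bias   -> lr 0.02, weight_decay 0.0
# bn.weight   -> lr 0.01, weight_decay 0.0
# bn.bias     -> lr 0.01, weight_decay 0.0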
@@ -0,0 +1,22 @@
import inspect

import torch
from mmcv.utils import Registry

OPTIMIZERS = Registry('optimizer')


def register_torch_optimizers():
    torch_optimizers = []
    for module_name in dir(torch.optim):
        if module_name.startswith('__'):
            continue
        _optim = getattr(torch.optim, module_name)
        if inspect.isclass(_optim) and issubclass(_optim,
                                                  torch.optim.Optimizer):
            OPTIMIZERS.register_module(_optim)
            torch_optimizers.append(module_name)
    return torch_optimizers


TORCH_OPTIMIZERS = register_torch_optimizers()
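Because every optimizer in `torch.optim` is registered automatically, the registry mainly pays off when adding a custom optimizer. The following sketch is not part of the diff: `MyOptimizer` is a hypothetical class, registered the same way `register_torch_optimizers()` registers the built-ins, after which a config can refer to it by class name.

import torch
import torch.nn as nn

from mmaction.core import build_optimizer
from mmaction.core.optimizer import OPTIMIZERS


# Hypothetical custom optimizer: plain SGD with a different default momentum.
class MyOptimizer(torch.optim.SGD):

    def __init__(self, params, lr, momentum=0.95, **kwargs):
        super().__init__(params, lr=lr, momentum=momentum, **kwargs)


# Register the class directly, mirroring register_torch_optimizers().
OPTIMIZERS.register_module(MyOptimizer)

# build_from_cfg() looks up the 'type' string in the OPTIMIZERS registry.
model = nn.Conv1d(1, 1, 1)
optimizer = build_optimizer(model, dict(type='MyOptimizer', lr=0.01))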
@@ -0,0 +1,159 @@
import pytest
import torch
import torch.nn as nn

from mmaction.core import build_optimizer
from mmaction.core.optimizer.registry import TORCH_OPTIMIZERS


class ExampleModel(nn.Module):

    def __init__(self):
        super(ExampleModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)
        self.bn = nn.BatchNorm2d(8)
        self.gn = nn.GroupNorm(3, 8)

    def forward(self, imgs):
        return imgs


def test_build_optimizer():
    with pytest.raises(TypeError):
        optimizer_cfg = dict(paramwise_options=['error'])
        model = ExampleModel()
        build_optimizer(model, optimizer_cfg)

    with pytest.raises(ValueError):
        optimizer_cfg = dict(
            paramwise_options=dict(bias_decay_mult=1, norm_decay_mult=1),
            lr=0.0001,
            weight_decay=None)
        model = ExampleModel()
        build_optimizer(model, optimizer_cfg)

    base_lr = 0.0001
    base_wd = 0.0002
    momentum = 0.9

    # basic config with ExampleModel
    optimizer_cfg = dict(
        type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
    model = ExampleModel()
    optimizer = build_optimizer(model, optimizer_cfg)
    param_dict = dict(model.named_parameters())
    param_groups = optimizer.param_groups[0]
    assert isinstance(optimizer, torch.optim.SGD)
    assert optimizer.defaults['lr'] == 0.0001
    assert optimizer.defaults['momentum'] == 0.9
    assert optimizer.defaults['weight_decay'] == 0.0002
    assert len(param_groups['params']) == 6
    assert torch.equal(param_groups['params'][0], param_dict['conv1.weight'])
    assert torch.equal(param_groups['params'][1], param_dict['conv1.bias'])
    assert torch.equal(param_groups['params'][2], param_dict['bn.weight'])
    assert torch.equal(param_groups['params'][3], param_dict['bn.bias'])
    assert torch.equal(param_groups['params'][4], param_dict['gn.weight'])
    assert torch.equal(param_groups['params'][5], param_dict['gn.bias'])

    # basic config with Parallel model
    model = torch.nn.DataParallel(ExampleModel())
    optimizer = build_optimizer(model, optimizer_cfg)
    param_dict = dict(model.named_parameters())
    param_groups = optimizer.param_groups[0]
    assert isinstance(optimizer, torch.optim.SGD)
    assert optimizer.defaults['lr'] == 0.0001
    assert optimizer.defaults['momentum'] == 0.9
    assert optimizer.defaults['weight_decay'] == 0.0002
    assert len(param_groups['params']) == 6
    assert torch.equal(param_groups['params'][0],
                       param_dict['module.conv1.weight'])
    assert torch.equal(param_groups['params'][1],
                       param_dict['module.conv1.bias'])
    assert torch.equal(param_groups['params'][2],
                       param_dict['module.bn.weight'])
    assert torch.equal(param_groups['params'][3], param_dict['module.bn.bias'])
    assert torch.equal(param_groups['params'][4],
                       param_dict['module.gn.weight'])
    assert torch.equal(param_groups['params'][5], param_dict['module.gn.bias'])

    # Empty paramwise_options with ExampleModel
    optimizer_cfg['paramwise_options'] = dict()
    model = ExampleModel()
    optimizer = build_optimizer(model, optimizer_cfg)
    param_groups = optimizer.param_groups
    assert isinstance(optimizer, torch.optim.SGD)
    assert optimizer.defaults['lr'] == 0.0001
    assert optimizer.defaults['momentum'] == 0.9
    assert optimizer.defaults['weight_decay'] == 0.0002
    for i, (name, param) in enumerate(model.named_parameters()):
        param_group = param_groups[i]
        assert param_group['params'] == [param]
        assert param_group['momentum'] == 0.9
        assert param_group['lr'] == 0.0001
        assert param_group['weight_decay'] == 0.0002

    # Empty paramwise_options with ExampleModel and no grad
    for param in model.parameters():
        param.requires_grad = False
    optimizer = build_optimizer(model, optimizer_cfg)
    param_groups = optimizer.param_groups
    assert isinstance(optimizer, torch.optim.SGD)
    assert optimizer.defaults['lr'] == 0.0001
    assert optimizer.defaults['momentum'] == 0.9
    assert optimizer.defaults['weight_decay'] == 0.0002
    for i, (name, param) in enumerate(model.named_parameters()):
        param_group = param_groups[i]
        assert param_group['params'] == [param]
        assert param_group['momentum'] == 0.9
        assert param_group['lr'] == 0.0001
        assert param_group['weight_decay'] == 0.0002

    # paramwise_options with ExampleModel
    paramwise_options = dict(
        bias_lr_mult=0.9, bias_decay_mult=0.8, norm_decay_mult=0.7)
    optimizer_cfg['paramwise_options'] = paramwise_options
    model = ExampleModel()
    optimizer = build_optimizer(model, optimizer_cfg)
    param_groups = optimizer.param_groups
    assert isinstance(optimizer, torch.optim.SGD)
    assert optimizer.defaults['lr'] == 0.0001
    assert optimizer.defaults['momentum'] == 0.9
    assert optimizer.defaults['weight_decay'] == 0.0002
    for i, (name, param) in enumerate(model.named_parameters()):
        param_group = param_groups[i]
        assert param_group['params'] == [param]
        assert param_group['momentum'] == 0.9
    assert param_groups[0]['lr'] == 0.0001
    assert param_groups[0]['weight_decay'] == 0.0002
    assert param_groups[1]['lr'] == 0.0001 * 0.9
    assert param_groups[1]['weight_decay'] == 0.0002 * 0.8
    assert param_groups[2]['lr'] == 0.0001
    assert param_groups[2]['weight_decay'] == 0.0002 * 0.7
    assert param_groups[3]['lr'] == 0.0001
    assert param_groups[3]['weight_decay'] == 0.0002 * 0.7
    assert param_groups[4]['lr'] == 0.0001
    assert param_groups[4]['weight_decay'] == 0.0002 * 0.7
    assert param_groups[5]['lr'] == 0.0001
    assert param_groups[5]['weight_decay'] == 0.0002 * 0.7

    # paramwise_options with ExampleModel and no grad
    for param in model.parameters():
        param.requires_grad = False
    optimizer = build_optimizer(model, optimizer_cfg)
    param_groups = optimizer.param_groups
    assert isinstance(optimizer, torch.optim.SGD)
    assert optimizer.defaults['lr'] == 0.0001
    assert optimizer.defaults['momentum'] == 0.9
    assert optimizer.defaults['weight_decay'] == 0.0002
    for i, (name, param) in enumerate(model.named_parameters()):
        param_group = param_groups[i]
        assert param_group['params'] == [param]
        assert param_group['momentum'] == 0.9
        assert param_group['lr'] == 0.0001
        assert param_group['weight_decay'] == 0.0002

    torch_optimizers = [
        'ASGD', 'Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'LBFGS',
        'Optimizer', 'RMSprop', 'Rprop', 'SGD', 'SparseAdam'
    ]
    assert set(torch_optimizers).issubset(set(TORCH_OPTIMIZERS))