Add RegNets
ffiirree committed Nov 18, 2021
1 parent e3eca38 commit ec9248c
Showing 4 changed files with 302 additions and 3 deletions.
1 change: 1 addition & 0 deletions models/__init__.py
@@ -18,6 +18,7 @@
from .resmlp import *
from .aaai22 import *
from .rexnet import *
from .regnet import *
from .vision_transformer import *
from .convmixer import *

16 changes: 15 additions & 1 deletion models/core/blocks.py
Expand Up @@ -9,6 +9,7 @@
_NORM_POSITION: str = 'before'
_NORMALIZER: nn.Module = nn.BatchNorm2d
_NONLINEAR: nn.Module = partial(nn.ReLU, inplace=True)
_SE_INNER_NONLINEAR: nn.Module = partial(nn.ReLU, inplace=True)
_SE_GATING_FN: nn.Module = nn.Sigmoid
_SE_DIVISOR: int = 8
_SE_USE_NORM: bool = False
@@ -46,26 +47,39 @@ def nonlinear(layer: nn.Module):

@contextmanager
def se(
inner_nonlinear: nn.Module = _SE_INNER_NONLINEAR,
gating_fn: nn.Module = _SE_GATING_FN,
divisor: int = _SE_DIVISOR,
use_norm: bool = _SE_USE_NORM
):
global _SE_INNER_NONLINEAR
global _SE_GATING_FN
global _SE_DIVISOR
global _SE_USE_NORM

_pre_inner_fn = _SE_INNER_NONLINEAR
_pre_fn = _SE_GATING_FN
_pre_divisor = _SE_DIVISOR
_pre_use_norm = _SE_USE_NORM
_SE_INNER_NONLINEAR = inner_nonlinear
_SE_GATING_FN = gating_fn
_SE_DIVISOR = divisor
_SE_USE_NORM = use_norm
yield
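    # the with-block (or decorated call) runs at the yield; the saved defaults are restored below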
_SE_INNER_NONLINEAR = _pre_inner_fn
_SE_GATING_FN = _pre_fn
_SE_DIVISOR = _pre_divisor
_SE_USE_NORM = _pre_use_norm


def normalizer_fn(channels):
return _NORMALIZER(channels)


def activation_fn():
return _NONLINEAR()


def norm_activation(
channels,
normalizer_fn: nn.Module = None,
@@ -85,7 +99,7 @@ def norm_activation(
return [activation_fn()]

    if activation_fn is None:
-        return [normalizer_fn()]
+        return [normalizer_fn(channels)]

if norm_position == 'after':
return [activation_fn(), normalizer_fn(channels)]
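
The `se` context manager above swaps the module-level SE defaults and restores the saved values after the `yield`; since `contextlib.contextmanager` objects also work as decorators, `regnet.py` can apply it as `@blocks.se(divisor=1)` (see below). A minimal usage sketch, assuming `models.core.blocks` is importable and that `SEBlock(channels, ratio)` takes the call shape used in `regnet.py`:

import torch.nn as nn
from models.core import blocks

# Any SE layer built inside this block uses Hardsigmoid gating and divisor=1;
# the previous defaults are restored automatically on exit.
with blocks.se(gating_fn=nn.Hardsigmoid, divisor=1):
    se_layer = blocks.SEBlock(64, 0.25)  # hypothetical call, mirroring regnet.py's usage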
4 changes: 2 additions & 2 deletions models/mobilenetv3.py
@@ -67,7 +67,7 @@ def __init__(
self.classifier = nn.Sequential(
nn.Linear(576, 1024),
hardswish(),
-            nn.Dropout(dropout_rate),
+            nn.Dropout(dropout_rate, inplace=True),
nn.Linear(1024, num_classes)
)

@@ -137,7 +137,7 @@ def __init__(
self.classifier = nn.Sequential(
nn.Linear(960, 1280),
hardswish(),
-            nn.Dropout(dropout_rate),
+            nn.Dropout(dropout_rate, inplace=True),
nn.Linear(1280, num_classes)
)

284 changes: 284 additions & 0 deletions models/regnet.py
@@ -0,0 +1,284 @@
'''
Papers:
    [RegNet] Designing Network Design Spaces
RegNet trends:
    1. The depth of the best models is stable across regimes, with an optimal depth of ~20 blocks (60 layers);
    2. the best models use a bottleneck ratio of 1.0, which effectively removes the bottleneck;
    3. the width multiplier wm of good models is ~2.5.
Notes:
    1. The inverted bottleneck degrades the EDF slightly, and depthwise conv performs even worse relative to b = 1 and g >= 1;
    2. SE is useful;
    3. Swish outperforms ReLU at low flops, but ReLU is better at high flops.
       Interestingly, if g is restricted to 1 (depthwise conv), Swish performs much better than ReLU.
'''
import os
import math
import torch
import torch.nn as nn
from .core import blocks, export, make_divisible
from typing import Any


class BottleneckTransform(nn.Sequential):
@blocks.se(divisor=1)
def __init__(
self,
inp,
oup,
stride,
group_width,
bottleneck_multiplier,
se_ratio
):
super().__init__()

wb = int(round(oup * bottleneck_multiplier))

self.add_module('1x1-1', blocks.Conv2d1x1Block(inp, wb))
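        # groups = wb // group_width, so each conv group covers exactly group_width channels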
self.add_module('3x3', blocks.Conv2dBlock(wb, wb, stride=stride, groups=(wb // group_width)))

if se_ratio:
            self.add_module('se', blocks.SEBlock(wb, (inp * se_ratio) / wb))  # SE width is inp * se_ratio: the ratio is applied to the block input, not to wb

self.add_module('1x1-2', blocks.Conv2d1x1BN(wb, oup))


class ResBottleneckBlock(nn.Module):
"""Residual bottleneck block: x + F(x), F = bottleneck transform."""

def __init__(
self,
inp: int,
oup: int,
stride: int,
group_width: int = 1,
bottleneck_multiplier: float = 1.0,
se_ratio: float = None,
) -> None:
super().__init__()

# Use skip connection with projection if shape changes
self.proj = None
should_proj = (inp != oup) or (stride != 1)
if should_proj:
self.proj = blocks.Conv2d1x1BN(inp, oup, stride)

self.f = BottleneckTransform(
inp,
oup,
stride,
group_width,
bottleneck_multiplier,
se_ratio,
)

self.activation = blocks.activation_fn()

def forward(self, x):
if self.proj is not None:
x = self.proj(x) + self.f(x)
else:
x = x + self.f(x)
return self.activation(x)


class RegStage(nn.Sequential):
def __init__(
self,
in_width,
out_width,
stride,
depth,
group_widths,
bottleneck_multiplier,
se_ratio: float,
stage_index: int
):
super().__init__()

for i in range(depth):
self.add_module(
f'block{stage_index}-{i}',
ResBottleneckBlock(
in_width if i == 0 else out_width,
out_width,
stride if i == 0 else 1,
group_widths,
bottleneck_multiplier,
se_ratio
)
)


@export
class RegNet(nn.Module):
def __init__(
self,
in_channels: int = 3,
num_classes: int = 1000,
stem_width: int = 32,
d: int = None,
w0: int = None,
wa: float = None,
wm: float = None,
b: float = None,
g: int = None,
se_ratio: float = None,
dropout_rate: float = 0.0,
**kwargs: Any
):
"""
d: the number of blocks
w0: initial width
wa: slope
            uj = w0 + wa * j for 0 <= j < d  (width of block j)
        wm: width multiplier
b: bottleneck ratio
g: group width
"""
super().__init__()

self.features = nn.Sequential()
self.features.add_module('stem', blocks.Conv2dBlock(in_channels, stem_width, stride=2))

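        # Per-block widths from the paper's linear parameterization:
        #   u_j = w0 + wa * j for 0 <= j < d; each u_j is then snapped to the
        #   nearest power of wm relative to w0 (s_j) and rounded to a multiple of 8.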
uj = w0 + wa * torch.arange(d)
sj = torch.round(torch.log(uj / w0) / math.log(wm))

widths = (torch.round((w0 * torch.pow(wm, sj)) / 8) * 8).int().tolist()
        widths = [int(make_divisible(w * b, min(g, w * b)) / b) for w in widths]  # adjust each width so it is compatible with the group width
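        # Consecutive blocks that end up with the same quantized width are merged into one stage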
num_stages = len(set(widths))
stage_depths = [(torch.tensor(widths) == w).sum().item() for w in torch.unique(torch.tensor(widths))]
stage_widths = torch.unique(torch.tensor(widths)).numpy().tolist()
group_widths = [g] * num_stages
group_widths = [min(g, int(w * b)) for g, w in zip(group_widths, stage_widths)]
bottleneck_multipliers = [b] * num_stages

stage_widths = [stem_width] + stage_widths

for i in range(num_stages):
self.features.add_module(
f'stage{i}',
RegStage(
stage_widths[i],
stage_widths[i+1],
2,
stage_depths[i],
group_widths[i],
bottleneck_multipliers[i],
se_ratio,
i + 1
)
)

self.avg = nn.AdaptiveAvgPool2d((1, 1))
self.classifier = nn.Sequential(
nn.Dropout(dropout_rate, inplace=True),
nn.Linear(stage_widths[-1], num_classes)
)

def forward(self, x):
x = self.features(x)
x = self.avg(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x


def _regnet(
d: int,
w0: int,
wa: float,
wm: float,
b: float = 1.0,
g: int = None,
se_ratio: float = None,
pretrained: bool = False,
pth: str = None,
progress: bool = True,
**kwargs: Any
):
model = RegNet(d=d, w0=w0, wa=wa, wm=wm, b=b, g=g, se_ratio=se_ratio, **kwargs)

if pretrained:
if pth is not None:
state_dict = torch.load(os.path.expanduser(pth))
else:
assert 'url' in kwargs and kwargs['url'] != '', 'Invalid URL.'
state_dict = torch.hub.load_state_dict_from_url(
kwargs['url'],
progress=progress
)
model.load_state_dict(state_dict)
return model


@export
def regnet_x_400mf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
return _regnet(22, 24, 24.48, 2.54, 1.0, 16, None, pretrained, pth, progress, **kwargs)


@export
def regnet_x_800mf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
return _regnet(16, 56, 35.73, 2.28, 1.0, 16, None, pretrained, pth, progress, **kwargs)


@export
def regnet_x_1_6gf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
return _regnet(18, 80, 34.01, 2.25, 1.0, 24, None, pretrained, pth, progress, **kwargs)


@export
def regnet_x_3_2gf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
return _regnet(25, 88, 26.32, 2.25, 1.0, 48, None, pretrained, pth, progress, **kwargs)


@export
def regnet_x_8gf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
return _regnet(23, 80, 49.56, 2.88, 1.0, 120, None, pretrained, pth, progress, **kwargs)


@export
def regnet_x_16gf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
return _regnet(22, 216, 55.59, 2.1, 1.0, 128, None, pretrained, pth, progress, **kwargs)


@export
def regnet_x_32gf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
return _regnet(23, 320, 69.86, 2.0, 1.0, 168, None, pretrained, pth, progress, **kwargs)


@export
def regnet_y_400mf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
return _regnet(16, 48, 27.89, 2.09, 1.0, 8, 0.25, pretrained, pth, progress, **kwargs)


@export
def regnet_y_800mf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
return _regnet(14, 56, 38.84, 2.4, 1.0, 16, 0.25, pretrained, pth, progress, **kwargs)


@export
def regnet_y_1_6gf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
return _regnet(27, 48, 20.71, 2.65, 1.0, 24, 0.25, pretrained, pth, progress, **kwargs)


@export
def regnet_y_3_2gf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
return _regnet(21, 80, 42.63, 2.66, 1.0, 24, 0.25, pretrained, pth, progress, **kwargs)


@export
def regnet_y_8gf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
return _regnet(17, 192, 76.82, 2.19, 1.0, 56, 0.25, pretrained, pth, progress, **kwargs)


@export
def regnet_y_16gf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
return _regnet(18, 200, 106.23, 2.48, 1.0, 112, 0.25, pretrained, pth, progress, **kwargs)


@export
def regnet_y_32gf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
return _regnet(20, 232, 115.89, 2.53, 1.0, 232, 0.25, pretrained, pth, progress, **kwargs)
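
A quick smoke test for the factories above (a sketch, assuming @export makes these names importable from the package, as the models/__init__.py change suggests):

import torch
from models import regnet_y_400mf

model = regnet_y_400mf()           # randomly initialized; no pretrained weights
x = torch.randn(1, 3, 224, 224)    # a single 224x224 RGB image
logits = model(x)
print(logits.shape)                # torch.Size([1, 1000])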
