diff --git a/models/__init__.py b/models/__init__.py
index 89e83b1..a3b7377 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -18,6 +18,7 @@
 from .resmlp import *
 from .aaai22 import *
 from .rexnet import *
+from .regnet import *
 from .vision_transformer import *
 from .convmixer import *
 
diff --git a/models/core/blocks.py b/models/core/blocks.py
index 50e80c7..2cb10b2 100644
--- a/models/core/blocks.py
+++ b/models/core/blocks.py
@@ -9,6 +9,7 @@
 _NORM_POSIITON: str = 'before'
 _NORMALIZER: nn.Module = nn.BatchNorm2d
 _NONLINEAR: nn.Module = partial(nn.ReLU, inplace=True)
+_SE_INNER_NONLINEAR: nn.Module = partial(nn.ReLU, inplace=True)
 _SE_GATING_FN: nn.Module = nn.Sigmoid
 _SE_DIVISOR: int = 8
 _SE_USE_NORM: bool = False
@@ -46,26 +47,39 @@ def nonlinear(layer: nn.Module):
 @contextmanager
 def se(
+    inner_nonlinear: nn.Module = _SE_INNER_NONLINEAR,
     gating_fn: nn.Module = _SE_GATING_FN,
     divisor: int = _SE_DIVISOR,
     use_norm: bool = _SE_USE_NORM
 ):
+    global _SE_INNER_NONLINEAR
     global _SE_GATING_FN
     global _SE_DIVISOR
     global _SE_USE_NORM
 
+    _pre_inner_fn = _SE_INNER_NONLINEAR
     _pre_fn = _SE_GATING_FN
     _pre_divisor = _SE_DIVISOR
     _pre_use_norm = _SE_USE_NORM
 
+    _SE_INNER_NONLINEAR = inner_nonlinear
     _SE_GATING_FN = gating_fn
     _SE_DIVISOR = divisor
     _SE_USE_NORM = use_norm
 
     yield
 
+    _SE_INNER_NONLINEAR = _pre_inner_fn
     _SE_GATING_FN = _pre_fn
     _SE_DIVISOR = _pre_divisor
     _SE_USE_NORM = _pre_use_norm
 
 
+def normalizer_fn(channels):
+    return _NORMALIZER(channels)
+
+
+def activation_fn():
+    return _NONLINEAR()
+
+
 def norm_activation(
     channels,
     normalizer_fn: nn.Module = None,
@@ -85,7 +99,7 @@ def norm_activation(
         return [activation_fn()]
 
     if activation_fn == None:
-        return [normalizer_fn()]
+        return [normalizer_fn(channels)]
 
     if norm_position == 'after':
         return [activation_fn(), normalizer_fn(channels)]
diff --git a/models/mobilenetv3.py b/models/mobilenetv3.py
index 70c6a6f..f437ae2 100644
--- a/models/mobilenetv3.py
+++ b/models/mobilenetv3.py
@@ -67,7 +67,7 @@ def __init__(
         self.classifier = nn.Sequential(
             nn.Linear(576, 1024),
             hardswish(),
-            nn.Dropout(dropout_rate),
+            nn.Dropout(dropout_rate, inplace=True),
             nn.Linear(1024, num_classes)
         )
 
@@ -137,7 +137,7 @@ def __init__(
         self.classifier = nn.Sequential(
             nn.Linear(960, 1280),
             hardswish(),
-            nn.Dropout(dropout_rate),
+            nn.Dropout(dropout_rate, inplace=True),
             nn.Linear(1280, num_classes)
         )
 
diff --git a/models/regnet.py b/models/regnet.py
new file mode 100644
index 0000000..62064c4
--- /dev/null
+++ b/models/regnet.py
@@ -0,0 +1,284 @@
+'''
+Papers:
+    [RegNet] Designing Network Design Spaces
+
+RegNet trends:
+    1. The depth of the best models is stable across regimes, with an optimal depth of ~20 blocks (60 layers);
+    2. The best models use a bottleneck ratio of 1.0, which effectively removes the bottleneck;
+    3. The width multiplier wm of good models is ~2.5.
+
+Notice:
+    1. The inverted bottleneck (b < 1) degrades the EDF slightly, and depthwise conv (g = 1) performs even worse relative to b = 1 and g >= 1;
+    2. SE is useful;
+    3. Swish outperforms ReLU at low flops, but ReLU is better at high flops.
+       Interestingly, if g is restricted to be 1 (depthwise conv), Swish performs much better than ReLU.
+'''
+import os
+import math
+import torch
+import torch.nn as nn
+from .core import blocks, export, make_divisible
+from typing import Any
+
+
+class BottleneckTransform(nn.Sequential):
+    @blocks.se(divisor=1)
+    def __init__(
+        self,
+        inp,
+        oup,
+        stride,
+        group_width,
+        bottleneck_multiplier,
+        se_ratio
+    ):
+        super().__init__()
+
+        wb = int(round(oup * bottleneck_multiplier))
+
+        self.add_module('1x1-1', blocks.Conv2d1x1Block(inp, wb))
+        self.add_module('3x3', blocks.Conv2dBlock(wb, wb, stride=stride, groups=(wb // group_width)))
+
+        if se_ratio:
+            self.add_module('se', blocks.SEBlock(wb, (inp * se_ratio) / wb))  # SE width is derived from the block input width (inp), not from wb
+
+        self.add_module('1x1-2', blocks.Conv2d1x1BN(wb, oup))
+
+
+class ResBottleneckBlock(nn.Module):
+    """Residual bottleneck block: x + F(x), F = bottleneck transform."""
+
+    def __init__(
+        self,
+        inp: int,
+        oup: int,
+        stride: int,
+        group_width: int = 1,
+        bottleneck_multiplier: float = 1.0,
+        se_ratio: float = None,
+    ) -> None:
+        super().__init__()
+
+        # Use skip connection with projection if shape changes
+        self.proj = None
+        should_proj = (inp != oup) or (stride != 1)
+        if should_proj:
+            self.proj = blocks.Conv2d1x1BN(inp, oup, stride)
+
+        self.f = BottleneckTransform(
+            inp,
+            oup,
+            stride,
+            group_width,
+            bottleneck_multiplier,
+            se_ratio,
+        )
+
+        self.activation = blocks.activation_fn()
+
+    def forward(self, x):
+        if self.proj is not None:
+            x = self.proj(x) + self.f(x)
+        else:
+            x = x + self.f(x)
+        return self.activation(x)
+
+
+class RegStage(nn.Sequential):
+    def __init__(
+        self,
+        in_width,
+        out_width,
+        stride,
+        depth,
+        group_widths,
+        bottleneck_multiplier,
+        se_ratio: float,
+        stage_index: int
+    ):
+        super().__init__()
+
+        for i in range(depth):
+            self.add_module(
+                f'block{stage_index}-{i}',
+                ResBottleneckBlock(
+                    in_width if i == 0 else out_width,
+                    out_width,
+                    stride if i == 0 else 1,
+                    group_widths,
+                    bottleneck_multiplier,
+                    se_ratio
+                )
+            )
+
+
+@export
+class RegNet(nn.Module):
+    def __init__(
+        self,
+        in_channels: int = 3,
+        num_classes: int = 1000,
+        stem_width: int = 32,
+        d: int = None,
+        w0: int = None,
+        wa: float = None,
+        wm: float = None,
+        b: float = None,
+        g: int = None,
+        se_ratio: float = None,
+        dropout_rate: float = 0.0,
+        **kwargs: Any
+    ):
+        """
+        d: the number of blocks
+        w0: initial width
+        wa: slope; uj = w0 + wa * j for 0 <= j < d gives the width of block j
+        wm: width multiplier used to quantize the per-block widths
+        b: bottleneck ratio
+        g: group width
+        """
+        super().__init__()
+
+        self.features = nn.Sequential()
+        self.features.add_module('stem', blocks.Conv2dBlock(in_channels, stem_width, stride=2))
+
+        uj = w0 + wa * torch.arange(d)
+        sj = torch.round(torch.log(uj / w0) / math.log(wm))
+
+        widths = (torch.round((w0 * torch.pow(wm, sj)) / 8) * 8).int().tolist()
+        widths = [int(make_divisible(w * b, min(g, w * b)) / b) for w in widths]  # adjust the widths so they are compatible with the group widths
+        num_stages = len(set(widths))
+        stage_depths = [(torch.tensor(widths) == w).sum().item() for w in torch.unique(torch.tensor(widths))]
+        stage_widths = torch.unique(torch.tensor(widths)).numpy().tolist()
+        group_widths = [g] * num_stages
+        group_widths = [min(g, int(w * b)) for g, w in zip(group_widths, stage_widths)]
+        bottleneck_multipliers = [b] * num_stages
+
+        stage_widths = [stem_width] + stage_widths
+
+        for i in range(num_stages):
+            self.features.add_module(
+                f'stage{i}',
+                RegStage(
+                    stage_widths[i],
+                    stage_widths[i+1],
+                    2,
+                    stage_depths[i],
+                    group_widths[i],
+                    bottleneck_multipliers[i],
+                    se_ratio,
+                    i + 1
+                )
+            )
+
+        self.avg = nn.AdaptiveAvgPool2d((1, 1))
+
+        self.classifier = nn.Sequential(
+            nn.Dropout(dropout_rate, inplace=True),
+            nn.Linear(stage_widths[-1], num_classes)
+        )
+
+    def forward(self, x):
+        x = self.features(x)
+        x = self.avg(x)
+        x = torch.flatten(x, 1)
+        x = self.classifier(x)
+        return x
+
+
+def _regnet(
+    d: int,
+    w0: int,
+    wa: float,
+    wm: float,
+    b: float = 1.0,
+    g: int = None,
+    se_ratio: float = None,
+    pretrained: bool = False,
+    pth: str = None,
+    progress: bool = True,
+    **kwargs: Any
+):
+    model = RegNet(d=d, w0=w0, wa=wa, wm=wm, b=b, g=g, se_ratio=se_ratio, **kwargs)
+
+    if pretrained:
+        if pth is not None:
+            state_dict = torch.load(os.path.expanduser(pth))
+        else:
+            assert 'url' in kwargs and kwargs['url'] != '', 'Invalid URL.'
+            state_dict = torch.hub.load_state_dict_from_url(
+                kwargs['url'],
+                progress=progress
+            )
+        model.load_state_dict(state_dict)
+    return model
+
+
+@export
+def regnet_x_400mf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
+    return _regnet(22, 24, 24.48, 2.54, 1.0, 16, None, pretrained, pth, progress, **kwargs)
+
+
+@export
+def regnet_x_800mf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
+    return _regnet(16, 56, 35.73, 2.28, 1.0, 16, None, pretrained, pth, progress, **kwargs)
+
+
+@export
+def regnet_x_1_6gf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
+    return _regnet(18, 80, 34.01, 2.25, 1.0, 24, None, pretrained, pth, progress, **kwargs)
+
+
+@export
+def regnet_x_3_2gf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
+    return _regnet(25, 88, 26.32, 2.25, 1.0, 48, None, pretrained, pth, progress, **kwargs)
+
+
+@export
+def regnet_x_8gf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
+    return _regnet(23, 80, 49.56, 2.88, 1.0, 120, None, pretrained, pth, progress, **kwargs)
+
+
+@export
+def regnet_x_16gf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
+    return _regnet(22, 216, 55.59, 2.1, 1.0, 128, None, pretrained, pth, progress, **kwargs)
+
+
+@export
+def regnet_x_32gf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
+    return _regnet(23, 320, 69.86, 2.0, 1.0, 168, None, pretrained, pth, progress, **kwargs)
+
+
+@export
+def regnet_y_400mf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
+    return _regnet(16, 48, 27.89, 2.09, 1.0, 8, 0.25, pretrained, pth, progress, **kwargs)
+
+
+@export
+def regnet_y_800mf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
+    return _regnet(14, 56, 38.84, 2.4, 1.0, 16, 0.25, pretrained, pth, progress, **kwargs)
+
+
+@export
+def regnet_y_1_6gf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
+    return _regnet(27, 48, 20.71, 2.65, 1.0, 24, 0.25, pretrained, pth, progress, **kwargs)
+
+
+@export
+def regnet_y_3_2gf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
+    return _regnet(21, 80, 42.63, 2.66, 1.0, 24, 0.25, pretrained, pth, progress, **kwargs)
+
+
+@export
+def regnet_y_8gf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
+    return _regnet(17, 192, 76.82, 2.19, 1.0, 56, 0.25, pretrained, pth, progress, **kwargs)
+
+
+@export
+def regnet_y_16gf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
+    return _regnet(18, 200, 106.23, 2.48, 1.0, 112, 0.25, pretrained, pth, progress, **kwargs)
+
+
+@export
+def regnet_y_32gf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
+    return _regnet(20, 232, 115.89, 2.53, 1.0, 232, 0.25, pretrained, pth, progress, **kwargs)
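
For reference, the per-block width computation that regnet.py performs with tensors (uj = w0 + wa * j, snapped to powers of wm and quantized to multiples of 8) can be sketched in plain Python as below. This is an illustrative re-derivation only, not part of the patch; the helper name quantized_widths is made up here, and rounding may differ from the tensor version in rare half-way cases.

import math

def quantized_widths(d, w0, wa, wm, q=8):
    # u_j = w0 + wa * j: linearly increasing per-block widths for j = 0..d-1
    u = [w0 + wa * j for j in range(d)]
    # s_j = round(log(u_j / w0) / log(wm)): snap each width to w0 * wm ** s_j
    s = [round(math.log(uj / w0) / math.log(wm)) for uj in u]
    # quantize to multiples of q (regnet.py uses q = 8)
    return [int(round(w0 * wm ** sj / q) * q) for sj in s]

# e.g. the regnet_y_400mf configuration: d=16, w0=48, wa=27.89, wm=2.09
widths = quantized_widths(16, 48, 27.89, 2.09)
# blocks with equal width form one stage; the number of repeats is that stage's depth
stage_widths = sorted(set(widths))
stage_depths = [widths.count(w) for w in stage_widths]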
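
A minimal smoke test for the new constructors might look like the following, assuming the package is importable from the repository root and that @export publishes the builders in the models namespace as it does for the other model files:

import torch
from models import regnet_y_400mf

model = regnet_y_400mf(num_classes=1000)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([1, 1000])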