diff --git a/tests/test_models.py b/tests/test_models.py index d3ee71c462..247415b04d 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -58,6 +58,8 @@ EXCLUDE_FILTERS = ['*enormous*'] NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*'] +EXCLUDE_JIT_FILTERS = [] + TARGET_FWD_SIZE = MAX_FWD_SIZE = 384 TARGET_BWD_SIZE = 128 MAX_BWD_SIZE = 320 @@ -277,7 +279,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size): if 'GITHUB_ACTIONS' not in os.environ: - @pytest.mark.timeout(120) + @pytest.mark.timeout(240) @pytest.mark.parametrize('model_name', list_models(pretrained=True)) @pytest.mark.parametrize('batch_size', [1]) def test_model_load_pretrained(model_name, batch_size): @@ -286,19 +288,13 @@ def test_model_load_pretrained(model_name, batch_size): create_model(model_name, pretrained=True, in_chans=in_chans, num_classes=5) create_model(model_name, pretrained=True, in_chans=in_chans, num_classes=0) - @pytest.mark.timeout(120) + @pytest.mark.timeout(240) @pytest.mark.parametrize('model_name', list_models(pretrained=True, exclude_filters=NON_STD_FILTERS)) @pytest.mark.parametrize('batch_size', [1]) def test_model_features_pretrained(model_name, batch_size): """Create that pretrained weights load when features_only==True.""" create_model(model_name, pretrained=True, features_only=True) -EXCLUDE_JIT_FILTERS = [ - '*iabn*', 'tresnet*', # models using inplace abn unlikely to ever be scriptable - 'dla*', 'hrnet*', 'ghostnet*' # hopefully fix at some point - 'vit_large_*', 'vit_huge_*', 'vit_gi*', -] - @pytest.mark.torchscript @pytest.mark.timeout(120) diff --git a/timm/layers/classifier.py b/timm/layers/classifier.py index 78adbf9af7..2eb4ec2eb6 100644 --- a/timm/layers/classifier.py +++ b/timm/layers/classifier.py @@ -52,6 +52,7 @@ def create_classifier( pool_type: str = 'avg', use_conv: bool = False, input_fmt: str = 'NCHW', + drop_rate: Optional[float] = None, ): global_pool, num_pooled_features = _create_pool( num_features, @@ -65,6 +66,9 @@ def create_classifier( num_classes, use_conv=use_conv, ) + if drop_rate is not None: + dropout = nn.Dropout(drop_rate) + return global_pool, dropout, fc return global_pool, fc diff --git a/timm/layers/conv_bn_act.py b/timm/layers/conv_bn_act.py index 9e7c64b858..84aaf4bf1a 100644 --- a/timm/layers/conv_bn_act.py +++ b/timm/layers/conv_bn_act.py @@ -11,9 +11,26 @@ class ConvNormAct(nn.Module): def __init__( - self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, - bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, drop_layer=None): + self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding='', + dilation=1, + groups=1, + bias=False, + apply_act=True, + norm_layer=nn.BatchNorm2d, + norm_kwargs=None, + act_layer=nn.ReLU, + act_kwargs=None, + drop_layer=None, + ): super(ConvNormAct, self).__init__() + norm_kwargs = norm_kwargs or {} + act_kwargs = act_kwargs or {} + self.conv = create_conv2d( in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) @@ -21,8 +38,14 @@ def __init__( # NOTE for backwards compatibility with models that use separate norm and act layer definitions norm_act_layer = get_norm_act_layer(norm_layer, act_layer) # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` - norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} - self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) + if drop_layer: + 
norm_kwargs['drop_layer'] = drop_layer + self.bn = norm_act_layer( + out_channels, + apply_act=apply_act, + act_kwargs=act_kwargs, + **norm_kwargs, + ) @property def in_channels(self): @@ -57,10 +80,27 @@ def create_aa(aa_layer, channels, stride=2, enable=True): class ConvNormActAa(nn.Module): def __init__( - self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, - bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, drop_layer=None): + self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding='', + dilation=1, + groups=1, + bias=False, + apply_act=True, + norm_layer=nn.BatchNorm2d, + norm_kwargs=None, + act_layer=nn.ReLU, + act_kwargs=None, + aa_layer=None, + drop_layer=None, + ): super(ConvNormActAa, self).__init__() use_aa = aa_layer is not None and stride == 2 + norm_kwargs = norm_kwargs or {} + act_kwargs = act_kwargs or {} self.conv = create_conv2d( in_channels, out_channels, kernel_size, stride=1 if use_aa else stride, @@ -69,8 +109,9 @@ def __init__( # NOTE for backwards compatibility with models that use separate norm and act layer definitions norm_act_layer = get_norm_act_layer(norm_layer, act_layer) # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` - norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} - self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) + if drop_layer: + norm_kwargs['drop_layer'] = drop_layer + self.bn = norm_act_layer(out_channels, apply_act=apply_act, act_kwargs=act_kwargs, **norm_kwargs) self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa) @property diff --git a/timm/layers/norm_act.py b/timm/layers/norm_act.py index d1bf6d0b6c..49505c58f5 100644 --- a/timm/layers/norm_act.py +++ b/timm/layers/norm_act.py @@ -24,6 +24,18 @@ from .trace_utils import _assert +def _create_act(act_layer, act_kwargs=None, inplace=False, apply_act=True): + act_layer = get_act_layer(act_layer) # string -> nn.Module + act_kwargs = act_kwargs or {} + if act_layer is not None and apply_act: + if inplace: + act_kwargs['inplace'] = inplace + act = act_layer(**act_kwargs) + else: + act = nn.Identity() + return act + + class BatchNormAct2d(nn.BatchNorm2d): """BatchNorm + Activation @@ -40,31 +52,33 @@ def __init__( track_running_stats=True, apply_act=True, act_layer=nn.ReLU, - act_params=None, # FIXME not the final approach + act_kwargs=None, inplace=True, drop_layer=None, device=None, - dtype=None + dtype=None, ): try: factory_kwargs = {'device': device, 'dtype': dtype} super(BatchNormAct2d, self).__init__( - num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats, - **factory_kwargs + num_features, + eps=eps, + momentum=momentum, + affine=affine, + track_running_stats=track_running_stats, + **factory_kwargs, ) except TypeError: # NOTE for backwards compat with old PyTorch w/o factory device/dtype support super(BatchNormAct2d, self).__init__( - num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) + num_features, + eps=eps, + momentum=momentum, + affine=affine, + track_running_stats=track_running_stats, + ) self.drop = drop_layer() if drop_layer is not None else nn.Identity() - act_layer = get_act_layer(act_layer) # string -> nn.Module - if act_layer is not None and apply_act: - act_args = dict(inplace=True) if inplace else {} - if act_params is not None: - act_args['negative_slope'] = act_params - self.act = 
act_layer(**act_args)
-        else:
-            self.act = nn.Identity()
+        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

     def forward(self, x):
         # cut & paste of torch.nn.BatchNorm2d.forward impl to avoid issues with torchscript and tracing
@@ -188,6 +202,7 @@ def __init__(
             eps: float = 1e-5,
             apply_act=True,
             act_layer=nn.ReLU,
+            act_kwargs=None,
             inplace=True,
             drop_layer=None,
     ):
@@ -199,12 +214,7 @@ def __init__(
         self.register_buffer("running_var", torch.ones(num_features))
         self.drop = drop_layer() if drop_layer is not None else nn.Identity()
-        act_layer = get_act_layer(act_layer)  # string -> nn.Module
-        if act_layer is not None and apply_act:
-            act_args = dict(inplace=True) if inplace else {}
-            self.act = act_layer(**act_args)
-        else:
-            self.act = nn.Identity()
+        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

     def _load_from_state_dict(
             self,
@@ -344,6 +354,7 @@ def __init__(
             group_size=None,
             apply_act=True,
             act_layer=nn.ReLU,
+            act_kwargs=None,
             inplace=True,
             drop_layer=None,
     ):
@@ -354,12 +365,8 @@ def __init__(
             affine=affine,
         )
         self.drop = drop_layer() if drop_layer is not None else nn.Identity()
-        act_layer = get_act_layer(act_layer)  # string -> nn.Module
-        if act_layer is not None and apply_act:
-            act_args = dict(inplace=True) if inplace else {}
-            self.act = act_layer(**act_args)
-        else:
-            self.act = nn.Identity()
+        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
+        self._fast_norm = is_fast_norm()

     def forward(self, x):
@@ -380,17 +387,14 @@ def __init__(
             affine=True,
             apply_act=True,
             act_layer=nn.ReLU,
+            act_kwargs=None,
             inplace=True,
             drop_layer=None,
     ):
         super(GroupNorm1Act, self).__init__(1, num_channels, eps=eps, affine=affine)
         self.drop = drop_layer() if drop_layer is not None else nn.Identity()
-        act_layer = get_act_layer(act_layer)  # string -> nn.Module
-        if act_layer is not None and apply_act:
-            act_args = dict(inplace=True) if inplace else {}
-            self.act = act_layer(**act_args)
-        else:
-            self.act = nn.Identity()
+        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
+        self._fast_norm = is_fast_norm()

     def forward(self, x):
@@ -411,17 +415,15 @@ def __init__(
             affine=True,
             apply_act=True,
             act_layer=nn.ReLU,
+            act_kwargs=None,
             inplace=True,
             drop_layer=None,
     ):
         super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine)
         self.drop = drop_layer() if drop_layer is not None else nn.Identity()
         act_layer = get_act_layer(act_layer)  # string -> nn.Module
-        if act_layer is not None and apply_act:
-            act_args = dict(inplace=True) if inplace else {}
-            self.act = act_layer(**act_args)
-        else:
-            self.act = nn.Identity()
+        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
+        self._fast_norm = is_fast_norm()

     def forward(self, x):
@@ -442,17 +444,13 @@ def __init__(
             affine=True,
             apply_act=True,
             act_layer=nn.ReLU,
+            act_kwargs=None,
             inplace=True,
             drop_layer=None,
     ):
         super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine)
         self.drop = drop_layer() if drop_layer is not None else nn.Identity()
-        act_layer = get_act_layer(act_layer)  # string -> nn.Module
-        if act_layer is not None and apply_act:
-            act_args = dict(inplace=True) if inplace else {}
-            self.act = act_layer(**act_args)
-        else:
-            self.act = nn.Identity()
+        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
         self._fast_norm = is_fast_norm()

     def forward(self, x):
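NOTE (usage sketch, not part of the patch): the norm_act.py changes above replace the per-class activation boilerplate, and the old `act_params` hook that could only carry `negative_slope`, with the shared `_create_act()` helper, so every norm+act layer now forwards arbitrary activation constructor kwargs via `act_kwargs`. A minimal sketch of the resulting API, assuming a timm build that exports `BatchNormAct2d` from `timm.layers`; the leaky_relu choice is illustrative:

    import torch
    from timm.layers import BatchNormAct2d

    # act_kwargs is passed straight through to the activation constructor;
    # _create_act() resolves the string via get_act_layer() and merges inplace=True in.
    bn_act = BatchNormAct2d(
        64,
        act_layer='leaky_relu',
        act_kwargs=dict(negative_slope=0.1),
        inplace=True,
    )
    x = torch.randn(2, 64, 8, 8)
    assert bn_act(x).shape == x.shape  # BN -> (optional drop) -> LeakyReLU(0.1)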
diff --git a/timm/layers/patch_embed.py b/timm/layers/patch_embed.py
index 05768674f6..a4946531b2 100644
--- a/timm/layers/patch_embed.py
+++ b/timm/layers/patch_embed.py
@@ -29,7 +29,7 @@ class PatchEmbed(nn.Module):

     def __init__(
             self,
-            img_size: int = 224,
+            img_size: Optional[int] = 224,
             patch_size: int = 16,
             in_chans: int = 3,
             embed_dim: int = 768,
@@ -39,12 +39,16 @@ def __init__(
             bias: bool = True,
     ):
         super().__init__()
-        img_size = to_2tuple(img_size)
-        patch_size = to_2tuple(patch_size)
-        self.img_size = img_size
-        self.patch_size = patch_size
-        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
-        self.num_patches = self.grid_size[0] * self.grid_size[1]
+        self.patch_size = to_2tuple(patch_size)
+        if img_size is not None:
+            self.img_size = to_2tuple(img_size)
+            self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
+            self.num_patches = self.grid_size[0] * self.grid_size[1]
+        else:
+            self.img_size = None
+            self.grid_size = None
+            self.num_patches = None
+
         if output_fmt is not None:
             self.flatten = False
             self.output_fmt = Format(output_fmt)
@@ -58,8 +62,10 @@ def __init__(

     def forward(self, x):
         B, C, H, W = x.shape
-        _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
-        _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
+        if self.img_size is not None:
+            _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
+            _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
+
         x = self.proj(x)
         if self.flatten:
             x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
diff --git a/timm/layers/space_to_depth.py b/timm/layers/space_to_depth.py
index b4ba66890c..5867456cdd 100644
--- a/timm/layers/space_to_depth.py
+++ b/timm/layers/space_to_depth.py
@@ -3,6 +3,8 @@


 class SpaceToDepth(nn.Module):
+    bs: torch.jit.Final[int]
+
     def __init__(self, block_size=4):
         super().__init__()
         assert block_size == 4
@@ -12,7 +14,7 @@ def forward(self, x):
         N, C, H, W = x.size()
         x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs)  # (N, C, H//bs, bs, W//bs, bs)
         x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)
-        x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs)  # (N, C*bs^2, H//bs, W//bs)
+        x = x.view(N, C * self.bs * self.bs, H // self.bs, W // self.bs)  # (N, C*bs^2, H//bs, W//bs)
         return x
diff --git a/timm/models/__init__.py b/timm/models/__init__.py
index 7b7540dbdf..e65a933f42 100644
--- a/timm/models/__init__.py
+++ b/timm/models/__init__.py
@@ -21,7 +21,6 @@
 from .focalnet import *
 from .gcvit import *
 from .ghostnet import *
-from .gluon_xception import *
 from .hardcorenas import *
 from .hrnet import *
 from .inception_resnet_v2 import *
diff --git a/timm/models/coat.py b/timm/models/coat.py
index f58d57a70f..b8afbb294e 100644
--- a/timm/models/coat.py
+++ b/timm/models/coat.py
@@ -15,43 +15,13 @@
 import torch.nn.functional as F

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert
+from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, LayerNorm
 from ._builder import build_model_with_cfg
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs

 __all__ = ['CoaT']


-def _cfg_coat(url='', **kwargs):
-    return {
-        'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embed1.proj', 'classifier': 'head', - **kwargs - } - - -default_cfgs = { - 'coat_tiny': _cfg_coat( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_tiny-473c2a20.pth' - ), - 'coat_mini': _cfg_coat( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_mini-2c6baf49.pth' - ), - 'coat_lite_tiny': _cfg_coat( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_tiny-461b07a7.pth' - ), - 'coat_lite_mini': _cfg_coat( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_mini-d7842000.pth' - ), - 'coat_lite_small': _cfg_coat( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_small-fea1d5a1.pth' - ), -} - - class ConvRelPosEnc(nn.Module): """ Convolutional relative position encoding. """ def __init__(self, head_chs, num_heads, window): @@ -147,7 +117,7 @@ def forward(self, x, size: Tuple[int, int]): # Generate Q, K, V. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # [B, h, N, Ch] + q, k, v = qkv.unbind(0) # [B, h, N, Ch] # Factorized attention. k_softmax = k.softmax(dim=2) @@ -334,7 +304,12 @@ def interpolate(self, x, scale_factor: float, size: Tuple[int, int]): img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W) img_tokens = F.interpolate( - img_tokens, scale_factor=scale_factor, recompute_scale_factor=False, mode='bilinear', align_corners=False) + img_tokens, + scale_factor=scale_factor, + recompute_scale_factor=False, + mode='bilinear', + align_corners=False, + ) img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2) out = torch.cat((cls_token, img_tokens), dim=1) @@ -384,17 +359,17 @@ def __init__( patch_size=16, in_chans=3, num_classes=1000, - embed_dims=(0, 0, 0, 0), - serial_depths=(0, 0, 0, 0), + embed_dims=(64, 128, 320, 512), + serial_depths=(3, 4, 6, 3), parallel_depth=0, - num_heads=0, - mlp_ratios=(0, 0, 0, 0), + num_heads=8, + mlp_ratios=(4, 4, 4, 4), qkv_bias=True, drop_rate=0., proj_drop_rate=0., attn_drop_rate=0., drop_path_rate=0., - norm_layer=partial(nn.LayerNorm, eps=1e-6), + norm_layer=LayerNorm, return_interm_layers=False, out_features=None, crpe_window=None, @@ -711,6 +686,7 @@ def remove_cls(x): def checkpoint_filter_fn(state_dict, model): out_dict = {} + state_dict = state_dict.get('model', state_dict) for k, v in state_dict.items(): # original model had unused norm layers, removing them requires filtering pretrained checkpoints if k.startswith('norm1') or \ @@ -726,52 +702,100 @@ def _create_coat(variant, pretrained=False, default_cfg=None, **kwargs): raise RuntimeError('features_only not implemented for Vision Transformer models.') model = build_model_with_cfg( - CoaT, variant, pretrained, + CoaT, + variant, + pretrained, pretrained_filter_fn=checkpoint_filter_fn, - **kwargs) + **kwargs, + ) return model +def _cfg_coat(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed1.proj', 'classifier': 
'head',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'coat_tiny.in1k': _cfg_coat(hf_hub_id='timm/'),
+    'coat_mini.in1k': _cfg_coat(hf_hub_id='timm/'),
+    'coat_small.in1k': _cfg_coat(hf_hub_id='timm/'),
+    'coat_lite_tiny.in1k': _cfg_coat(hf_hub_id='timm/'),
+    'coat_lite_mini.in1k': _cfg_coat(hf_hub_id='timm/'),
+    'coat_lite_small.in1k': _cfg_coat(hf_hub_id='timm/'),
+    'coat_lite_medium.in1k': _cfg_coat(hf_hub_id='timm/'),
+    'coat_lite_medium_384.in1k': _cfg_coat(
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash',
+    ),
+})
+
+
 @register_model
 def coat_tiny(pretrained=False, **kwargs):
     model_cfg = dict(
-        patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6,
-        num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)
-    model = _create_coat('coat_tiny', pretrained=pretrained, **model_cfg)
+        patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6)
+    model = _create_coat('coat_tiny', pretrained=pretrained, **dict(model_cfg, **kwargs))
     return model


 @register_model
 def coat_mini(pretrained=False, **kwargs):
     model_cfg = dict(
-        patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6,
-        num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)
-    model = _create_coat('coat_mini', pretrained=pretrained, **model_cfg)
+        patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6)
+    model = _create_coat('coat_mini', pretrained=pretrained, **dict(model_cfg, **kwargs))
+    return model
+
+
+@register_model
+def coat_small(pretrained=False, **kwargs):
+    model_cfg = dict(
+        patch_size=4, embed_dims=[152, 320, 320, 320], serial_depths=[2, 2, 2, 2], parallel_depth=6)
+    model = _create_coat('coat_small', pretrained=pretrained, **dict(model_cfg, **kwargs))
     return model


 @register_model
 def coat_lite_tiny(pretrained=False, **kwargs):
     model_cfg = dict(
-        patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0,
-        num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
-    model = _create_coat('coat_lite_tiny', pretrained=pretrained, **model_cfg)
+        patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], mlp_ratios=[8, 8, 4, 4])
+    model = _create_coat('coat_lite_tiny', pretrained=pretrained, **dict(model_cfg, **kwargs))
     return model


 @register_model
 def coat_lite_mini(pretrained=False, **kwargs):
     model_cfg = dict(
-        patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], parallel_depth=0,
-        num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
-    model = _create_coat('coat_lite_mini', pretrained=pretrained, **model_cfg)
+        patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], mlp_ratios=[8, 8, 4, 4])
+    model = _create_coat('coat_lite_mini', pretrained=pretrained, **dict(model_cfg, **kwargs))
     return model


 @register_model
 def coat_lite_small(pretrained=False, **kwargs):
     model_cfg = dict(
-        patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], parallel_depth=0,
-        num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
-    model = _create_coat('coat_lite_small', pretrained=pretrained, **model_cfg)
+        patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], mlp_ratios=[8, 8, 4, 4])
+    model = _create_coat('coat_lite_small', pretrained=pretrained, **dict(model_cfg, **kwargs))
+    return model
+
+
+@register_model
+def coat_lite_medium(pretrained=False, **kwargs):
+    model_cfg = dict(
+        patch_size=4, embed_dims=[128, 256, 320, 512], serial_depths=[3, 6, 10, 8])
+    model = _create_coat('coat_lite_medium', pretrained=pretrained, **dict(model_cfg, **kwargs))
+    return model
+
+
+@register_model
+def coat_lite_medium_384(pretrained=False, **kwargs):
+    model_cfg = dict(
+        img_size=384, patch_size=4, embed_dims=[128, 256, 320, 512], serial_depths=[3, 6, 10, 8])
+    model = _create_coat('coat_lite_medium_384', pretrained=pretrained, **dict(model_cfg, **kwargs))
     return model
\ No newline at end of file
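NOTE (usage sketch, not part of the patch): the entrypoints above switch from baking `**kwargs` into `model_cfg` to merging at call time with `dict(model_cfg, **kwargs)`, so caller overrides cleanly win over the variant defaults, and `generate_default_cfgs` exposes the pretrained weights under `model.tag` names. Illustrative use via timm's standard factory:

    import timm

    # kwargs override the model_cfg defaults through the dict(model_cfg, **kwargs) merge
    model = timm.create_model('coat_lite_medium', num_classes=10)

    # tagged weights registered by generate_default_cfgs resolve at create time
    model = timm.create_model('coat_lite_small.in1k', pretrained=True)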
diff --git a/timm/models/convit.py b/timm/models/convit.py
index 1921741862..16e8b5c4e9 100644
--- a/timm/models/convit.py
+++ b/timm/models/convit.py
@@ -28,37 +28,16 @@
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, trunc_normal_, PatchEmbed, Mlp
+from timm.layers import DropPath, trunc_normal_, PatchEmbed, Mlp, LayerNorm
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_module
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs
 from .vision_transformer_hybrid import HybridEmbed

 __all__ = ['ConViT']


-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
-        'first_conv': 'patch_embed.proj', 'classifier': 'head',
-        **kwargs
-    }
-
-
-default_cfgs = {
-    # ConViT
-    'convit_tiny': _cfg(
-        url="https://dl.fbaipublicfiles.com/convit/convit_tiny.pth"),
-    'convit_small': _cfg(
-        url="https://dl.fbaipublicfiles.com/convit/convit_small.pth"),
-    'convit_base': _cfg(
-        url="https://dl.fbaipublicfiles.com/convit/convit_base.pth")
-}
-
-
 @register_notrace_module  # reason: FX can't symbolically trace control flow in forward method
 class GPSA(nn.Module):
     def __init__(
@@ -218,7 +197,7 @@ def __init__(
             attn_drop=0.,
             drop_path=0.,
             act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm,
+            norm_layer=LayerNorm,
             use_gpsa=True,
             locality_strength=1.,
     ):
@@ -280,7 +259,7 @@ def __init__(
             attn_drop_rate=0.,
             drop_path_rate=0.,
             hybrid_backbone=None,
-            norm_layer=nn.LayerNorm,
+            norm_layer=LayerNorm,
             local_up_to_layer=3,
             locality_strength=1.,
             use_pos_embed=True,
@@ -300,7 +279,11 @@ def __init__(
                 hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
         else:
             self.patch_embed = PatchEmbed(
-                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+                img_size=img_size,
+                patch_size=patch_size,
+                in_chans=in_chans,
+                embed_dim=embed_dim,
+            )

         num_patches = self.patch_embed.num_patches
         self.num_patches = num_patches
@@ -405,28 +388,43 @@ def _create_convit(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(ConViT, variant, pretrained, **kwargs)


+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
+        'first_conv': 'patch_embed.proj', 'classifier': 'head',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    # ConViT
+    'convit_tiny.fb_in1k': _cfg(hf_hub_id='timm/'),
+    'convit_small.fb_in1k': _cfg(hf_hub_id='timm/'),
+    'convit_base.fb_in1k': _cfg(hf_hub_id='timm/')
+})
+
+
 @register_model
 def convit_tiny(pretrained=False, **kwargs):
     model_args = dict(
-
local_up_to_layer=10, locality_strength=1.0, embed_dim=48, num_heads=4) + model = _create_convit(variant='convit_tiny', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def convit_small(pretrained=False, **kwargs): model_args = dict( - local_up_to_layer=10, locality_strength=1.0, embed_dim=48, - num_heads=9, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - model = _create_convit(variant='convit_small', pretrained=pretrained, **model_args) + local_up_to_layer=10, locality_strength=1.0, embed_dim=48, num_heads=9) + model = _create_convit(variant='convit_small', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def convit_base(pretrained=False, **kwargs): model_args = dict( - local_up_to_layer=10, locality_strength=1.0, embed_dim=48, - num_heads=16, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - model = _create_convit(variant='convit_base', pretrained=pretrained, **model_args) + local_up_to_layer=10, locality_strength=1.0, embed_dim=48, num_heads=16) + model = _create_convit(variant='convit_base', pretrained=pretrained, **dict(model_args, **kwargs)) return model diff --git a/timm/models/convmixer.py b/timm/models/convmixer.py index ff9f214325..edb3c3aebd 100644 --- a/timm/models/convmixer.py +++ b/timm/models/convmixer.py @@ -6,31 +6,13 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import SelectAdaptivePool2d -from ._registry import register_model +from ._registry import register_model, generate_default_cfgs from ._builder import build_model_with_cfg from ._manipulate import checkpoint_seq __all__ = ['ConvMixer'] -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .96, 'interpolation': 'bicubic', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head', - 'first_conv': 'stem.0', - **kwargs - } - - -default_cfgs = { - 'convmixer_1536_20': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_1536_20_ks9_p7.pth.tar'), - 'convmixer_768_32': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_768_32_ks7_p7_relu.pth.tar'), - 'convmixer_1024_20_ks9_p14': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_1024_20_ks9_p14.pth.tar') -} - - class Residual(nn.Module): def __init__(self, fn): super().__init__() @@ -122,6 +104,25 @@ def _create_convmixer(variant, pretrained=False, **kwargs): return build_model_with_cfg(ConvMixer, variant, pretrained, **kwargs) +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .96, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head', + 'first_conv': 'stem.0', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'convmixer_1536_20.in1k': _cfg(hf_hub_id='timm/'), + 'convmixer_768_32.in1k': _cfg(hf_hub_id='timm/'), + 'convmixer_1024_20_ks9_p14.in1k': _cfg(hf_hub_id='timm/') +}) + + + @register_model def convmixer_1536_20(pretrained=False, **kwargs): model_args = dict(dim=1536, depth=20, kernel_size=9, patch_size=7, **kwargs) diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py index 5499529198..39709e4426 100644 --- a/timm/models/crossvit.py +++ b/timm/models/crossvit.py @@ -36,56 +36,12 @@ from timm.layers import DropPath, to_2tuple, trunc_normal_, _assert from ._builder import 
build_model_with_cfg from ._features_fx import register_notrace_function -from ._registry import register_model +from ._registry import register_model, generate_default_cfgs from .vision_transformer import Block __all__ = ['CrossViT'] # model_registry will add each entrypoint fn to this -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None, 'crop_pct': 0.875, - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True, - 'first_conv': ('patch_embed.0.proj', 'patch_embed.1.proj'), - 'classifier': ('head.0', 'head.1'), - **kwargs - } - - -default_cfgs = { - 'crossvit_15_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_224.pth'), - 'crossvit_15_dagger_240': _cfg( - url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_224.pth', - first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), - ), - 'crossvit_15_dagger_408': _cfg( - url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_384.pth', - input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0, - ), - 'crossvit_18_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_224.pth'), - 'crossvit_18_dagger_240': _cfg( - url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_224.pth', - first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), - ), - 'crossvit_18_dagger_408': _cfg( - url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_384.pth', - input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0, - ), - 'crossvit_9_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_224.pth'), - 'crossvit_9_dagger_240': _cfg( - url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_dagger_224.pth', - first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), - ), - 'crossvit_base_240': _cfg( - url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_base_224.pth'), - 'crossvit_small_240': _cfg( - url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_small_224.pth'), - 'crossvit_tiny_240': _cfg( - url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_tiny_224.pth'), -} - - class PatchEmbed(nn.Module): """ Image to Patch Embedding """ @@ -531,6 +487,47 @@ def pretrained_filter_fn(state_dict): ) +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None, 'crop_pct': 0.875, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True, + 'first_conv': ('patch_embed.0.proj', 'patch_embed.1.proj'), + 'classifier': ('head.0', 'head.1'), + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'crossvit_15_240.in1k': _cfg(hf_hub_id='timm/'), + 'crossvit_15_dagger_240.in1k': _cfg( + hf_hub_id='timm/', + first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), + ), + 'crossvit_15_dagger_408.in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0, + ), + 'crossvit_18_240.in1k': _cfg(hf_hub_id='timm/'), + 'crossvit_18_dagger_240.in1k': _cfg( + hf_hub_id='timm/', + first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), + ), + 'crossvit_18_dagger_408.in1k': _cfg( + 
hf_hub_id='timm/', + input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0, + ), + 'crossvit_9_240.in1k': _cfg(hf_hub_id='timm/'), + 'crossvit_9_dagger_240.in1k': _cfg( + hf_hub_id='timm/', + first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), + ), + 'crossvit_base_240.in1k': _cfg(hf_hub_id='timm/'), + 'crossvit_small_240.in1k': _cfg(hf_hub_id='timm/'), + 'crossvit_tiny_240.in1k': _cfg(hf_hub_id='timm/'), +}) + + @register_model def crossvit_tiny_240(pretrained=False, **kwargs): model_args = dict( diff --git a/timm/models/dla.py b/timm/models/dla.py index 5231225e74..565798d974 100644 --- a/timm/models/dla.py +++ b/timm/models/dla.py @@ -1,5 +1,5 @@ """ Deep Layer Aggregation and DLA w/ Res2Net -DLA original adapted from Official Pytorch impl at: +DLA original adapted from Official Pytorch impl at: https://github.com/ucbdrive/dla DLA Paper: `Deep Layer Aggregation` - https://arxiv.org/abs/1707.06484 Res2Net additions from: https://github.com/gasvn/Res2Net/ @@ -15,55 +15,28 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import create_classifier from ._builder import build_model_with_cfg -from ._registry import register_model +from ._registry import register_model, generate_default_cfgs __all__ = ['DLA'] -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bilinear', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'base_layer.0', 'classifier': 'fc', - **kwargs - } - - -default_cfgs = { - 'dla34': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dla34-2b83ff04.pth'), - 'dla46_c': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dla46_c-9b68d685.pth'), - 'dla46x_c': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dla46x_c-6bc5b5c8.pth'), - 'dla60x_c': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dla60x_c-a38e054a.pth'), - 'dla60': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dla60-9e91bd4d.pth'), - 'dla60x': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dla60x-6818f6bb.pth'), - 'dla102': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dla102-21f57b54.pth'), - 'dla102x': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dla102x-7ec0aa2a.pth'), - 'dla102x2': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dla102x2-ac4239c4.pth'), - 'dla169': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dla169-7c767967.pth'), - 'dla60_res2net': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net_dla60_4s-d88db7f9.pth'), - 'dla60_res2next': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next_dla60_4s-d327927b.pth'), -} - - class DlaBasic(nn.Module): """DLA Basic""" def __init__(self, inplanes, planes, stride=1, dilation=1, **_): super(DlaBasic, self).__init__() self.conv1 = nn.Conv2d( - inplanes, planes, kernel_size=3, stride=stride, padding=dilation, bias=False, dilation=dilation) + inplanes, planes, kernel_size=3, + stride=stride, padding=dilation, 
bias=False, dilation=dilation) self.bn1 = nn.BatchNorm2d(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d( - planes, planes, kernel_size=3, stride=1, padding=dilation, bias=False, dilation=dilation) + planes, planes, kernel_size=3, + stride=1, padding=dilation, bias=False, dilation=dilation) self.bn2 = nn.BatchNorm2d(planes) self.stride = stride - def forward(self, x, shortcut=None, children: Optional[List[torch.Tensor]] = None): + def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None): if shortcut is None: shortcut = x @@ -93,8 +66,8 @@ def __init__(self, inplanes, outplanes, stride=1, dilation=1, cardinality=1, bas self.conv1 = nn.Conv2d(inplanes, mid_planes, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(mid_planes) self.conv2 = nn.Conv2d( - mid_planes, mid_planes, kernel_size=3, stride=stride, padding=dilation, - bias=False, dilation=dilation, groups=cardinality) + mid_planes, mid_planes, kernel_size=3, + stride=stride, padding=dilation, bias=False, dilation=dilation, groups=cardinality) self.bn2 = nn.BatchNorm2d(mid_planes) self.conv3 = nn.Conv2d(mid_planes, outplanes, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(outplanes) @@ -143,8 +116,8 @@ def __init__(self, inplanes, outplanes, stride=1, dilation=1, scale=4, cardinali bns = [] for _ in range(num_scale_convs): convs.append(nn.Conv2d( - mid_planes, mid_planes, kernel_size=3, stride=stride, - padding=dilation, dilation=dilation, groups=cardinality, bias=False)) + mid_planes, mid_planes, kernel_size=3, + stride=stride, padding=dilation, dilation=dilation, groups=cardinality, bias=False)) bns.append(nn.BatchNorm2d(mid_planes)) self.convs = nn.ModuleList(convs) self.bns = nn.ModuleList(bns) @@ -211,8 +184,20 @@ def forward(self, x_children: List[torch.Tensor]): class DlaTree(nn.Module): def __init__( - self, levels, block, in_channels, out_channels, stride=1, dilation=1, cardinality=1, - base_width=64, level_root=False, root_dim=0, root_kernel_size=1, root_shortcut=False): + self, + levels, + block, + in_channels, + out_channels, + stride=1, + dilation=1, + cardinality=1, + base_width=64, + level_root=False, + root_dim=0, + root_kernel_size=1, + root_shortcut=False, + ): super(DlaTree, self).__init__() if root_dim == 0: root_dim = 2 * out_channels @@ -235,9 +220,22 @@ def __init__( else: cargs.update(dict(root_kernel_size=root_kernel_size, root_shortcut=root_shortcut)) self.tree1 = DlaTree( - levels - 1, block, in_channels, out_channels, stride, root_dim=0, **cargs) + levels - 1, + block, + in_channels, + out_channels, + stride, + root_dim=0, + **cargs, + ) self.tree2 = DlaTree( - levels - 1, block, out_channels, out_channels, root_dim=root_dim + out_channels, **cargs) + levels - 1, + block, + out_channels, + out_channels, + root_dim=root_dim + out_channels, + **cargs, + ) self.root = None self.level_root = level_root self.root_dim = root_dim @@ -262,20 +260,31 @@ def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional class DLA(nn.Module): def __init__( - self, levels, channels, output_stride=32, num_classes=1000, in_chans=3, global_pool='avg', - cardinality=1, base_width=64, block=DlaBottle2neck, shortcut_root=False, drop_rate=0.0): + self, + levels, + channels, + output_stride=32, + num_classes=1000, + in_chans=3, + global_pool='avg', + cardinality=1, + base_width=64, + block=DlaBottle2neck, + shortcut_root=False, + drop_rate=0.0, + ): super(DLA, self).__init__() self.channels = channels self.num_classes = num_classes 
self.cardinality = cardinality
         self.base_width = base_width
-        self.drop_rate = drop_rate
         assert output_stride == 32  # FIXME support dilation
         self.base_layer = nn.Sequential(
             nn.Conv2d(in_chans, channels[0], kernel_size=7, stride=1, padding=3, bias=False),
             nn.BatchNorm2d(channels[0]),
-            nn.ReLU(inplace=True))
+            nn.ReLU(inplace=True),
+        )
         self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
         self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2)
         cargs = dict(cardinality=cardinality, base_width=base_width, root_shortcut=shortcut_root)
@@ -293,8 +302,13 @@
         ]
         self.num_features = channels[-1]
-        self.global_pool, self.fc = create_classifier(
-            self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
+        self.global_pool, self.head_drop, self.fc = create_classifier(
+            self.num_features,
+            self.num_classes,
+            pool_type=global_pool,
+            use_conv=True,
+            drop_rate=drop_rate,
+        )
         self.flatten = nn.Flatten(1) if global_pool else nn.Identity()

         for m in self.modules():
@@ -310,7 +324,8 @@
         for i in range(convs):
             modules.extend([
                 nn.Conv2d(
-                    inplanes, planes, kernel_size=3, stride=stride if i == 0 else 1,
+                    inplanes, planes, kernel_size=3,
+                    stride=stride if i == 0 else 1,
                     padding=dilation, bias=False, dilation=dilation),
                 nn.BatchNorm2d(planes),
                 nn.ReLU(inplace=True)])
@@ -356,8 +371,7 @@
     def forward_head(self, x, pre_logits: bool = False):
         x = self.global_pool(x)
-        if self.drop_rate > 0.:
-            x = F.dropout(x, p=self.drop_rate, training=self.training)
+        x = self.head_drop(x)
         if pre_logits:
             return self.flatten(x)
         x = self.fc(x)
@@ -371,103 +385,131 @@
 def _create_dla(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
-        DLA, variant, pretrained,
+        DLA,
+        variant,
+        pretrained,
         pretrained_strict=False,
         feature_cfg=dict(out_indices=(1, 2, 3, 4, 5)),
-        **kwargs)
+        **kwargs,
+    )
+
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bilinear',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'base_layer.0', 'classifier': 'fc',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'dla34.in1k': _cfg(hf_hub_id='timm/'),
+    'dla46_c.in1k': _cfg(hf_hub_id='timm/'),
+    'dla46x_c.in1k': _cfg(hf_hub_id='timm/'),
+    'dla60x_c.in1k': _cfg(hf_hub_id='timm/'),
+    'dla60.in1k': _cfg(hf_hub_id='timm/'),
+    'dla60x.in1k': _cfg(hf_hub_id='timm/'),
+    'dla102.in1k': _cfg(hf_hub_id='timm/'),
+    'dla102x.in1k': _cfg(hf_hub_id='timm/'),
+    'dla102x2.in1k': _cfg(hf_hub_id='timm/'),
+    'dla169.in1k': _cfg(hf_hub_id='timm/'),
+    'dla60_res2net.in1k': _cfg(hf_hub_id='timm/'),
+    'dla60_res2next.in1k': _cfg(hf_hub_id='timm/'),
+})
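NOTE (usage sketch, not part of the patch): DLA's head now comes from the `drop_rate` form of `create_classifier()` added in the timm/layers/classifier.py hunk earlier in this patch, which returns a three-tuple with an `nn.Dropout` between pooling and the classifier, replacing the manual `F.dropout` call removed from `forward_head()` above. A minimal sketch of that API:

    import torch
    from timm.layers import create_classifier

    # with drop_rate set, a (global_pool, head_drop, fc) three-tuple is returned;
    # leaving drop_rate=None keeps the old (global_pool, fc) two-tuple behaviour
    global_pool, head_drop, fc = create_classifier(
        2048, 1000, pool_type='avg', use_conv=True, drop_rate=0.2)
    x = torch.randn(2, 2048, 7, 7)
    logits = fc(head_drop(global_pool(x)))  # use_conv=True keeps NCHW; shape (2, 1000, 1, 1)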


 @register_model
 def dla60_res2net(pretrained=False, **kwargs):
-    model_kwargs = dict(
+    model_args = dict(
         levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024),
-        block=DlaBottle2neck, cardinality=1, base_width=28, **kwargs)
-    return _create_dla('dla60_res2net', pretrained, **model_kwargs)
+        block=DlaBottle2neck, cardinality=1, base_width=28)
+    return _create_dla('dla60_res2net', pretrained, **dict(model_args, **kwargs))


 @register_model
 def dla60_res2next(pretrained=False,**kwargs):
-    model_kwargs = dict(
+    model_args = dict(
         levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024),
-        block=DlaBottle2neck, cardinality=8, base_width=4, **kwargs)
-    return _create_dla('dla60_res2next', pretrained, **model_kwargs)
+        block=DlaBottle2neck, cardinality=8, base_width=4)
+    return _create_dla('dla60_res2next', pretrained, **dict(model_args, **kwargs))


 @register_model
 def dla34(pretrained=False, **kwargs):  # DLA-34
-    model_kwargs = dict(
-        levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 128, 256, 512],
-        block=DlaBasic, **kwargs)
-    return _create_dla('dla34', pretrained, **model_kwargs)
+    model_args = dict(
+        levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 128, 256, 512], block=DlaBasic)
+    return _create_dla('dla34', pretrained, **dict(model_args, **kwargs))


 @register_model
 def dla46_c(pretrained=False, **kwargs):  # DLA-46-C
-    model_kwargs = dict(
-        levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256],
-        block=DlaBottleneck, **kwargs)
-    return _create_dla('dla46_c', pretrained, **model_kwargs)
+    model_args = dict(
+        levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256], block=DlaBottleneck)
+    return _create_dla('dla46_c', pretrained, **dict(model_args, **kwargs))


 @register_model
 def dla46x_c(pretrained=False, **kwargs):  # DLA-X-46-C
-    model_kwargs = dict(
+    model_args = dict(
         levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256],
-        block=DlaBottleneck, cardinality=32, base_width=4, **kwargs)
-    return _create_dla('dla46x_c', pretrained, **model_kwargs)
+        block=DlaBottleneck, cardinality=32, base_width=4)
+    return _create_dla('dla46x_c', pretrained, **dict(model_args, **kwargs))


 @register_model
 def dla60x_c(pretrained=False, **kwargs):  # DLA-X-60-C
-    model_kwargs = dict(
+    model_args = dict(
         levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 64, 64, 128, 256],
-        block=DlaBottleneck, cardinality=32, base_width=4, **kwargs)
-    return _create_dla('dla60x_c', pretrained, **model_kwargs)
+        block=DlaBottleneck, cardinality=32, base_width=4)
+    return _create_dla('dla60x_c', pretrained, **dict(model_args, **kwargs))


 @register_model
 def dla60(pretrained=False, **kwargs):  # DLA-60
-    model_kwargs = dict(
+    model_args = dict(
         levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, **kwargs)
-    return _create_dla('dla60', pretrained, **model_kwargs)
+        block=DlaBottleneck)
+    return _create_dla('dla60', pretrained, **dict(model_args, **kwargs))


 @register_model
 def dla60x(pretrained=False, **kwargs):  # DLA-X-60
-    model_kwargs = dict(
+    model_args = dict(
         levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, cardinality=32, base_width=4, **kwargs)
-    return _create_dla('dla60x', pretrained, **model_kwargs)
+        block=DlaBottleneck, cardinality=32, base_width=4)
+    return _create_dla('dla60x', pretrained, **dict(model_args, **kwargs))


 @register_model
 def dla102(pretrained=False, **kwargs):  # DLA-102
-    model_kwargs = dict(
+    model_args = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, shortcut_root=True, **kwargs)
-    return _create_dla('dla102', pretrained, **model_kwargs)
+        block=DlaBottleneck, shortcut_root=True)
+    return _create_dla('dla102', pretrained, **dict(model_args, **kwargs))


 @register_model
 def dla102x(pretrained=False, **kwargs):  # DLA-X-102
-    model_kwargs = dict(
+    model_args = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, cardinality=32, base_width=4, shortcut_root=True, **kwargs)
-    return _create_dla('dla102x', pretrained, **model_kwargs)
+        block=DlaBottleneck, cardinality=32, base_width=4, shortcut_root=True)
+    return _create_dla('dla102x', pretrained,
**dict(model_args, **kwargs)) @register_model def dla102x2(pretrained=False, **kwargs): # DLA-X-102 64 - model_kwargs = dict( + model_args = dict( levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024], - block=DlaBottleneck, cardinality=64, base_width=4, shortcut_root=True, **kwargs) - return _create_dla('dla102x2', pretrained, **model_kwargs) + block=DlaBottleneck, cardinality=64, base_width=4, shortcut_root=True) + return _create_dla('dla102x2', pretrained, **dict(model_args, **kwargs)) @register_model def dla169(pretrained=False, **kwargs): # DLA-169 - model_kwargs = dict( + model_args = dict( levels=[1, 1, 2, 3, 5, 1], channels=[16, 32, 128, 256, 512, 1024], - block=DlaBottleneck, shortcut_root=True, **kwargs) - return _create_dla('dla169', pretrained, **model_kwargs) + block=DlaBottleneck, shortcut_root=True) + return _create_dla('dla169', pretrained, **dict(model_args, **kwargs)) diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py index e770a5be1d..c597b7ed12 100644 --- a/timm/models/edgenext.py +++ b/timm/models/edgenext.py @@ -17,7 +17,8 @@ from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d +from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d, \ + use_fused_attn from ._builder import build_model_with_cfg from ._features_fx import register_notrace_module from ._manipulate import named_apply, checkpoint_seq @@ -37,14 +38,16 @@ def __init__(self, hidden_dim=32, dim=768, temperature=10000): self.dim = dim def forward(self, shape: Tuple[int, int, int]): - inv_mask = ~torch.zeros(shape).to(device=self.token_projection.weight.device, dtype=torch.bool) - y_embed = inv_mask.cumsum(1, dtype=torch.float32) - x_embed = inv_mask.cumsum(2, dtype=torch.float32) + device = self.token_projection.weight.device + dtype = self.token_projection.weight.dtype + inv_mask = ~torch.zeros(shape).to(device=device, dtype=torch.bool) + y_embed = inv_mask.cumsum(1, dtype=dtype) + x_embed = inv_mask.cumsum(2, dtype=dtype) eps = 1e-6 y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale - dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=inv_mask.device) + dim_t = torch.arange(self.hidden_dim, dtype=dtype, device=device) dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim) pos_x = x_embed[:, :, :, None] / dim_t @@ -129,9 +132,9 @@ def forward(self, x): attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) * self.temperature attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) + x = (attn @ v) - x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C) - + x = x.permute(0, 3, 1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x @@ -494,25 +497,25 @@ def _cfg(url='', **kwargs): default_cfgs = generate_default_cfgs({ 'edgenext_xx_small.in1k': _cfg( - url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_xx_small.pth", + hf_hub_id='timm/', test_input_size=(3, 288, 288), test_crop_pct=1.0), 'edgenext_x_small.in1k': _cfg( - url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_x_small.pth", + hf_hub_id='timm/', test_input_size=(3, 288, 288), test_crop_pct=1.0), 'edgenext_small.usi_in1k': _cfg( # USI weights - url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth", + 
hf_hub_id='timm/', crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0, ), 'edgenext_base.usi_in1k': _cfg( # USI weights - url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth", + hf_hub_id='timm/', crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0, ), 'edgenext_base.in21k_ft_in1k': _cfg( # USI weights - url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.21/edgenext_base_IN21K.pth", + hf_hub_id='timm/', crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0, ), 'edgenext_small_rw.sw_in1k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/edgenext_small_rw-sw-b00041bb.pth', + hf_hub_id='timm/', test_input_size=(3, 320, 320), test_crop_pct=1.0, ), }) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 221a133d11..3a9fc13a6a 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -961,10 +961,10 @@ def _cfg(url='', **kwargs): url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b4_ra2_320-7eb33cd5.pth', hf_hub_id='timm/', input_size=(3, 320, 320), pool_size=(10, 10), test_input_size=(3, 384, 384), crop_pct=1.0), - 'efficientnet_b5.in12k_ft_in1k': _cfg( + 'efficientnet_b5.sw_in12k_ft_in1k': _cfg( hf_hub_id='timm/', input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, crop_mode='squash'), - 'efficientnet_b5.in12k': _cfg( + 'efficientnet_b5.sw_in12k': _cfg( hf_hub_id='timm/', input_size=(3, 416, 416), pool_size=(13, 13), crop_pct=0.95, num_classes=11821), 'efficientnet_b6.untrained': _cfg( @@ -1149,6 +1149,19 @@ def _cfg(url='', **kwargs): mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), + 'tf_efficientnet_b5.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth', + hf_hub_id='timm/', + input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 'tf_efficientnet_b7.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth', + hf_hub_id='timm/', + input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + 'tf_efficientnet_b8.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth', + hf_hub_id='timm/', + input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), + 'tf_efficientnet_b0.aa_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth', hf_hub_id='timm/', @@ -1169,22 +1182,44 @@ def _cfg(url='', **kwargs): url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth', hf_hub_id='timm/', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), - 'tf_efficientnet_b5.ra_in1k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth', + 'tf_efficientnet_b5.aa_in1k': _cfg( + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_aa-99018a74.pth', hf_hub_id='timm/', input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), 'tf_efficientnet_b6.aa_in1k': _cfg( 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth', hf_hub_id='timm/', input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), - 'tf_efficientnet_b7.ra_in1k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth', + 'tf_efficientnet_b7.aa_in1k': _cfg( + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_aa-076e3472.pth', hf_hub_id='timm/', input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), - 'tf_efficientnet_b8.ra_in1k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth', - hf_hub_id='timm/', - input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), + + 'tf_efficientnet_b0.in1k': _cfg( + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0-0af12548.pth', + #hf_hub_id='timm/', + input_size=(3, 224, 224)), + 'tf_efficientnet_b1.in1k': _cfg( + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1-5c1377c4.pth', + #hf_hub_id='timm/', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'tf_efficientnet_b2.in1k': _cfg( + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2-e393ef04.pth', + #hf_hub_id='timm/', + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'tf_efficientnet_b3.in1k': _cfg( + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3-e3bd6955.pth', + #hf_hub_id='timm/', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'tf_efficientnet_b4.in1k': _cfg( + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4-74ee3bed.pth', + #hf_hub_id='timm/', + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + 'tf_efficientnet_b5.in1k': _cfg( + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5-c6949ce9.pth', + #hf_hub_id='timm/', + input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 'tf_efficientnet_es.in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth', diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index 492049b9e7..4838ee6261 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -16,63 +16,62 @@ from ._builder import build_model_with_cfg from ._efficientnet_blocks import SqueezeExcite, ConvBnAct from ._manipulate import checkpoint_seq -from ._registry import register_model +from ._registry import register_model, generate_default_cfgs __all__ = ['GhostNet'] -def _cfg(url='', **kwargs): - return { - 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bilinear', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'conv_stem', 'classifier': 'classifier', - **kwargs - } - - -default_cfgs = { - 'ghostnet_050': _cfg(url=''), - 'ghostnet_100': _cfg( - url='https://github.com/huawei-noah/CV-backbones/releases/download/ghostnet_pth/ghostnet_1x.pth'), - 'ghostnet_130': _cfg(url=''), -} - - _SE_LAYER = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=partial(make_divisible, divisor=4)) class 
GhostModule(nn.Module): - def __init__(self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True): + def __init__( + self, + in_chs, + out_chs, + kernel_size=1, + ratio=2, + dw_size=3, + stride=1, + relu=True, + ): super(GhostModule, self).__init__() - self.oup = oup - init_channels = math.ceil(oup / ratio) - new_channels = init_channels * (ratio - 1) + self.out_chs = out_chs + init_chs = math.ceil(out_chs / ratio) + new_chs = init_chs * (ratio - 1) self.primary_conv = nn.Sequential( - nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size//2, bias=False), - nn.BatchNorm2d(init_channels), - nn.ReLU(inplace=True) if relu else nn.Sequential(), + nn.Conv2d(in_chs, init_chs, kernel_size, stride, kernel_size // 2, bias=False), + nn.BatchNorm2d(init_chs), + nn.ReLU(inplace=True) if relu else nn.Identity(), ) self.cheap_operation = nn.Sequential( - nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size//2, groups=init_channels, bias=False), - nn.BatchNorm2d(new_channels), - nn.ReLU(inplace=True) if relu else nn.Sequential(), + nn.Conv2d(init_chs, new_chs, dw_size, 1, dw_size//2, groups=init_chs, bias=False), + nn.BatchNorm2d(new_chs), + nn.ReLU(inplace=True) if relu else nn.Identity(), ) def forward(self, x): x1 = self.primary_conv(x) x2 = self.cheap_operation(x1) out = torch.cat([x1, x2], dim=1) - return out[:, :self.oup, :, :] + return out[:, :self.out_chs, :, :] class GhostBottleneck(nn.Module): """ Ghost bottleneck w/ optional SE""" - def __init__(self, in_chs, mid_chs, out_chs, dw_kernel_size=3, - stride=1, act_layer=nn.ReLU, se_ratio=0.): + def __init__( + self, + in_chs, + mid_chs, + out_chs, + dw_kernel_size=3, + stride=1, + act_layer=nn.ReLU, + se_ratio=0., + ): super(GhostBottleneck, self).__init__() has_se = se_ratio is not None and se_ratio > 0. self.stride = stride @@ -133,7 +132,15 @@ def forward(self, x): class GhostNet(nn.Module): def __init__( - self, cfgs, num_classes=1000, width=1.0, in_chans=3, output_stride=32, global_pool='avg', drop_rate=0.2): + self, + cfgs, + num_classes=1000, + width=1.0, + in_chans=3, + output_stride=32, + global_pool='avg', + drop_rate=0.2, + ): super(GhostNet, self).__init__() # setting of inverted residual blocks assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported' @@ -275,9 +282,30 @@ def _create_ghostnet(variant, width=1.0, pretrained=False, **kwargs): **kwargs, ) return build_model_with_cfg( - GhostNet, variant, pretrained, + GhostNet, + variant, + pretrained, feature_cfg=dict(flatten_sequential=True), - **model_kwargs) + **model_kwargs, + ) + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv_stem', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'ghostnet_050.untrained': _cfg(), + 'ghostnet_100.in1k': _cfg( + url='https://github.com/huawei-noah/CV-backbones/releases/download/ghostnet_pth/ghostnet_1x.pth'), + 'ghostnet_130.untrained': _cfg(), +}) @register_model diff --git a/timm/models/gluon_xception.py b/timm/models/gluon_xception.py deleted file mode 100644 index b487d0fd18..0000000000 --- a/timm/models/gluon_xception.py +++ /dev/null @@ -1,267 +0,0 @@ -"""Pytorch impl of Gluon Xception -This is a port of the Gluon Xception code and weights, itself ported from a PyTorch DeepLab impl. 
- -Gluon model: (https://gluon-cv.mxnet.io/_modules/gluoncv/model_zoo/xception.html) -Original PyTorch DeepLab impl: https://github.com/jfzhang95/pytorch-deeplab-xception - -Hacked together by / Copyright 2020 Ross Wightman -""" -from collections import OrderedDict - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import create_classifier, get_padding -from ._builder import build_model_with_cfg -from ._registry import register_model - -__all__ = ['Xception65'] - -default_cfgs = { - 'gluon_xception65': { - 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_xception-7015a15c.pth', - 'input_size': (3, 299, 299), - 'crop_pct': 0.903, - 'pool_size': (10, 10), - 'interpolation': 'bicubic', - 'mean': IMAGENET_DEFAULT_MEAN, - 'std': IMAGENET_DEFAULT_STD, - 'num_classes': 1000, - 'first_conv': 'conv1', - 'classifier': 'fc' - # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 - }, -} - -""" PADDING NOTES -The original PyTorch and Gluon impl of these models dutifully reproduced the -aligned padding added to Tensorflow models for Deeplab. This padding was compensating -for Tensorflow 'SAME' padding. PyTorch symmetric padding behaves the way we'd want it to. -""" - - -class SeparableConv2d(nn.Module): - def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, norm_layer=None): - super(SeparableConv2d, self).__init__() - self.kernel_size = kernel_size - self.dilation = dilation - - # depthwise convolution - padding = get_padding(kernel_size, stride, dilation) - self.conv_dw = nn.Conv2d( - inplanes, inplanes, kernel_size, stride=stride, - padding=padding, dilation=dilation, groups=inplanes, bias=bias) - self.bn = norm_layer(num_features=inplanes) - # pointwise convolution - self.conv_pw = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias) - - def forward(self, x): - x = self.conv_dw(x) - x = self.bn(x) - x = self.conv_pw(x) - return x - - -class Block(nn.Module): - def __init__(self, inplanes, planes, stride=1, dilation=1, start_with_relu=True, norm_layer=None): - super(Block, self).__init__() - if isinstance(planes, (list, tuple)): - assert len(planes) == 3 - else: - planes = (planes,) * 3 - outplanes = planes[-1] - - if outplanes != inplanes or stride != 1: - self.skip = nn.Sequential() - self.skip.add_module('conv1', nn.Conv2d( - inplanes, outplanes, 1, stride=stride, bias=False)), - self.skip.add_module('bn1', norm_layer(num_features=outplanes)) - else: - self.skip = None - - rep = OrderedDict() - for i in range(3): - rep['act%d' % (i + 1)] = nn.ReLU(inplace=True) - rep['conv%d' % (i + 1)] = SeparableConv2d( - inplanes, planes[i], 3, stride=stride if i == 2 else 1, dilation=dilation, norm_layer=norm_layer) - rep['bn%d' % (i + 1)] = norm_layer(planes[i]) - inplanes = planes[i] - - if not start_with_relu: - del rep['act1'] - else: - rep['act1'] = nn.ReLU(inplace=False) - self.rep = nn.Sequential(rep) - - def forward(self, x): - skip = x - if self.skip is not None: - skip = self.skip(skip) - x = self.rep(x) + skip - return x - - -class Xception65(nn.Module): - """Modified Aligned Xception. 
- - NOTE: only the 65 layer version is included here, the 71 layer variant - was not correct and had no pretrained weights - """ - - def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d, - drop_rate=0., global_pool='avg'): - super(Xception65, self).__init__() - self.num_classes = num_classes - self.drop_rate = drop_rate - if output_stride == 32: - entry_block3_stride = 2 - exit_block20_stride = 2 - middle_dilation = 1 - exit_dilation = (1, 1) - elif output_stride == 16: - entry_block3_stride = 2 - exit_block20_stride = 1 - middle_dilation = 1 - exit_dilation = (1, 2) - elif output_stride == 8: - entry_block3_stride = 1 - exit_block20_stride = 1 - middle_dilation = 2 - exit_dilation = (2, 4) - else: - raise NotImplementedError - - # Entry flow - self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False) - self.bn1 = norm_layer(num_features=32) - self.act1 = nn.ReLU(inplace=True) - - self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False) - self.bn2 = norm_layer(num_features=64) - self.act2 = nn.ReLU(inplace=True) - - self.block1 = Block(64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer) - self.block1_act = nn.ReLU(inplace=True) - self.block2 = Block(128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer) - self.block3 = Block(256, 728, stride=entry_block3_stride, norm_layer=norm_layer) - - # Middle flow - self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block( - 728, 728, stride=1, dilation=middle_dilation, norm_layer=norm_layer)) for i in range(4, 20)])) - - # Exit flow - self.block20 = Block( - 728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_dilation[0], norm_layer=norm_layer) - self.block20_act = nn.ReLU(inplace=True) - - self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer) - self.bn3 = norm_layer(num_features=1536) - self.act3 = nn.ReLU(inplace=True) - - self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer) - self.bn4 = norm_layer(num_features=1536) - self.act4 = nn.ReLU(inplace=True) - - self.num_features = 2048 - self.conv5 = SeparableConv2d( - 1536, self.num_features, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer) - self.bn5 = norm_layer(num_features=self.num_features) - self.act5 = nn.ReLU(inplace=True) - self.feature_info = [ - dict(num_chs=64, reduction=2, module='act2'), - dict(num_chs=128, reduction=4, module='block1_act'), - dict(num_chs=256, reduction=8, module='block3.rep.act1'), - dict(num_chs=728, reduction=16, module='block20.rep.act1'), - dict(num_chs=2048, reduction=32, module='act5'), - ] - - self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) - - @torch.jit.ignore - def group_matcher(self, coarse=False): - matcher = dict( - stem=r'^conv[12]|bn[12]', - blocks=[ - (r'^mid\.block(\d+)', None), - (r'^block(\d+)', None), - (r'^conv[345]|bn[345]', (99,)), - ], - ) - return matcher - - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - assert not enable, "gradient checkpointing not supported" - - @torch.jit.ignore - def get_classifier(self): - return self.fc - - def reset_classifier(self, num_classes, global_pool='avg'): - self.num_classes = num_classes - self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) - - def forward_features(self, x): - # Entry flow - x = self.conv1(x) - x = self.bn1(x) - x = 
self.act1(x) - - x = self.conv2(x) - x = self.bn2(x) - x = self.act2(x) - - x = self.block1(x) - x = self.block1_act(x) - # c1 = x - x = self.block2(x) - # c2 = x - x = self.block3(x) - - # Middle flow - x = self.mid(x) - # c3 = x - - # Exit flow - x = self.block20(x) - x = self.block20_act(x) - x = self.conv3(x) - x = self.bn3(x) - x = self.act3(x) - - x = self.conv4(x) - x = self.bn4(x) - x = self.act4(x) - - x = self.conv5(x) - x = self.bn5(x) - x = self.act5(x) - return x - - def forward_head(self, x): - x = self.global_pool(x) - if self.drop_rate: - F.dropout(x, self.drop_rate, training=self.training) - x = self.fc(x) - return x - - def forward(self, x): - x = self.forward_features(x) - x = self.forward_head(x) - return x - - -def _create_gluon_xception(variant, pretrained=False, **kwargs): - return build_model_with_cfg( - Xception65, variant, pretrained, - feature_cfg=dict(feature_cls='hook'), - **kwargs) - - -@register_model -def gluon_xception65(pretrained=False, **kwargs): - """ Modified Aligned Xception-65 - """ - return _create_gluon_xception('gluon_xception65', pretrained, **kwargs) diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index 338d409edd..db75bc0f3b 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -19,8 +19,8 @@ from timm.layers import create_classifier from ._builder import build_model_with_cfg, pretrained_cfg_for_features from ._features import FeatureInfo -from ._registry import register_model -from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE +from ._registry import register_model, generate_default_cfgs +from .resnet import BasicBlock, Bottleneck # leveraging ResNet block_types w/ additional features like SE __all__ = ['HighResolutionNet', 'HighResolutionNetFeatures'] # model_registry will add each entrypoint fn to this @@ -28,371 +28,352 @@ _logger = logging.getLogger(__name__) -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bilinear', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'conv1', 'classifier': 'classifier', - **kwargs - } - - -default_cfgs = { - 'hrnet_w18_small': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnet_w18_small_v1-f460c6bc.pth'), - 'hrnet_w18_small_v2': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnet_w18_small_v2-4c50a8cb.pth'), - 'hrnet_w18': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w18-8cb57bb9.pth'), - 'hrnet_w30': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w30-8d7f8dab.pth'), - 'hrnet_w32': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w32-90d8c5fb.pth'), - 'hrnet_w40': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w40-7cd397a4.pth'), - 'hrnet_w44': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w44-c9ac8c18.pth'), - 'hrnet_w48': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w48-abd2e6ab.pth'), - 'hrnet_w64': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w64-b47cc881.pth'), -} - cfg_cls = dict( hrnet_w18_small=dict( - STEM_WIDTH=64, - 
STAGE1=dict( - NUM_MODULES=1, - NUM_BRANCHES=1, - BLOCK='BOTTLENECK', - NUM_BLOCKS=(1,), - NUM_CHANNELS=(32,), - FUSE_METHOD='SUM', - ), - STAGE2=dict( - NUM_MODULES=1, - NUM_BRANCHES=2, - BLOCK='BASIC', - NUM_BLOCKS=(2, 2), - NUM_CHANNELS=(16, 32), - FUSE_METHOD='SUM' - ), - STAGE3=dict( - NUM_MODULES=1, - NUM_BRANCHES=3, - BLOCK='BASIC', - NUM_BLOCKS=(2, 2, 2), - NUM_CHANNELS=(16, 32, 64), - FUSE_METHOD='SUM' - ), - STAGE4=dict( - NUM_MODULES=1, - NUM_BRANCHES=4, - BLOCK='BASIC', - NUM_BLOCKS=(2, 2, 2, 2), - NUM_CHANNELS=(16, 32, 64, 128), - FUSE_METHOD='SUM', + stem_width=64, + stage1=dict( + num_modules=1, + num_branches=1, + block_type='BOTTLENECK', + num_blocks=(1,), + num_channels=(32,), + fuse_method='SUM', + ), + stage2=dict( + num_modules=1, + num_branches=2, + block_type='BASIC', + num_blocks=(2, 2), + num_channels=(16, 32), + fuse_method='SUM' + ), + stage3=dict( + num_modules=1, + num_branches=3, + block_type='BASIC', + num_blocks=(2, 2, 2), + num_channels=(16, 32, 64), + fuse_method='SUM' + ), + stage4=dict( + num_modules=1, + num_branches=4, + block_type='BASIC', + num_blocks=(2, 2, 2, 2), + num_channels=(16, 32, 64, 128), + fuse_method='SUM', ), ), hrnet_w18_small_v2=dict( - STEM_WIDTH=64, - STAGE1=dict( - NUM_MODULES=1, - NUM_BRANCHES=1, - BLOCK='BOTTLENECK', - NUM_BLOCKS=(2,), - NUM_CHANNELS=(64,), - FUSE_METHOD='SUM', - ), - STAGE2=dict( - NUM_MODULES=1, - NUM_BRANCHES=2, - BLOCK='BASIC', - NUM_BLOCKS=(2, 2), - NUM_CHANNELS=(18, 36), - FUSE_METHOD='SUM' - ), - STAGE3=dict( - NUM_MODULES=3, - NUM_BRANCHES=3, - BLOCK='BASIC', - NUM_BLOCKS=(2, 2, 2), - NUM_CHANNELS=(18, 36, 72), - FUSE_METHOD='SUM' - ), - STAGE4=dict( - NUM_MODULES=2, - NUM_BRANCHES=4, - BLOCK='BASIC', - NUM_BLOCKS=(2, 2, 2, 2), - NUM_CHANNELS=(18, 36, 72, 144), - FUSE_METHOD='SUM', + stem_width=64, + stage1=dict( + num_modules=1, + num_branches=1, + block_type='BOTTLENECK', + num_blocks=(2,), + num_channels=(64,), + fuse_method='SUM', + ), + stage2=dict( + num_modules=1, + num_branches=2, + block_type='BASIC', + num_blocks=(2, 2), + num_channels=(18, 36), + fuse_method='SUM' + ), + stage3=dict( + num_modules=3, + num_branches=3, + block_type='BASIC', + num_blocks=(2, 2, 2), + num_channels=(18, 36, 72), + fuse_method='SUM' + ), + stage4=dict( + num_modules=2, + num_branches=4, + block_type='BASIC', + num_blocks=(2, 2, 2, 2), + num_channels=(18, 36, 72, 144), + fuse_method='SUM', ), ), hrnet_w18=dict( - STEM_WIDTH=64, - STAGE1=dict( - NUM_MODULES=1, - NUM_BRANCHES=1, - BLOCK='BOTTLENECK', - NUM_BLOCKS=(4,), - NUM_CHANNELS=(64,), - FUSE_METHOD='SUM', - ), - STAGE2=dict( - NUM_MODULES=1, - NUM_BRANCHES=2, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4), - NUM_CHANNELS=(18, 36), - FUSE_METHOD='SUM' - ), - STAGE3=dict( - NUM_MODULES=4, - NUM_BRANCHES=3, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4), - NUM_CHANNELS=(18, 36, 72), - FUSE_METHOD='SUM' - ), - STAGE4=dict( - NUM_MODULES=3, - NUM_BRANCHES=4, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4, 4), - NUM_CHANNELS=(18, 36, 72, 144), - FUSE_METHOD='SUM', + stem_width=64, + stage1=dict( + num_modules=1, + num_branches=1, + block_type='BOTTLENECK', + num_blocks=(4,), + num_channels=(64,), + fuse_method='SUM', + ), + stage2=dict( + num_modules=1, + num_branches=2, + block_type='BASIC', + num_blocks=(4, 4), + num_channels=(18, 36), + fuse_method='SUM' + ), + stage3=dict( + num_modules=4, + num_branches=3, + block_type='BASIC', + num_blocks=(4, 4, 4), + num_channels=(18, 36, 72), + fuse_method='SUM' + ), + stage4=dict( + num_modules=3, + num_branches=4, + block_type='BASIC', + 
num_blocks=(4, 4, 4, 4), + num_channels=(18, 36, 72, 144), + fuse_method='SUM', ), ), hrnet_w30=dict( - STEM_WIDTH=64, - STAGE1=dict( - NUM_MODULES=1, - NUM_BRANCHES=1, - BLOCK='BOTTLENECK', - NUM_BLOCKS=(4,), - NUM_CHANNELS=(64,), - FUSE_METHOD='SUM', - ), - STAGE2=dict( - NUM_MODULES=1, - NUM_BRANCHES=2, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4), - NUM_CHANNELS=(30, 60), - FUSE_METHOD='SUM' - ), - STAGE3=dict( - NUM_MODULES=4, - NUM_BRANCHES=3, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4), - NUM_CHANNELS=(30, 60, 120), - FUSE_METHOD='SUM' - ), - STAGE4=dict( - NUM_MODULES=3, - NUM_BRANCHES=4, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4, 4), - NUM_CHANNELS=(30, 60, 120, 240), - FUSE_METHOD='SUM', + stem_width=64, + stage1=dict( + num_modules=1, + num_branches=1, + block_type='BOTTLENECK', + num_blocks=(4,), + num_channels=(64,), + fuse_method='SUM', + ), + stage2=dict( + num_modules=1, + num_branches=2, + block_type='BASIC', + num_blocks=(4, 4), + num_channels=(30, 60), + fuse_method='SUM' + ), + stage3=dict( + num_modules=4, + num_branches=3, + block_type='BASIC', + num_blocks=(4, 4, 4), + num_channels=(30, 60, 120), + fuse_method='SUM' + ), + stage4=dict( + num_modules=3, + num_branches=4, + block_type='BASIC', + num_blocks=(4, 4, 4, 4), + num_channels=(30, 60, 120, 240), + fuse_method='SUM', ), ), hrnet_w32=dict( - STEM_WIDTH=64, - STAGE1=dict( - NUM_MODULES=1, - NUM_BRANCHES=1, - BLOCK='BOTTLENECK', - NUM_BLOCKS=(4,), - NUM_CHANNELS=(64,), - FUSE_METHOD='SUM', - ), - STAGE2=dict( - NUM_MODULES=1, - NUM_BRANCHES=2, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4), - NUM_CHANNELS=(32, 64), - FUSE_METHOD='SUM' - ), - STAGE3=dict( - NUM_MODULES=4, - NUM_BRANCHES=3, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4), - NUM_CHANNELS=(32, 64, 128), - FUSE_METHOD='SUM' - ), - STAGE4=dict( - NUM_MODULES=3, - NUM_BRANCHES=4, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4, 4), - NUM_CHANNELS=(32, 64, 128, 256), - FUSE_METHOD='SUM', + stem_width=64, + stage1=dict( + num_modules=1, + num_branches=1, + block_type='BOTTLENECK', + num_blocks=(4,), + num_channels=(64,), + fuse_method='SUM', + ), + stage2=dict( + num_modules=1, + num_branches=2, + block_type='BASIC', + num_blocks=(4, 4), + num_channels=(32, 64), + fuse_method='SUM' + ), + stage3=dict( + num_modules=4, + num_branches=3, + block_type='BASIC', + num_blocks=(4, 4, 4), + num_channels=(32, 64, 128), + fuse_method='SUM' + ), + stage4=dict( + num_modules=3, + num_branches=4, + block_type='BASIC', + num_blocks=(4, 4, 4, 4), + num_channels=(32, 64, 128, 256), + fuse_method='SUM', ), ), hrnet_w40=dict( - STEM_WIDTH=64, - STAGE1=dict( - NUM_MODULES=1, - NUM_BRANCHES=1, - BLOCK='BOTTLENECK', - NUM_BLOCKS=(4,), - NUM_CHANNELS=(64,), - FUSE_METHOD='SUM', - ), - STAGE2=dict( - NUM_MODULES=1, - NUM_BRANCHES=2, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4), - NUM_CHANNELS=(40, 80), - FUSE_METHOD='SUM' - ), - STAGE3=dict( - NUM_MODULES=4, - NUM_BRANCHES=3, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4), - NUM_CHANNELS=(40, 80, 160), - FUSE_METHOD='SUM' - ), - STAGE4=dict( - NUM_MODULES=3, - NUM_BRANCHES=4, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4, 4), - NUM_CHANNELS=(40, 80, 160, 320), - FUSE_METHOD='SUM', + stem_width=64, + stage1=dict( + num_modules=1, + num_branches=1, + block_type='BOTTLENECK', + num_blocks=(4,), + num_channels=(64,), + fuse_method='SUM', + ), + stage2=dict( + num_modules=1, + num_branches=2, + block_type='BASIC', + num_blocks=(4, 4), + num_channels=(40, 80), + fuse_method='SUM' + ), + stage3=dict( + num_modules=4, + num_branches=3, + block_type='BASIC', + num_blocks=(4, 4, 4), + 
num_channels=(40, 80, 160), + fuse_method='SUM' + ), + stage4=dict( + num_modules=3, + num_branches=4, + block_type='BASIC', + num_blocks=(4, 4, 4, 4), + num_channels=(40, 80, 160, 320), + fuse_method='SUM', ), ), hrnet_w44=dict( - STEM_WIDTH=64, - STAGE1=dict( - NUM_MODULES=1, - NUM_BRANCHES=1, - BLOCK='BOTTLENECK', - NUM_BLOCKS=(4,), - NUM_CHANNELS=(64,), - FUSE_METHOD='SUM', - ), - STAGE2=dict( - NUM_MODULES=1, - NUM_BRANCHES=2, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4), - NUM_CHANNELS=(44, 88), - FUSE_METHOD='SUM' - ), - STAGE3=dict( - NUM_MODULES=4, - NUM_BRANCHES=3, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4), - NUM_CHANNELS=(44, 88, 176), - FUSE_METHOD='SUM' - ), - STAGE4=dict( - NUM_MODULES=3, - NUM_BRANCHES=4, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4, 4), - NUM_CHANNELS=(44, 88, 176, 352), - FUSE_METHOD='SUM', + stem_width=64, + stage1=dict( + num_modules=1, + num_branches=1, + block_type='BOTTLENECK', + num_blocks=(4,), + num_channels=(64,), + fuse_method='SUM', + ), + stage2=dict( + num_modules=1, + num_branches=2, + block_type='BASIC', + num_blocks=(4, 4), + num_channels=(44, 88), + fuse_method='SUM' + ), + stage3=dict( + num_modules=4, + num_branches=3, + block_type='BASIC', + num_blocks=(4, 4, 4), + num_channels=(44, 88, 176), + fuse_method='SUM' + ), + stage4=dict( + num_modules=3, + num_branches=4, + block_type='BASIC', + num_blocks=(4, 4, 4, 4), + num_channels=(44, 88, 176, 352), + fuse_method='SUM', ), ), hrnet_w48=dict( - STEM_WIDTH=64, - STAGE1=dict( - NUM_MODULES=1, - NUM_BRANCHES=1, - BLOCK='BOTTLENECK', - NUM_BLOCKS=(4,), - NUM_CHANNELS=(64,), - FUSE_METHOD='SUM', - ), - STAGE2=dict( - NUM_MODULES=1, - NUM_BRANCHES=2, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4), - NUM_CHANNELS=(48, 96), - FUSE_METHOD='SUM' - ), - STAGE3=dict( - NUM_MODULES=4, - NUM_BRANCHES=3, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4), - NUM_CHANNELS=(48, 96, 192), - FUSE_METHOD='SUM' - ), - STAGE4=dict( - NUM_MODULES=3, - NUM_BRANCHES=4, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4, 4), - NUM_CHANNELS=(48, 96, 192, 384), - FUSE_METHOD='SUM', + stem_width=64, + stage1=dict( + num_modules=1, + num_branches=1, + block_type='BOTTLENECK', + num_blocks=(4,), + num_channels=(64,), + fuse_method='SUM', + ), + stage2=dict( + num_modules=1, + num_branches=2, + block_type='BASIC', + num_blocks=(4, 4), + num_channels=(48, 96), + fuse_method='SUM' + ), + stage3=dict( + num_modules=4, + num_branches=3, + block_type='BASIC', + num_blocks=(4, 4, 4), + num_channels=(48, 96, 192), + fuse_method='SUM' + ), + stage4=dict( + num_modules=3, + num_branches=4, + block_type='BASIC', + num_blocks=(4, 4, 4, 4), + num_channels=(48, 96, 192, 384), + fuse_method='SUM', ), ), hrnet_w64=dict( - STEM_WIDTH=64, - STAGE1=dict( - NUM_MODULES=1, - NUM_BRANCHES=1, - BLOCK='BOTTLENECK', - NUM_BLOCKS=(4,), - NUM_CHANNELS=(64,), - FUSE_METHOD='SUM', - ), - STAGE2=dict( - NUM_MODULES=1, - NUM_BRANCHES=2, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4), - NUM_CHANNELS=(64, 128), - FUSE_METHOD='SUM' - ), - STAGE3=dict( - NUM_MODULES=4, - NUM_BRANCHES=3, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4), - NUM_CHANNELS=(64, 128, 256), - FUSE_METHOD='SUM' - ), - STAGE4=dict( - NUM_MODULES=3, - NUM_BRANCHES=4, - BLOCK='BASIC', - NUM_BLOCKS=(4, 4, 4, 4), - NUM_CHANNELS=(64, 128, 256, 512), - FUSE_METHOD='SUM', + stem_width=64, + stage1=dict( + num_modules=1, + num_branches=1, + block_type='BOTTLENECK', + num_blocks=(4,), + num_channels=(64,), + fuse_method='SUM', + ), + stage2=dict( + num_modules=1, + num_branches=2, + block_type='BASIC', + num_blocks=(4, 4), + num_channels=(64, 128), + 
fuse_method='SUM' + ), + stage3=dict( + num_modules=4, + num_branches=3, + block_type='BASIC', + num_blocks=(4, 4, 4), + num_channels=(64, 128, 256), + fuse_method='SUM' + ), + stage4=dict( + num_modules=3, + num_branches=4, + block_type='BASIC', + num_blocks=(4, 4, 4, 4), + num_channels=(64, 128, 256, 512), + fuse_method='SUM', ), ) ) class HighResolutionModule(nn.Module): - def __init__(self, num_branches, blocks, num_blocks, num_in_chs, - num_channels, fuse_method, multi_scale_output=True): + def __init__( + self, + num_branches, + block_types, + num_blocks, + num_in_chs, + num_channels, + fuse_method, + multi_scale_output=True, + ): super(HighResolutionModule, self).__init__() self._check_branches( - num_branches, blocks, num_blocks, num_in_chs, num_channels) + num_branches, + block_types, + num_blocks, + num_in_chs, + num_channels, + ) self.num_in_chs = num_in_chs self.fuse_method = fuse_method @@ -401,43 +382,47 @@ def __init__(self, num_branches, blocks, num_blocks, num_in_chs, self.multi_scale_output = multi_scale_output self.branches = self._make_branches( - num_branches, blocks, num_blocks, num_channels) + num_branches, + block_types, + num_blocks, + num_channels, + ) self.fuse_layers = self._make_fuse_layers() self.fuse_act = nn.ReLU(False) - def _check_branches(self, num_branches, blocks, num_blocks, num_in_chs, num_channels): + def _check_branches(self, num_branches, block_types, num_blocks, num_in_chs, num_channels): error_msg = '' if num_branches != len(num_blocks): - error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(num_branches, len(num_blocks)) + error_msg = 'num_branches({}) <> num_blocks({})'.format(num_branches, len(num_blocks)) elif num_branches != len(num_channels): - error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(num_branches, len(num_channels)) + error_msg = 'num_branches({}) <> num_channels({})'.format(num_branches, len(num_channels)) elif num_branches != len(num_in_chs): - error_msg = 'NUM_BRANCHES({}) <> num_in_chs({})'.format(num_branches, len(num_in_chs)) + error_msg = 'num_branches({}) <> num_in_chs({})'.format(num_branches, len(num_in_chs)) if error_msg: _logger.error(error_msg) raise ValueError(error_msg) - def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1): + def _make_one_branch(self, branch_index, block_type, num_blocks, num_channels, stride=1): downsample = None - if stride != 1 or self.num_in_chs[branch_index] != num_channels[branch_index] * block.expansion: + if stride != 1 or self.num_in_chs[branch_index] != num_channels[branch_index] * block_type.expansion: downsample = nn.Sequential( nn.Conv2d( - self.num_in_chs[branch_index], num_channels[branch_index] * block.expansion, + self.num_in_chs[branch_index], num_channels[branch_index] * block_type.expansion, kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=_BN_MOMENTUM), + nn.BatchNorm2d(num_channels[branch_index] * block_type.expansion, momentum=_BN_MOMENTUM), ) - layers = [block(self.num_in_chs[branch_index], num_channels[branch_index], stride, downsample)] - self.num_in_chs[branch_index] = num_channels[branch_index] * block.expansion + layers = [block_type(self.num_in_chs[branch_index], num_channels[branch_index], stride, downsample)] + self.num_in_chs[branch_index] = num_channels[branch_index] * block_type.expansion for i in range(1, num_blocks[branch_index]): - layers.append(block(self.num_in_chs[branch_index], num_channels[branch_index])) + 
layers.append(block_type(self.num_in_chs[branch_index], num_channels[branch_index])) return nn.Sequential(*layers) - def _make_branches(self, num_branches, block, num_blocks, num_channels): + def _make_branches(self, num_branches, block_type, num_blocks, num_channels): branches = [] for i in range(num_branches): - branches.append(self._make_one_branch(i, block, num_blocks, num_channels)) + branches.append(self._make_one_branch(i, block_type, num_blocks, num_channels)) return nn.ModuleList(branches) @@ -462,16 +447,18 @@ def _make_fuse_layers(self): conv3x3s = [] for k in range(i - j): if k == i - j - 1: - num_outchannels_conv3x3 = num_in_chs[i] + num_out_chs_conv3x3 = num_in_chs[i] conv3x3s.append(nn.Sequential( - nn.Conv2d(num_in_chs[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), - nn.BatchNorm2d(num_outchannels_conv3x3, momentum=_BN_MOMENTUM))) + nn.Conv2d(num_in_chs[j], num_out_chs_conv3x3, 3, 2, 1, bias=False), + nn.BatchNorm2d(num_out_chs_conv3x3, momentum=_BN_MOMENTUM) + )) else: - num_outchannels_conv3x3 = num_in_chs[j] + num_out_chs_conv3x3 = num_in_chs[j] conv3x3s.append(nn.Sequential( - nn.Conv2d(num_in_chs[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), - nn.BatchNorm2d(num_outchannels_conv3x3, momentum=_BN_MOMENTUM), - nn.ReLU(False))) + nn.Conv2d(num_in_chs[j], num_out_chs_conv3x3, 3, 2, 1, bias=False), + nn.BatchNorm2d(num_out_chs_conv3x3, momentum=_BN_MOMENTUM), + nn.ReLU(False) + )) fuse_layer.append(nn.Sequential(*conv3x3s)) fuse_layers.append(nn.ModuleList(fuse_layer)) @@ -480,7 +467,7 @@ def _make_fuse_layers(self): def get_num_in_chs(self): return self.num_in_chs - def forward(self, x: List[torch.Tensor]): + def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: if self.num_branches == 1: return [self.branches[0](x[0])] @@ -489,18 +476,44 @@ def forward(self, x: List[torch.Tensor]): x_fuse = [] for i, fuse_outer in enumerate(self.fuse_layers): - y = x[0] if i == 0 else fuse_outer[0](x[0]) - for j in range(1, self.num_branches): - if i == j: - y = y + x[j] + y = None + for j, f in enumerate(fuse_outer): + if y is None: + y = f(x[j]) else: - y = y + fuse_outer[j](x[j]) + y = y + f(x[j]) x_fuse.append(self.fuse_act(y)) - return x_fuse -blocks_dict = { +class SequentialList(nn.Sequential): + + def __init__(self, *args): + super(SequentialList, self).__init__(*args) + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (List[torch.Tensor]) -> (List[torch.Tensor]) + pass + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (torch.Tensor) -> (List[torch.Tensor]) + pass + + def forward(self, x) -> List[torch.Tensor]: + for module in self: + x = module(x) + return x + + +@torch.jit.interface +class ModuleInterface(torch.nn.Module): + def forward(self, input: torch.Tensor) -> torch.Tensor: # `input` has a same name in Sequential forward + pass + + +block_types_dict = { 'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck } @@ -508,12 +521,23 @@ def forward(self, x: List[torch.Tensor]): class HighResolutionNet(nn.Module): - def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0.0, head='classification'): + def __init__( + self, + cfg, + in_chans=3, + num_classes=1000, + output_stride=32, + global_pool='avg', + drop_rate=0.0, + head='classification', + **kwargs, + ): super(HighResolutionNet, self).__init__() self.num_classes = num_classes - self.drop_rate = drop_rate + assert output_stride == 32 # FIXME support dilation - stem_width = cfg['STEM_WIDTH'] + cfg.update(**kwargs) + stem_width = 
cfg['stem_width'] self.conv1 = nn.Conv2d(in_chans, stem_width, kernel_size=3, stride=2, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(stem_width, momentum=_BN_MOMENTUM) self.act1 = nn.ReLU(inplace=True) @@ -521,68 +545,80 @@ def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_ra self.bn2 = nn.BatchNorm2d(64, momentum=_BN_MOMENTUM) self.act2 = nn.ReLU(inplace=True) - self.stage1_cfg = cfg['STAGE1'] - num_channels = self.stage1_cfg['NUM_CHANNELS'][0] - block = blocks_dict[self.stage1_cfg['BLOCK']] - num_blocks = self.stage1_cfg['NUM_BLOCKS'][0] - self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) - stage1_out_channel = block.expansion * num_channels - - self.stage2_cfg = cfg['STAGE2'] - num_channels = self.stage2_cfg['NUM_CHANNELS'] - block = blocks_dict[self.stage2_cfg['BLOCK']] - num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] + self.stage1_cfg = cfg['stage1'] + num_channels = self.stage1_cfg['num_channels'][0] + block_type = block_types_dict[self.stage1_cfg['block_type']] + num_blocks = self.stage1_cfg['num_blocks'][0] + self.layer1 = self._make_layer(block_type, 64, num_channels, num_blocks) + stage1_out_channel = block_type.expansion * num_channels + + self.stage2_cfg = cfg['stage2'] + num_channels = self.stage2_cfg['num_channels'] + block_type = block_types_dict[self.stage2_cfg['block_type']] + num_channels = [num_channels[i] * block_type.expansion for i in range(len(num_channels))] self.transition1 = self._make_transition_layer([stage1_out_channel], num_channels) self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels) - self.stage3_cfg = cfg['STAGE3'] - num_channels = self.stage3_cfg['NUM_CHANNELS'] - block = blocks_dict[self.stage3_cfg['BLOCK']] - num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] + self.stage3_cfg = cfg['stage3'] + num_channels = self.stage3_cfg['num_channels'] + block_type = block_types_dict[self.stage3_cfg['block_type']] + num_channels = [num_channels[i] * block_type.expansion for i in range(len(num_channels))] self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels) self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels) - self.stage4_cfg = cfg['STAGE4'] - num_channels = self.stage4_cfg['NUM_CHANNELS'] - block = blocks_dict[self.stage4_cfg['BLOCK']] - num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] + self.stage4_cfg = cfg['stage4'] + num_channels = self.stage4_cfg['num_channels'] + block_type = block_types_dict[self.stage4_cfg['block_type']] + num_channels = [num_channels[i] * block_type.expansion for i in range(len(num_channels))] self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels) self.stage4, pre_stage_channels = self._make_stage(self.stage4_cfg, num_channels, multi_scale_output=True) self.head = head self.head_channels = None # set if _make_head called + head_conv_bias = cfg.pop('head_conv_bias', True) if head == 'classification': # Classification Head self.num_features = 2048 - self.incre_modules, self.downsamp_modules, self.final_layer = self._make_head(pre_stage_channels) - self.global_pool, self.classifier = create_classifier( - self.num_features, self.num_classes, pool_type=global_pool) - elif head == 'incre': - self.num_features = 2048 - self.incre_modules, _, _ = self._make_head(pre_stage_channels, True) + self.incre_modules, self.downsamp_modules, self.final_layer = self._make_head( + 
pre_stage_channels, + conv_bias=head_conv_bias, + ) + self.global_pool, self.head_drop, self.classifier = create_classifier( + self.num_features, + self.num_classes, + pool_type=global_pool, + drop_rate=drop_rate, + ) else: - self.incre_modules = None - self.num_features = 256 + if head == 'incre': + self.num_features = 2048 + self.incre_modules, _, _ = self._make_head(pre_stage_channels, incre_only=True) + else: + self.num_features = 256 + self.incre_modules = None + self.global_pool = nn.Identity() + self.head_drop = nn.Identity() + self.classifier = nn.Identity() curr_stride = 2 # module names aren't actually valid here, hook or FeatureNet based extraction would not work self.feature_info = [dict(num_chs=64, reduction=curr_stride, module='stem')] for i, c in enumerate(self.head_channels if self.head_channels else num_channels): curr_stride *= 2 - c = c * 4 if self.head_channels else c # head block expansion factor of 4 + c = c * 4 if self.head_channels else c # head block_type expansion factor of 4 self.feature_info += [dict(num_chs=c, reduction=curr_stride, module=f'stage{i + 1}')] self.init_weights() - def _make_head(self, pre_stage_channels, incre_only=False): - head_block = Bottleneck + def _make_head(self, pre_stage_channels, incre_only=False, conv_bias=True): + head_block_type = Bottleneck self.head_channels = [32, 64, 128, 256] # Increasing the #channels on each resolution # from C, 2C, 4C, 8C to 128, 256, 512, 1024 incre_modules = [] for i, channels in enumerate(pre_stage_channels): - incre_modules.append(self._make_layer(head_block, channels, self.head_channels[i], 1, stride=1)) + incre_modules.append(self._make_layer(head_block_type, channels, self.head_channels[i], 1, stride=1)) incre_modules = nn.ModuleList(incre_modules) if incre_only: return incre_modules, None, None @@ -590,11 +626,12 @@ def _make_head(self, pre_stage_channels, incre_only=False): # downsampling modules downsamp_modules = [] for i in range(len(pre_stage_channels) - 1): - in_channels = self.head_channels[i] * head_block.expansion - out_channels = self.head_channels[i + 1] * head_block.expansion + in_channels = self.head_channels[i] * head_block_type.expansion + out_channels = self.head_channels[i + 1] * head_block_type.expansion downsamp_module = nn.Sequential( nn.Conv2d( - in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1), + in_channels=in_channels, out_channels=out_channels, + kernel_size=3, stride=2, padding=1, bias=conv_bias), nn.BatchNorm2d(out_channels, momentum=_BN_MOMENTUM), nn.ReLU(inplace=True) ) @@ -603,9 +640,8 @@ def _make_head(self, pre_stage_channels, incre_only=False): final_layer = nn.Sequential( nn.Conv2d( - in_channels=self.head_channels[3] * head_block.expansion, - out_channels=self.num_features, kernel_size=1, stride=1, padding=0 - ), + in_channels=self.head_channels[3] * head_block_type.expansion, out_channels=self.num_features, + kernel_size=1, stride=1, padding=0, bias=conv_bias), nn.BatchNorm2d(self.num_features, momentum=_BN_MOMENTUM), nn.ReLU(inplace=True) ) @@ -629,49 +665,49 @@ def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer) else: conv3x3s = [] for j in range(i + 1 - num_branches_pre): - inchannels = num_channels_pre_layer[-1] - outchannels = num_channels_cur_layer[i] if j == i - num_branches_pre else inchannels + _in_chs = num_channels_pre_layer[-1] + _out_chs = num_channels_cur_layer[i] if j == i - num_branches_pre else _in_chs conv3x3s.append(nn.Sequential( - nn.Conv2d(inchannels, outchannels, 3, 2, 1, 
bias=False), - nn.BatchNorm2d(outchannels, momentum=_BN_MOMENTUM), + nn.Conv2d(_in_chs, _out_chs, 3, 2, 1, bias=False), + nn.BatchNorm2d(_out_chs, momentum=_BN_MOMENTUM), nn.ReLU(inplace=True))) transition_layers.append(nn.Sequential(*conv3x3s)) return nn.ModuleList(transition_layers) - def _make_layer(self, block, inplanes, planes, blocks, stride=1): + def _make_layer(self, block_type, inplanes, planes, block_types, stride=1): downsample = None - if stride != 1 or inplanes != planes * block.expansion: + if stride != 1 or inplanes != planes * block_type.expansion: downsample = nn.Sequential( - nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(planes * block.expansion, momentum=_BN_MOMENTUM), + nn.Conv2d(inplanes, planes * block_type.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block_type.expansion, momentum=_BN_MOMENTUM), ) - layers = [block(inplanes, planes, stride, downsample)] - inplanes = planes * block.expansion - for i in range(1, blocks): - layers.append(block(inplanes, planes)) + layers = [block_type(inplanes, planes, stride, downsample)] + inplanes = planes * block_type.expansion + for i in range(1, block_types): + layers.append(block_type(inplanes, planes)) return nn.Sequential(*layers) def _make_stage(self, layer_config, num_in_chs, multi_scale_output=True): - num_modules = layer_config['NUM_MODULES'] - num_branches = layer_config['NUM_BRANCHES'] - num_blocks = layer_config['NUM_BLOCKS'] - num_channels = layer_config['NUM_CHANNELS'] - block = blocks_dict[layer_config['BLOCK']] - fuse_method = layer_config['FUSE_METHOD'] + num_modules = layer_config['num_modules'] + num_branches = layer_config['num_branches'] + num_blocks = layer_config['num_blocks'] + num_channels = layer_config['num_channels'] + block_type = block_types_dict[layer_config['block_type']] + fuse_method = layer_config['fuse_method'] modules = [] for i in range(num_modules): # multi_scale_output is only used last module reset_multi_scale_output = multi_scale_output or i < num_modules - 1 modules.append(HighResolutionModule( - num_branches, block, num_blocks, num_in_chs, num_channels, fuse_method, reset_multi_scale_output) + num_branches, block_type, num_blocks, num_in_chs, num_channels, fuse_method, reset_multi_scale_output) ) num_in_chs = modules[-1].get_num_in_chs() - return nn.Sequential(*modules), num_in_chs + return SequentialList(*modules), num_in_chs @torch.jit.ignore def init_weights(self): @@ -687,7 +723,7 @@ def init_weights(self): def group_matcher(self, coarse=False): matcher = dict( stem=r'^conv[12]|bn[12]', - blocks=r'^(?:layer|stage|transition)(\d+)' if coarse else [ + block_types=r'^(?:layer|stage|transition)(\d+)' if coarse else [ (r'^layer(\d+)\.(\d+)', None), (r'^stage(\d+)\.(\d+)', None), (r'^transition(\d+)', (99999,)), @@ -734,17 +770,22 @@ def forward_features(self, x): yl = self.stages(x) if self.incre_modules is None or self.downsamp_modules is None: return yl - y = self.incre_modules[0](yl[0]) - for i, down in enumerate(self.downsamp_modules): - y = self.incre_modules[i + 1](yl[i + 1]) + down(y) + + y = None + for i, incre in enumerate(self.incre_modules): + if y is None: + y = incre(yl[i]) + else: + down: ModuleInterface = self.downsamp_modules[i - 1] # needed for torchscript module indexing + y = incre(yl[i]) + down.forward(y) + y = self.final_layer(y) return y def forward_head(self, x, pre_logits: bool = False): # Classification Head x = self.global_pool(x) - if self.drop_rate > 0.: - x = 
F.dropout(x, p=self.drop_rate, training=self.training) + x = self.head_drop(x) return x if pre_logits else self.classifier(x) def forward(self, x): @@ -764,12 +805,29 @@ class HighResolutionNetFeatures(HighResolutionNet): conv is used for stride 2 features. """ - def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0.0, - feature_location='incre', out_indices=(0, 1, 2, 3, 4)): + def __init__( + self, + cfg, + in_chans=3, + num_classes=1000, + output_stride=32, + global_pool='avg', + drop_rate=0.0, + feature_location='incre', + out_indices=(0, 1, 2, 3, 4), + **kwargs, + ): assert feature_location in ('incre', '') super(HighResolutionNetFeatures, self).__init__( - cfg, in_chans=in_chans, num_classes=num_classes, global_pool=global_pool, - drop_rate=drop_rate, head=feature_location) + cfg, + in_chans=in_chans, + num_classes=num_classes, + output_stride=output_stride, + global_pool=global_pool, + drop_rate=drop_rate, + head=feature_location, + **kwargs, + ) self.feature_info = FeatureInfo(self.feature_info, out_indices) self._out_idx = {i for i in out_indices} @@ -795,7 +853,7 @@ def forward(self, x) -> List[torch.tensor]: return out -def _create_hrnet(variant, pretrained, **model_kwargs): +def _create_hrnet(variant, pretrained=False, cfg_variant=None, **model_kwargs): model_cls = HighResolutionNet features_only = False kwargs_filter = None @@ -803,18 +861,59 @@ def _create_hrnet(variant, pretrained, **model_kwargs): model_cls = HighResolutionNetFeatures kwargs_filter = ('num_classes', 'global_pool') features_only = True + cfg_variant = cfg_variant or variant model = build_model_with_cfg( - model_cls, variant, pretrained, - model_cfg=cfg_cls[variant], + model_cls, + variant, + pretrained, + model_cfg=cfg_cls[cfg_variant], pretrained_strict=not features_only, kwargs_filter=kwargs_filter, - **model_kwargs) + **model_kwargs, + ) if features_only: model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg) model.default_cfg = model.pretrained_cfg # backwards compat return model +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'hrnet_w18_small.ms_in1k': _cfg(hf_hub_id='timm/'), + 'hrnet_w18_small_v2.ms_in1k': _cfg(hf_hub_id='timm/'), + 'hrnet_w18.ms_aug_in1k': _cfg( + hf_hub_id='timm/', + crop_pct=0.95, + ), + 'hrnet_w18.ms_in1k': _cfg(hf_hub_id='timm/'), + 'hrnet_w30.ms_in1k': _cfg(hf_hub_id='timm/'), + 'hrnet_w32.ms_in1k': _cfg(hf_hub_id='timm/'), + 'hrnet_w40.ms_in1k': _cfg(hf_hub_id='timm/'), + 'hrnet_w44.ms_in1k': _cfg(hf_hub_id='timm/'), + 'hrnet_w48.ms_in1k': _cfg(hf_hub_id='timm/'), + 'hrnet_w64.ms_in1k': _cfg(hf_hub_id='timm/'), + + 'hrnet_w18_ssld.paddle_in1k': _cfg( + hf_hub_id='timm/', + crop_pct=0.95, test_crop_pct=1.0, test_input_size=(3, 288, 288) + ), + 'hrnet_w48_ssld.paddle_in1k': _cfg( + hf_hub_id='timm/', + crop_pct=0.95, test_crop_pct=1.0, test_input_size=(3, 288, 288) + ), +}) + + @register_model def hrnet_w18_small(pretrained=False, **kwargs): return _create_hrnet('hrnet_w18_small', pretrained, **kwargs) @@ -858,3 +957,16 @@ def hrnet_w48(pretrained=False, **kwargs): @register_model def hrnet_w64(pretrained=False, **kwargs): return _create_hrnet('hrnet_w64', pretrained, **kwargs) + + +@register_model +def hrnet_w18_ssld(pretrained=False, 
**kwargs): + kwargs.setdefault('head_conv_bias', False) + return _create_hrnet('hrnet_w18_ssld', cfg_variant='hrnet_w18', pretrained=pretrained, **kwargs) + + +@register_model +def hrnet_w48_ssld(pretrained=False, **kwargs): + kwargs.setdefault('head_conv_bias', False) + return _create_hrnet('hrnet_w48_ssld', cfg_variant='hrnet_w48', pretrained=pretrained, **kwargs) + diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py index 3006f3d2e9..f69a1a8131 100644 --- a/timm/models/inception_resnet_v2.py +++ b/timm/models/inception_resnet_v2.py @@ -2,75 +2,41 @@ Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License) """ +from functools import partial import torch import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from timm.layers import create_classifier +from timm.layers import create_classifier, ConvNormAct from ._builder import build_model_with_cfg from ._manipulate import flatten_modules -from ._registry import register_model +from ._registry import register_model, generate_default_cfgs, register_model_deprecations __all__ = ['InceptionResnetV2'] -default_cfgs = { - # ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz - 'inception_resnet_v2': { - 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6.pth', - 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8), - 'crop_pct': 0.8975, 'interpolation': 'bicubic', - 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, - 'first_conv': 'conv2d_1a.conv', 'classifier': 'classif', - 'label_offset': 1, # 1001 classes in pretrained weights - }, - # ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz - 'ens_adv_inception_resnet_v2': { - 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth', - 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8), - 'crop_pct': 0.8975, 'interpolation': 'bicubic', - 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, - 'first_conv': 'conv2d_1a.conv', 'classifier': 'classif', - 'label_offset': 1, # 1001 classes in pretrained weights - } -} - - -class BasicConv2d(nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): - super(BasicConv2d, self).__init__() - self.conv = nn.Conv2d( - in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) - self.bn = nn.BatchNorm2d(out_planes, eps=.001) - self.relu = nn.ReLU(inplace=False) - - def forward(self, x): - x = self.conv(x) - x = self.bn(x) - x = self.relu(x) - return x - class Mixed_5b(nn.Module): - def __init__(self): + def __init__(self, conv_block=None): super(Mixed_5b, self).__init__() + conv_block = conv_block or ConvNormAct - self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1) + self.branch0 = conv_block(192, 96, kernel_size=1, stride=1) self.branch1 = nn.Sequential( - BasicConv2d(192, 48, kernel_size=1, stride=1), - BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2) + conv_block(192, 48, kernel_size=1, stride=1), + conv_block(48, 64, kernel_size=5, stride=1, padding=2) ) self.branch2 = nn.Sequential( - BasicConv2d(192, 64, kernel_size=1, stride=1), - BasicConv2d(64, 96, kernel_size=3, stride=1, 
padding=1), - BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1) + conv_block(192, 64, kernel_size=1, stride=1), + conv_block(64, 96, kernel_size=3, stride=1, padding=1), + conv_block(96, 96, kernel_size=3, stride=1, padding=1) ) self.branch3 = nn.Sequential( nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), - BasicConv2d(192, 64, kernel_size=1, stride=1) + conv_block(192, 64, kernel_size=1, stride=1) ) def forward(self, x): @@ -83,26 +49,26 @@ def forward(self, x): class Block35(nn.Module): - def __init__(self, scale=1.0): + def __init__(self, scale=1.0, conv_block=None): super(Block35, self).__init__() self.scale = scale + conv_block = conv_block or ConvNormAct - self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1) + self.branch0 = conv_block(320, 32, kernel_size=1, stride=1) self.branch1 = nn.Sequential( - BasicConv2d(320, 32, kernel_size=1, stride=1), - BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1) + conv_block(320, 32, kernel_size=1, stride=1), + conv_block(32, 32, kernel_size=3, stride=1, padding=1) ) self.branch2 = nn.Sequential( - BasicConv2d(320, 32, kernel_size=1, stride=1), - BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1), - BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1) + conv_block(320, 32, kernel_size=1, stride=1), + conv_block(32, 48, kernel_size=3, stride=1, padding=1), + conv_block(48, 64, kernel_size=3, stride=1, padding=1) ) self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1) - self.relu = nn.ReLU(inplace=False) + self.act = nn.ReLU() def forward(self, x): x0 = self.branch0(x) @@ -111,20 +77,21 @@ def forward(self, x): out = torch.cat((x0, x1, x2), 1) out = self.conv2d(out) out = out * self.scale + x - out = self.relu(out) + out = self.act(out) return out class Mixed_6a(nn.Module): - def __init__(self): + def __init__(self, conv_block=None): super(Mixed_6a, self).__init__() + conv_block = conv_block or ConvNormAct - self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2) + self.branch0 = conv_block(320, 384, kernel_size=3, stride=2) self.branch1 = nn.Sequential( - BasicConv2d(320, 256, kernel_size=1, stride=1), - BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1), - BasicConv2d(256, 384, kernel_size=3, stride=2) + conv_block(320, 256, kernel_size=1, stride=1), + conv_block(256, 256, kernel_size=3, stride=1, padding=1), + conv_block(256, 384, kernel_size=3, stride=2) ) self.branch2 = nn.MaxPool2d(3, stride=2) @@ -138,21 +105,21 @@ def forward(self, x): class Block17(nn.Module): - def __init__(self, scale=1.0): + def __init__(self, scale=1.0, conv_block=None): super(Block17, self).__init__() self.scale = scale + conv_block = conv_block or ConvNormAct - self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1) + self.branch0 = conv_block(1088, 192, kernel_size=1, stride=1) self.branch1 = nn.Sequential( - BasicConv2d(1088, 128, kernel_size=1, stride=1), - BasicConv2d(128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3)), - BasicConv2d(160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)) + conv_block(1088, 128, kernel_size=1, stride=1), + conv_block(128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3)), + conv_block(160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)) ) self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1) - self.relu = nn.ReLU(inplace=False) + self.act = nn.ReLU() def forward(self, x): x0 = self.branch0(x) @@ -160,28 +127,29 @@ def forward(self, x): out = torch.cat((x0, x1), 1) out = self.conv2d(out) out = out * self.scale + x - out = self.relu(out) + out =
self.act(out) return out class Mixed_7a(nn.Module): - def __init__(self): + def __init__(self, conv_block=None): super(Mixed_7a, self).__init__() + conv_block = conv_block or ConvNormAct self.branch0 = nn.Sequential( - BasicConv2d(1088, 256, kernel_size=1, stride=1), - BasicConv2d(256, 384, kernel_size=3, stride=2) + conv_block(1088, 256, kernel_size=1, stride=1), + conv_block(256, 384, kernel_size=3, stride=2) ) self.branch1 = nn.Sequential( - BasicConv2d(1088, 256, kernel_size=1, stride=1), - BasicConv2d(256, 288, kernel_size=3, stride=2) + conv_block(1088, 256, kernel_size=1, stride=1), + conv_block(256, 288, kernel_size=3, stride=2) ) self.branch2 = nn.Sequential( - BasicConv2d(1088, 256, kernel_size=1, stride=1), - BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1), - BasicConv2d(288, 320, kernel_size=3, stride=2) + conv_block(1088, 256, kernel_size=1, stride=1), + conv_block(256, 288, kernel_size=3, stride=1, padding=1), + conv_block(288, 320, kernel_size=3, stride=2) ) self.branch3 = nn.MaxPool2d(3, stride=2) @@ -197,21 +165,21 @@ def forward(self, x): class Block8(nn.Module): - def __init__(self, scale=1.0, no_relu=False): + def __init__(self, scale=1.0, no_relu=False, conv_block=None): super(Block8, self).__init__() self.scale = scale + conv_block = conv_block or ConvNormAct - self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1) + self.branch0 = conv_block(2080, 192, kernel_size=1, stride=1) self.branch1 = nn.Sequential( - BasicConv2d(2080, 192, kernel_size=1, stride=1), - BasicConv2d(192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1)), - BasicConv2d(224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)) + conv_block(2080, 192, kernel_size=1, stride=1), + conv_block(192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1)), + conv_block(224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)) ) self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1) - self.relu = None if no_relu else nn.ReLU(inplace=False) + self.relu = None if no_relu else nn.ReLU() def forward(self, x): x0 = self.branch0(x) @@ -225,81 +193,58 @@ def forward(self, x): class InceptionResnetV2(nn.Module): - def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., output_stride=32, global_pool='avg'): + def __init__( + self, + num_classes=1000, + in_chans=3, + drop_rate=0., + output_stride=32, + global_pool='avg', + norm_layer='batchnorm2d', + norm_eps=1e-3, + act_layer='relu', + ): super(InceptionResnetV2, self).__init__() - self.drop_rate = drop_rate self.num_classes = num_classes self.num_features = 1536 assert output_stride == 32 + conv_block = partial( + ConvNormAct, + padding=0, + norm_layer=norm_layer, + act_layer=act_layer, + norm_kwargs=dict(eps=norm_eps), + act_kwargs=dict(inplace=True), + ) - self.conv2d_1a = BasicConv2d(in_chans, 32, kernel_size=3, stride=2) - self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1) - self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1) + self.conv2d_1a = conv_block(in_chans, 32, kernel_size=3, stride=2) + self.conv2d_2a = conv_block(32, 32, kernel_size=3, stride=1) + self.conv2d_2b = conv_block(32, 64, kernel_size=3, stride=1, padding=1) self.feature_info = [dict(num_chs=64, reduction=2, module='conv2d_2b')] self.maxpool_3a = nn.MaxPool2d(3, stride=2) - self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1) - self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1) + self.conv2d_3b = conv_block(64, 80, kernel_size=1, stride=1) + self.conv2d_4a = conv_block(80, 192, kernel_size=3, stride=1)
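+ # NOTE: the conv_block partial above pins the norm eps to 1e-3 via norm_kwargs, matching the eps=.001 of the removed BasicConv2d; nn.BatchNorm2d's default eps of 1e-5 would subtly change outputs under the TF-ported pretrained weights.
+ # e.g. conv_block(64, 80, kernel_size=1, stride=1) builds Conv2d(64, 80, 1, bias=False) -> BatchNorm2d(80, eps=1e-3) -> ReLU, a drop-in replacement for the old BasicConv2d stack.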
self.feature_info += [dict(num_chs=192, reduction=4, module='conv2d_4a')] self.maxpool_5a = nn.MaxPool2d(3, stride=2) - self.mixed_5b = Mixed_5b() - self.repeat = nn.Sequential( - Block35(scale=0.17), - Block35(scale=0.17), - Block35(scale=0.17), - Block35(scale=0.17), - Block35(scale=0.17), - Block35(scale=0.17), - Block35(scale=0.17), - Block35(scale=0.17), - Block35(scale=0.17), - Block35(scale=0.17) - ) + self.mixed_5b = Mixed_5b(conv_block=conv_block) + self.repeat = nn.Sequential(*[Block35(scale=0.17, conv_block=conv_block) for _ in range(10)]) self.feature_info += [dict(num_chs=320, reduction=8, module='repeat')] - self.mixed_6a = Mixed_6a() - self.repeat_1 = nn.Sequential( - Block17(scale=0.10), - Block17(scale=0.10), - Block17(scale=0.10), - Block17(scale=0.10), - Block17(scale=0.10), - Block17(scale=0.10), - Block17(scale=0.10), - Block17(scale=0.10), - Block17(scale=0.10), - Block17(scale=0.10), - Block17(scale=0.10), - Block17(scale=0.10), - Block17(scale=0.10), - Block17(scale=0.10), - Block17(scale=0.10), - Block17(scale=0.10), - Block17(scale=0.10), - Block17(scale=0.10), - Block17(scale=0.10), - Block17(scale=0.10) - ) + self.mixed_6a = Mixed_6a(conv_block=conv_block) + self.repeat_1 = nn.Sequential(*[Block17(scale=0.10, conv_block=conv_block) for _ in range(20)]) self.feature_info += [dict(num_chs=1088, reduction=16, module='repeat_1')] - self.mixed_7a = Mixed_7a() - self.repeat_2 = nn.Sequential( - Block8(scale=0.20), - Block8(scale=0.20), - Block8(scale=0.20), - Block8(scale=0.20), - Block8(scale=0.20), - Block8(scale=0.20), - Block8(scale=0.20), - Block8(scale=0.20), - Block8(scale=0.20) - ) - self.block8 = Block8(no_relu=True) - self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1) + self.mixed_7a = Mixed_7a(conv_block=conv_block) + self.repeat_2 = nn.Sequential(*[Block8(scale=0.20, conv_block=conv_block) for _ in range(9)]) + + self.block8 = Block8(no_relu=True, conv_block=conv_block) + self.conv2d_7b = conv_block(2080, self.num_features, kernel_size=1, stride=1) self.feature_info += [dict(num_chs=self.num_features, reduction=32, module='conv2d_7b')] - self.global_pool, self.classif = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + self.global_pool, self.head_drop, self.classif = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool, drop_rate=drop_rate) @torch.jit.ignore def group_matcher(self, coarse=False): @@ -352,8 +297,7 @@ def forward_features(self, x): def forward_head(self, x, pre_logits: bool = False): x = self.global_pool(x) - if self.drop_rate > 0: - x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.head_drop(x) return x if pre_logits else self.classif(x) def forward(self, x): @@ -366,18 +310,36 @@ def _create_inception_resnet_v2(variant, pretrained=False, **kwargs): return build_model_with_cfg(InceptionResnetV2, variant, pretrained, **kwargs) +default_cfgs = generate_default_cfgs({ + # ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz + 'inception_resnet_v2.tf_in1k': { + 'hf_hub_id': 'timm/', + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6.pth', + 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'crop_pct': 0.8975, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'conv2d_1a.conv', 'classifier': 'classif', + 'label_offset': 1, # 1001 classes in pretrained weights 
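+        # NOTE: the TF-ported checkpoint classifier has 1001 outputs (index 0 is a background class);
+        # label_offset=1 lets pretrained loading drop that leading class when a standard 1000-class in1k head is requested.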
+ }, + # As per https://arxiv.org/abs/1705.07204 and + # ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz + 'inception_resnet_v2.tf_ens_adv_in1k': { + 'hf_hub_id': 'timm/', + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth', + 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'crop_pct': 0.8975, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'conv2d_1a.conv', 'classifier': 'classif', + 'label_offset': 1, # 1001 classes in pretrained weights + } +}) + + @register_model def inception_resnet_v2(pretrained=False, **kwargs): - r"""InceptionResnetV2 model architecture from the - `"InceptionV4, Inception-ResNet..." ` paper. - """ return _create_inception_resnet_v2('inception_resnet_v2', pretrained=pretrained, **kwargs) -@register_model -def ens_adv_inception_resnet_v2(pretrained=False, **kwargs): - r""" Ensemble Adversarially trained InceptionResnetV2 model architecture - As per https://arxiv.org/abs/1705.07204 and - https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models. - """ - return _create_inception_resnet_v2('ens_adv_inception_resnet_v2', pretrained=pretrained, **kwargs) +register_model_deprecations(__name__, { + 'ens_adv_inception_resnet_v2': 'inception_resnet_v2.tf_ens_adv_in1k', +}) \ No newline at end of file diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index 28794ce6ea..dc9e130248 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -3,61 +3,27 @@ Originally from torchvision Inception3 model Licensed BSD-Clause 3 https://github.com/pytorch/vision/blob/master/LICENSE """ +from functools import partial + import torch import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from timm.layers import trunc_normal_, create_classifier, Linear +from timm.layers import trunc_normal_, create_classifier, Linear, ConvNormAct from ._builder import build_model_with_cfg from ._builder import resolve_pretrained_cfg from ._manipulate import flatten_modules -from ._registry import register_model - -__all__ = ['InceptionV3', 'InceptionV3Aux'] # model_registry will add each entrypoint fn to this - +from ._registry import register_model, generate_default_cfgs, register_model_deprecations -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8), - 'crop_pct': 0.875, 'interpolation': 'bicubic', - 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, - 'first_conv': 'Conv2d_1a_3x3.conv', 'classifier': 'fc', - **kwargs - } - - -default_cfgs = { - # original PyTorch weights, ported from Tensorflow but modified - 'inception_v3': _cfg( - # NOTE checkpoint has aux logit layer weights - url='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth'), - # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz) - 'tf_inception_v3': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth', - num_classes=1000, label_offset=1), - # my port of Tensorflow adversarially trained Inception V3 from - # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz - 'adv_inception_v3': _cfg( - 
diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py
index 28794ce6ea..dc9e130248 100644
--- a/timm/models/inception_v3.py
+++ b/timm/models/inception_v3.py
@@ -3,61 +3,27 @@
 Originally from torchvision Inception3 model
 Licensed BSD-Clause 3 https://github.com/pytorch/vision/blob/master/LICENSE
 """
+from functools import partial
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from timm.layers import trunc_normal_, create_classifier, Linear
+from timm.layers import trunc_normal_, create_classifier, Linear, ConvNormAct
 from ._builder import build_model_with_cfg
 from ._builder import resolve_pretrained_cfg
 from ._manipulate import flatten_modules
-from ._registry import register_model
-
-__all__ = ['InceptionV3', 'InceptionV3Aux']  # model_registry will add each entrypoint fn to this
-
+from ._registry import register_model, generate_default_cfgs, register_model_deprecations
 
-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
-        'crop_pct': 0.875, 'interpolation': 'bicubic',
-        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
-        'first_conv': 'Conv2d_1a_3x3.conv', 'classifier': 'fc',
-        **kwargs
-    }
-
-
-default_cfgs = {
-    # original PyTorch weights, ported from Tensorflow but modified
-    'inception_v3': _cfg(
-        # NOTE checkpoint has aux logit layer weights
-        url='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth'),
-    # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
-    'tf_inception_v3': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth',
-        num_classes=1000, label_offset=1),
-    # my port of Tensorflow adversarially trained Inception V3 from
-    # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
-    'adv_inception_v3': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth',
-        num_classes=1000, label_offset=1),
-    # from gluon pretrained models, best performing in terms of accuracy/loss metrics
-    # https://gluon-cv.mxnet.io/model_zoo/classification.html
-    'gluon_inception_v3': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_inception_v3-9f746940.pth',
-        mean=IMAGENET_DEFAULT_MEAN,  # also works well with inception defaults
-        std=IMAGENET_DEFAULT_STD,  # also works well with inception defaults
-    )
-}
+__all__ = ['InceptionV3']  # model_registry will add each entrypoint fn to this
 
 
 class InceptionA(nn.Module):
 
     def __init__(self, in_channels, pool_features, conv_block=None):
         super(InceptionA, self).__init__()
-        if conv_block is None:
-            conv_block = BasicConv2d
+        conv_block = conv_block or ConvNormAct
         self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)
 
         self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
@@ -94,8 +60,7 @@ class InceptionB(nn.Module):
 
     def __init__(self, in_channels, conv_block=None):
         super(InceptionB, self).__init__()
-        if conv_block is None:
-            conv_block = BasicConv2d
+        conv_block = conv_block or ConvNormAct
         self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)
 
         self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
@@ -123,8 +88,7 @@ class InceptionC(nn.Module):
 
     def __init__(self, in_channels, channels_7x7, conv_block=None):
         super(InceptionC, self).__init__()
-        if conv_block is None:
-            conv_block = BasicConv2d
+        conv_block = conv_block or ConvNormAct
         self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)
 
         c7 = channels_7x7
@@ -168,8 +132,7 @@ class InceptionD(nn.Module):
 
     def __init__(self, in_channels, conv_block=None):
         super(InceptionD, self).__init__()
-        if conv_block is None:
-            conv_block = BasicConv2d
+        conv_block = conv_block or ConvNormAct
         self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
         self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)
@@ -200,8 +163,7 @@ class InceptionE(nn.Module):
 
     def __init__(self, in_channels, conv_block=None):
         super(InceptionE, self).__init__()
-        if conv_block is None:
-            conv_block = BasicConv2d
+        conv_block = conv_block or ConvNormAct
         self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)
 
         self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
@@ -248,8 +210,7 @@ class InceptionAux(nn.Module):
 
     def __init__(self, in_channels, num_classes, conv_block=None):
         super(InceptionAux, self).__init__()
-        if conv_block is None:
-            conv_block = BasicConv2d
+        conv_block = conv_block or ConvNormAct
         self.conv0 = conv_block(in_channels, 128, kernel_size=1)
         self.conv1 = conv_block(128, 768, kernel_size=5)
         self.conv1.stddev = 0.01
@@ -274,52 +235,56 @@ def forward(self, x):
         return x
 
 
-class BasicConv2d(nn.Module):
-
-    def __init__(self, in_channels, out_channels, **kwargs):
-        super(BasicConv2d, self).__init__()
-        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
-        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
-
-    def forward(self, x):
-        x = self.conv(x)
-        x = self.bn(x)
-        return F.relu(x, inplace=True)
-
-
 class InceptionV3(nn.Module):
-    """Inception-V3 with no AuxLogits
-    FIXME two class defs are redundant, but less screwing around with torchsript fussyness and inconsistent returns
+    """Inception-V3
     """
-
-    def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', aux_logits=False):
+    aux_logits: torch.jit.Final[bool]
+
+    def __init__(
+            self,
+            num_classes=1000,
+            in_chans=3,
+            drop_rate=0.,
+            global_pool='avg',
+            aux_logits=False,
+            norm_layer='batchnorm2d',
+            norm_eps=1e-3,
+            act_layer='relu',
+    ):
         super(InceptionV3, self).__init__()
         self.num_classes = num_classes
-        self.drop_rate = drop_rate
         self.aux_logits = aux_logits
-
-        self.Conv2d_1a_3x3 = BasicConv2d(in_chans, 32, kernel_size=3, stride=2)
-        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
-        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
+        conv_block = partial(
+            ConvNormAct,
+            padding=0,
+            norm_layer=norm_layer,
+            act_layer=act_layer,
+            norm_kwargs=dict(eps=norm_eps),
+            act_kwargs=dict(inplace=True),
+        )
+
+        self.Conv2d_1a_3x3 = conv_block(in_chans, 32, kernel_size=3, stride=2)
+        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
+        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
         self.Pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
-        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
-        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
+        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
+        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
         self.Pool2 = nn.MaxPool2d(kernel_size=3, stride=2)
-        self.Mixed_5b = InceptionA(192, pool_features=32)
-        self.Mixed_5c = InceptionA(256, pool_features=64)
-        self.Mixed_5d = InceptionA(288, pool_features=64)
-        self.Mixed_6a = InceptionB(288)
-        self.Mixed_6b = InceptionC(768, channels_7x7=128)
-        self.Mixed_6c = InceptionC(768, channels_7x7=160)
-        self.Mixed_6d = InceptionC(768, channels_7x7=160)
-        self.Mixed_6e = InceptionC(768, channels_7x7=192)
+        self.Mixed_5b = InceptionA(192, pool_features=32, conv_block=conv_block)
+        self.Mixed_5c = InceptionA(256, pool_features=64, conv_block=conv_block)
+        self.Mixed_5d = InceptionA(288, pool_features=64, conv_block=conv_block)
+        self.Mixed_6a = InceptionB(288, conv_block=conv_block)
+        self.Mixed_6b = InceptionC(768, channels_7x7=128, conv_block=conv_block)
+        self.Mixed_6c = InceptionC(768, channels_7x7=160, conv_block=conv_block)
+        self.Mixed_6d = InceptionC(768, channels_7x7=160, conv_block=conv_block)
+        self.Mixed_6e = InceptionC(768, channels_7x7=192, conv_block=conv_block)
         if aux_logits:
-            self.AuxLogits = InceptionAux(768, num_classes)
+            self.AuxLogits = InceptionAux(768, num_classes, conv_block=conv_block)
         else:
             self.AuxLogits = None
-        self.Mixed_7a = InceptionD(768)
-        self.Mixed_7b = InceptionE(1280)
-        self.Mixed_7c = InceptionE(2048)
+        self.Mixed_7a = InceptionD(768, conv_block=conv_block)
+        self.Mixed_7b = InceptionE(1280, conv_block=conv_block)
+        self.Mixed_7c = InceptionE(2048, conv_block=conv_block)
         self.feature_info = [
             dict(num_chs=64, reduction=2, module='Conv2d_2b_3x3'),
             dict(num_chs=192, reduction=4, module='Conv2d_4a_3x3'),
@@ -329,7 +294,12 @@ def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg'
         ]
 
         self.num_features = 2048
-        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
+        self.global_pool, self.head_drop, self.fc = create_classifier(
+            self.num_features,
+            self.num_classes,
+            pool_type=global_pool,
+            drop_rate=drop_rate,
+        )
 
         for m in self.modules():
             if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
@@ -394,85 +364,99 @@ def forward_postaux(self, x):
 
     def forward_features(self, x):
         x = self.forward_preaux(x)
+        if self.aux_logits:
+            aux = self.AuxLogits(x)
+            x = self.forward_postaux(x)
+            return x, aux
         x = self.forward_postaux(x)
         return x
 
     def forward_head(self, x):
         x = self.global_pool(x)
-        if self.drop_rate > 0:
-            x = F.dropout(x, p=self.drop_rate, training=self.training)
+        x = self.head_drop(x)
         x = self.fc(x)
         return x
 
     def forward(self, x):
+        if self.aux_logits:
+            x, aux = self.forward_features(x)
+            x = self.forward_head(x)
+            return x, aux
         x = self.forward_features(x)
         x = self.forward_head(x)
         return x
 
 
-class InceptionV3Aux(InceptionV3):
-    """InceptionV3 with AuxLogits
-    """
-
-    def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', aux_logits=True):
-        super(InceptionV3Aux, self).__init__(
-            num_classes, in_chans, drop_rate, global_pool, aux_logits)
-
-    def forward_features(self, x):
-        x = self.forward_preaux(x)
-        aux = self.AuxLogits(x) if self.training else None
-        x = self.forward_postaux(x)
-        return x, aux
-
-    def forward(self, x):
-        x, aux = self.forward_features(x)
-        x = self.forward_head(x)
-        return x, aux
-
-
 def _create_inception_v3(variant, pretrained=False, **kwargs):
     pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
-    aux_logits = kwargs.pop('aux_logits', False)
+    aux_logits = kwargs.get('aux_logits', False)
+    has_aux_logits = False
+    if pretrained_cfg:
+        # only torchvision pretrained weights have aux logits
+        has_aux_logits = pretrained_cfg.tag == 'tv_in1k'
     if aux_logits:
         assert not kwargs.pop('features_only', False)
-        model_cls = InceptionV3Aux
-        load_strict = variant == 'inception_v3'
+        load_strict = has_aux_logits
     else:
-        model_cls = InceptionV3
-        load_strict = variant != 'inception_v3'
+        load_strict = not has_aux_logits
 
     return build_model_with_cfg(
-        model_cls, variant, pretrained,
+        InceptionV3,
+        variant,
+        pretrained,
         pretrained_cfg=pretrained_cfg,
         pretrained_strict=load_strict,
-        **kwargs)
+        **kwargs,
+    )
 
 
-@register_model
-def inception_v3(pretrained=False, **kwargs):
-    # original PyTorch weights, ported from Tensorflow but modified
-    model = _create_inception_v3('inception_v3', pretrained=pretrained, **kwargs)
-    return model
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
+        'crop_pct': 0.875, 'interpolation': 'bicubic',
+        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+        'first_conv': 'Conv2d_1a_3x3.conv', 'classifier': 'fc',
+        **kwargs
+    }
 
 
-@register_model
-def tf_inception_v3(pretrained=False, **kwargs):
+default_cfgs = generate_default_cfgs({
+    # original PyTorch weights, ported from Tensorflow but modified
+    'inception_v3.tv_in1k': _cfg(
+        # NOTE checkpoint has aux logit layer weights
+        hf_hub_id='timm/',
+        url='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth'),
     # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
-    model = _create_inception_v3('tf_inception_v3', pretrained=pretrained, **kwargs)
-    return model
-
-
-@register_model
-def adv_inception_v3(pretrained=False, **kwargs):
+    'inception_v3.tf_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth',
+        num_classes=1000, label_offset=1),
    # my port of Tensorflow adversarially trained Inception V3 from
    # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
-    model = _create_inception_v3('adv_inception_v3', pretrained=pretrained, **kwargs)
-    return model
+    'inception_v3.tf_adv_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth',
+        num_classes=1000, label_offset=1),
+    # from gluon pretrained models, best performing in terms of accuracy/loss metrics
+    # https://gluon-cv.mxnet.io/model_zoo/classification.html
+    'inception_v3.gluon_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_inception_v3-9f746940.pth',
+        mean=IMAGENET_DEFAULT_MEAN,  # also works well with inception defaults
+        std=IMAGENET_DEFAULT_STD,  # also works well with inception defaults
+    )
+})
 
 
 @register_model
-def gluon_inception_v3(pretrained=False, **kwargs):
-    # from gluon pretrained models, best performing in terms of accuracy/loss metrics
-    # https://gluon-cv.mxnet.io/model_zoo/classification.html
-    model = _create_inception_v3('gluon_inception_v3', pretrained=pretrained, **kwargs)
+def inception_v3(pretrained=False, **kwargs):
+    model = _create_inception_v3('inception_v3', pretrained=pretrained, **kwargs)
     return model
+
+
+register_model_deprecations(__name__, {
+    'tf_inception_v3': 'inception_v3.tf_in1k',
+    'adv_inception_v3': 'inception_v3.tf_adv_in1k',
+    'gluon_inception_v3': 'inception_v3.gluon_in1k',
+})
\ No newline at end of file
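
A hedged sketch of the folded-in aux head behaviour shown above (assumes this patch is applied): with `aux_logits=True` the single `InceptionV3` class now returns a `(logits, aux_logits)` tuple directly, with no separate `InceptionV3Aux` subclass.

```python
import torch
from timm import create_model

model = create_model('inception_v3', aux_logits=True, num_classes=10)
model.train()
logits, aux = model(torch.randn(2, 3, 299, 299))
print(logits.shape, aux.shape)  # torch.Size([2, 10]) torch.Size([2, 10])
```
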
diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py
index c1559829a3..e225a48f12 100644
--- a/timm/models/inception_v4.py
+++ b/timm/models/inception_v4.py
@@ -2,49 +2,24 @@
 Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is
 based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License)
 """
+from functools import partial
+
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from timm.layers import create_classifier
+from timm.layers import create_classifier, ConvNormAct
 from ._builder import build_model_with_cfg
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs
 
 __all__ = ['InceptionV4']
 
-default_cfgs = {
-    'inception_v4': {
-        'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/inceptionv4-8e4777a0.pth',
-        'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
-        'crop_pct': 0.875, 'interpolation': 'bicubic',
-        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
-        'first_conv': 'features.0.conv', 'classifier': 'last_linear',
-        'label_offset': 1,  # 1001 classes in pretrained weights
-    }
-}
-
-
-class BasicConv2d(nn.Module):
-    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
-        super(BasicConv2d, self).__init__()
-        self.conv = nn.Conv2d(
-            in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
-        self.bn = nn.BatchNorm2d(out_planes, eps=0.001)
-        self.relu = nn.ReLU(inplace=True)
-
-    def forward(self, x):
-        x = self.conv(x)
-        x = self.bn(x)
-        x = self.relu(x)
-        return x
-
 
 class Mixed3a(nn.Module):
-    def __init__(self):
+    def __init__(self, conv_block=ConvNormAct):
         super(Mixed3a, self).__init__()
         self.maxpool = nn.MaxPool2d(3, stride=2)
-        self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2)
+        self.conv = conv_block(64, 96, kernel_size=3, stride=2)
 
     def forward(self, x):
         x0 = self.maxpool(x)
@@ -54,19 +29,19 @@ def forward(self, x):
 
 
 class Mixed4a(nn.Module):
-    def __init__(self):
+    def __init__(self, conv_block=ConvNormAct):
         super(Mixed4a, self).__init__()
 
         self.branch0 = nn.Sequential(
-            BasicConv2d(160, 64, kernel_size=1, stride=1),
-            BasicConv2d(64, 96, kernel_size=3, stride=1)
+            conv_block(160, 64, kernel_size=1, stride=1),
+            conv_block(64, 96, kernel_size=3, stride=1)
         )
 
         self.branch1 = nn.Sequential(
-            BasicConv2d(160, 64, kernel_size=1, stride=1),
-            BasicConv2d(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3)),
-            BasicConv2d(64, 64, kernel_size=(7, 1), stride=1, padding=(3, 0)),
-            BasicConv2d(64, 96, kernel_size=(3, 3), stride=1)
+            conv_block(160, 64, kernel_size=1, stride=1),
+            conv_block(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3)),
+            conv_block(64, 64, kernel_size=(7, 1), stride=1, padding=(3, 0)),
+            conv_block(64, 96, kernel_size=(3, 3), stride=1)
         )
 
     def forward(self, x):
@@ -77,9 +52,9 @@ def forward(self, x):
 
 
 class Mixed5a(nn.Module):
-    def __init__(self):
+    def __init__(self, conv_block=ConvNormAct):
         super(Mixed5a, self).__init__()
-        self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2)
+        self.conv = conv_block(192, 192, kernel_size=3, stride=2)
         self.maxpool = nn.MaxPool2d(3, stride=2)
 
     def forward(self, x):
@@ -90,24 +65,24 @@ def forward(self, x):
 
 
 class InceptionA(nn.Module):
-    def __init__(self):
+    def __init__(self, conv_block=ConvNormAct):
         super(InceptionA, self).__init__()
-        self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1)
+        self.branch0 = conv_block(384, 96, kernel_size=1, stride=1)
 
         self.branch1 = nn.Sequential(
-            BasicConv2d(384, 64, kernel_size=1, stride=1),
-            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1)
+            conv_block(384, 64, kernel_size=1, stride=1),
+            conv_block(64, 96, kernel_size=3, stride=1, padding=1)
         )
 
         self.branch2 = nn.Sequential(
-            BasicConv2d(384, 64, kernel_size=1, stride=1),
-            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
-            BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
+            conv_block(384, 64, kernel_size=1, stride=1),
+            conv_block(64, 96, kernel_size=3, stride=1, padding=1),
+            conv_block(96, 96, kernel_size=3, stride=1, padding=1)
         )
 
         self.branch3 = nn.Sequential(
             nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
-            BasicConv2d(384, 96, kernel_size=1, stride=1)
+            conv_block(384, 96, kernel_size=1, stride=1)
         )
 
     def forward(self, x):
@@ -120,14 +95,14 @@ def forward(self, x):
 
 
 class ReductionA(nn.Module):
-    def __init__(self):
+    def __init__(self, conv_block=ConvNormAct):
         super(ReductionA, self).__init__()
-        self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2)
+        self.branch0 = conv_block(384, 384, kernel_size=3, stride=2)
 
         self.branch1 = nn.Sequential(
-            BasicConv2d(384, 192, kernel_size=1, stride=1),
-            BasicConv2d(192, 224, kernel_size=3, stride=1, padding=1),
-            BasicConv2d(224, 256, kernel_size=3, stride=2)
+            conv_block(384, 192, kernel_size=1, stride=1),
+            conv_block(192, 224, kernel_size=3, stride=1, padding=1),
+            conv_block(224, 256, kernel_size=3, stride=2)
         )
 
         self.branch2 = nn.MaxPool2d(3, stride=2)
@@ -141,27 +116,27 @@ def forward(self, x):
 
 
 class InceptionB(nn.Module):
-    def __init__(self):
+    def __init__(self, conv_block=ConvNormAct):
         super(InceptionB, self).__init__()
-        self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1)
+        self.branch0 = conv_block(1024, 384, kernel_size=1, stride=1)
 
         self.branch1 = nn.Sequential(
-            BasicConv2d(1024, 192, kernel_size=1, stride=1),
-            BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
-            BasicConv2d(224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0))
+            conv_block(1024, 192, kernel_size=1, stride=1),
+            conv_block(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
+            conv_block(224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0))
         )
 
         self.branch2 = nn.Sequential(
-            BasicConv2d(1024, 192, kernel_size=1, stride=1),
-            BasicConv2d(192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)),
-            BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
-            BasicConv2d(224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0)),
-            BasicConv2d(224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3))
+            conv_block(1024, 192, kernel_size=1, stride=1),
+            conv_block(192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)),
+            conv_block(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
+            conv_block(224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0)),
+            conv_block(224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3))
        )
 
         self.branch3 = nn.Sequential(
             nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
-            BasicConv2d(1024, 128, kernel_size=1, stride=1)
+            conv_block(1024, 128, kernel_size=1, stride=1)
         )
 
     def forward(self, x):
@@ -174,19 +149,19 @@ def forward(self, x):
 
 
 class ReductionB(nn.Module):
-    def __init__(self):
+    def __init__(self, conv_block=ConvNormAct):
         super(ReductionB, self).__init__()
 
         self.branch0 = nn.Sequential(
-            BasicConv2d(1024, 192, kernel_size=1, stride=1),
-            BasicConv2d(192, 192, kernel_size=3, stride=2)
+            conv_block(1024, 192, kernel_size=1, stride=1),
+            conv_block(192, 192, kernel_size=3, stride=2)
         )
 
         self.branch1 = nn.Sequential(
-            BasicConv2d(1024, 256, kernel_size=1, stride=1),
-            BasicConv2d(256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)),
-            BasicConv2d(256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0)),
-            BasicConv2d(320, 320, kernel_size=3, stride=2)
+            conv_block(1024, 256, kernel_size=1, stride=1),
+            conv_block(256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)),
+            conv_block(256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0)),
+            conv_block(320, 320, kernel_size=3, stride=2)
         )
 
         self.branch2 = nn.MaxPool2d(3, stride=2)
@@ -200,24 +175,24 @@ def forward(self, x):
 
 
 class InceptionC(nn.Module):
-    def __init__(self):
+    def __init__(self, conv_block=ConvNormAct):
         super(InceptionC, self).__init__()
 
-        self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1)
+        self.branch0 = conv_block(1536, 256, kernel_size=1, stride=1)
 
-        self.branch1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
-        self.branch1_1a = BasicConv2d(384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
-        self.branch1_1b = BasicConv2d(384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))
+        self.branch1_0 = conv_block(1536, 384, kernel_size=1, stride=1)
+        self.branch1_1a = conv_block(384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
+        self.branch1_1b = conv_block(384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))
 
-        self.branch2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
-        self.branch2_1 = BasicConv2d(384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0))
-        self.branch2_2 = BasicConv2d(448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1))
-        self.branch2_3a = BasicConv2d(512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
-        self.branch2_3b = BasicConv2d(512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))
+        self.branch2_0 = conv_block(1536, 384, kernel_size=1, stride=1)
+        self.branch2_1 = conv_block(384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0))
+        self.branch2_2 = conv_block(448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1))
+        self.branch2_3a = conv_block(512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
+        self.branch2_3b = conv_block(512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))
 
         self.branch3 = nn.Sequential(
             nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
-            BasicConv2d(1536, 256, kernel_size=1, stride=1)
+            conv_block(1536, 256, kernel_size=1, stride=1)
         )
 
     def forward(self, x):
@@ -242,37 +217,44 @@ def forward(self, x):
 
 
 class InceptionV4(nn.Module):
-    def __init__(self, num_classes=1000, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg'):
+    def __init__(
+            self,
+            num_classes=1000,
+            in_chans=3,
+            output_stride=32,
+            drop_rate=0.,
+            global_pool='avg',
+            norm_layer='batchnorm2d',
+            norm_eps=1e-3,
+            act_layer='relu',
+    ):
         super(InceptionV4, self).__init__()
         assert output_stride == 32
-        self.drop_rate = drop_rate
         self.num_classes = num_classes
         self.num_features = 1536
-
-        self.features = nn.Sequential(
-            BasicConv2d(in_chans, 32, kernel_size=3, stride=2),
-            BasicConv2d(32, 32, kernel_size=3, stride=1),
-            BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1),
-            Mixed3a(),
-            Mixed4a(),
-            Mixed5a(),
-            InceptionA(),
-            InceptionA(),
-            InceptionA(),
-            InceptionA(),
-            ReductionA(),  # Mixed6a
-            InceptionB(),
-            InceptionB(),
-            InceptionB(),
-            InceptionB(),
-            InceptionB(),
-            InceptionB(),
-            InceptionB(),
-            ReductionB(),  # Mixed7a
-            InceptionC(),
-            InceptionC(),
-            InceptionC(),
+        conv_block = partial(
+            ConvNormAct,
+            padding=0,
+            norm_layer=norm_layer,
+            act_layer=act_layer,
+            norm_kwargs=dict(eps=norm_eps),
+            act_kwargs=dict(inplace=True),
         )
+
+        features = [
+            conv_block(in_chans, 32, kernel_size=3, stride=2),
+            conv_block(32, 32, kernel_size=3, stride=1),
+            conv_block(32, 64, kernel_size=3, stride=1, padding=1),
+            Mixed3a(conv_block),
+            Mixed4a(conv_block),
+            Mixed5a(conv_block),
+        ]
+        features += [InceptionA(conv_block) for _ in range(4)]
+        features += [ReductionA(conv_block)]  # Mixed6a
+        features += [InceptionB(conv_block) for _ in range(7)]
+        features += [ReductionB(conv_block)]  # Mixed7a
+        features += [InceptionC(conv_block) for _ in range(3)]
+        self.features = nn.Sequential(*features)
         self.feature_info = [
             dict(num_chs=64, reduction=2, module='features.2'),
             dict(num_chs=160, reduction=4, module='features.3'),
@@ -280,8 +262,8 @@ def __init__(self, num_classes=1000, in_chans=3, output_stride=32, drop_rate=0.,
             dict(num_chs=1024, reduction=16, module='features.17'),
             dict(num_chs=1536, reduction=32, module='features.21'),
         ]
-        self.global_pool, self.last_linear = create_classifier(
-            self.num_features, self.num_classes, pool_type=global_pool)
+        self.global_pool, self.head_drop, self.last_linear = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool, drop_rate=drop_rate)
 
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
@@ -308,8 +290,7 @@ def forward_features(self, x):
 
     def forward_head(self, x, pre_logits: bool = False):
         x = self.global_pool(x)
-        if self.drop_rate > 0:
-            x = F.dropout(x, p=self.drop_rate, training=self.training)
+        x = self.head_drop(x)
         return x if pre_logits else self.last_linear(x)
 
     def forward(self, x):
@@ -320,9 +301,25 @@ def forward(self, x):
 
 def _create_inception_v4(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
-        InceptionV4, variant, pretrained,
+        InceptionV4,
+        variant,
+        pretrained,
        feature_cfg=dict(flatten_sequential=True),
-        **kwargs)
+        **kwargs,
+    )
+
+
+default_cfgs = generate_default_cfgs({
+    'inception_v4.tf_in1k': {
+        'hf_hub_id': 'timm/',
+        'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/inceptionv4-8e4777a0.pth',
+        'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
+        'crop_pct': 0.875, 'interpolation': 'bicubic',
+        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+        'first_conv': 'features.0.conv', 'classifier': 'last_linear',
+        'label_offset': 1,  # 1001 classes in pretrained weights
+    }
+})
 
 
 @register_model
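
A minimal sketch of the `conv_block` pattern the Inception refactors use (assumes `ConvNormAct` accepts the `norm_kwargs`/`act_kwargs` arguments introduced by this patch): a `partial` pins the TF-style BN epsilon and inplace ReLU once, so every block sees the same conv+norm+act configuration that `BasicConv2d` used to hard-code.

```python
from functools import partial
import torch
from timm.layers import ConvNormAct

conv_block = partial(
    ConvNormAct,
    padding=0,
    norm_layer='batchnorm2d',
    act_layer='relu',
    norm_kwargs=dict(eps=1e-3),     # matches the old BasicConv2d eps=0.001
    act_kwargs=dict(inplace=True),
)
block = conv_block(3, 32, kernel_size=3, stride=2)
print(block(torch.randn(1, 3, 299, 299)).shape)  # torch.Size([1, 32, 149, 149])
```
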
diff --git a/timm/models/mobilevit.py b/timm/models/mobilevit.py
index 6d51c263d1..9fb986a6e4 100644
--- a/timm/models/mobilevit.py
+++ b/timm/models/mobilevit.py
@@ -23,77 +23,13 @@
 from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath, is_exportable
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_module
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs, register_model_deprecations
 from .byobnet import register_block, ByoBlockCfg, ByoModelCfg, ByobNet, LayerFn, num_groups
 from .vision_transformer import Block as TransformerBlock
 
 __all__ = []
 
 
-def _cfg(url='', **kwargs):
-    return {
-        'url': url, 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
-        'crop_pct': 0.9, 'interpolation': 'bicubic',
-        'mean': (0., 0., 0.), 'std': (1., 1., 1.),
-        'first_conv': 'stem.conv', 'classifier': 'head.fc',
-        'fixed_input_size': False,
-        **kwargs
-    }
-
-
-default_cfgs = {
-    'mobilevit_xxs': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_xxs-ad385b40.pth'),
-    'mobilevit_xs': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_xs-8fbd6366.pth'),
-    'mobilevit_s': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_s-38a5a959.pth'),
-    'semobilevit_s': _cfg(),
-
-    'mobilevitv2_050': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_050-49951ee2.pth',
-        crop_pct=0.888),
-    'mobilevitv2_075': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_075-b5556ef6.pth',
-        crop_pct=0.888),
-    'mobilevitv2_100': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_100-e464ef3b.pth',
-        crop_pct=0.888),
-    'mobilevitv2_125': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_125-0ae35027.pth',
-        crop_pct=0.888),
-    'mobilevitv2_150': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150-737c5019.pth',
-        crop_pct=0.888),
-    'mobilevitv2_175': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175-16462ee2.pth',
-        crop_pct=0.888),
-    'mobilevitv2_200': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200-b3422f67.pth',
-        crop_pct=0.888),
-
-    'mobilevitv2_150_in22ft1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150_in22ft1k-0b555d7b.pth',
-        crop_pct=0.888),
-    'mobilevitv2_175_in22ft1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175_in22ft1k-4117fa1f.pth',
-        crop_pct=0.888),
-    'mobilevitv2_200_in22ft1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200_in22ft1k-1d7c8927.pth',
-        crop_pct=0.888),
-
-    'mobilevitv2_150_384_in22ft1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150_384_in22ft1k-9e142854.pth',
-        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
-    'mobilevitv2_175_384_in22ft1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175_384_in22ft1k-059cbe56.pth',
-        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
-    'mobilevitv2_200_384_in22ft1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200_384_in22ft1k-32c87503.pth',
-        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
-}
-
-
 def _inverted_residual_block(d, c, s, br=4.0):
     # inverted residual is a bottleneck block with bottle_ratio > 1 applied to in_chs, linear output, gs=1 (depthwise)
     return ByoBlockCfg(
@@ -600,7 +536,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
         x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
 
-        x = self.conv_proj(x)
         return x
 
 
@@ -625,6 +560,66 @@ def _create_mobilevit2(variant, cfg_variant=None, pretrained=False, **kwargs):
         **kwargs)
 
 
+def _cfg(url='', **kwargs):
+    return {
+        'url': url, 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
+        'crop_pct': 0.9, 'interpolation': 'bicubic',
+        'mean': (0., 0., 0.), 'std': (1., 1., 1.),
+        'first_conv': 'stem.conv', 'classifier': 'head.fc',
+        'fixed_input_size': False,
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'mobilevit_xxs.cvnets_in1k': _cfg(hf_hub_id='timm/'),
+    'mobilevit_xs.cvnets_in1k': _cfg(hf_hub_id='timm/'),
+    'mobilevit_s.cvnets_in1k': _cfg(hf_hub_id='timm/'),
+
+    'mobilevitv2_050.cvnets_in1k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.888),
+    'mobilevitv2_075.cvnets_in1k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.888),
+    'mobilevitv2_100.cvnets_in1k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.888),
+    'mobilevitv2_125.cvnets_in1k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.888),
+    'mobilevitv2_150.cvnets_in1k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.888),
+    'mobilevitv2_175.cvnets_in1k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.888),
+    'mobilevitv2_200.cvnets_in1k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.888),
+
+    'mobilevitv2_150.cvnets_in22k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.888),
+    'mobilevitv2_175.cvnets_in22k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.888),
+    'mobilevitv2_200.cvnets_in22k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.888),
+
+    'mobilevitv2_150.cvnets_in22k_ft_in1k_384': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
+    'mobilevitv2_175.cvnets_in22k_ft_in1k_384': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
+    'mobilevitv2_200.cvnets_in22k_ft_in1k_384': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
+})
+
+
 @register_model
 def mobilevit_xxs(pretrained=False, **kwargs):
     return _create_mobilevit('mobilevit_xxs', pretrained=pretrained, **kwargs)
@@ -640,11 +635,6 @@ def mobilevit_s(pretrained=False, **kwargs):
     return _create_mobilevit('mobilevit_s', pretrained=pretrained, **kwargs)
 
 
-@register_model
-def semobilevit_s(pretrained=False, **kwargs):
-    return _create_mobilevit('semobilevit_s', pretrained=pretrained, **kwargs)
-
-
 @register_model
 def mobilevitv2_050(pretrained=False, **kwargs):
     return _create_mobilevit('mobilevitv2_050', pretrained=pretrained, **kwargs)
@@ -680,37 +670,12 @@ def mobilevitv2_200(pretrained=False, **kwargs):
     return _create_mobilevit('mobilevitv2_200', pretrained=pretrained, **kwargs)
 
 
-@register_model
-def mobilevitv2_150_in22ft1k(pretrained=False, **kwargs):
-    return _create_mobilevit(
-        'mobilevitv2_150_in22ft1k', cfg_variant='mobilevitv2_150', pretrained=pretrained, **kwargs)
-
-
-@register_model
-def mobilevitv2_175_in22ft1k(pretrained=False, **kwargs):
-    return _create_mobilevit(
-        'mobilevitv2_175_in22ft1k', cfg_variant='mobilevitv2_175', pretrained=pretrained, **kwargs)
+register_model_deprecations(__name__, {
+    'mobilevitv2_150_in22ft1k': 'mobilevitv2_150.cvnets_in22k_ft_in1k',
+    'mobilevitv2_175_in22ft1k': 'mobilevitv2_175.cvnets_in22k_ft_in1k',
+    'mobilevitv2_200_in22ft1k': 'mobilevitv2_200.cvnets_in22k_ft_in1k',
 
-
-@register_model
-def mobilevitv2_200_in22ft1k(pretrained=False, **kwargs):
-    return _create_mobilevit(
-        'mobilevitv2_200_in22ft1k', cfg_variant='mobilevitv2_200', pretrained=pretrained, **kwargs)
-
-
-@register_model
-def mobilevitv2_150_384_in22ft1k(pretrained=False, **kwargs):
-    return _create_mobilevit(
-        'mobilevitv2_150_384_in22ft1k', cfg_variant='mobilevitv2_150', pretrained=pretrained, **kwargs)
-
-
-@register_model
-def mobilevitv2_175_384_in22ft1k(pretrained=False, **kwargs):
-    return _create_mobilevit(
-        'mobilevitv2_175_384_in22ft1k', cfg_variant='mobilevitv2_175', pretrained=pretrained, **kwargs)
-
-
-@register_model
-def mobilevitv2_200_384_in22ft1k(pretrained=False, **kwargs):
-    return _create_mobilevit(
-        'mobilevitv2_200_384_in22ft1k', cfg_variant='mobilevitv2_200', pretrained=pretrained, **kwargs)
\ No newline at end of file
+    'mobilevitv2_150_384_in22ft1k': 'mobilevitv2_150.cvnets_in22k_ft_in1k_384',
+    'mobilevitv2_175_384_in22ft1k': 'mobilevitv2_175.cvnets_in22k_ft_in1k_384',
+    'mobilevitv2_200_384_in22ft1k': 'mobilevitv2_200.cvnets_in22k_ft_in1k_384',
+})
\ No newline at end of file
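
Hedged sketch of what the deprecation mapping above buys (assuming timm resolves names registered via `register_model_deprecations`): the old entrypoints stay callable but map to the new `model.tag` variants, typically with a deprecation warning.

```python
from timm import create_model

# old name still works, resolving to the new tagged variant...
m = create_model('mobilevitv2_150_in22ft1k', pretrained=False)
# ...but the tagged form is the one to use going forward
m = create_model('mobilevitv2_150.cvnets_in22k_ft_in1k', pretrained=False)
```
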
diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py
index 0b2178d624..61fdc7f752 100644
--- a/timm/models/nasnet.py
+++ b/timm/models/nasnet.py
@@ -10,25 +10,10 @@
 
 from timm.layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier
 from ._builder import build_model_with_cfg
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs
 
 __all__ = ['NASNetALarge']
 
-default_cfgs = {
-    'nasnetalarge': {
-        'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nasnetalarge-dc4a7b8b.pth',
-        'input_size': (3, 331, 331),
-        'pool_size': (11, 11),
-        'crop_pct': 0.911,
-        'interpolation': 'bicubic',
-        'mean': (0.5, 0.5, 0.5),
-        'std': (0.5, 0.5, 0.5),
-        'num_classes': 1000,
-        'first_conv': 'conv0.conv',
-        'classifier': 'last_linear',
-        'label_offset': 1,  # 1001 classes in pretrained weights
-    },
-}
 
 
 class ActConvBn(nn.Module):
@@ -408,14 +393,22 @@ class NASNetALarge(nn.Module):
     """NASNetALarge (6 @ 4032) """
 
     def __init__(
-            self, num_classes=1000, in_chans=3, stem_size=96, channel_multiplier=2,
-            num_features=4032, output_stride=32, drop_rate=0., global_pool='avg', pad_type='same'):
+            self,
+            num_classes=1000,
+            in_chans=3,
+            stem_size=96,
+            channel_multiplier=2,
+            num_features=4032,
+            output_stride=32,
+            drop_rate=0.,
+            global_pool='avg',
+            pad_type='same',
+    ):
         super(NASNetALarge, self).__init__()
         self.num_classes = num_classes
         self.stem_size = stem_size
         self.num_features = num_features
         self.channel_multiplier = channel_multiplier
-        self.drop_rate = drop_rate
         assert output_stride == 32
 
         channels = self.num_features // 24
@@ -501,8 +494,8 @@ def __init__(
             dict(num_chs=4032, reduction=32, module='act'),
         ]
 
-        self.global_pool, self.last_linear = create_classifier(
-            self.num_features, self.num_classes, pool_type=global_pool)
+        self.global_pool, self.head_drop, self.last_linear = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool, drop_rate=drop_rate)
 
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
@@ -562,8 +555,7 @@ def forward_features(self, x):
 
     def forward_head(self, x):
         x = self.global_pool(x)
-        if self.drop_rate > 0:
-            x = F.dropout(x, self.drop_rate, training=self.training)
+        x = self.head_drop(x)
         x = self.last_linear(x)
         return x
 
@@ -575,9 +567,30 @@ def forward(self, x):
 
 def _create_nasnet(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
-        NASNetALarge, variant, pretrained,
+        NASNetALarge,
+        variant,
+        pretrained,
         feature_cfg=dict(feature_cls='hook', no_rewrite=True),  # not possible to re-write this model
-        **kwargs)
+        **kwargs,
+    )
+
+
+default_cfgs = generate_default_cfgs({
+    'nasnetalarge.tf_in1k': {
+        'hf_hub_id': 'timm/',
+        'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nasnetalarge-dc4a7b8b.pth',
+        'input_size': (3, 331, 331),
+        'pool_size': (11, 11),
+        'crop_pct': 0.911,
+        'interpolation': 'bicubic',
+        'mean': (0.5, 0.5, 0.5),
+        'std': (0.5, 0.5, 0.5),
+        'num_classes': 1000,
+        'first_conv': 'conv0.conv',
+        'classifier': 'last_linear',
+        'label_offset': 1,  # 1001 classes in pretrained weights
+    },
+})
 
 
 @register_model
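
A small sketch of why the `head_drop` swap above is behaviour-preserving (illustration only, not from the patch): the stateful `nn.Dropout` module delegates to the same functional call the old code used, but now shows up in the module tree and follows `.train()`/`.eval()` automatically.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

drop_rate = 0.5
head_drop = nn.Dropout(drop_rate)
head_drop.train()

x = torch.randn(4, 4032)
torch.manual_seed(0)
a = head_drop(x)
torch.manual_seed(0)
b = F.dropout(x, drop_rate, training=True)  # the code path this patch removes
print(torch.allclose(a, b))                 # True
```
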
diff --git a/timm/models/nest.py b/timm/models/nest.py
index 681593df48..bc6984e69d 100644
--- a/timm/models/nest.py
+++ b/timm/models/nest.py
@@ -26,52 +26,30 @@
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_, _assert
-from timm.layers import create_conv2d, create_pool2d, to_ntuple
+from timm.layers import create_conv2d, create_pool2d, to_ntuple, use_fused_attn, LayerNorm
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
 from ._manipulate import checkpoint_seq, named_apply
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs, register_model_deprecations
 
 __all__ = ['Nest']  # model_registry will add each entrypoint fn to this
 
 _logger = logging.getLogger(__name__)
 
 
-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': [14, 14],
-        'crop_pct': .875, 'interpolation': 'bicubic', 'fixed_input_size': True,
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'patch_embed.proj', 'classifier': 'head',
-        **kwargs
-    }
-
-
-default_cfgs = {
-    # (weights from official Google JAX impl)
-    'nest_base': _cfg(),
-    'nest_small': _cfg(),
-    'nest_tiny': _cfg(),
-    'jx_nest_base': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_base-8bc41011.pth'),
-    'jx_nest_small': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_small-422eaded.pth'),
-    'jx_nest_tiny': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_tiny-e3428fb9.pth'),
-}
-
-
 class Attention(nn.Module):
     """
     This is much like `.vision_transformer.Attention` but uses *localised* self attention by accepting an input with
     an extra "image block" dim
     """
+    fused_attn: torch.jit.Final[bool]
+
     def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
         self.scale = head_dim ** -0.5
+        self.fused_attn = use_fused_attn()
 
         self.qkv = nn.Linear(dim, 3*dim, bias=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
@@ -87,12 +65,17 @@ def forward(self, x):
         qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5)
         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
 
-        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, H, T, N, N)
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
+        if self.fused_attn:
+            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p)
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)  # (B, H, T, N, N)
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = attn @ v
 
         # (B, H, T, N, C'), permute -> (B, T, N, C', H)
-        x = (attn @ v).permute(0, 2, 3, 4, 1).reshape(B, T, N, C)
+        x = x.permute(0, 2, 3, 4, 1).reshape(B, T, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x  # (B, T, N, C)
@@ -118,11 +101,22 @@ def __init__(
     ):
         super().__init__()
         self.norm1 = norm_layer(dim)
-        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            attn_drop=attn_drop,
+            proj_drop=proj_drop,
+        )
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
-        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=mlp_hidden_dim,
+            act_layer=act_layer,
+            drop=proj_drop,
+        )
 
     def forward(self, x):
         y = self.norm1(x)
@@ -317,7 +311,7 @@ def __init__(
         self.num_classes = num_classes
         self.num_features = embed_dims[-1]
         self.feature_info = []
-        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+        norm_layer = norm_layer or LayerNorm
         act_layer = act_layer or nn.GELU
         self.drop_rate = drop_rate
         self.num_levels = num_levels
@@ -490,14 +484,39 @@ def checkpoint_filter_fn(state_dict, model):
 
 def _create_nest(variant, pretrained=False, **kwargs):
     model = build_model_with_cfg(
-        Nest, variant, pretrained,
+        Nest,
+        variant,
+        pretrained,
         feature_cfg=dict(out_indices=(0, 1, 2), flatten_sequential=True),
         pretrained_filter_fn=checkpoint_filter_fn,
-        **kwargs)
+        **kwargs,
+    )
 
     return model
 
 
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': [14, 14],
+        'crop_pct': .875, 'interpolation': 'bicubic', 'fixed_input_size': True,
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'patch_embed.proj', 'classifier': 'head',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'nest_base.untrained': _cfg(),
+    'nest_small.untrained': _cfg(),
+    'nest_tiny.untrained': _cfg(),
+    # (weights from official Google JAX impl, require 'SAME' padding)
+    'nest_base_jx.goog_in1k': _cfg(hf_hub_id='timm/'),
+    'nest_small_jx.goog_in1k': _cfg(hf_hub_id='timm/'),
+    'nest_tiny_jx.goog_in1k': _cfg(hf_hub_id='timm/'),
+})
+
+
 @register_model
 def nest_base(pretrained=False, **kwargs):
     """ Nest-B @ 224x224
@@ -527,30 +546,38 @@ def nest_tiny(pretrained=False, **kwargs):
 
 
 @register_model
-def jx_nest_base(pretrained=False, **kwargs):
-    """ Nest-B @ 224x224, Pretrained weights converted from official Jax impl.
+def nest_base_jx(pretrained=False, **kwargs):
+    """ Nest-B @ 224x224
     """
-    kwargs['pad_type'] = 'same'
-    model_kwargs = dict(embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), **kwargs)
-    model = _create_nest('jx_nest_base', pretrained=pretrained, **model_kwargs)
+    kwargs.setdefault('pad_type', 'same')
+    model_kwargs = dict(
+        embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), **kwargs)
+    model = _create_nest('nest_base_jx', pretrained=pretrained, **model_kwargs)
     return model
 
 
 @register_model
-def jx_nest_small(pretrained=False, **kwargs):
-    """ Nest-S @ 224x224, Pretrained weights converted from official Jax impl.
+def nest_small_jx(pretrained=False, **kwargs):
+    """ Nest-S @ 224x224
     """
-    kwargs['pad_type'] = 'same'
+    kwargs.setdefault('pad_type', 'same')
     model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), **kwargs)
-    model = _create_nest('jx_nest_small', pretrained=pretrained, **model_kwargs)
+    model = _create_nest('nest_small_jx', pretrained=pretrained, **model_kwargs)
     return model
 
 
 @register_model
-def jx_nest_tiny(pretrained=False, **kwargs):
-    """ Nest-T @ 224x224, Pretrained weights converted from official Jax impl.
+def nest_tiny_jx(pretrained=False, **kwargs):
+    """ Nest-T @ 224x224
     """
-    kwargs['pad_type'] = 'same'
+    kwargs.setdefault('pad_type', 'same')
     model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), **kwargs)
-    model = _create_nest('jx_nest_tiny', pretrained=pretrained, **model_kwargs)
+    model = _create_nest('nest_tiny_jx', pretrained=pretrained, **model_kwargs)
     return model
+
+
+register_model_deprecations(__name__, {
+    'jx_nest_base': 'nest_base_jx',
+    'jx_nest_small': 'nest_small_jx',
+    'jx_nest_tiny': 'nest_tiny_jx',
+})
\ No newline at end of file
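
Illustration of the fused-attention change in `Attention.forward` above (a sketch, not part of the patch): with dropout at 0, the `F.scaled_dot_product_attention` path and the manual softmax path agree up to floating point error, which is why the two branches can be toggled by `use_fused_attn()`.

```python
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 8, 4, 16, 32) for _ in range(3))  # (B, H, T, N, C')

fused = F.scaled_dot_product_attention(q, k, v)

scale = q.shape[-1] ** -0.5
attn = (q * scale) @ k.transpose(-2, -1)
manual = attn.softmax(dim=-1) @ v

print(torch.allclose(fused, manual, atol=1e-5))  # True
```
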
diff --git a/timm/models/pit.py b/timm/models/pit.py
index 17dc679c06..4c5addd821 100644
--- a/timm/models/pit.py
+++ b/timm/models/pit.py
@@ -14,57 +14,21 @@
 import math
 import re
 from functools import partial
-from typing import Tuple
+from typing import Sequence, Tuple
 
 import torch
 from torch import nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import trunc_normal_, to_2tuple
+from timm.layers import trunc_normal_, to_2tuple, LayerNorm
 from ._builder import build_model_with_cfg
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs
 from .vision_transformer import Block
 
 __all__ = ['PoolingVisionTransformer']  # model_registry will add each entrypoint fn to this
 
 
-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
-        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'patch_embed.conv', 'classifier': 'head',
-        **kwargs
-    }
-
-
-default_cfgs = {
-    # deit models (FB weights)
-    'pit_ti_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_730.pth'),
-    'pit_xs_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_781.pth'),
-    'pit_s_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_809.pth'),
-    'pit_b_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_820.pth'),
-    'pit_ti_distilled_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_distill_746.pth',
-        classifier=('head', 'head_dist')),
-    'pit_xs_distilled_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_distill_791.pth',
-        classifier=('head', 'head_dist')),
-    'pit_s_distilled_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_distill_819.pth',
-        classifier=('head', 'head_dist')),
-    'pit_b_distilled_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_distill_840.pth',
-        classifier=('head', 'head_dist')),
-}
-
-
 class SequentialTuple(nn.Sequential):
     """ This module exists to work around torchscript typing issues list -> list"""
     def __init__(self, *args):
@@ -87,11 +51,13 @@ def __init__(
             proj_drop=.0,
             attn_drop=.0,
             drop_path_prob=None,
+            norm_layer=None,
     ):
         super(Transformer, self).__init__()
-        self.layers = nn.ModuleList([])
         embed_dim = base_dim * heads
 
+        self.pool = pool
+        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
         self.blocks = nn.Sequential(*[
             Block(
                 dim=embed_dim,
@@ -105,30 +71,29 @@ def __init__(
             )
             for i in range(depth)])
 
-        self.pool = pool
-
     def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         x, cls_tokens = x
-        B, C, H, W = x.shape
         token_length = cls_tokens.shape[1]
+        if self.pool is not None:
+            x, cls_tokens = self.pool(x, cls_tokens)
+        B, C, H, W = x.shape
 
         x = x.flatten(2).transpose(1, 2)
         x = torch.cat((cls_tokens, x), dim=1)
+        x = self.norm(x)
 
         x = self.blocks(x)
         cls_tokens = x[:, :token_length]
         x = x[:, token_length:]
         x = x.transpose(1, 2).reshape(B, C, H, W)
 
-        if self.pool is not None:
-            x, cls_tokens = self.pool(x, cls_tokens)
         return x, cls_tokens
 
 
-class ConvHeadPooling(nn.Module):
+class Pooling(nn.Module):
     def __init__(self, in_feature, out_feature, stride, padding_mode='zeros'):
-        super(ConvHeadPooling, self).__init__()
+        super(Pooling, self).__init__()
 
         self.conv = nn.Conv2d(
             in_feature,
@@ -148,10 +113,26 @@ def forward(self, x, cls_token) -> Tuple[torch.Tensor, torch.Tensor]:
 
 
 class ConvEmbedding(nn.Module):
-    def __init__(self, in_channels, out_channels, patch_size, stride, padding):
+    def __init__(
+            self,
+            in_channels,
+            out_channels,
+            img_size: int = 224,
+            patch_size: int = 16,
+            stride: int = 8,
+            padding: int = 0,
+    ):
         super(ConvEmbedding, self).__init__()
+        padding = padding
+        self.img_size = to_2tuple(img_size)
+        self.patch_size = to_2tuple(patch_size)
+        self.height = math.floor((self.img_size[0] + 2 * padding - self.patch_size[0]) / stride + 1)
+        self.width = math.floor((self.img_size[1] + 2 * padding - self.patch_size[1]) / stride + 1)
+        self.grid_size = (self.height, self.width)
+
         self.conv = nn.Conv2d(
-            in_channels, out_channels, kernel_size=patch_size, stride=stride, padding=padding, bias=True)
+            in_channels, out_channels, kernel_size=patch_size,
+            stride=stride, padding=padding, bias=True)
 
     def forward(self, x):
         x = self.conv(x)
@@ -166,13 +147,14 @@ class PoolingVisionTransformer(nn.Module):
     """
     def __init__(
            self,
-            img_size,
-            patch_size,
-            stride,
-            base_dims,
-            depth,
-            heads,
-            mlp_ratio,
+            img_size: int = 224,
+            patch_size: int = 16,
+            stride: int = 8,
+            stem_type: str = 'overlap',
+            base_dims: Sequence[int] = (48, 48, 48),
+            depth: Sequence[int] = (2, 6, 4),
+            heads: Sequence[int] = (2, 4, 8),
+            mlp_ratio: float = 4,
             num_classes=1000,
             in_chans=3,
             global_pool='token',
@@ -186,50 +168,48 @@ def __init__(
         super(PoolingVisionTransformer, self).__init__()
         assert global_pool in ('token',)
 
-        padding = 0
-        img_size = to_2tuple(img_size)
-        patch_size = to_2tuple(patch_size)
-        height = math.floor((img_size[0] + 2 * padding - patch_size[0]) / stride + 1)
-        width = math.floor((img_size[1] + 2 * padding - patch_size[1]) / stride + 1)
-
         self.base_dims = base_dims
         self.heads = heads
+        embed_dim = base_dims[0] * heads[0]
         self.num_classes = num_classes
         self.global_pool = global_pool
         self.num_tokens = 2 if distilled else 1
+        self.feature_info = []
 
-        self.patch_size = patch_size
-        self.pos_embed = nn.Parameter(torch.randn(1, base_dims[0] * heads[0], height, width))
-        self.patch_embed = ConvEmbedding(in_chans, base_dims[0] * heads[0], patch_size, stride, padding)
-
-        self.cls_token = nn.Parameter(torch.randn(1, self.num_tokens, base_dims[0] * heads[0]))
+        self.patch_embed = ConvEmbedding(in_chans, embed_dim, img_size, patch_size, stride)
+        self.pos_embed = nn.Parameter(torch.randn(1, embed_dim, self.patch_embed.height, self.patch_embed.width))
+        self.cls_token = nn.Parameter(torch.randn(1, self.num_tokens, embed_dim))
         self.pos_drop = nn.Dropout(p=pos_drop_drate)
 
         transformers = []
         # stochastic depth decay rule
         dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depth)).split(depth)]
-        for stage in range(len(depth)):
+        prev_dim = embed_dim
+        for i in range(len(depth)):
             pool = None
-            if stage < len(heads) - 1:
-                pool = ConvHeadPooling(
-                    base_dims[stage] * heads[stage],
-                    base_dims[stage + 1] * heads[stage + 1],
+            embed_dim = base_dims[i] * heads[i]
+            if i > 0:
+                pool = Pooling(
+                    prev_dim,
+                    embed_dim,
                     stride=2,
                 )
             transformers += [Transformer(
-                base_dims[stage],
-                depth[stage],
-                heads[stage],
+                base_dims[i],
+                depth[i],
+                heads[i],
                 mlp_ratio,
                 pool=pool,
                 proj_drop=proj_drop_rate,
                 attn_drop=attn_drop_rate,
-                drop_path_prob=dpr[stage],
-            )
-            ]
+                drop_path_prob=dpr[i],
+            )]
+            prev_dim = embed_dim
+            self.feature_info += [dict(num_chs=prev_dim, reduction=(stride - 1) * 2**i, module=f'transformers.{i}')]
+
         self.transformers = SequentialTuple(*transformers)
         self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6)
-        self.num_features = self.embed_dim = base_dims[-1] * heads[-1]
+        self.num_features = self.embed_dim = embed_dim
 
         # Classifier head
         self.head_drop = nn.Dropout(drop_rate)
@@ -318,25 +298,58 @@ def checkpoint_filter_fn(state_dict, model):
         # if k == 'pos_embed' and v.shape != model.pos_embed.shape:
         #     # To resize pos embedding when using model at different size from pretrained weights
         #     v = resize_pos_embed(v, model.pos_embed)
-        k = p_blocks.sub(lambda exp: f'transformers.{int(exp.group(1))}.pool.', k)
+        k = p_blocks.sub(lambda exp: f'transformers.{int(exp.group(1)) + 1}.pool.', k)
         out_dict[k] = v
     return out_dict
 
 
 def _create_pit(variant, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Vision Transformer models.')
+    default_out_indices = tuple(range(3))
+    out_indices = kwargs.pop('out_indices', default_out_indices)
 
     model = build_model_with_cfg(
         PoolingVisionTransformer,
         variant,
         pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(feature_cls='hook', no_rewrite=True, out_indices=out_indices),
        **kwargs,
     )
     return model
 
 
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'patch_embed.conv', 'classifier': 'head',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    # deit models (FB weights)
+    'pit_ti_224.in1k': _cfg(hf_hub_id='timm/'),
+    'pit_xs_224.in1k': _cfg(hf_hub_id='timm/'),
+    'pit_s_224.in1k': _cfg(hf_hub_id='timm/'),
+    'pit_b_224.in1k': _cfg(hf_hub_id='timm/'),
+    'pit_ti_distilled_224.in1k': _cfg(
+        hf_hub_id='timm/',
+        classifier=('head', 'head_dist')),
+    'pit_xs_distilled_224.in1k': _cfg(
+        hf_hub_id='timm/',
+        classifier=('head', 'head_dist')),
+    'pit_s_distilled_224.in1k': _cfg(
+        hf_hub_id='timm/',
+        classifier=('head', 'head_dist')),
+    'pit_b_distilled_224.in1k': _cfg(
+        hf_hub_id='timm/',
+        classifier=('head', 'head_dist')),
+})
+
+
 @register_model
 def pit_b_224(pretrained, **kwargs):
     model_args = dict(
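
Hedged sketch of the `+ 1` shift in `checkpoint_filter_fn` above: pooling moved from the end of stage `i` to the start of stage `i + 1`, so keys under the old pooling index now map one stage later. The regex below is a placeholder standing in for the `p_blocks` pattern defined earlier in pit.py (not shown in this hunk).

```python
import re

p_blocks = re.compile(r'pools\.(\d)\.')  # assumed/illustrative pattern
old_key = 'pools.0.conv.weight'
new_key = p_blocks.sub(lambda m: f'transformers.{int(m.group(1)) + 1}.pool.', old_key)
print(new_key)  # transformers.1.pool.conv.weight
```
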
**kwargs): return build_model_with_cfg( - PNASNet5Large, variant, pretrained, + PNASNet5Large, + variant, + pretrained, feature_cfg=dict(feature_cls='hook', no_rewrite=True), # not possible to re-write this model - **kwargs) + **kwargs, + ) + + +default_cfgs = generate_default_cfgs({ + 'pnasnet5large.tf_in1k': { + 'hf_hub_id': 'timm/', + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/pnasnet5large-bf079911.pth', + 'input_size': (3, 331, 331), + 'pool_size': (11, 11), + 'crop_pct': 0.911, + 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), + 'std': (0.5, 0.5, 0.5), + 'num_classes': 1000, + 'first_conv': 'conv_0.conv', + 'classifier': 'last_linear', + 'label_offset': 1, # 1001 classes in pretrained weights + }, +}) @register_model diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py index d230e788b5..0b4c54c81e 100644 --- a/timm/models/pvt_v2.py +++ b/timm/models/pvt_v2.py @@ -16,42 +16,21 @@ """ import math -from functools import partial from typing import Tuple, List, Callable, Union import torch import torch.nn as nn +import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_ +from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_, LayerNorm, use_fused_attn from ._builder import build_model_with_cfg -from ._registry import register_model +from ._registry import register_model, generate_default_cfgs __all__ = ['PyramidVisionTransformerV2'] -def _cfg(url='', **kwargs): - return { - 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.9, 'interpolation': 'bicubic', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head', 'fixed_input_size': False, - **kwargs - } - - -default_cfgs = { - 'pvt_v2_b0': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b0.pth'), - 'pvt_v2_b1': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b1.pth'), - 'pvt_v2_b2': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b2.pth'), - 'pvt_v2_b3': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b3.pth'), - 'pvt_v2_b4': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b4.pth'), - 'pvt_v2_b5': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b5.pth'), - 'pvt_v2_b2_li': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b2_li.pth') -} - - class MlpWithDepthwiseConv(nn.Module): def __init__( self, @@ -87,6 +66,8 @@ def forward(self, x, feat_size: List[int]): class Attention(nn.Module): + fused_attn: torch.jit.Final[bool] + def __init__( self, dim, @@ -104,6 +85,7 @@ def __init__( self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 + self.fused_attn = use_fused_attn() self.q = nn.Linear(dim, dim, bias=qkv_bias) self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) @@ -132,26 +114,31 @@ def forward(self, x, feat_size: List[int]): q = self.q(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) if self.pool is not None: - x_ = x.permute(0, 2, 1).reshape(B, C, H, W) - x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1) - x_ = self.norm(x_) - x_ = self.act(x_) - kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + x = x.permute(0, 2, 1).reshape(B, C, H, W) + x = 
self.sr(self.pool(x)).reshape(B, C, -1).permute(0, 2, 1) + x = self.norm(x) + x = self.act(x) + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) else: if self.sr is not None: - x_ = x.permute(0, 2, 1).reshape(B, C, H, W) - x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) - x_ = self.norm(x_) - kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + x = x.permute(0, 2, 1).reshape(B, C, H, W) + x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1) + x = self.norm(x) + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) else: kv = self.kv(x).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) k, v = kv.unbind(0) - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) + if self.fused_attn: + # dropout_p is only non-zero in training; F.scaled_dot_product_attention + # applies it unconditionally, unlike the nn.Dropout module it replaces + x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v - x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = x.transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x @@ -171,7 +158,7 @@ def __init__( attn_drop=0., drop_path=0., act_layer=nn.GELU, - norm_layer=nn.LayerNorm, + norm_layer=LayerNorm, ): super().__init__() self.norm1 = norm_layer(dim) @@ -184,7 +171,8 @@ def __init__( attn_drop=attn_drop, proj_drop=proj_drop, ) - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) self.mlp = MlpWithDepthwiseConv( in_features=dim, @@ -193,10 +181,11 @@ def __init__( drop=proj_drop, extra_relu=linear_attn, ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() def forward(self, x, feat_size: List[int]): - x = x + self.drop_path(self.attn(self.norm1(x), feat_size)) - x = x + self.drop_path(self.mlp(self.norm2(x), feat_size)) + x = x + self.drop_path1(self.attn(self.norm1(x), feat_size)) + x = x + self.drop_path2(self.mlp(self.norm2(x), feat_size)) return x @@ -216,10 +205,9 @@ def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768): def forward(self, x): x = self.proj(x) - feat_size = x.shape[-2:] - x = x.flatten(2).transpose(1, 2) + x = x.permute(0, 2, 3, 1) x = self.norm(x) - return x, feat_size + return x class PyramidVisionTransformerStage(nn.Module): @@ -237,7 +225,7 @@ def __init__( proj_drop: float = 0., attn_drop: float = 0., drop_path: Union[List[float], float] = 0.0, - norm_layer: Callable = nn.LayerNorm, + norm_layer: Callable = LayerNorm, ): super().__init__() self.grad_checkpointing = False @@ -247,7 +235,8 @@ def __init__( patch_size=3, stride=2, in_chans=dim, - embed_dim=dim_out) + embed_dim=dim_out, + ) else: assert dim == dim_out self.downsample = None @@ -267,23 +256,27 @@ def __init__( self.norm = norm_layer(dim_out) - def forward(self, x, feat_size: List[int]) -> Tuple[torch.Tensor, List[int]]: + def forward(self, x): + # x is either B, C, H, W (if downsample) or B, H, W, C if not if self.downsample is not None: - x, feat_size = self.downsample(x) + # input to downsample is B, C, H, W + x = self.downsample(x) # output B, H, W, C + B, H, W, C = x.shape + feat_size = (H, W) + x = x.reshape(B, -1, C) for blk in self.blocks: if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint.checkpoint(blk, x, feat_size) else: x = blk(x, feat_size) x = self.norm(x) - x = x.reshape(x.shape[0], feat_size[0], feat_size[1], -1).permute(0, 3, 1, 2).contiguous() - return x, feat_size + x = x.reshape(B, feat_size[0], feat_size[1], -1).permute(0, 3, 1, 2).contiguous() + return x class PyramidVisionTransformerV2(nn.Module): def __init__( self, - img_size=None, in_chans=3, num_classes=1000, global_pool='avg', @@ -298,7 +291,7 @@ def __init__( proj_drop_rate=0., attn_drop_rate=0., drop_path_rate=0., - norm_layer=nn.LayerNorm, + norm_layer=LayerNorm, ): super().__init__() self.num_classes = num_classes @@ -310,19 +303,21 @@ def __init__( num_heads = to_ntuple(num_stages)(num_heads) sr_ratios = to_ntuple(num_stages)(sr_ratios) assert(len(embed_dims)) == num_stages + self.feature_info = [] self.patch_embed = OverlapPatchEmbed( patch_size=7, stride=4, in_chans=in_chans, - embed_dim=embed_dims[0]) + embed_dim=embed_dims[0], + ) dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] cur = 0 prev_dim = embed_dims[0] - self.stages = nn.ModuleList() + stages = [] for i in range(num_stages): - self.stages.append(PyramidVisionTransformerStage( + stages += [PyramidVisionTransformerStage( dim=prev_dim, dim_out=embed_dims[i], depth=depths[i], @@ -336,9 +331,11 @@ def __init__( attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, - )) + )] prev_dim = embed_dims[i] cur += depths[i] + self.feature_info += [dict(num_chs=prev_dim, reduction=4 * 2**i, module=f'stages.{i}')] + self.stages = nn.Sequential(*stages) # classification head self.num_features = embed_dims[-1] @@ -390,9 +387,8 @@ def reset_classifier(self, num_classes, global_pool=None): self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): - x, feat_size = self.patch_embed(x) - for stage in self.stages: - x, feat_size = stage(x, feat_size=feat_size) 
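+        # NOTE: patch_embed now returns NHWC; each stage captures feat_size
+        # itself, flattens to (B, N, C) for its blocks, and restores NCHW on
+        # output, so feat_size no longer threads through this loop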
+ x = self.patch_embed(x) + x = self.stages(x) return x def forward_head(self, x, pre_logits: bool = False): @@ -428,69 +424,80 @@ def _checkpoint_filter_fn(state_dict, model): def _create_pvt2(variant, pretrained=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Vision Transformer models.') + default_out_indices = tuple(range(4)) + out_indices = kwargs.pop('out_indices', default_out_indices) model = build_model_with_cfg( - PyramidVisionTransformerV2, variant, pretrained, + PyramidVisionTransformerV2, + variant, + pretrained, pretrained_filter_fn=_checkpoint_filter_fn, - **kwargs + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + **kwargs, ) return model +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.9, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', 'fixed_input_size': False, + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'pvt_v2_b0': _cfg(hf_hub_id='timm/'), + 'pvt_v2_b1': _cfg(hf_hub_id='timm/'), + 'pvt_v2_b2': _cfg(hf_hub_id='timm/'), + 'pvt_v2_b3': _cfg(hf_hub_id='timm/'), + 'pvt_v2_b4': _cfg(hf_hub_id='timm/'), + 'pvt_v2_b5': _cfg(hf_hub_id='timm/'), + 'pvt_v2_b2_li': _cfg(hf_hub_id='timm/'), +}) + + @register_model def pvt_v2_b0(pretrained=False, **kwargs): - model_kwargs = dict( - depths=(2, 2, 2, 2), embed_dims=(32, 64, 160, 256), num_heads=(1, 2, 5, 8), - norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - return _create_pvt2('pvt_v2_b0', pretrained=pretrained, **model_kwargs) + model_args = dict(depths=(2, 2, 2, 2), embed_dims=(32, 64, 160, 256), num_heads=(1, 2, 5, 8)) + return _create_pvt2('pvt_v2_b0', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def pvt_v2_b1(pretrained=False, **kwargs): - model_kwargs = dict( - depths=(2, 2, 2, 2), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), - norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - return _create_pvt2('pvt_v2_b1', pretrained=pretrained, **model_kwargs) + model_args = dict(depths=(2, 2, 2, 2), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8)) + return _create_pvt2('pvt_v2_b1', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def pvt_v2_b2(pretrained=False, **kwargs): - model_kwargs = dict( - depths=(3, 4, 6, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), - norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - return _create_pvt2('pvt_v2_b2', pretrained=pretrained, **model_kwargs) + model_args = dict(depths=(3, 4, 6, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8)) + return _create_pvt2('pvt_v2_b2', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def pvt_v2_b3(pretrained=False, **kwargs): - model_kwargs = dict( - depths=(3, 4, 18, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), - norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - return _create_pvt2('pvt_v2_b3', pretrained=pretrained, **model_kwargs) + model_args = dict(depths=(3, 4, 18, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8)) + return _create_pvt2('pvt_v2_b3', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def pvt_v2_b4(pretrained=False, **kwargs): - model_kwargs = dict( - depths=(3, 8, 27, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), - norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - return 
_create_pvt2('pvt_v2_b4', pretrained=pretrained, **model_kwargs) + model_args = dict(depths=(3, 8, 27, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8)) + return _create_pvt2('pvt_v2_b4', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def pvt_v2_b5(pretrained=False, **kwargs): - model_kwargs = dict( - depths=(3, 6, 40, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), - mlp_ratios=(4, 4, 4, 4), norm_layer=partial(nn.LayerNorm, eps=1e-6), - **kwargs) - return _create_pvt2('pvt_v2_b5', pretrained=pretrained, **model_kwargs) + model_args = dict( + depths=(3, 6, 40, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), mlp_ratios=(4, 4, 4, 4)) + return _create_pvt2('pvt_v2_b5', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def pvt_v2_b2_li(pretrained=False, **kwargs): - model_kwargs = dict( - depths=(3, 4, 6, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), - norm_layer=partial(nn.LayerNorm, eps=1e-6), linear=True, **kwargs) - return _create_pvt2('pvt_v2_b2_li', pretrained=pretrained, **model_kwargs) + model_args = dict( + depths=(3, 4, 6, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), linear=True) + return _create_pvt2('pvt_v2_b2_li', pretrained=pretrained, **dict(model_args, **kwargs)) diff --git a/timm/models/res2net.py b/timm/models/res2net.py index 29a49953e0..5804a4e8b0 100644 --- a/timm/models/res2net.py +++ b/timm/models/res2net.py @@ -9,41 +9,12 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from ._builder import build_model_with_cfg -from ._registry import register_model +from ._registry import register_model, generate_default_cfgs from .resnet import ResNet __all__ = [] -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bilinear', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'conv1', 'classifier': 'fc', - **kwargs - } - - -default_cfgs = { - 'res2net50_26w_4s': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_4s-06e79181.pth'), - 'res2net50_48w_2s': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_48w_2s-afed724a.pth'), - 'res2net50_14w_8s': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_14w_8s-6527dddc.pth'), - 'res2net50_26w_6s': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_6s-19041792.pth'), - 'res2net50_26w_8s': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_8s-2c7c9f12.pth'), - 'res2net101_26w_4s': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net101_26w_4s-02a759a1.pth'), - 'res2next50': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next50_4s-6ef7e7bf.pth'), -} - - class Bottle2neck(nn.Module): """ Res2Net/Res2NeXT Bottleneck Adapted from https://github.com/gasvn/Res2Net/blob/master/res2net.py @@ -149,11 +120,33 @@ def _create_res2net(variant, pretrained=False, **kwargs): return build_model_with_cfg(ResNet, variant, pretrained, **kwargs) +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': 
IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'res2net50_26w_4s.in1k': _cfg(hf_hub_id='timm/'), + 'res2net50_48w_2s.in1k': _cfg(hf_hub_id='timm/'), + 'res2net50_14w_8s.in1k': _cfg(hf_hub_id='timm/'), + 'res2net50_26w_6s.in1k': _cfg(hf_hub_id='timm/'), + 'res2net50_26w_8s.in1k': _cfg(hf_hub_id='timm/'), + 'res2net101_26w_4s.in1k': _cfg(hf_hub_id='timm/'), + 'res2next50.in1k': _cfg(hf_hub_id='timm/'), + 'res2net50d.in1k': _cfg(hf_hub_id='timm/', first_conv='conv1.0'), + 'res2net101d.in1k': _cfg(hf_hub_id='timm/', first_conv='conv1.0'), +}) + + @register_model def res2net50_26w_4s(pretrained=False, **kwargs): """Constructs a Res2Net-50 26w4s model. - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet """ model_args = dict( block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=4)) @@ -163,8 +156,6 @@ def res2net50_26w_4s(pretrained=False, **kwargs): @register_model def res2net101_26w_4s(pretrained=False, **kwargs): """Constructs a Res2Net-101 26w4s model. - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet """ model_args = dict( block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, block_args=dict(scale=4)) @@ -174,8 +165,6 @@ def res2net101_26w_4s(pretrained=False, **kwargs): @register_model def res2net50_26w_6s(pretrained=False, **kwargs): """Constructs a Res2Net-50 26w6s model. - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet """ model_args = dict( block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=6)) @@ -185,8 +174,6 @@ def res2net50_26w_6s(pretrained=False, **kwargs): @register_model def res2net50_26w_8s(pretrained=False, **kwargs): """Constructs a Res2Net-50 26w8s model. - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet """ model_args = dict( block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=8)) @@ -196,8 +183,6 @@ def res2net50_26w_8s(pretrained=False, **kwargs): @register_model def res2net50_48w_2s(pretrained=False, **kwargs): """Constructs a Res2Net-50 48w2s model. - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet """ model_args = dict( block=Bottle2neck, layers=[3, 4, 6, 3], base_width=48, block_args=dict(scale=2)) @@ -207,8 +192,6 @@ def res2net50_48w_2s(pretrained=False, **kwargs): @register_model def res2net50_14w_8s(pretrained=False, **kwargs): """Constructs a Res2Net-50 14w8s model. 
- Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet """ model_args = dict( block=Bottle2neck, layers=[3, 4, 6, 3], base_width=14, block_args=dict(scale=8)) @@ -218,9 +201,27 @@ def res2net50_14w_8s(pretrained=False, **kwargs): @register_model def res2next50(pretrained=False, **kwargs): """Constructs a Res2NeXt-50 4s model. - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet """ model_args = dict( block=Bottle2neck, layers=[3, 4, 6, 3], base_width=4, cardinality=8, block_args=dict(scale=4)) return _create_res2net('res2next50', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def res2net50d(pretrained=False, **kwargs): + """Constructs a Res2Net-50-D model (deep stem, avg-pool downsample). + """ + model_args = dict( + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, stem_type='deep', + avg_down=True, stem_width=32, block_args=dict(scale=4)) + return _create_res2net('res2net50d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def res2net101d(pretrained=False, **kwargs): + """Constructs a Res2Net-101-D model (deep stem, avg-pool downsample). + """ + model_args = dict( + block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, stem_type='deep', + avg_down=True, stem_width=32, block_args=dict(scale=4)) + return _create_res2net('res2net101d', pretrained, **dict(model_args, **kwargs)) diff --git a/timm/models/resnest.py b/timm/models/resnest.py index 38303f9c6d..add9afde08 100644 --- a/timm/models/resnest.py +++ b/timm/models/resnest.py @@ -11,45 +11,10 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import SplitAttn from ._builder import build_model_with_cfg -from ._registry import register_model +from ._registry import register_model, generate_default_cfgs from .resnet import ResNet -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bilinear', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'conv1.0', 'classifier': 'fc', - **kwargs - } - -default_cfgs = { - 'resnest14d': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest14-9c8fe254.pth'), - 'resnest26d': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest26-50eb607c.pth'), - 'resnest50d': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50-528c19ca.pth'), - 'resnest101e': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest101-22405ba7.pth', - input_size=(3, 256, 256), pool_size=(8, 8)), - 'resnest200e': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest200-75117900.pth', - input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=0.909, interpolation='bicubic'), - 'resnest269e': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest269-0cc87c48.pth', - input_size=(3, 416, 416), pool_size=(13, 13), crop_pct=0.928, interpolation='bicubic'), - 'resnest50d_4s2x40d': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_4s2x40d-41d14ed0.pth', - interpolation='bicubic'), - 'resnest50d_1s4x24d': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_1s4x24d-d4a4f76f.pth', - interpolation='bicubic') -} - - class ResNestBottleneck(nn.Module): """ResNet Bottleneck """ @@ -153,7 
+118,45 @@ def forward(self, x): def _create_resnest(variant, pretrained=False, **kwargs): - return build_model_with_cfg(ResNet, variant, pretrained, **kwargs) + return build_model_with_cfg( + ResNet, + variant, + pretrained, + **kwargs, + ) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1.0', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'resnest14d.gluon_in1k': _cfg(hf_hub_id='timm/'), + 'resnest26d.gluon_in1k': _cfg(hf_hub_id='timm/'), + 'resnest50d.in1k': _cfg(hf_hub_id='timm/'), + 'resnest101e.in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'resnest200e.in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=0.909, interpolation='bicubic'), + 'resnest269e.in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 416, 416), pool_size=(13, 13), crop_pct=0.928, interpolation='bicubic'), + 'resnest50d_4s2x40d.in1k': _cfg( + hf_hub_id='timm/', + interpolation='bicubic'), + 'resnest50d_1s4x24d.in1k': _cfg( + hf_hub_id='timm/', + interpolation='bicubic') +}) @register_model diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index 4d40c49a2e..8aeaadd95a 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -18,41 +18,11 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import create_classifier from ._builder import build_model_with_cfg -from ._registry import register_model +from ._registry import register_model, generate_default_cfgs __all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (4, 4), - 'crop_pct': 0.875, 'interpolation': 'bilinear', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'stem.0', 'classifier': 'fc', - **kwargs - } - - -default_cfgs = { - 'selecsls42': _cfg( - url='', - interpolation='bicubic'), - 'selecsls42b': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-selecsls/selecsls42b-8af30141.pth', - interpolation='bicubic'), - 'selecsls60': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-selecsls/selecsls60-bbf87526.pth', - interpolation='bicubic'), - 'selecsls60b': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-selecsls/selecsls60b-94e619b5.pth', - interpolation='bicubic'), - 'selecsls84': _cfg( - url='', - interpolation='bicubic'), -} - - class SequentialList(nn.Sequential): def __init__(self, *args): @@ -155,7 +125,6 @@ class SelecSLS(nn.Module): def __init__(self, cfg, num_classes=1000, in_chans=3, drop_rate=0.0, global_pool='avg'): self.num_classes = num_classes - self.drop_rate = drop_rate super(SelecSLS, self).__init__() self.stem = conv_bn(in_chans, 32, stride=2) @@ -165,14 +134,16 @@ def __init__(self, cfg, num_classes=1000, in_chans=3, drop_rate=0.0, global_pool self.num_features = cfg['num_features'] self.feature_info = cfg['feature_info'] - self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + self.global_pool, self.head_drop, self.fc = create_classifier( + self.num_features, + self.num_classes, + pool_type=global_pool, + drop_rate=drop_rate, + ) for n, m in 
self.named_modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.weight, 1.) - nn.init.constant_(m.bias, 0.) @torch.jit.ignore def group_matcher(self, coarse=False): @@ -202,8 +173,7 @@ def forward_features(self, x): def forward_head(self, x, pre_logits: bool = False): x = self.global_pool(x) - if self.drop_rate > 0.: - x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.head_drop(x) return x if pre_logits else self.fc(x) def forward(self, x): @@ -336,10 +306,41 @@ def _create_selecsls(variant, pretrained, **kwargs): # this model can do 6 feature levels by default, unlike most others, leave as 0-4 to avoid surprises? return build_model_with_cfg( - SelecSLS, variant, pretrained, + SelecSLS, + variant, + pretrained, model_cfg=cfg, feature_cfg=dict(out_indices=(0, 1, 2, 3, 4), flatten_sequential=True), - **kwargs) + **kwargs, + ) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (4, 4), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'selecsls42.untrained': _cfg( + interpolation='bicubic'), + 'selecsls42b.in1k': _cfg( + hf_hub_id='timm/', + interpolation='bicubic'), + 'selecsls60.in1k': _cfg( + hf_hub_id='timm/', + interpolation='bicubic'), + 'selecsls60b.in1k': _cfg( + hf_hub_id='timm/', + interpolation='bicubic'), + 'selecsls84.untrained': _cfg( + interpolation='bicubic'), +}) @register_model diff --git a/timm/models/senet.py b/timm/models/senet.py index d36e985419..4706baa73e 100644 --- a/timm/models/senet.py +++ b/timm/models/senet.py @@ -21,45 +21,11 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import create_classifier from ._builder import build_model_with_cfg -from ._registry import register_model +from ._registry import register_model, generate_default_cfgs __all__ = ['SENet'] -def _cfg(url='', **kwargs): - return { - 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bilinear', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'layer0.conv1', 'classifier': 'last_linear', - **kwargs - } - - -default_cfgs = { - 'legacy_senet154': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/legacy_senet154-e9eb9fe6.pth'), - 'legacy_seresnet18': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet18-4bb0ce65.pth', - interpolation='bicubic'), - 'legacy_seresnet34': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet34-a4004e63.pth'), - 'legacy_seresnet50': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet50-ce0d4300.pth'), - 'legacy_seresnet101': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet101-7e38fcc6.pth'), - 'legacy_seresnet152': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet152-d17c99b7.pth'), - 'legacy_seresnext26_32x4d': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26_32x4d-65ebdb501.pth', - interpolation='bicubic'), - 
'legacy_seresnext50_32x4d': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/legacy_se_resnext50_32x4d-f3651bad.pth'), - 'legacy_seresnext101_32x4d': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/legacy_se_resnext101_32x4d-37725eac.pth'), -} - - def _weight_init(m): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') @@ -401,6 +367,40 @@ def _create_senet(variant, pretrained=False, **kwargs): return build_model_with_cfg(SENet, variant, pretrained, **kwargs) +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'layer0.conv1', 'classifier': 'last_linear', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'legacy_senet154.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/legacy_senet154-e9eb9fe6.pth'), + 'legacy_seresnet18.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet18-4bb0ce65.pth', + interpolation='bicubic'), + 'legacy_seresnet34.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet34-a4004e63.pth'), + 'legacy_seresnet50.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet50-ce0d4300.pth'), + 'legacy_seresnet101.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet101-7e38fcc6.pth'), + 'legacy_seresnet152.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet152-d17c99b7.pth'), + 'legacy_seresnext26_32x4d.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26_32x4d-65ebdb501.pth', + interpolation='bicubic'), + 'legacy_seresnext50_32x4d.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/legacy_se_resnext50_32x4d-f3651bad.pth'), + 'legacy_seresnext101_32x4d.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/legacy_se_resnext101_32x4d-37725eac.pth'), +}) + + @register_model def legacy_seresnet18(pretrained=False, **kwargs): model_args = dict( diff --git a/timm/models/sequencer.py b/timm/models/sequencer.py index f3f758b9ef..2899d29ed2 100644 --- a/timm/models/sequencer.py +++ b/timm/models/sequencer.py @@ -8,36 +8,19 @@ import math from functools import partial +from itertools import accumulate from typing import Tuple import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT -from timm.layers import lecun_normal_, DropPath, Mlp, PatchEmbed as TimmPatchEmbed +from timm.layers import lecun_normal_, DropPath, Mlp, PatchEmbed, ClassifierHead from ._builder import build_model_with_cfg from ._manipulate import named_apply -from ._registry import register_model +from ._registry import register_model, generate_default_cfgs -__all__ = ['Sequencer2D'] # model_registry will add each entrypoint fn to this - - -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': DEFAULT_CROP_PCT, 'interpolation': 'bicubic', 'fixed_input_size': 
True, - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'stem.proj', 'classifier': 'head', - **kwargs - } - - -default_cfgs = dict( - sequencer2d_s=_cfg(url="https://github.com/okojoalg/sequencer/releases/download/weights/sequencer2d_s.pth"), - sequencer2d_m=_cfg(url="https://github.com/okojoalg/sequencer/releases/download/weights/sequencer2d_m.pth"), - sequencer2d_l=_cfg(url="https://github.com/okojoalg/sequencer/releases/download/weights/sequencer2d_l.pth"), -) +__all__ = ['Sequencer2d'] # model_registry will add each entrypoint fn to this def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False): @@ -73,27 +56,6 @@ def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=Fals module.init_weights() -def get_stage( - index, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer, rnn_layer, mlp_layer, - norm_layer, act_layer, num_layers, bidirectional, union, - with_fc, drop=0., drop_path_rate=0., **kwargs): - assert len(layers) == len(patch_sizes) == len(embed_dims) == len(hidden_sizes) == len(mlp_ratios) - blocks = [] - for block_idx in range(layers[index]): - drop_path = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1) - blocks.append(block_layer( - embed_dims[index], hidden_sizes[index], mlp_ratio=mlp_ratios[index], - rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer, - num_layers=num_layers, bidirectional=bidirectional, union=union, with_fc=with_fc, - drop=drop, drop_path=drop_path)) - - if index < len(embed_dims) - 1: - blocks.append(Downsample2D(embed_dims[index], embed_dims[index + 1], patch_sizes[index + 1])) - - blocks = nn.Sequential(*blocks) - return blocks - - class RNNIdentity(nn.Module): def __init__(self, *args, **kwargs): super(RNNIdentity, self).__init__() @@ -102,12 +64,18 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, None]: return x, None -class RNN2DBase(nn.Module): +class RNN2dBase(nn.Module): def __init__( - self, input_size: int, hidden_size: int, - num_layers: int = 1, bias: bool = True, bidirectional: bool = True, - union="cat", with_fc=True): + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + bias: bool = True, + bidirectional: bool = True, + union="cat", + with_fc=True, + ): super().__init__() self.input_size = input_size @@ -190,29 +158,67 @@ def forward(self, x): return x -class LSTM2D(RNN2DBase): +class LSTM2d(RNN2dBase): def __init__( - self, input_size: int, hidden_size: int, - num_layers: int = 1, bias: bool = True, bidirectional: bool = True, - union="cat", with_fc=True): + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + bias: bool = True, + bidirectional: bool = True, + union="cat", + with_fc=True, + ): super().__init__(input_size, hidden_size, num_layers, bias, bidirectional, union, with_fc) if self.with_vertical: - self.rnn_v = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bias=bias, bidirectional=bidirectional) + self.rnn_v = nn.LSTM( + input_size, + hidden_size, + num_layers, + batch_first=True, + bias=bias, + bidirectional=bidirectional, + ) if self.with_horizontal: - self.rnn_h = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bias=bias, bidirectional=bidirectional) + self.rnn_h = nn.LSTM( + input_size, + hidden_size, + num_layers, + batch_first=True, + bias=bias, + bidirectional=bidirectional, + ) -class Sequencer2DBlock(nn.Module): +class Sequencer2dBlock(nn.Module): def __init__( - self, dim, hidden_size, 
mlp_ratio=3.0, rnn_layer=LSTM2D, mlp_layer=Mlp, - norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, - num_layers=1, bidirectional=True, union="cat", with_fc=True, drop=0., drop_path=0.): + self, + dim, + hidden_size, + mlp_ratio=3.0, + rnn_layer=LSTM2d, + mlp_layer=Mlp, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + num_layers=1, + bidirectional=True, + union="cat", + with_fc=True, + drop=0., + drop_path=0., + ): super().__init__() channels_dim = int(mlp_ratio * dim) self.norm1 = norm_layer(dim) - self.rnn_tokens = rnn_layer(dim, hidden_size, num_layers=num_layers, bidirectional=bidirectional, - union=union, with_fc=with_fc) + self.rnn_tokens = rnn_layer( + dim, + hidden_size, + num_layers=num_layers, + bidirectional=bidirectional, + union=union, + with_fc=with_fc, + ) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop) @@ -223,17 +229,6 @@ def forward(self, x): return x -class PatchEmbed(TimmPatchEmbed): - def forward(self, x): - x = self.proj(x) - if self.flatten: - x = x.flatten(2).transpose(1, 2) # BCHW -> BNC - else: - x = x.permute(0, 2, 3, 1) # BCHW -> BHWC - x = self.norm(x) - return x - - class Shuffle(nn.Module): def __init__(self): super().__init__() @@ -247,7 +242,7 @@ def forward(self, x): return x -class Downsample2D(nn.Module): +class Downsample2d(nn.Module): def __init__(self, input_dim, output_dim, patch_size): super().__init__() self.down = nn.Conv2d(input_dim, output_dim, kernel_size=patch_size, stride=patch_size) @@ -259,20 +254,74 @@ def forward(self, x): return x -class Sequencer2D(nn.Module): +class Sequencer2dStage(nn.Module): + def __init__( + self, + dim, + dim_out, + depth, + patch_size, + hidden_size, + mlp_ratio, + downsample=False, + block_layer=Sequencer2dBlock, + rnn_layer=LSTM2d, + mlp_layer=Mlp, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + num_layers=1, + bidirectional=True, + union="cat", + with_fc=True, + drop=0., + drop_path=0., + ): + super().__init__() + if downsample: + self.downsample = Downsample2d(dim, dim_out, patch_size) + else: + assert dim == dim_out + self.downsample = nn.Identity() + + blocks = [] + for block_idx in range(depth): + blocks.append(block_layer( + dim_out, + hidden_size, + mlp_ratio=mlp_ratio, + rnn_layer=rnn_layer, + mlp_layer=mlp_layer, + norm_layer=norm_layer, + act_layer=act_layer, + num_layers=num_layers, + bidirectional=bidirectional, + union=union, + with_fc=with_fc, + drop=drop, + drop_path=drop_path[block_idx] if isinstance(drop_path, (list, tuple)) else drop_path, + )) + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + x = self.downsample(x) + x = self.blocks(x) + return x + + +class Sequencer2d(nn.Module): def __init__( self, num_classes=1000, img_size=224, in_chans=3, global_pool='avg', - layers=[4, 3, 8, 3], - patch_sizes=[7, 2, 1, 1], - embed_dims=[192, 384, 384, 384], - hidden_sizes=[48, 96, 96, 96], - mlp_ratios=[3.0, 3.0, 3.0, 3.0], - block_layer=Sequencer2DBlock, - rnn_layer=LSTM2D, + layers=(4, 3, 8, 3), + patch_sizes=(7, 2, 2, 1), + embed_dims=(192, 384, 384, 384), + hidden_sizes=(48, 96, 96, 96), + mlp_ratios=(3.0, 3.0, 3.0, 3.0), + block_layer=Sequencer2dBlock, + rnn_layer=LSTM2d, mlp_layer=Mlp, norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, @@ -291,23 +340,56 @@ def __init__( self.global_pool = global_pool self.num_features = embed_dims[-1] # num_features for consistency with other 
models self.feature_dim = -1 # channel dim index for feature outputs (rank 4, NHWC) - self.embed_dims = embed_dims + self.output_fmt = 'NHWC' + self.feature_info = [] + self.stem = PatchEmbed( - img_size=img_size, patch_size=patch_sizes[0], in_chans=in_chans, - embed_dim=embed_dims[0], norm_layer=norm_layer if stem_norm else None, - flatten=False) - - self.blocks = nn.Sequential(*[ - get_stage( - i, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer=block_layer, - rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer, - num_layers=num_rnn_layers, bidirectional=bidirectional, - union=union, with_fc=with_fc, drop=drop_rate, drop_path_rate=drop_path_rate, - ) - for i, _ in enumerate(embed_dims)]) + img_size=None, + patch_size=patch_sizes[0], + in_chans=in_chans, + embed_dim=embed_dims[0], + norm_layer=norm_layer if stem_norm else None, + flatten=False, + output_fmt='NHWC', + ) + assert len(layers) == len(patch_sizes) == len(embed_dims) == len(hidden_sizes) == len(mlp_ratios) + reductions = list(accumulate(patch_sizes, lambda x, y: x * y)) + stages = [] + prev_dim = embed_dims[0] + for i, _ in enumerate(embed_dims): + stages += [Sequencer2dStage( + prev_dim, + embed_dims[i], + depth=layers[i], + downsample=i > 0, + patch_size=patch_sizes[i], + hidden_size=hidden_sizes[i], + mlp_ratio=mlp_ratios[i], + block_layer=block_layer, + rnn_layer=rnn_layer, + mlp_layer=mlp_layer, + norm_layer=norm_layer, + act_layer=act_layer, + num_layers=num_rnn_layers, + bidirectional=bidirectional, + union=union, + with_fc=with_fc, + drop=drop_rate, + drop_path=drop_path_rate, + )] + prev_dim = embed_dims[i] + self.feature_info += [dict(num_chs=prev_dim, reduction=reductions[i], module=f'stages.{i}')] + + self.stages = nn.Sequential(*stages) self.norm = norm_layer(embed_dims[-1]) - self.head = nn.Linear(embed_dims[-1], self.num_classes) if num_classes > 0 else nn.Identity() + self.head = ClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=drop_rate, + input_fmt=self.output_fmt, + ) self.init_weights(nlhb=nlhb) @@ -320,8 +402,11 @@ def group_matcher(self, coarse=False): return dict( stem=r'^stem', blocks=[ - (r'^blocks\.(\d+)\..*\.down', (99999,)), - (r'^blocks\.(\d+)', None) if coarse else (r'^blocks\.(\d+)\.(\d+)', None), + (r'^stages\.(\d+)', None), + (r'^norm', (99999,)) + ] if coarse else [ + (r'^stages\.(\d+)\.blocks\.(\d+)', None), + (r'^stages\.(\d+)\.downsample', (0,)), (r'^norm', (99999,)) ] ) @@ -336,21 +421,16 @@ def get_classifier(self): def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes - if global_pool is not None: - assert global_pool in ('', 'avg') - self.global_pool = global_pool - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head.reset(num_classes, pool_type=global_pool) def forward_features(self, x): x = self.stem(x) - x = self.blocks(x) + x = self.stages(x) x = self.norm(x) return x def forward_head(self, x, pre_logits: bool = False): - if self.global_pool == 'avg': - x = x.mean(dim=(1, 2)) - return x if pre_logits else self.head(x) + return self.head(x, pre_logits=True) if pre_logits else self.head(x) def forward(self, x): x = self.forward_features(x) @@ -358,15 +438,56 @@ def forward(self, x): return x -def _create_sequencer2d(variant, pretrained=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Sequencer2D models.') +def checkpoint_filter_fn(state_dict, 
model): + """ Remap original checkpoints -> timm """ + if 'stages.0.blocks.0.norm1.weight' in state_dict: + return state_dict # already translated checkpoint + if 'model' in state_dict: + state_dict = state_dict['model'] + + import re + out_dict = {} + for k, v in state_dict.items(): + k = re.sub(r'blocks.([0-9]+).([0-9]+).down', lambda x: f'stages.{int(x.group(1)) + 1}.downsample.down', k) + k = re.sub(r'blocks.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k) + k = k.replace('head.', 'head.fc.') + out_dict[k] = v - model = build_model_with_cfg(Sequencer2D, variant, pretrained, **kwargs) + return out_dict + + +def _create_sequencer2d(variant, pretrained=False, **kwargs): + default_out_indices = tuple(range(3)) + out_indices = kwargs.pop('out_indices', default_out_indices) + + model = build_model_with_cfg( + Sequencer2d, + variant, + pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + **kwargs, + ) return model -# main +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': DEFAULT_CROP_PCT, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.proj', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'sequencer2d_s.in1k': _cfg(hf_hub_id='timm/'), + 'sequencer2d_m.in1k': _cfg(hf_hub_id='timm/'), + 'sequencer2d_l.in1k': _cfg(hf_hub_id='timm/'), +}) + @register_model def sequencer2d_s(pretrained=False, **kwargs): @@ -376,12 +497,12 @@ def sequencer2d_s(pretrained=False, **kwargs): embed_dims=[192, 384, 384, 384], hidden_sizes=[48, 96, 96, 96], mlp_ratios=[3.0, 3.0, 3.0, 3.0], - rnn_layer=LSTM2D, + rnn_layer=LSTM2d, bidirectional=True, union="cat", with_fc=True, - **kwargs) - model = _create_sequencer2d('sequencer2d_s', pretrained=pretrained, **model_args) + ) + model = _create_sequencer2d('sequencer2d_s', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -393,12 +514,12 @@ def sequencer2d_m(pretrained=False, **kwargs): embed_dims=[192, 384, 384, 384], hidden_sizes=[48, 96, 96, 96], mlp_ratios=[3.0, 3.0, 3.0, 3.0], - rnn_layer=LSTM2D, + rnn_layer=LSTM2d, bidirectional=True, union="cat", with_fc=True, **kwargs) - model = _create_sequencer2d('sequencer2d_m', pretrained=pretrained, **model_args) + model = _create_sequencer2d('sequencer2d_m', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -410,10 +531,10 @@ def sequencer2d_l(pretrained=False, **kwargs): embed_dims=[192, 384, 384, 384], hidden_sizes=[48, 96, 96, 96], mlp_ratios=[3.0, 3.0, 3.0, 3.0], - rnn_layer=LSTM2D, + rnn_layer=LSTM2d, bidirectional=True, union="cat", with_fc=True, **kwargs) - model = _create_sequencer2d('sequencer2d_l', pretrained=pretrained, **model_args) + model = _create_sequencer2d('sequencer2d_l', pretrained=pretrained, **dict(model_args, **kwargs)) return model diff --git a/timm/models/sknet.py b/timm/models/sknet.py index 425bd7c219..a0b8d6bf8a 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -15,34 +15,10 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import SelectiveKernel, ConvNormAct, create_attn from ._builder import build_model_with_cfg -from ._registry import register_model +from ._registry import register_model, generate_default_cfgs from .resnet import ResNet -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 
'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bicubic', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'conv1', 'classifier': 'fc', - **kwargs - } - - -default_cfgs = { - 'skresnet18': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth'), - 'skresnet34': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth'), - 'skresnet50': _cfg(), - 'skresnet50d': _cfg( - first_conv='conv1.0'), - 'skresnext50_32x4d': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth'), -} - - class SelectiveKernelBasic(nn.Module): expansion = 1 @@ -166,7 +142,33 @@ def forward(self, x): def _create_skresnet(variant, pretrained=False, **kwargs): - return build_model_with_cfg(ResNet, variant, pretrained, **kwargs) + return build_model_with_cfg( + ResNet, + variant, + pretrained, + **kwargs, + ) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'skresnet18.ra_in1k': _cfg(hf_hub_id='timm/'), + 'skresnet34.ra_in1k': _cfg(hf_hub_id='timm/'), + 'skresnet50.untrained': _cfg(), + 'skresnet50d.untrained': _cfg( + first_conv='conv1.0'), + 'skresnext50_32x4d.ra_in1k': _cfg(hf_hub_id='timm/'), +}) @register_model diff --git a/timm/models/twins.py b/timm/models/twins.py index ddf7897d2a..dda2a5d1c7 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -23,44 +23,11 @@ from timm.layers import Mlp, DropPath, to_2tuple, trunc_normal_, use_fused_attn from ._builder import build_model_with_cfg from ._features_fx import register_notrace_module -from ._registry import register_model +from ._registry import register_model, generate_default_cfgs from .vision_transformer import Attention __all__ = ['Twins'] # model_registry will add each entrypoint fn to this - -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embeds.0.proj', 'classifier': 'head', - **kwargs - } - - -default_cfgs = { - 'twins_pcpvt_small': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_small-e70e7e7a.pth', - ), - 'twins_pcpvt_base': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_base-e5ecb09b.pth', - ), - 'twins_pcpvt_large': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_large-d273f802.pth', - ), - 'twins_svt_small': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_small-42e5f78c.pth', - ), - 'twins_svt_base': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_base-c2265010.pth', - ), - 'twins_svt_large': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_large-90f6aaa9.pth', - ), -} - Size_ = Tuple[int, 
int] @@ -469,6 +436,27 @@ def _create_twins(variant, pretrained=False, **kwargs): return model +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embeds.0.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'twins_pcpvt_small.in1k': _cfg(hf_hub_id='timm/'), + 'twins_pcpvt_base.in1k': _cfg(hf_hub_id='timm/'), + 'twins_pcpvt_large.in1k': _cfg(hf_hub_id='timm/'), + 'twins_svt_small.in1k': _cfg(hf_hub_id='timm/'), + 'twins_svt_base.in1k': _cfg(hf_hub_id='timm/'), + 'twins_svt_large.in1k': _cfg(hf_hub_id='timm/'), +}) + + @register_model def twins_pcpvt_small(pretrained=False, **kwargs): model_args = dict( diff --git a/timm/models/vgg.py b/timm/models/vgg.py index abe9f8d5de..1ba12c9a92 100644 --- a/timm/models/vgg.py +++ b/timm/models/vgg.py @@ -15,34 +15,11 @@ from timm.layers import ClassifierHead from ._builder import build_model_with_cfg from ._features_fx import register_notrace_module -from ._registry import register_model +from ._registry import register_model, generate_default_cfgs __all__ = ['VGG'] -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bilinear', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'features.0', 'classifier': 'head.fc', - **kwargs - } - - -default_cfgs = { - 'vgg11': _cfg(url='https://download.pytorch.org/models/vgg11-bbd30ac9.pth'), - 'vgg13': _cfg(url='https://download.pytorch.org/models/vgg13-c768596a.pth'), - 'vgg16': _cfg(url='https://download.pytorch.org/models/vgg16-397923af.pth'), - 'vgg19': _cfg(url='https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'), - 'vgg11_bn': _cfg(url='https://download.pytorch.org/models/vgg11_bn-6002323d.pth'), - 'vgg13_bn': _cfg(url='https://download.pytorch.org/models/vgg13_bn-abd245e5.pth'), - 'vgg16_bn': _cfg(url='https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'), - 'vgg19_bn': _cfg(url='https://download.pytorch.org/models/vgg19_bn-c79401a0.pth'), -} - - cfgs: Dict[str, List[Union[str, int]]] = { 'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], @@ -55,8 +32,15 @@ def _cfg(url='', **kwargs): class ConvMlp(nn.Module): def __init__( - self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0, - drop_rate: float = 0.2, act_layer: nn.Module = None, conv_layer: nn.Module = None): + self, + in_features=512, + out_features=4096, + kernel_size=7, + mlp_ratio=1.0, + drop_rate: float = 0.2, + act_layer: nn.Module = None, + conv_layer: nn.Module = None, + ): super(ConvMlp, self).__init__() self.input_kernel_size = kernel_size mid_features = int(out_features * mlp_ratio) @@ -124,10 +108,20 @@ def __init__( self.feature_info.append(dict(num_chs=prev_chs, reduction=net_stride, module=f'features.{len(layers) - 1}')) self.pre_logits = ConvMlp( - prev_chs, self.num_features, 7, mlp_ratio=mlp_ratio, - drop_rate=drop_rate, act_layer=act_layer, conv_layer=conv_layer) + prev_chs, + self.num_features, + 7, + mlp_ratio=mlp_ratio, + drop_rate=drop_rate, + act_layer=act_layer, + conv_layer=conv_layer, + ) self.head = ClassifierHead( - self.num_features, num_classes, pool_type=global_pool, 
drop_rate=drop_rate) + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=drop_rate, + ) self._initialize_weights() @@ -147,7 +141,11 @@ def get_classifier(self): def reset_classifier(self, num_classes, global_pool='avg'): self.num_classes = num_classes self.head = ClassifierHead( - self.num_features, self.num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + self.num_features, + self.num_classes, + pool_type=global_pool, + drop_rate=self.drop_rate, + ) def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.features(x) @@ -197,14 +195,40 @@ def _create_vgg(variant: str, pretrained: bool, **kwargs: Any) -> VGG: # NOTE: VGG is one of few models with stride==1 features w/ 6 out_indices [0..5] out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4, 5)) model = build_model_with_cfg( - VGG, variant, pretrained, + VGG, + variant, + pretrained, model_cfg=cfgs[cfg], feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), pretrained_filter_fn=_filter_fn, - **kwargs) + **kwargs, + ) return model +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'features.0', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'vgg11.tv_in1k': _cfg(hf_hub_id='timm/'), + 'vgg13.tv_in1k': _cfg(hf_hub_id='timm/'), + 'vgg16.tv_in1k': _cfg(hf_hub_id='timm/'), + 'vgg19.tv_in1k': _cfg(hf_hub_id='timm/'), + 'vgg11_bn.tv_in1k': _cfg(hf_hub_id='timm/'), + 'vgg13_bn.tv_in1k': _cfg(hf_hub_id='timm/'), + 'vgg16_bn.tv_in1k': _cfg(hf_hub_id='timm/'), + 'vgg19_bn.tv_in1k': _cfg(hf_hub_id='timm/'), +}) + + @register_model def vgg11(pretrained: bool = False, **kwargs: Any) -> VGG: r"""VGG 11-layer model (configuration "A") from diff --git a/timm/models/visformer.py b/timm/models/visformer.py index ae83c96336..15ae7bb9b8 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -14,30 +14,11 @@ from timm.layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier, use_fused_attn from ._builder import build_model_with_cfg from ._manipulate import checkpoint_seq -from ._registry import register_model +from ._registry import register_model, generate_default_cfgs __all__ = ['Visformer'] -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'stem.0', 'classifier': 'head', - **kwargs - } - - -default_cfgs = dict( - visformer_tiny=_cfg(), - visformer_small=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/visformer_small-839e1f5b.pth' - ), -) - - class SpatialMlp(nn.Module): def __init__( self, @@ -464,6 +445,23 @@ def _create_visformer(variant, pretrained=False, default_cfg=None, **kwargs): return model +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0', 'classifier': 'head', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'visformer_tiny.in1k': _cfg(hf_hub_id='timm/'), + 'visformer_small.in1k': 
_cfg(hf_hub_id='timm/'), +}) + + @register_model def visformer_tiny(pretrained=False, **kwargs): model_cfg = dict( diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index f7160b9f83..0e962a62cf 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -61,11 +61,12 @@ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): k = self.k_norm(k) if self.fused_attn: - attn_bias = None if self.rel_pos is not None: attn_bias = self.rel_pos.get_bias() elif shared_rel_pos is not None: attn_bias = shared_rel_pos + else: + attn_bias = None x = torch.nn.functional.scaled_dot_product_attention( q, k, v, diff --git a/timm/models/xception.py b/timm/models/xception.py index 99e74b4630..14b6e4f120 100644 --- a/timm/models/xception.py +++ b/timm/models/xception.py @@ -27,26 +27,10 @@ from timm.layers import create_classifier from ._builder import build_model_with_cfg -from ._registry import register_model +from ._registry import register_model, generate_default_cfgs, register_model_deprecations __all__ = ['Xception'] -default_cfgs = { - 'xception': { - 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/xception-43020ad28.pth', - 'input_size': (3, 299, 299), - 'pool_size': (10, 10), - 'crop_pct': 0.8975, - 'interpolation': 'bicubic', - 'mean': (0.5, 0.5, 0.5), - 'std': (0.5, 0.5, 0.5), - 'num_classes': 1000, - 'first_conv': 'conv1', - 'classifier': 'fc' - # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 - } -} - class SeparableConv2d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1): @@ -244,6 +228,29 @@ def _xception(variant, pretrained=False, **kwargs): **kwargs) +default_cfgs = generate_default_cfgs({ + 'legacy_xception.tf_in1k': { + 'hf_hub_id': 'timm/', + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/xception-43020ad28.pth', + 'input_size': (3, 299, 299), + 'pool_size': (10, 10), + 'crop_pct': 0.8975, + 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), + 'std': (0.5, 0.5, 0.5), + 'num_classes': 1000, + 'first_conv': 'conv1', + 'classifier': 'fc' + # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 + } +}) + + @register_model -def xception(pretrained=False, **kwargs): - return _xception('xception', pretrained=pretrained, **kwargs) +def legacy_xception(pretrained=False, **kwargs): + return _xception('legacy_xception', pretrained=pretrained, **kwargs) + + +register_model_deprecations(__name__, { + 'xception': 'legacy_xception', +})
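# A minimal usage sketch of the registry changes in this diff. It assumes a
# timm build that includes these commits; the model and tag names below come
# from the generate_default_cfgs() blocks above, everything else is the
# standard timm factory API.
import timm

# tagged cfg keys like 'pvt_v2_b0.in1k' now resolve via generate_default_cfgs;
# the bare architecture name still works
model = timm.create_model('pvt_v2_b0', pretrained=False)

# features_only now works for pvt_v2 (it raised a RuntimeError before this
# diff) thanks to the new feature_info entries and flatten_sequential cfg
features = timm.create_model('pvt_v2_b0', features_only=True, out_indices=(0, 1, 2, 3))

# 'xception' is deprecated via register_model_deprecations and remaps to
# 'legacy_xception' with a warning
legacy = timm.create_model('legacy_xception', pretrained=False)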