diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py
index 45f2e5410a..d55faccc4b 100644
--- a/timm/layers/__init__.py
+++ b/timm/layers/__init__.py
@@ -34,7 +34,7 @@
     SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
 from .padding import get_padding, get_same_padding, pad_same
 from .patch_dropout import PatchDropout
-from .patch_embed import PatchEmbed, resample_patch_embed
+from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed
 from .pool2d_same import AvgPool2dSame, create_pool2d
 from .pos_embed import resample_abs_pos_embed
 from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords
diff --git a/timm/layers/patch_embed.py b/timm/layers/patch_embed.py
index a4946531b2..6ca6a1a120 100644
--- a/timm/layers/patch_embed.py
+++ b/timm/layers/patch_embed.py
@@ -9,7 +9,7 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import logging
-from typing import List, Optional, Callable
+from typing import Callable, List, Optional, Tuple, Union
 
 import torch
 from torch import nn as nn
@@ -75,6 +75,49 @@ def forward(self, x):
         return x
 
+
+class PatchEmbedWithSize(PatchEmbed):
+    """ 2D Image to Patch Embedding
+    """
+    output_fmt: Format
+
+    def __init__(
+            self,
+            img_size: Optional[int] = 224,
+            patch_size: int = 16,
+            in_chans: int = 3,
+            embed_dim: int = 768,
+            norm_layer: Optional[Callable] = None,
+            flatten: bool = True,
+            output_fmt: Optional[str] = None,
+            bias: bool = True,
+    ):
+        super().__init__(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            norm_layer=norm_layer,
+            flatten=flatten,
+            output_fmt=output_fmt,
+            bias=bias,
+        )
+
+    def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
+        B, C, H, W = x.shape
+        if self.img_size is not None:
+            _assert(H % self.patch_size[0] == 0, f"Input image height ({H}) must be divisible by patch size ({self.patch_size[0]}).")
+            _assert(W % self.patch_size[1] == 0, f"Input image width ({W}) must be divisible by patch size ({self.patch_size[1]}).")
+
+        x = self.proj(x)
+        grid_size = x.shape[-2:]
+        if self.flatten:
+            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
+        elif self.output_fmt != Format.NCHW:
+            x = nchw_to(x, self.output_fmt)
+        x = self.norm(x)
+        return x, grid_size
+
+
 def resample_patch_embed(
         patch_embed,
         new_size: List[int],
diff --git a/timm/layers/pos_embed.py b/timm/layers/pos_embed.py
index d0e675210c..c3afce76cc 100644
--- a/timm/layers/pos_embed.py
+++ b/timm/layers/pos_embed.py
@@ -24,29 +24,31 @@ def resample_abs_pos_embed(
         verbose: bool = False,
 ):
     # sort out sizes, assume square if old size not provided
-    new_size = to_2tuple(new_size)
-    new_ntok = new_size[0] * new_size[1]
-    if not old_size:
-        old_size = int(math.sqrt(posemb.shape[1] - num_prefix_tokens))
-    old_size = to_2tuple(old_size)
-    if new_size == old_size:  # might not both be same container type
+    num_pos_tokens = posemb.shape[1]
+    num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens
+    if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
         return posemb
 
+    if not old_size:
+        hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
+        old_size = hw, hw
+
     if num_prefix_tokens:
         posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
     else:
         posemb_prefix, posemb = None, posemb
 
     # do the interpolation
+    embed_dim = posemb.shape[-1]
     posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
     posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
-    posemb = posemb.permute(0, 2, 3, 1).reshape(1, new_ntok, -1)
-
-    if verbose:
-        _logger.info(f'Resized position embedding: {old_size} to {new_size}.')
+    posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
 
     # add back extra (class, etc) prefix tokens
     if posemb_prefix is not None:
-        print(posemb_prefix.shape, posemb.shape)
         posemb = torch.cat([posemb_prefix, posemb], dim=1)
+
+    if not torch.jit.is_scripting() and verbose:
+        _logger.info(f'Resized position embedding: {old_size} to {new_size}.')
+
     return posemb
diff --git a/timm/models/deit.py b/timm/models/deit.py
index be52a977e6..650ab6796e 100644
--- a/timm/models/deit.py
+++ b/timm/models/deit.py
@@ -11,11 +11,13 @@
 # Copyright (c) 2015-present, Facebook, Inc.
 # All rights reserved.
 from functools import partial
+from typing import Sequence, Union
 
 import torch
 from torch import nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import resample_abs_pos_embed
 from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
@@ -71,11 +73,37 @@ def reset_classifier(self, num_classes, global_pool=None):
     def set_distilled_training(self, enable=True):
         self.distilled_training = enable
 
+    def _intermediate_layers(
+            self,
+            x: torch.Tensor,
+            n: Union[int, Sequence] = 1,
+    ):
+        outputs, num_blocks = [], len(self.blocks)
+        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
+
+        # forward pass
+        x = self.patch_embed(x)
+        x = torch.cat((
+            self.cls_token.expand(x.shape[0], -1, -1),
+            self.dist_token.expand(x.shape[0], -1, -1),
+            x),
+            dim=1)
+        x = self.pos_drop(x + self.pos_embed)
+        x = self.patch_drop(x)
+        x = self.norm_pre(x)
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if i in take_indices:
+                outputs.append(x)
+
+        return outputs
+
     def forward_features(self, x) -> torch.Tensor:
         x = self.patch_embed(x)
         x = torch.cat((
             self.cls_token.expand(x.shape[0], -1, -1),
-            self.dist_token.expand(x.shape[0], -1, -1), x),
+            self.dist_token.expand(x.shape[0], -1, -1),
+            x),
             dim=1)
         x = self.pos_drop(x + self.pos_embed)
         if self.grad_checkpointing and not torch.jit.is_scripting():
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index b38a168dab..a418f2b0e3 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -27,7 +27,7 @@
 import math
 from collections import OrderedDict
 from functools import partial
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -125,7 +125,7 @@ def __init__(
             drop_path=0.,
             act_layer=nn.GELU,
             norm_layer=nn.LayerNorm,
-            ffn_layer=Mlp,
+            mlp_layer=Mlp,
     ):
         super().__init__()
         self.norm1 = norm_layer(dim)
@@ -142,7 +142,7 @@ def __init__(
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
         self.norm2 = norm_layer(dim)
-        self.mlp = ffn_layer(
+        self.mlp = mlp_layer(
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
             act_layer=act_layer,
@@ -172,7 +172,7 @@ def __init__(
             drop_path=0.,
             act_layer=nn.GELU,
             norm_layer=nn.LayerNorm,
-            ffn_layer=Mlp,
+            mlp_layer=Mlp,
     ):
         super().__init__()
         self.init_values = init_values
@@ -189,7 +189,7 @@ def __init__(
         self.norm1 = norm_layer(dim)
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
-        self.mlp = ffn_layer(
+        self.mlp = mlp_layer(
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
             act_layer=act_layer,
@@ -232,7 +232,7 @@ def __init__(
             drop_path=0.,
             act_layer=nn.GELU,
             norm_layer=nn.LayerNorm,
-            ffn_layer=None,  # NOTE: not used
+            mlp_layer=None,  # NOTE: not used
     ):
         super().__init__()
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
@@ -326,7 +326,7 @@ def __init__(
             drop_path=0.,
             act_layer=nn.GELU,
             norm_layer=nn.LayerNorm,
-            ffn_layer=Mlp,
+            mlp_layer=Mlp,
     ):
         super().__init__()
         self.num_parallel = num_parallel
@@ -349,7 +349,7 @@ def __init__(
             ])))
             self.ffns.append(nn.Sequential(OrderedDict([
                 ('norm', norm_layer(dim)),
-                ('mlp', ffn_layer(
+                ('mlp', mlp_layer(
                     dim,
                     hidden_features=int(dim * mlp_ratio),
                     act_layer=act_layer,
@@ -413,7 +413,7 @@ def __init__(
             norm_layer: Optional[Callable] = None,
             act_layer: Optional[Callable] = None,
             block_fn: Callable = Block,
-            ffn_layer: Callable = Mlp,
+            mlp_layer: Callable = Mlp,
     ):
         """
         Args:
@@ -435,7 +435,7 @@ def __init__(
             attn_drop_rate: Attention dropout rate.
             drop_path_rate: Stochastic depth rate.
             weight_init: Weight initialization scheme.
-            embed_layer: Patch embedding layey.
+            embed_layer: Patch embedding layer.
             norm_layer: Normalization layer.
             act_layer: MLP activation layer.
             block_fn: Transformer block layer.
@@ -490,7 +490,7 @@ def __init__(
                 drop_path=dpr[i],
                 norm_layer=norm_layer,
                 act_layer=act_layer,
-                ffn_layer=ffn_layer,
+                mlp_layer=mlp_layer,
             )
             for i in range(depth)])
         self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
@@ -560,6 +560,55 @@ def _pos_embed(self, x):
             x = x + self.pos_embed
         return self.pos_drop(x)
 
+    def _intermediate_layers(
+            self,
+            x: torch.Tensor,
+            n: Union[int, Sequence] = 1,
+    ):
+        outputs, num_blocks = [], len(self.blocks)
+        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
+
+        # forward pass
+        x = self.patch_embed(x)
+        x = self._pos_embed(x)
+        x = self.patch_drop(x)
+        x = self.norm_pre(x)
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if i in take_indices:
+                outputs.append(x)
+
+        return outputs
+
+    def get_intermediate_layers(
+            self,
+            x: torch.Tensor,
+            n: Union[int, Sequence] = 1,
+            reshape: bool = False,
+            return_class_token: bool = False,
+            norm: bool = False,
+    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+        """ Intermediate layer accessor (NOTE: This is a WIP experiment).
+        Inspired by DINO / DINOv2 interface
+        """
+        # take last n blocks if n is an int, if n is a sequence, select by matching indices
+        outputs = self._intermediate_layers(x, n)
+        if norm:
+            outputs = [self.norm(out) for out in outputs]
+        class_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
+        outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
+
+        if reshape:
+            grid_size = self.patch_embed.grid_size
+            outputs = [
+                out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
+                for out in outputs
+            ]
+
+        if return_class_token:
+            return tuple(zip(outputs, class_tokens))
+        return tuple(outputs)
+
     def forward_features(self, x):
         x = self.patch_embed(x)
         x = self._pos_embed(x)
@@ -816,9 +865,7 @@ def _convert_openai_clip(state_dict, model):
 
 def _convert_dinov2(state_dict, model):
     import re
-
     out_dict = {}
-
     for k, v in state_dict.items():
         if k == "mask_token":
             continue
@@ -828,11 +875,10 @@ def _convert_dinov2(state_dict, model):
         elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
             out_dict[k.replace("w3", "fc2")] = v
             continue
-
         out_dict[k] = v
-
     return out_dict
 
+
 def checkpoint_filter_fn(
         state_dict,
         model,
@@ -1072,19 +1118,27 @@ def _cfg(url='', **kwargs):
         hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
 
-    # DINOv2 pretrained - https://arxiv.org/abs/2304.07193 (no classifier head, for fine-tune only)
-    'vit_small_patch14_dinov2': _cfg(
+    # DINOv2 pretrained - https://arxiv.org/abs/2304.07193 (no classifier head, for fine-tune/features only)
+    'vit_small_patch14_dinov2.lvd142m': _cfg(
         url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
-    'vit_base_patch14_dinov2': _cfg(
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
+    'vit_base_patch14_dinov2.lvd142m': _cfg(
         url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
-    'vit_large_patch14_dinov2': _cfg(
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
+    'vit_large_patch14_dinov2.lvd142m': _cfg(
         url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
-    'vit_giant_patch14_dinov2': _cfg(
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
+    'vit_giant_patch14_dinov2.lvd142m': _cfg(
         url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
 
     # ViT ImageNet-21K-P pretraining by MILL
     'vit_base_patch16_224_miil.in21k': _cfg(
@@ -1359,6 +1413,22 @@ def _cfg(url='', **kwargs):
     'vit_base_patch16_xp_224.untrained': _cfg(url=''),
     'vit_large_patch14_xp_224.untrained': _cfg(url=''),
     'vit_huge_patch14_xp_224.untrained': _cfg(url=''),
+
+    'vit_base_patch16_224.mae': _cfg(
+        url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth',
+        # hf_hub_id='timm/',
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
+    'vit_large_patch16_224.mae': _cfg(
+        url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_large.pth',
+        # hf_hub_id='timm/',
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
+    'vit_huge_patch14_224.mae': _cfg(
+        url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_huge.pth',
+        # hf_hub_id='timm/',
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
 })
@@ -1904,10 +1974,8 @@ def vit_small_patch14_dinov2(pretrained=False, **kwargs):
     """ ViT-S/14 for DINOv2
     """
     model_args = dict(
-        patch_size=14, embed_dim=384, depth=12, num_heads=6,
-        init_values=1.0, img_size=518,
+        patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1.0, img_size=518,
     )
-
     model = _create_vision_transformer(
         'vit_small_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -1918,10 +1986,8 @@ def vit_base_patch14_dinov2(pretrained=False, **kwargs):
     """ ViT-B/14 for DINOv2
     """
     model_args = dict(
-        patch_size=14, embed_dim=768, depth=12, num_heads=12,
-        init_values=1.0, img_size=518,
+        patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1.0, img_size=518,
     )
-
     model = _create_vision_transformer(
         'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -1932,14 +1998,13 @@ def vit_large_patch14_dinov2(pretrained=False, **kwargs):
     """ ViT-L/14 for DINOv2
     """
     model_args = dict(
-        patch_size=14, embed_dim=1024, depth=24, num_heads=16,
-        init_values=1.0, img_size=518,
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1.0, img_size=518,
     )
-
     model = _create_vision_transformer(
         'vit_large_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
+
 @register_model
 def vit_giant_patch14_dinov2(pretrained=False, **kwargs):
     """ ViT-G/14 for DINOv2
@@ -1952,13 +2017,13 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs):
     model_args = dict(
         patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1.0,
-        mlp_ratio=2.66667 * 2, ffn_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
+        mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
     )
-
     model = _create_vision_transformer(
         'vit_giant_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
+
 register_model_deprecations(__name__, {
     'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
     'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',
diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py
index f41e90e89b..8cf7bec1e6 100644
--- a/timm/models/vision_transformer_hybrid.py
+++ b/timm/models/vision_transformer_hybrid.py
@@ -14,6 +14,7 @@
 Hacked together by / Copyright 2020, Ross Wightman
 """
 from functools import partial
+from typing import List, Tuple
 
 import torch
 import torch.nn as nn
@@ -74,10 +75,43 @@ def forward(self, x):
         x = self.backbone(x)
         if isinstance(x, (list, tuple)):
             x = x[-1]  # last feature if backbone outputs list/tuple of features
-        x = self.proj(x).flatten(2).transpose(1, 2)
+        x = self.proj(x)
+        x = x.flatten(2).transpose(1, 2)
         return x
 
 
+class HybridEmbedWithSize(HybridEmbed):
+    """ CNN Feature Map Embedding
+    Extract feature map from CNN, flatten, project to embedding dim.
+ """ + def __init__( + self, + backbone, + img_size=224, + patch_size=1, + feature_size=None, + in_chans=3, + embed_dim=768, + bias=True, + ): + super().__init__( + backbone=backbone, + img_size=img_size, + patch_size=patch_size, + feature_size=feature_size, + in_chans=in_chans, + embed_dim=embed_dim, + bias=bias, + ) + + def forward(self, x) -> Tuple[torch.Tensor, List[int]]: + x = self.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + x = self.proj(x) + return x.flatten(2).transpose(1, 2), x.shape[-2:] + + def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs): embed_layer = partial(HybridEmbed, backbone=backbone) kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set