Commit
Tweak DinoV2 add, add MAE ViT weights, add initial intermediate layer getter experiment
rwightman committed May 10, 2023
1 parent 59bea4c commit a01d8f8
Showing 6 changed files with 222 additions and 50 deletions.
2 changes: 1 addition & 1 deletion timm/layers/__init__.py
@@ -34,7 +34,7 @@
SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
from .padding import get_padding, get_same_padding, pad_same
from .patch_dropout import PatchDropout
from .patch_embed import PatchEmbed, resample_patch_embed
from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed
from .pool2d_same import AvgPool2dSame, create_pool2d
from .pos_embed import resample_abs_pos_embed
from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords
45 changes: 44 additions & 1 deletion timm/layers/patch_embed.py
@@ -9,7 +9,7 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
from typing import List, Optional, Callable
from typing import Callable, List, Optional, Tuple, Union

import torch
from torch import nn as nn
@@ -75,6 +75,49 @@ def forward(self, x):
return x


class PatchEmbedWithSize(PatchEmbed):
""" 2D Image to Patch Embedding
"""
output_fmt: Format

def __init__(
self,
img_size: Optional[int] = 224,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten: bool = True,
output_fmt: Optional[str] = None,
bias: bool = True,
):
super().__init__(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=norm_layer,
flatten=flatten,
output_fmt=output_fmt,
bias=bias,
)

def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
B, C, H, W = x.shape
if self.img_size is not None:
_assert(H % self.patch_size[0] == 0, f"Input image height ({H}) must be divisible by patch size ({self.patch_size[0]}).")
_assert(W % self.patch_size[1] == 0, f"Input image width ({W}) must be divisible by patch size ({self.patch_size[1]}).")

x = self.proj(x)
grid_size = x.shape[-2:]
if self.flatten:
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
elif self.output_fmt != Format.NCHW:
x = nchw_to(x, self.output_fmt)
x = self.norm(x)
return x, grid_size


def resample_patch_embed(
patch_embed,
new_size: List[int],
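As a quick illustration of the new layer (not part of the diff; it assumes a timm checkout that already contains this commit), PatchEmbedWithSize behaves like PatchEmbed but additionally returns the post-projection feature grid size:

# Illustrative usage sketch, not part of the commit.
import torch
from timm.layers import PatchEmbedWithSize

embed = PatchEmbedWithSize(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
x = torch.randn(2, 3, 224, 224)
tokens, grid_size = embed(x)  # patch tokens plus the H, W of the feature grid
print(tokens.shape)           # torch.Size([2, 196, 768]) with the default flatten=True
print(grid_size)              # torch.Size([14, 14]), i.e. 224 / 16 per side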
24 changes: 13 additions & 11 deletions timm/layers/pos_embed.py
@@ -24,29 +24,31 @@ def resample_abs_pos_embed(
verbose: bool = False,
):
# sort out sizes, assume square if old size not provided
new_size = to_2tuple(new_size)
new_ntok = new_size[0] * new_size[1]
if not old_size:
old_size = int(math.sqrt(posemb.shape[1] - num_prefix_tokens))
old_size = to_2tuple(old_size)
if new_size == old_size: # might not both be same container type
num_pos_tokens = posemb.shape[1]
num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens
if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
return posemb

if not old_size:
hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
old_size = hw, hw

if num_prefix_tokens:
posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
else:
posemb_prefix, posemb = None, posemb

# do the interpolation
embed_dim = posemb.shape[-1]
posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
posemb = posemb.permute(0, 2, 3, 1).reshape(1, new_ntok, -1)

if verbose:
_logger.info(f'Resized position embedding: {old_size} to {new_size}.')
posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)

# add back extra (class, etc) prefix tokens
if posemb_prefix is not None:
print(posemb_prefix.shape, posemb.shape)
posemb = torch.cat([posemb_prefix, posemb], dim=1)

if not torch.jit.is_scripting() and verbose:
_logger.info(f'Resized position embedding: {old_size} to {new_size}.')

return posemb
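As context for the rework above (illustrative only, not part of the diff, and assuming the unchanged parts of the function signature), the early-return check now compares token counts and the old grid size is inferred from the token count when not given:

# Illustrative usage sketch, not part of the commit.
import torch
from timm.layers import resample_abs_pos_embed

# ViT-B/16 style embedding: 1 class token + 14x14 patch tokens, embed dim 768
posemb = torch.randn(1, 1 + 14 * 14, 768)

# resize for a 24x24 patch grid (e.g. 384x384 input at patch size 16)
posemb_384 = resample_abs_pos_embed(posemb, new_size=(24, 24), num_prefix_tokens=1)
print(posemb_384.shape)  # torch.Size([1, 577, 768]) -> 1 prefix token + 24*24 patch tokens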
30 changes: 29 additions & 1 deletion timm/models/deit.py
@@ -11,11 +11,13 @@
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
from functools import partial
from typing import Sequence, Union

import torch
from torch import nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import resample_abs_pos_embed
from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
@@ -71,11 +73,37 @@ def reset_classifier(self, num_classes, global_pool=None):
def set_distilled_training(self, enable=True):
self.distilled_training = enable

def _intermediate_layers(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1,
):
outputs, num_blocks = [], len(self.blocks)
take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)

# forward pass
x = self.patch_embed(x)
x = torch.cat((
self.cls_token.expand(x.shape[0], -1, -1),
self.dist_token.expand(x.shape[0], -1, -1),
x),
dim=1)
x = self.pos_drop(x + self.pos_embed)
x = self.patch_drop(x)
x = self.norm_pre(x)
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in take_indices:
outputs.append(x)

return outputs

def forward_features(self, x) -> torch.Tensor:
x = self.patch_embed(x)
x = torch.cat((
self.cls_token.expand(x.shape[0], -1, -1),
self.dist_token.expand(x.shape[0], -1, -1), x),
self.dist_token.expand(x.shape[0], -1, -1),
x),
dim=1)
x = self.pos_drop(x + self.pos_embed)
if self.grad_checkpointing and not torch.jit.is_scripting():
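A sketch of how the experimental getter on the distilled DeiT class might be called (illustrative only, not part of the diff; the model name is just one distilled DeiT variant that uses this class):

# Illustrative usage sketch, not part of the commit.
import torch
import timm

# any distilled DeiT variant is built on the VisionTransformerDistilled class patched above
model = timm.create_model('deit_base_distilled_patch16_224', pretrained=False).eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feats = model._intermediate_layers(x, n=4)  # outputs of the last 4 blocks

print(len(feats))       # 4
print(feats[0].shape)   # torch.Size([1, 198, 768]) -> cls + dist + 14*14 patch tokens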
